自注意力机制:从基础到进阶的神奇变体

引言:自注意力为何成为核心组件?

自注意力机制(Self-attention)是Transformer架构的核心,通过计算序列中元素间的相互关系,捕捉长距离依赖和上下文信息。其核心优势在于并行计算能力动态权重分配,相比传统RNN或CNN,能更高效地处理序列数据。然而,原始自注意力机制存在计算复杂度(O(n²))和空间占用高的痛点,尤其在长序列场景下。为此,研究者提出了多种变体,针对不同需求优化性能与效果。本文将系统梳理自注意力机制的经典变体及其应用场景,为开发者提供架构设计参考。

一、基础自注意力机制:核心原理与实现

1.1 原始自注意力公式

原始自注意力通过查询(Q)、键(K)、值(V)三个矩阵的交互计算权重:

  1. import torch
  2. import torch.nn as nn
  3. class SelfAttention(nn.Module):
  4. def __init__(self, embed_dim):
  5. super().__init__()
  6. self.q_proj = nn.Linear(embed_dim, embed_dim)
  7. self.k_proj = nn.Linear(embed_dim, embed_dim)
  8. self.v_proj = nn.Linear(embed_dim, embed_dim)
  9. self.scale = embed_dim ** -0.5
  10. def forward(self, x):
  11. Q = self.q_proj(x) # (batch, seq_len, embed_dim)
  12. K = self.k_proj(x)
  13. V = self.v_proj(x)
  14. # 计算注意力分数
  15. scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale # (batch, seq_len, seq_len)
  16. weights = torch.softmax(scores, dim=-1)
  17. output = torch.bmm(weights, V) # (batch, seq_len, embed_dim)
  18. return output

关键点

  • 缩放因子scale):防止点积结果过大导致梯度消失。
  • 权重归一化:通过softmax确保每行的权重和为1。

1.2 多头自注意力(Multi-head Attention)

通过将Q、K、V拆分为多个头(head),并行计算不同子空间的注意力,增强模型表达能力:

  1. class MultiHeadAttention(nn.Module):
  2. def __init__(self, embed_dim, num_heads):
  3. super().__init__()
  4. self.head_dim = embed_dim // num_heads
  5. self.num_heads = num_heads
  6. self.q_proj = nn.Linear(embed_dim, embed_dim)
  7. self.k_proj = nn.Linear(embed_dim, embed_dim)
  8. self.v_proj = nn.Linear(embed_dim, embed_dim)
  9. self.out_proj = nn.Linear(embed_dim, embed_dim)
  10. self.scale = (self.head_dim) ** -0.5
  11. def forward(self, x):
  12. batch_size, seq_len, _ = x.shape
  13. Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  14. K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  15. V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  16. scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
  17. weights = torch.softmax(scores, dim=-1)
  18. output = torch.matmul(weights, V).transpose(1, 2).contiguous()
  19. output = output.view(batch_size, seq_len, -1)
  20. return self.out_proj(output)

优势

  • 每个头关注不同的特征子空间(如语法、语义)。
  • 参数总量与单头相当(embed_dim不变)。

二、自注意力机制的变体与优化

2.1 稀疏自注意力(Sparse Attention)

问题:原始自注意力计算复杂度为O(n²),长序列下内存消耗大。
解决方案:限制注意力范围,仅计算部分元素对的交互。
典型变体

  • 局部窗口注意力(Local Window Attention):将序列划分为固定大小的窗口,每个token仅关注窗口内的其他token。例如,Swin Transformer中的滑动窗口机制。
  • 随机注意力(Random Attention):随机选择部分token进行交互,平衡计算效率与信息覆盖。
  • 轴向注意力(Axial Attention):分别在行和列方向计算注意力,适用于图像等二维数据。

代码示例(局部窗口注意力)

  1. def local_window_attention(Q, K, V, window_size):
  2. batch_size, seq_len, _ = Q.shape
  3. windows = []
  4. for i in range(0, seq_len, window_size):
  5. start, end = i, min(i + window_size, seq_len)
  6. Q_window = Q[:, start:end, :]
  7. K_window = K[:, start:end, :]
  8. V_window = V[:, start:end, :]
  9. scores = torch.bmm(Q_window, K_window.transpose(1, 2)) * (Q_window.shape[-1] ** -0.5)
  10. weights = torch.softmax(scores, dim=-1)
  11. window_output = torch.bmm(weights, V_window)
  12. windows.append(window_output)
  13. return torch.cat(windows, dim=1)

2.2 相对位置编码(Relative Position Encoding)

问题:原始绝对位置编码无法直接建模token间的相对距离。
解决方案:引入相对位置偏置(Relative Position Bias),计算注意力时考虑位置差。
公式
[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + B\right)V ]
其中,( B \in \mathbb{R}^{2L-1} )为可学习的相对位置偏置矩阵,( L )为最大相对距离。

代码示例

  1. class RelativePositionAttention(nn.Module):
  2. def __init__(self, embed_dim, max_pos=512):
  3. super().__init__()
  4. self.max_pos = max_pos
  5. self.rel_pos_bias = nn.Parameter(torch.randn(2 * max_pos - 1, embed_dim))
  6. def forward(self, Q, K, pos_ids): # pos_ids: (batch, seq_len)
  7. batch_size, seq_len, _ = Q.shape
  8. scores = torch.bmm(Q, K.transpose(1, 2)) * (Q.shape[-1] ** -0.5)
  9. # 计算相对位置偏置
  10. rel_pos = pos_ids.unsqueeze(2) - pos_ids.unsqueeze(1) # (batch, seq_len, seq_len)
  11. rel_pos = rel_pos.clamp(-self.max_pos + 1, self.max_pos - 1)
  12. rel_pos_indices = rel_pos + self.max_pos - 1 # 转换为0~2L-2
  13. rel_pos_bias = self.rel_pos_bias[rel_pos_indices.long()] # (batch, seq_len, seq_len, embed_dim)
  14. rel_pos_bias = rel_pos_bias.mean(dim=-1) # (batch, seq_len, seq_len)
  15. scores = scores + rel_pos_bias
  16. weights = torch.softmax(scores, dim=-1)
  17. output = torch.bmm(weights, V)
  18. return output

2.3 线性注意力(Linear Attention)

问题:自注意力的点积计算依赖矩阵乘法,复杂度高。
解决方案:通过核方法(Kernel Method)将点积转换为线性运算,降低复杂度至O(n)。
公式
[ \text{Attention}(Q, K, V) = \text{softmax}(QK^T)V \approx \phi(Q)(\phi(K)^TV) ]
其中,( \phi )为非线性函数(如elu(x)+1)。

代码示例

  1. class LinearAttention(nn.Module):
  2. def __init__(self, embed_dim):
  3. super().__init__()
  4. self.q_proj = nn.Linear(embed_dim, embed_dim)
  5. self.k_proj = nn.Linear(embed_dim, embed_dim)
  6. self.v_proj = nn.Linear(embed_dim, embed_dim)
  7. self.phi = lambda x: torch.nn.functional.elu(x) + 1 # 核函数
  8. def forward(self, x):
  9. Q = self.q_proj(x)
  10. K = self.k_proj(x)
  11. V = self.v_proj(x)
  12. # 核化计算
  13. Q_phi = self.phi(Q)
  14. K_phi = self.phi(K)
  15. D = torch.sum(K_phi, dim=-1, keepdim=True) # (batch, seq_len, 1)
  16. weights = torch.bmm(Q_phi, K_phi.transpose(1, 2)) / D # (batch, seq_len, seq_len)
  17. output = torch.bmm(weights, V)
  18. return output

三、应用场景与最佳实践

3.1 长序列处理

  • 选择稀疏注意力:如局部窗口或轴向注意力,平衡效率与效果。
  • 分块计算:将长序列拆分为多个块,分别计算注意力后合并。

3.2 计算资源受限场景

  • 使用线性注意力:降低内存占用,适合移动端或边缘设备。
  • 量化与剪枝:对Q、K、V矩阵进行低比特量化,减少计算量。

3.3 多模态任务

  • 结合相对位置编码:在图像或视频任务中,建模空间或时间上的相对关系。
  • 跨模态注意力:设计共享的Q矩阵,分别与文本K/V和图像K/V交互。

总结

自注意力机制的变体通过稀疏化、位置编码优化和线性化等手段,解决了原始机制在长序列、计算效率和多模态场景下的痛点。开发者可根据任务需求选择合适的变体,例如:

  • 长文本生成:稀疏窗口注意力 + 相对位置编码。
  • 实时语音识别:线性注意力 + 多头并行。
  • 图像描述:轴向注意力 + 跨模态交互。

未来,自注意力机制将进一步与图神经网络、动态路由等技术融合,拓展其在非欧几里得数据和复杂推理任务中的应用。