引言:自注意力为何成为核心组件?
自注意力机制(Self-attention)是Transformer架构的核心,通过计算序列中元素间的相互关系,捕捉长距离依赖和上下文信息。其核心优势在于并行计算能力和动态权重分配,相比传统RNN或CNN,能更高效地处理序列数据。然而,原始自注意力机制存在计算复杂度(O(n²))和空间占用高的痛点,尤其在长序列场景下。为此,研究者提出了多种变体,针对不同需求优化性能与效果。本文将系统梳理自注意力机制的经典变体及其应用场景,为开发者提供架构设计参考。
一、基础自注意力机制:核心原理与实现
1.1 原始自注意力公式
原始自注意力通过查询(Q)、键(K)、值(V)三个矩阵的交互计算权重:
import torchimport torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.scale = embed_dim ** -0.5def forward(self, x):Q = self.q_proj(x) # (batch, seq_len, embed_dim)K = self.k_proj(x)V = self.v_proj(x)# 计算注意力分数scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale # (batch, seq_len, seq_len)weights = torch.softmax(scores, dim=-1)output = torch.bmm(weights, V) # (batch, seq_len, embed_dim)return output
关键点:
- 缩放因子(
scale):防止点积结果过大导致梯度消失。 - 权重归一化:通过
softmax确保每行的权重和为1。
1.2 多头自注意力(Multi-head Attention)
通过将Q、K、V拆分为多个头(head),并行计算不同子空间的注意力,增强模型表达能力:
class MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.head_dim = embed_dim // num_headsself.num_heads = num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)self.scale = (self.head_dim) ** -0.5def forward(self, x):batch_size, seq_len, _ = x.shapeQ = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scaleweights = torch.softmax(scores, dim=-1)output = torch.matmul(weights, V).transpose(1, 2).contiguous()output = output.view(batch_size, seq_len, -1)return self.out_proj(output)
优势:
- 每个头关注不同的特征子空间(如语法、语义)。
- 参数总量与单头相当(
embed_dim不变)。
二、自注意力机制的变体与优化
2.1 稀疏自注意力(Sparse Attention)
问题:原始自注意力计算复杂度为O(n²),长序列下内存消耗大。
解决方案:限制注意力范围,仅计算部分元素对的交互。
典型变体:
- 局部窗口注意力(Local Window Attention):将序列划分为固定大小的窗口,每个token仅关注窗口内的其他token。例如,Swin Transformer中的滑动窗口机制。
- 随机注意力(Random Attention):随机选择部分token进行交互,平衡计算效率与信息覆盖。
- 轴向注意力(Axial Attention):分别在行和列方向计算注意力,适用于图像等二维数据。
代码示例(局部窗口注意力):
def local_window_attention(Q, K, V, window_size):batch_size, seq_len, _ = Q.shapewindows = []for i in range(0, seq_len, window_size):start, end = i, min(i + window_size, seq_len)Q_window = Q[:, start:end, :]K_window = K[:, start:end, :]V_window = V[:, start:end, :]scores = torch.bmm(Q_window, K_window.transpose(1, 2)) * (Q_window.shape[-1] ** -0.5)weights = torch.softmax(scores, dim=-1)window_output = torch.bmm(weights, V_window)windows.append(window_output)return torch.cat(windows, dim=1)
2.2 相对位置编码(Relative Position Encoding)
问题:原始绝对位置编码无法直接建模token间的相对距离。
解决方案:引入相对位置偏置(Relative Position Bias),计算注意力时考虑位置差。
公式:
[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + B\right)V ]
其中,( B \in \mathbb{R}^{2L-1} )为可学习的相对位置偏置矩阵,( L )为最大相对距离。
代码示例:
class RelativePositionAttention(nn.Module):def __init__(self, embed_dim, max_pos=512):super().__init__()self.max_pos = max_posself.rel_pos_bias = nn.Parameter(torch.randn(2 * max_pos - 1, embed_dim))def forward(self, Q, K, pos_ids): # pos_ids: (batch, seq_len)batch_size, seq_len, _ = Q.shapescores = torch.bmm(Q, K.transpose(1, 2)) * (Q.shape[-1] ** -0.5)# 计算相对位置偏置rel_pos = pos_ids.unsqueeze(2) - pos_ids.unsqueeze(1) # (batch, seq_len, seq_len)rel_pos = rel_pos.clamp(-self.max_pos + 1, self.max_pos - 1)rel_pos_indices = rel_pos + self.max_pos - 1 # 转换为0~2L-2rel_pos_bias = self.rel_pos_bias[rel_pos_indices.long()] # (batch, seq_len, seq_len, embed_dim)rel_pos_bias = rel_pos_bias.mean(dim=-1) # (batch, seq_len, seq_len)scores = scores + rel_pos_biasweights = torch.softmax(scores, dim=-1)output = torch.bmm(weights, V)return output
2.3 线性注意力(Linear Attention)
问题:自注意力的点积计算依赖矩阵乘法,复杂度高。
解决方案:通过核方法(Kernel Method)将点积转换为线性运算,降低复杂度至O(n)。
公式:
[ \text{Attention}(Q, K, V) = \text{softmax}(QK^T)V \approx \phi(Q)(\phi(K)^TV) ]
其中,( \phi )为非线性函数(如elu(x)+1)。
代码示例:
class LinearAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.phi = lambda x: torch.nn.functional.elu(x) + 1 # 核函数def forward(self, x):Q = self.q_proj(x)K = self.k_proj(x)V = self.v_proj(x)# 核化计算Q_phi = self.phi(Q)K_phi = self.phi(K)D = torch.sum(K_phi, dim=-1, keepdim=True) # (batch, seq_len, 1)weights = torch.bmm(Q_phi, K_phi.transpose(1, 2)) / D # (batch, seq_len, seq_len)output = torch.bmm(weights, V)return output
三、应用场景与最佳实践
3.1 长序列处理
- 选择稀疏注意力:如局部窗口或轴向注意力,平衡效率与效果。
- 分块计算:将长序列拆分为多个块,分别计算注意力后合并。
3.2 计算资源受限场景
- 使用线性注意力:降低内存占用,适合移动端或边缘设备。
- 量化与剪枝:对Q、K、V矩阵进行低比特量化,减少计算量。
3.3 多模态任务
- 结合相对位置编码:在图像或视频任务中,建模空间或时间上的相对关系。
- 跨模态注意力:设计共享的Q矩阵,分别与文本K/V和图像K/V交互。
总结
自注意力机制的变体通过稀疏化、位置编码优化和线性化等手段,解决了原始机制在长序列、计算效率和多模态场景下的痛点。开发者可根据任务需求选择合适的变体,例如:
- 长文本生成:稀疏窗口注意力 + 相对位置编码。
- 实时语音识别:线性注意力 + 多头并行。
- 图像描述:轴向注意力 + 跨模态交互。
未来,自注意力机制将进一步与图神经网络、动态路由等技术融合,拓展其在非欧几里得数据和复杂推理任务中的应用。