从零实现Self-Attention模块:原理、代码与优化实践

一、Self-Attention核心机制解析

Self-Attention机制通过计算序列中每个元素与其他所有元素的关联性,动态生成权重分布。其核心公式可分解为三个关键步骤:

1.1 线性变换与矩阵拆分

输入序列X∈ℝ^(n×d)首先经过三个独立的线性变换:

  1. import torch
  2. import torch.nn as nn
  3. class SelfAttention(nn.Module):
  4. def __init__(self, embed_dim, num_heads):
  5. super().__init__()
  6. self.embed_dim = embed_dim
  7. self.num_heads = num_heads
  8. self.head_dim = embed_dim // num_heads
  9. # 确保embed_dim能被num_heads整除
  10. assert self.head_dim * num_heads == embed_dim, \
  11. "embed_dim must be divisible by num_heads"
  12. # 定义QKV变换矩阵
  13. self.q_proj = nn.Linear(embed_dim, embed_dim)
  14. self.k_proj = nn.Linear(embed_dim, embed_dim)
  15. self.v_proj = nn.Linear(embed_dim, embed_dim)

其中W^Q,W^K,W^V∈ℝ^(d×d)将输入映射到查询(Query)、键(Key)、值(Value)空间。多头注意力通过将d维空间拆分为h个d/h维子空间实现并行计算。

1.2 注意力分数计算

注意力分数矩阵S∈ℝ^(n×n)通过QK^T计算得到:
S = QK^T / √d_k
其中√d_k是缩放因子,防止点积结果过大导致softmax梯度消失。实现时需注意矩阵乘法的维度对齐:

  1. def forward(self, x):
  2. batch_size = x.size(0)
  3. # 生成QKV矩阵 (batch_size, seq_len, embed_dim)
  4. Q = self.q_proj(x)
  5. K = self.k_proj(x)
  6. V = self.v_proj(x)
  7. # 多头拆分 (batch_size, num_heads, seq_len, head_dim)
  8. Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
  9. K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
  10. V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

1.3 权重归一化与加权求和

通过softmax将注意力分数转换为概率分布,再与V矩阵相乘得到最终输出:
Attention(Q,K,V) = softmax(S/√d_k)V
实现时需使用mask机制处理变长序列:

  1. # 计算注意力分数 (batch_size, num_heads, seq_len, seq_len)
  2. attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
  3. # 可选:添加注意力mask(如处理padding位置)
  4. if mask is not None:
  5. attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
  6. # 计算注意力权重
  7. attn_weights = torch.softmax(attn_scores, dim=-1)
  8. # 加权求和 (batch_size, num_heads, seq_len, head_dim)
  9. output = torch.matmul(attn_weights, V)

二、完整实现与关键优化

2.1 多头注意力整合

将h个头的输出拼接后通过线性变换恢复d维空间:

  1. # 拼接多头输出 (batch_size, seq_len, num_heads, head_dim)
  2. output = output.transpose(1, 2).contiguous()
  3. output = output.view(batch_size, -1, self.embed_dim)
  4. # 最终线性变换
  5. output = self.out_proj(output)
  6. return output

完整类定义需包含输出投影层:

  1. class MultiHeadAttention(nn.Module):
  2. def __init__(self, embed_dim, num_heads):
  3. super().__init__()
  4. self.embed_dim = embed_dim
  5. self.num_heads = num_heads
  6. self.head_dim = embed_dim // num_heads
  7. # ...(前述QKV投影定义)...
  8. # 输出投影层
  9. self.out_proj = nn.Linear(embed_dim, embed_dim)
  10. def forward(self, x, mask=None):
  11. # ...(前述forward实现)...
  12. return output

2.2 性能优化技巧

  1. 矩阵运算优化:使用einsum简化张量运算
    1. # 使用einsum替代matmul+transpose组合
    2. attn_scores = torch.einsum('bhid,bhjd->bhij', Q, K) / (self.head_dim ** 0.5)
  2. 内存效率提升:通过contiguous()view()避免显式transpose
  3. 数值稳定性:在softmax前添加极小值防止log(0)
    1. attn_scores = attn_scores - attn_scores.max(dim=-1, keepdim=True)[0]
    2. attn_weights = torch.softmax(attn_scores, dim=-1)

三、实际应用与扩展

3.1 模型集成示例

在Transformer编码器层中的集成方式:

  1. class TransformerEncoderLayer(nn.Module):
  2. def __init__(self, embed_dim, num_heads, ff_dim):
  3. super().__init__()
  4. self.self_attn = MultiHeadAttention(embed_dim, num_heads)
  5. self.ffn = nn.Sequential(
  6. nn.Linear(embed_dim, ff_dim),
  7. nn.ReLU(),
  8. nn.Linear(ff_dim, embed_dim)
  9. )
  10. self.norm1 = nn.LayerNorm(embed_dim)
  11. self.norm2 = nn.LayerNorm(embed_dim)
  12. def forward(self, x, mask=None):
  13. # 自注意力子层
  14. attn_output = self.self_attn(x, mask)
  15. x = x + attn_output
  16. x = self.norm1(x)
  17. # 前馈子层
  18. ffn_output = self.ffn(x)
  19. x = x + ffn_output
  20. x = self.norm2(x)
  21. return x

3.2 常见问题处理

  1. 维度不匹配:确保embed_dim % num_heads == 0
  2. 梯度消失:检查缩放因子√d_k是否正确应用
  3. 内存爆炸:长序列处理时考虑局部注意力或稀疏注意力

3.3 扩展方向

  1. 相对位置编码:通过相对距离改进位置表示
  2. 线性注意力:使用核方法降低时间复杂度至O(n)
  3. 自适应注意力跨度:动态调整感受野大小

四、验证与调试指南

4.1 单元测试要点

  1. 验证输出维度是否符合预期:

    1. def test_shape():
    2. batch_size, seq_len, embed_dim = 2, 10, 64
    3. num_heads = 8
    4. x = torch.randn(batch_size, seq_len, embed_dim)
    5. mha = MultiHeadAttention(embed_dim, num_heads)
    6. output = mha(x)
    7. assert output.shape == (batch_size, seq_len, embed_dim)
  2. 检查注意力权重和是否为1:
    1. def test_attention_weights():
    2. # ...(准备输入)...
    3. attn_weights = torch.softmax(attn_scores, dim=-1)
    4. # 检查每个query位置的权重和是否≈1
    5. assert torch.allclose(attn_weights.sum(dim=-1), torch.ones_like(attn_weights.sum(dim=-1)), atol=1e-5)

4.2 可视化调试

使用matplotlib绘制注意力权重热力图:

  1. import matplotlib.pyplot as plt
  2. def plot_attention(attn_weights):
  3. plt.figure(figsize=(10, 8))
  4. plt.imshow(attn_weights[0, 0].cpu().detach().numpy(), cmap='hot')
  5. plt.colorbar()
  6. plt.title("Attention Weights Heatmap")
  7. plt.show()

通过本文的实现,开发者可以深入理解Self-Attention的底层计算逻辑,掌握从数学原理到工程优化的完整过程。实际开发中,建议先在小型数据集上验证模块正确性,再逐步扩展到大规模应用场景。对于生产环境,可考虑使用优化过的库实现(如百度智能云提供的深度学习框架中的优化算子),在保持可解释性的同时提升计算效率。