PyTorch实现Transformer:核心代码结构与自注意力机制详解

PyTorch实现Transformer:核心代码结构与自注意力机制详解

Transformer模型凭借自注意力机制在自然语言处理领域取得了突破性成果,其核心思想通过并行计算捕捉序列中元素间的全局依赖关系。本文将基于PyTorch框架,从代码实现角度深入解析Transformer的完整结构,重点拆解自注意力机制的实现细节,并提供可复用的代码框架与优化建议。

一、Transformer整体架构设计

Transformer的编码器-解码器结构由N个相同层堆叠而成,每层包含两个核心子模块:

  1. 多头自注意力层:并行计算多个注意力头,捕捉不同维度的依赖关系
  2. 前馈神经网络层:对每个位置的表示进行非线性变换
  1. import torch
  2. import torch.nn as nn
  3. import math
  4. class TransformerEncoderLayer(nn.Module):
  5. def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
  6. super().__init__()
  7. self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
  8. self.linear1 = nn.Linear(d_model, dim_feedforward)
  9. self.dropout = nn.Dropout(dropout)
  10. self.linear2 = nn.Linear(dim_feedforward, d_model)
  11. self.norm1 = nn.LayerNorm(d_model)
  12. self.norm2 = nn.LayerNorm(d_model)
  13. self.dropout1 = nn.Dropout(dropout)
  14. self.dropout2 = nn.Dropout(dropout)
  15. self.activation = nn.ReLU()
  16. def forward(self, src, src_mask=None):
  17. # 自注意力子层
  18. src2, attn_weights = self.self_attn(src, src, src, attn_mask=src_mask)
  19. src = src + self.dropout1(src2)
  20. src = self.norm1(src)
  21. # 前馈子层
  22. src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
  23. src = src + self.dropout2(src2)
  24. src = self.norm2(src)
  25. return src, attn_weights

关键设计要点:

  1. 残差连接:通过src + self.dropout(src2)实现,缓解梯度消失问题
  2. 层归一化:在每个子层后应用,稳定训练过程
  3. 参数共享:同一层的不同注意力头共享输入/输出投影矩阵

二、自注意力机制实现解析

自注意力机制的核心是计算查询(Q)、键(K)、值(V)三者间的相似度,其数学表达式为:

[ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V ]

1. 单头注意力实现

  1. class SingleHeadAttention(nn.Module):
  2. def __init__(self, d_model):
  3. super().__init__()
  4. self.q_proj = nn.Linear(d_model, d_model)
  5. self.k_proj = nn.Linear(d_model, d_model)
  6. self.v_proj = nn.Linear(d_model, d_model)
  7. self.out_proj = nn.Linear(d_model, d_model)
  8. self.scale = math.sqrt(d_model) # 缩放因子
  9. def forward(self, x):
  10. Q = self.q_proj(x) # (batch, seq_len, d_model)
  11. K = self.k_proj(x)
  12. V = self.v_proj(x)
  13. # 计算注意力分数
  14. scores = torch.bmm(Q, K.transpose(1,2)) / self.scale # (batch, seq_len, seq_len)
  15. attn_weights = torch.softmax(scores, dim=-1)
  16. # 加权求和
  17. output = torch.bmm(attn_weights, V) # (batch, seq_len, d_model)
  18. return self.out_proj(output), attn_weights

2. 多头注意力优化实现

实际实现中采用矩阵并行计算优化性能:

  1. class MultiHeadAttention(nn.Module):
  2. def __init__(self, d_model, nhead):
  3. super().__init__()
  4. assert d_model % nhead == 0
  5. self.d_model = d_model
  6. self.nhead = nhead
  7. self.d_head = d_model // nhead
  8. # 共享参数的投影矩阵
  9. self.in_proj = nn.Linear(d_model, 3 * d_model)
  10. self.out_proj = nn.Linear(d_model, d_model)
  11. self.scale = math.sqrt(self.d_head)
  12. def forward(self, x):
  13. batch_size, seq_len, _ = x.size()
  14. # 线性投影生成QKV
  15. qkv = self.in_proj(x) # (batch, seq_len, 3*d_model)
  16. qkv = qkv.view(batch_size, seq_len, 3, self.nhead, self.d_head)
  17. qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch, nhead, seq_len, d_head]
  18. Q, K, V = qkv[0], qkv[1], qkv[2]
  19. # 计算注意力
  20. attn_scores = torch.einsum('bhld,bhsd->bhls', Q, K) / self.scale
  21. attn_weights = torch.softmax(attn_scores, dim=-1)
  22. # 加权求和
  23. output = torch.einsum('bhls,bhsd->bhld', attn_weights, V)
  24. output = output.permute(0, 2, 1, 3).contiguous() # [batch, seq_len, nhead, d_head]
  25. output = output.view(batch_size, seq_len, -1)
  26. return self.out_proj(output), attn_weights

关键优化技术:

  1. 矩阵分块:通过einsum操作实现高效矩阵乘法
  2. 内存连续:使用contiguous()保证张量内存布局
  3. 并行计算:同时处理所有注意力头

三、位置编码实现方案

Transformer通过位置编码注入序列顺序信息,常见实现包括:

1. 正弦位置编码

  1. class PositionalEncoding(nn.Module):
  2. def __init__(self, d_model, max_len=5000):
  3. super().__init__()
  4. position = torch.arange(max_len).unsqueeze(1)
  5. div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
  6. pe = torch.zeros(max_len, d_model)
  7. pe[:, 0::2] = torch.sin(position * div_term)
  8. pe[:, 1::2] = torch.cos(position * div_term)
  9. self.register_buffer('pe', pe)
  10. def forward(self, x):
  11. # x: (batch, seq_len, d_model)
  12. x = x + self.pe[:x.size(1)].unsqueeze(0)
  13. return x

2. 可学习位置编码

  1. class LearnablePositionalEncoding(nn.Module):
  2. def __init__(self, d_model, max_len=5000):
  3. super().__init__()
  4. self.pe = nn.Parameter(torch.zeros(max_len, d_model))
  5. nn.init.normal_(self.pe, mean=0, std=0.02)
  6. def forward(self, x):
  7. return x + self.pe[:x.size(1)].unsqueeze(0)

选择建议:

  • 正弦编码:适用于任意长度序列,无需训练参数
  • 可学习编码:在小规模数据集上可能获得更好效果,但需要固定最大长度

四、完整Transformer实现框架

  1. class TransformerModel(nn.Module):
  2. def __init__(self, ntoken, d_model=512, nhead=8, num_layers=6):
  3. super().__init__()
  4. self.d_model = d_model
  5. self.embedding = nn.Embedding(ntoken, d_model)
  6. self.pos_encoder = PositionalEncoding(d_model)
  7. encoder_layers = [
  8. TransformerEncoderLayer(d_model, nhead)
  9. for _ in range(num_layers)
  10. ]
  11. self.encoder = nn.Sequential(*encoder_layers)
  12. self.decoder = nn.Linear(d_model, ntoken)
  13. def forward(self, src, src_mask=None):
  14. # src: (seq_len, batch)
  15. src = self.embedding(src) * math.sqrt(self.d_model) # (seq_len, batch, d_model)
  16. src = src.permute(1, 0, 2) # (batch, seq_len, d_model)
  17. src = self.pos_encoder(src)
  18. memory = self.encoder(src, src_mask=src_mask)
  19. output = self.decoder(memory)
  20. return output

五、性能优化实践

  1. 混合精度训练

    1. scaler = torch.cuda.amp.GradScaler()
    2. with torch.cuda.amp.autocast():
    3. outputs = model(inputs)
    4. loss = criterion(outputs, targets)
    5. scaler.scale(loss).backward()
    6. scaler.step(optimizer)
    7. scaler.update()
  2. 注意力掩码实现

    1. def generate_square_subsequent_mask(sz):
    2. mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    3. mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    4. return mask
  3. 批处理优化

  • 固定序列长度或使用填充掩码
  • 采用梯度累积处理大batch

六、典型应用场景建议

  1. 文本分类:取最后一个位置的输出作为序列表示
  2. 序列标注:对每个位置的输出进行分类
  3. 文本生成:结合解码器结构实现自回归生成

实际部署时需注意:

  • 输入长度限制(通常512/1024)
  • 显存占用优化(FP16混合精度)
  • 推理速度优化(量化/蒸馏)

通过模块化设计,开发者可以基于上述代码框架快速构建适用于不同任务的Transformer模型,并根据具体需求调整模型深度、注意力头数等超参数。