Implementing a Transformer in PyTorch: Core Code Structure and the Self-Attention Mechanism in Detail
The Transformer has achieved breakthrough results in natural language processing thanks to its self-attention mechanism, whose core idea is to capture global dependencies between sequence elements through parallel computation. This article walks through a complete Transformer implementation in PyTorch, dissects the details of the self-attention mechanism, and provides a reusable code skeleton together with optimization suggestions.
I. Overall Transformer Architecture
The Transformer's encoder-decoder structure is built by stacking N identical layers, each containing two core sub-modules:
- Multi-head self-attention layer: computes several attention heads in parallel to capture dependencies in different representation subspaces
- Position-wise feed-forward network: applies a non-linear transformation to the representation at each position
```python
import torch
import torch.nn as nn
import math

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # batch_first=True so every tensor uses the (batch, seq_len, d_model) layout
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                               batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, src, src_mask=None):
        # Self-attention sub-layer with residual connection and layer norm
        src2, attn_weights = self.self_attn(src, src, src, attn_mask=src_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # Position-wise feed-forward sub-layer with residual connection and layer norm
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src, attn_weights
```
Key design points:
- Residual connections: implemented as `src + self.dropout1(src2)`, which alleviates vanishing gradients in deep stacks
- Layer normalization: applied after each sub-layer to stabilize training
- Shared projections: all heads in a layer draw their queries, keys, and values from one fused input projection matrix (each head using its own slice) and share a single output projection
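As a quick sanity check, the following sketch (a usage example of my own, with toy dimensions, assuming the batch-first `TransformerEncoderLayer` defined above) runs a random tensor through one layer and prints the resulting shapes:

```python
# Smoke test for the encoder layer above (toy dimensions, batch-first layout assumed).
layer = TransformerEncoderLayer(d_model=512, nhead=8)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out, attn = layer(x)
print(out.shape)              # torch.Size([2, 10, 512]) -- same shape as the input
print(attn.shape)             # torch.Size([2, 10, 10])  -- head-averaged attention weights
```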
II. Inside the Self-Attention Mechanism
The core of self-attention is to score the similarity between queries (Q), keys (K), and values (V):
$$
\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
1. Single-Head Attention
```python
class SingleHeadAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.scale = math.sqrt(d_model)  # scaling factor sqrt(d_k); here d_k == d_model

    def forward(self, x):
        Q = self.q_proj(x)  # (batch, seq_len, d_model)
        K = self.k_proj(x)
        V = self.v_proj(x)
        # Attention scores: scaled dot product between queries and keys
        scores = torch.bmm(Q, K.transpose(1, 2)) / self.scale  # (batch, seq_len, seq_len)
        attn_weights = torch.softmax(scores, dim=-1)
        # Weighted sum of the values
        output = torch.bmm(attn_weights, V)  # (batch, seq_len, d_model)
        return self.out_proj(output), attn_weights
```
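A short usage sketch (toy shapes of my choosing) verifies that the attention weights produced by `SingleHeadAttention` form a probability distribution over the sequence for every query position:

```python
# Usage sketch: each row of the attention matrix sums to 1 after the softmax.
attn = SingleHeadAttention(d_model=64)
x = torch.randn(2, 5, 64)          # (batch, seq_len, d_model)
out, weights = attn(x)
print(out.shape)                   # torch.Size([2, 5, 64])
print(weights.sum(dim=-1))         # every entry is ~1.0
```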
2. Multi-Head Attention (Optimized)
In practice, all heads are computed with fused, batched matrix operations for better performance:
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        assert d_model % nhead == 0
        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead
        # Fused projection producing Q, K, V for all heads in one matmul
        self.in_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.scale = math.sqrt(self.d_head)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        # Linear projection generating Q, K, V
        qkv = self.in_proj(x)  # (batch, seq_len, 3*d_model)
        qkv = qkv.view(batch_size, seq_len, 3, self.nhead, self.d_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, nhead, seq_len, d_head)
        Q, K, V = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention for all heads in parallel
        attn_scores = torch.einsum('bhld,bhsd->bhls', Q, K) / self.scale
        attn_weights = torch.softmax(attn_scores, dim=-1)
        # Weighted sum of the values
        output = torch.einsum('bhls,bhsd->bhld', attn_weights, V)
        output = output.permute(0, 2, 1, 3).contiguous()  # (batch, seq_len, nhead, d_head)
        output = output.view(batch_size, seq_len, -1)
        return self.out_proj(output), attn_weights
```
Key optimization techniques:
- Fused matrix multiplication: the `einsum` contractions compute the attention of all heads as one batched matrix product (see the equivalence check below)
- Memory contiguity: `contiguous()` restores a contiguous memory layout after `permute` so that `view` can reshape the tensor
- Parallelism: all attention heads are processed simultaneously
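To make the role of `einsum` concrete, the sketch below (toy shapes, not part of the class above) checks that the contraction used for the attention scores is just a batched matrix product:

```python
# The einsum contraction for the scores equals a batched matmul over (batch, nhead).
Q = torch.randn(2, 8, 10, 64)      # (batch, nhead, seq_len, d_head)
K = torch.randn(2, 8, 10, 64)
scores_einsum = torch.einsum('bhld,bhsd->bhls', Q, K)
scores_matmul = Q @ K.transpose(-2, -1)
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-6))  # True
```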
III. Positional Encoding
The Transformer injects information about token order through positional encodings. Common implementations include:
1. Sinusoidal Positional Encoding
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:x.size(1)].unsqueeze(0)
        return x
```
2. Learnable Positional Encoding
```python
class LearnablePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pe = nn.Parameter(torch.zeros(max_len, d_model))
        nn.init.normal_(self.pe, mean=0, std=0.02)

    def forward(self, x):
        return x + self.pe[:x.size(1)].unsqueeze(0)
```
Which one to choose:
- Sinusoidal encoding: works for sequences of arbitrary length and requires no trainable parameters
- Learnable encoding: may perform better on small datasets, but the maximum sequence length must be fixed in advance
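Since both modules take the same batch-first input and only add position information to it, they can be swapped freely; the quick check below (toy dimensions) confirms that neither changes the tensor shape:

```python
# Both positional encodings preserve the (batch, seq_len, d_model) shape.
x = torch.randn(2, 20, 512)
sin_pe = PositionalEncoding(d_model=512)
learned_pe = LearnablePositionalEncoding(d_model=512)
print(sin_pe(x).shape)       # torch.Size([2, 20, 512])
print(learned_pe(x).shape)   # torch.Size([2, 20, 512])
```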
IV. A Complete Transformer Skeleton
```python
class TransformerModel(nn.Module):
    def __init__(self, ntoken, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(ntoken, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # nn.ModuleList (rather than nn.Sequential) so that the mask can be passed
        # to every layer and the (output, attention) tuple can be unpacked
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, nhead)
            for _ in range(num_layers)
        ])
        self.decoder = nn.Linear(d_model, ntoken)

    def forward(self, src, src_mask=None):
        # src: (seq_len, batch) token ids
        src = self.embedding(src) * math.sqrt(self.d_model)  # (seq_len, batch, d_model)
        src = src.permute(1, 0, 2)  # (batch, seq_len, d_model)
        src = self.pos_encoder(src)
        for layer in self.layers:
            src, _ = layer(src, src_mask=src_mask)
        output = self.decoder(src)  # (batch, seq_len, ntoken)
        return output
```
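A forward-pass sketch (vocabulary size and sequence length are arbitrary choices of mine) illustrates the expected `(seq_len, batch)` input of token ids and the `(batch, seq_len, ntoken)` output of per-position logits:

```python
# Forward-pass sketch: token ids in, per-position vocabulary logits out.
model = TransformerModel(ntoken=10000, d_model=512, nhead=8, num_layers=6)
src = torch.randint(0, 10000, (35, 4))   # (seq_len, batch) token ids
logits = model(src)
print(logits.shape)                      # torch.Size([4, 35, 10000])
```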
V. Performance Optimization in Practice
- Mixed-precision training:
```python
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
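In a full training step one would typically also zero the gradients and, if desired, clip them after unscaling; the sketch below is one way to arrange this (it assumes `model`, `optimizer`, `criterion`, `inputs`, and `targets` already exist):

```python
# One AMP training step with explicit zero_grad and optional gradient clipping
# (model, optimizer, criterion, inputs, targets assumed to be defined elsewhere).
scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                               # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional gradient clipping
scaler.step(optimizer)
scaler.update()
```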
- Attention mask implementation:
```python
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
```
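The sketch below (toy size, reusing the encoder layer defined earlier) prints the mask for a length-4 sequence and shows how it is passed in as `src_mask`:

```python
# Causal mask usage: each position may attend only to itself and earlier positions.
mask = generate_square_subsequent_mask(4)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
layer = TransformerEncoderLayer(d_model=512, nhead=8)
x = torch.randn(2, 4, 512)            # (batch, seq_len, d_model)
out, attn = layer(x, src_mask=mask)   # masked self-attention
```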
- Batching optimizations:
  - Pad sequences to a fixed length, or use a padding mask
  - Use gradient accumulation to emulate large batches (a sketch follows this list)
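A minimal gradient-accumulation loop (assuming `model`, `optimizer`, `criterion`, a `data_loader`, and an accumulation factor of 4, all of which are placeholders here) looks roughly like this:

```python
# Gradient accumulation sketch: accumulate gradients over accum_steps mini-batches
# before one optimizer step (model, optimizer, criterion, data_loader are placeholders).
accum_steps = 4
optimizer.zero_grad()
for step, (inputs, targets) in enumerate(data_loader):
    loss = criterion(model(inputs), targets) / accum_steps  # scale so gradients average
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```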
VI. Typical Application Scenarios
- Text classification: use the output at a single position (e.g. the last one) as the sequence representation (see the sketch after this list)
- Sequence labeling: classify the output at every position
- Text generation: combine with a decoder stack for autoregressive generation
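To illustrate the first two patterns, the hypothetical task heads below sit on top of an encoder output `memory` of shape `(batch, seq_len, d_model)`; the class counts are arbitrary:

```python
# Hypothetical task heads over an encoder output of shape (batch, seq_len, d_model).
memory = torch.randn(2, 10, 512)

# Text classification: take one position's representation (here, the last) per sequence.
cls_head = nn.Linear(512, 5)              # 5 classes (arbitrary)
cls_logits = cls_head(memory[:, -1, :])   # (batch, num_classes)

# Sequence labeling: classify every position independently.
tag_head = nn.Linear(512, 9)              # 9 tags (arbitrary)
tag_logits = tag_head(memory)             # (batch, seq_len, num_tags)
```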
Points to watch when deploying in practice:
- Input length limits (typically 512 or 1024 tokens)
- GPU memory footprint (FP16 mixed precision)
- Inference speed (quantization / distillation; a dynamic-quantization sketch follows)
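For the quantization route, one option is PyTorch's dynamic quantization of the linear layers; a minimal sketch for CPU inference (applied to the model defined above) might look like this:

```python
# Dynamic quantization sketch: convert nn.Linear weights to int8 for CPU inference.
model = TransformerModel(ntoken=10000)
model.eval()
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
```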
Thanks to the modular design, developers can quickly build Transformer models for different tasks on top of the code skeleton above, adjusting hyperparameters such as model depth and the number of attention heads to the requirements of the specific task.