Complete Transformer Code Implementation: A Full Code Example and Walkthrough of the Transformer Model
Introduction
Since its introduction in the 2017 paper "Attention Is All You Need", the Transformer has become a cornerstone architecture in natural language processing (NLP). Its core innovation is to discard recurrent neural networks (RNNs) and convolutional neural networks (CNNs) entirely and to model sequences with self-attention alone. This article provides a complete PyTorch implementation of the Transformer and explains the design rationale and implementation details of each component, helping developers master this classic model from theory to practice.
Code Implementation: A Complete Transformer Example
1. Environment Setup and Dependencies
First, make sure your environment is configured correctly; Python 3.8+ and PyTorch 1.10+ are recommended. Install the dependencies with:
pip install torch numpy
2. Core Components
(1) Multi-Head Attention
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear projections, then split into num_heads heads
        Q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-1e20"))
        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)
        # Concatenate the heads and apply the output projection
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return self.out_linear(out)
Key points:
- Head splitting: the embedding dimension embed_dim is divided evenly into num_heads heads, and each head computes attention independently, which strengthens the model's ability to capture different kinds of positional relationships.
- Scaled dot product: scaling the scores by 1/sqrt(d_k) keeps the dot products from growing too large, which would otherwise push the softmax into regions with vanishing gradients.
- Masking: the optional mask argument masks out invalid positions, such as padding tokens or future tokens in a decoder.
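To make the expected tensor shapes concrete, here is a minimal usage sketch of the module above; the batch size, sequence length, and mask values are arbitrary examples introduced here, not part of the original article:

mha = MultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)          # (batch_size, seq_len, embed_dim)
pad_mask = torch.ones(2, 1, 1, 10)   # broadcasts over heads and query positions; zeros would mark masked tokens
out = mha(x, x, x, mask=pad_mask)
print(out.shape)                     # torch.Size([2, 10, 512])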
(2) Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, embed_dim); pe[:seq_len] broadcasts over the batch
        x = x + self.pe[:x.size(1)]
        return x
Design rationale:
- Sine and cosine functions at different frequencies generate the position encodings, which lets the model infer sequence order from relative positions.
- The encodings are added to the input embeddings rather than concatenated, so the dimensionality stays the same.
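For reference, the buffer pe built above corresponds to the closed form from the original paper, where pos is the token position, i indexes pairs of embedding dimensions, and d_model equals embed_dim:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))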
(3) Feed-Forward Network
class PositionwiseFeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))
Structure:
- A two-layer fully connected network with a ReLU activation in between.
- The hidden dimension is usually larger than the input dimension (e.g., embed_dim=512 with hidden_dim=2048), which increases the model's nonlinear capacity.
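Because the two linear layers act only on the last dimension, the block is applied independently at every sequence position. A quick shape check (the sizes are arbitrary example values):

ffn = PositionwiseFeedForward(embed_dim=512, hidden_dim=2048)
x = torch.randn(2, 10, 512)
print(ffn(x).shape)   # torch.Size([2, 10, 512]); output shape matches input shape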
3. Complete Transformer Encoder Implementation
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.feed_forward = PositionwiseFeedForward(embed_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        # Self-attention sub-layer
        attn_output = self.self_attn(x, x, x, src_mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # Feed-forward sub-layer
        ff_output = self.feed_forward(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, embed_dim, num_heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, src_mask=None):
        for layer in self.layers:
            x = layer(x, src_mask)
        return self.norm(x)
Key design choices:
- Residual connections: each sub-layer's output is added to its input (x + sublayer(x)), mitigating vanishing gradients in deep stacks.
- Layer normalization: applied after each residual connection, stabilizing training.
- Layer stacking: num_layers controls model depth; 6 layers is the typical setting from the original paper.
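The components can now be wired together. The sketch below is illustrative only: the vocabulary size, the padded token batch, and the PAD_ID constant are assumptions introduced here rather than part of the implementation above.

# Assemble embeddings, positional encoding, and the encoder stack (illustrative sizes).
PAD_ID, vocab_size, embed_dim = 0, 10000, 512
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_ID)
pos_encoding = PositionalEncoding(embed_dim)
encoder = TransformerEncoder(num_layers=6, embed_dim=embed_dim, num_heads=8, hidden_dim=2048)

tokens = torch.randint(1, vocab_size, (2, 10))    # (batch_size, seq_len)
tokens[:, 8:] = PAD_ID                            # pretend the last two positions are padding
src_mask = (tokens != PAD_ID).view(2, 1, 1, 10)   # 0 at padded positions, broadcast over heads

encoded = encoder(pos_encoding(embedding(tokens)), src_mask)
print(encoded.shape)                              # torch.Size([2, 10, 512])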
Training and Optimization Tips
- Learning-rate scheduling: use a Noam scheduler to adjust the learning rate dynamically (a combined usage sketch follows this list):

class NoamOpt:
    def __init__(self, model_size, factor, warmup_steps, optimizer):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.factor = factor
        self.model_size = model_size
        self._step = 0

    def step(self):
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self.optimizer.step()

    def rate(self):
        # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
        return self.factor * (self.model_size ** (-0.5)
                              * min(self._step ** (-0.5), self._step * self.warmup_steps ** (-1.5)))
- Label smoothing: in classification tasks, apply label smoothing (e.g., epsilon=0.1) to improve generalization:

def label_smoothing(y, epsilon, num_classes):
    # y: (batch_size,) tensor of class indices -> (batch_size, num_classes) soft targets
    y_smooth = torch.full((y.size(0), num_classes), epsilon / (num_classes - 1), device=y.device)
    y_smooth.scatter_(1, y.unsqueeze(1), 1 - epsilon)
    return y_smooth
- Batching: load data in mini-batches with DataLoader, e.g. batch_size=64, and enable pin_memory=True to speed up CPU-to-GPU transfers (also shown in the sketch below).
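The following sketch ties the three tips above together in a minimal training loop. It is illustrative only: the random dataset, the mean-pooled classifier head, and the values factor=2.0, warmup_steps=4000, and epsilon=0.1 are assumptions, not settings prescribed by this article.

from torch.utils.data import DataLoader, TensorDataset

# Random toy data: 256 sequences of length 10, each labelled with one of 5 classes.
vocab_size, num_classes, seq_len = 10000, 5, 10
dataset = TensorDataset(torch.randint(0, vocab_size, (256, seq_len)),
                        torch.randint(0, num_classes, (256,)))
loader = DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)

embed = nn.Embedding(vocab_size, 512)
pos_enc = PositionalEncoding(512)
encoder = TransformerEncoder(num_layers=2, embed_dim=512, num_heads=8, hidden_dim=2048)
classifier = nn.Linear(512, num_classes)

params = list(embed.parameters()) + list(encoder.parameters()) + list(classifier.parameters())
base_opt = torch.optim.Adam(params, lr=0.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = NoamOpt(model_size=512, factor=2.0, warmup_steps=4000, optimizer=base_opt)

for tokens, labels in loader:
    logits = classifier(encoder(pos_enc(embed(tokens))).mean(dim=1))   # mean-pool over positions
    target = label_smoothing(labels, epsilon=0.1, num_classes=num_classes)
    loss = F.kl_div(F.log_softmax(logits, dim=-1), target, reduction="batchmean")
    base_opt.zero_grad()
    loss.backward()
    scheduler.step()   # sets the Noam learning rate, then calls optimizer.step()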
Summary and Extensions
This article has provided a complete PyTorch implementation of the Transformer, covering multi-head attention, positional encoding, residual connections, and the other core components. You can customize the model by adjusting the following parameters:
- embed_dim: embedding dimension (typically 512 or 1024)
- num_heads: number of attention heads (8 or 16)
- num_layers: number of encoder layers (6 or 12)
- hidden_dim: hidden dimension of the feed-forward network (2048 or 4096)
Further optimization directions include (a brief sketch of the first two follows this list):
- Integrating torch.compile to speed up training
- Trying mixed-precision (fp16) training
- Extending the model with a decoder to build a full Seq2Seq architecture
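As a rough illustration of the first two directions, here is a minimal sketch assuming PyTorch 2.x and a CUDA device; the dummy loss and the tensor sizes are placeholders introduced here:

encoder = TransformerEncoder(num_layers=6, embed_dim=512, num_heads=8, hidden_dim=2048).cuda()
encoder = torch.compile(encoder)                 # graph compilation, requires PyTorch 2.0+
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()             # scales the loss to avoid fp16 underflow

x = torch.randn(2, 10, 512, device="cuda")
with torch.cuda.amp.autocast():                  # run the forward pass in fp16 where safe
    loss = encoder(x).pow(2).mean()              # dummy loss, for illustration only
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()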
By understanding these implementation details, developers can debug models more efficiently, improve the architecture, and apply it to practical tasks such as machine translation and text generation.