Transformer完整代码实现:从理论到实践的深度解析

Transformer完整代码实现:提供Transformer模型的完整代码示例及解释

引言

Transformer模型自2017年《Attention is All You Need》论文提出以来,已成为自然语言处理(NLP)领域的基石架构。其核心创新在于完全摒弃循环神经网络(RNN)和卷积神经网络(CNN),仅通过自注意力机制(Self-Attention)实现序列建模。本文将提供完整的Transformer实现代码(基于PyTorch),并深入解释每个组件的设计原理与实现细节,帮助开发者从理论到实践全面掌握这一经典模型。

代码实现:Transformer模型完整示例

1. 环境准备与依赖安装

首先确保环境配置正确,推荐使用Python 3.8+和PyTorch 1.10+。通过以下命令安装依赖:

  1. pip install torch numpy

2. 核心组件实现

(1)多头注意力机制(Multi-Head Attention)

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class MultiHeadAttention(nn.Module):
  5. def __init__(self, embed_dim, num_heads):
  6. super().__init__()
  7. self.embed_dim = embed_dim
  8. self.num_heads = num_heads
  9. self.head_dim = embed_dim // num_heads
  10. assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
  11. self.q_linear = nn.Linear(embed_dim, embed_dim)
  12. self.k_linear = nn.Linear(embed_dim, embed_dim)
  13. self.v_linear = nn.Linear(embed_dim, embed_dim)
  14. self.out_linear = nn.Linear(embed_dim, embed_dim)
  15. def forward(self, query, key, value, mask=None):
  16. batch_size = query.size(0)
  17. # 线性变换并分割多头
  18. Q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
  19. K = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
  20. V = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
  21. # 计算缩放点积注意力
  22. scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
  23. if mask is not None:
  24. scores = scores.masked_fill(mask == 0, float("-1e20"))
  25. attention = F.softmax(scores, dim=-1)
  26. out = torch.matmul(attention, V)
  27. # 合并多头并输出
  28. out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
  29. return self.out_linear(out)

关键点解释

  • 多头分割:将输入嵌入维度embed_dim均分为num_heads个头,每个头独立计算注意力,增强模型对不同位置关系的捕捉能力。
  • 缩放点积:通过1/sqrt(d_k)缩放避免点积结果过大导致梯度消失。
  • 掩码机制:可选的mask参数用于屏蔽无效位置(如填充位或未来信息)。

(2)位置编码(Positional Encoding)

  1. class PositionalEncoding(nn.Module):
  2. def __init__(self, embed_dim, max_len=5000):
  3. super().__init__()
  4. position = torch.arange(max_len).unsqueeze(1)
  5. div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
  6. pe = torch.zeros(max_len, embed_dim)
  7. pe[:, 0::2] = torch.sin(position * div_term)
  8. pe[:, 1::2] = torch.cos(position * div_term)
  9. self.register_buffer("pe", pe)
  10. def forward(self, x):
  11. # x shape: (batch_size, seq_len, embed_dim)
  12. x = x + self.pe[:x.size(1)]
  13. return x

设计原理

  • 使用正弦和余弦函数的不同频率生成位置编码,使模型能通过相对位置推断序列顺序。
  • 直接与输入嵌入相加,而非拼接,保持维度一致。

(3)前馈网络(Feed Forward Network)

  1. class PositionwiseFeedForward(nn.Module):
  2. def __init__(self, embed_dim, hidden_dim):
  3. super().__init__()
  4. self.fc1 = nn.Linear(embed_dim, hidden_dim)
  5. self.fc2 = nn.Linear(hidden_dim, embed_dim)
  6. def forward(self, x):
  7. return self.fc2(F.relu(self.fc1(x)))

结构特点

  • 两层全连接网络,中间使用ReLU激活。
  • 隐藏层维度通常大于输入维度(如embed_dim=512hidden_dim=2048),增强非线性表达能力。

3. 完整Transformer编码器实现

  1. class TransformerEncoderLayer(nn.Module):
  2. def __init__(self, embed_dim, num_heads, hidden_dim, dropout=0.1):
  3. super().__init__()
  4. self.self_attn = MultiHeadAttention(embed_dim, num_heads)
  5. self.feed_forward = PositionwiseFeedForward(embed_dim, hidden_dim)
  6. self.norm1 = nn.LayerNorm(embed_dim)
  7. self.norm2 = nn.LayerNorm(embed_dim)
  8. self.dropout1 = nn.Dropout(dropout)
  9. self.dropout2 = nn.Dropout(dropout)
  10. def forward(self, x, src_mask=None):
  11. # 自注意力子层
  12. attn_output = self.self_attn(x, x, x, src_mask)
  13. x = x + self.dropout1(attn_output)
  14. x = self.norm1(x)
  15. # 前馈子层
  16. ff_output = self.feed_forward(x)
  17. x = x + self.dropout2(ff_output)
  18. x = self.norm2(x)
  19. return x
  20. class TransformerEncoder(nn.Module):
  21. def __init__(self, num_layers, embed_dim, num_heads, hidden_dim, dropout=0.1):
  22. super().__init__()
  23. self.layers = nn.ModuleList(
  24. [TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout) for _ in range(num_layers)]
  25. )
  26. self.norm = nn.LayerNorm(embed_dim)
  27. def forward(self, x, src_mask=None):
  28. for layer in self.layers:
  29. x = layer(x, src_mask)
  30. return self.norm(x)

关键设计

  • 残差连接:每个子层的输出与输入相加(x + sublayer(x)),缓解梯度消失。
  • 层归一化:在残差连接后应用,稳定训练过程。
  • 堆叠多层:通过num_layers控制模型深度,典型值为6层。

训练与优化建议

  1. 学习率调度:使用Noam调度器动态调整学习率:

    1. class NoamOpt:
    2. def __init__(self, model_size, factor, warmup_steps, optimizer):
    3. self.optimizer = optimizer
    4. self.warmup_steps = warmup_steps
    5. self.factor = factor
    6. self.model_size = model_size
    7. self._step = 0
    8. def step(self):
    9. self._step += 1
    10. rate = self.rate()
    11. for p in self.optimizer.param_groups:
    12. p["lr"] = rate
    13. self.optimizer.step()
    14. def rate(self):
    15. return self.factor * (self.model_size ** (-0.5) * min(self._step ** (-0.5), self._step * self.warmup_steps ** (-1.5)))
  2. 标签平滑:在分类任务中应用标签平滑(如epsilon=0.1)提升泛化性:

    1. def label_smoothing(y, epsilon, num_classes):
    2. y_smooth = torch.full_like(y, epsilon / (num_classes - 1))
    3. y_smooth.scatter_(1, y.unsqueeze(1), 1 - epsilon)
    4. return y_smooth
  3. 批量处理:使用DataLoader分批加载数据,设置batch_size=64,并启用pin_memory=True加速GPU传输。

总结与扩展

本文提供了Transformer模型的完整PyTorch实现,覆盖多头注意力、位置编码、残差连接等核心组件。开发者可通过调整以下参数定制模型:

  • embed_dim:嵌入维度(通常512/1024)
  • num_heads:注意力头数(8/16)
  • num_layers:编码器层数(6/12)
  • hidden_dim:前馈网络隐藏层维度(2048/4096)

进一步优化方向包括:

  1. 集成torch.compile加速训练
  2. 尝试混合精度训练(fp16
  3. 扩展为解码器结构实现完整Seq2Seq模型

通过理解代码实现细节,开发者能更高效地调试模型、改进架构,并应用于机器翻译、文本生成等实际任务。