Transformer整体架构解析:从编码器到解码器的全流程设计

Transformer整体架构解析:从编码器到解码器的全流程设计

Transformer模型自2017年提出以来,凭借其并行计算能力和长序列处理优势,迅速成为自然语言处理(NLP)领域的核心架构。本文将从整体架构出发,系统拆解编码器、解码器、注意力机制等关键组件,结合实现细节与优化策略,为开发者提供可落地的技术指南。

一、Transformer整体架构概览

Transformer采用编码器-解码器(Encoder-Decoder)结构,由N个相同的编码器层和N个相同的解码器层堆叠而成。输入序列通过编码器转换为隐藏表示,再由解码器生成目标序列。其核心设计摒弃了传统的循环神经网络(RNN),转而依赖自注意力机制(Self-Attention)和位置编码(Positional Encoding)实现序列建模。

1.1 架构分层与数据流

  • 输入层:接收词嵌入(Word Embedding)与位置编码的和作为输入。
  • 编码器层:每个编码器层包含多头自注意力子层和前馈神经网络子层,通过残差连接(Residual Connection)和层归一化(Layer Normalization)稳定训练。
  • 解码器层:每个解码器层包含掩码多头自注意力子层、编码器-解码器多头注意力子层和前馈神经网络子层,额外引入掩码机制防止信息泄露。
  • 输出层:通过线性变换和Softmax函数生成概率分布。

二、编码器:自注意力驱动的序列压缩

编码器的核心任务是将输入序列映射为上下文感知的隐藏表示。每个编码器层包含两个子层:

2.1 多头自注意力机制

自注意力机制通过计算序列中每个词与其他词的关联权重,动态捕捉上下文依赖。多头注意力将输入分割为多个子空间,并行计算注意力分数,增强模型表达能力。

  1. import torch
  2. import torch.nn as nn
  3. class MultiHeadAttention(nn.Module):
  4. def __init__(self, embed_dim, num_heads):
  5. super().__init__()
  6. self.embed_dim = embed_dim
  7. self.num_heads = num_heads
  8. self.head_dim = embed_dim // num_heads
  9. # 线性变换矩阵
  10. self.q_linear = nn.Linear(embed_dim, embed_dim)
  11. self.k_linear = nn.Linear(embed_dim, embed_dim)
  12. self.v_linear = nn.Linear(embed_dim, embed_dim)
  13. self.out_linear = nn.Linear(embed_dim, embed_dim)
  14. def forward(self, query, key, value, mask=None):
  15. # 线性变换并分割多头
  16. Q = self.q_linear(query).view(-1, self.num_heads, self.head_dim).transpose(0, 1)
  17. K = self.k_linear(key).view(-1, self.num_heads, self.head_dim).transpose(0, 1)
  18. V = self.v_linear(value).view(-1, self.num_heads, self.head_dim).transpose(0, 1)
  19. # 计算注意力分数
  20. scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
  21. # 应用掩码(可选)
  22. if mask is not None:
  23. scores = scores.masked_fill(mask == 0, float('-inf'))
  24. # 计算权重并聚合值
  25. attention = torch.softmax(scores, dim=-1)
  26. context = torch.matmul(attention, V)
  27. # 合并多头并输出
  28. context = context.transpose(0, 1).contiguous().view(-1, self.embed_dim)
  29. return self.out_linear(context)

关键点

  • 缩放点积注意力:通过1/sqrt(d_k)缩放分数,防止梯度消失。
  • 多头并行:每个头独立学习不同的注意力模式,最后拼接结果。

2.2 前馈神经网络与残差连接

编码器层的第二个子层是全连接前馈网络,包含两个线性变换和一个ReLU激活函数:

  1. FFN(x) = max(0, xW1 + b1)W2 + b2

残差连接和层归一化确保梯度稳定传播:

  1. class EncoderLayer(nn.Module):
  2. def __init__(self, embed_dim, num_heads, ffn_dim):
  3. super().__init__()
  4. self.self_attn = MultiHeadAttention(embed_dim, num_heads)
  5. self.ffn = nn.Sequential(
  6. nn.Linear(embed_dim, ffn_dim),
  7. nn.ReLU(),
  8. nn.Linear(ffn_dim, embed_dim)
  9. )
  10. self.norm1 = nn.LayerNorm(embed_dim)
  11. self.norm2 = nn.LayerNorm(embed_dim)
  12. def forward(self, x, mask=None):
  13. # 自注意力子层
  14. attn_output = self.self_attn(x, x, x, mask)
  15. x = x + attn_output
  16. x = self.norm1(x)
  17. # 前馈子层
  18. ffn_output = self.ffn(x)
  19. x = x + ffn_output
  20. x = self.norm2(x)
  21. return x

三、解码器:自回归生成与条件约束

解码器通过掩码多头注意力和编码器-解码器注意力,实现自回归生成(即逐词预测)。

3.1 掩码多头注意力

解码器在训练时需防止未来信息泄露,通过上三角掩码屏蔽后续位置的注意力:

  1. def create_mask(seq_len):
  2. mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
  3. return mask == 0 # True表示可访问

3.2 编码器-解码器注意力

解码器的第二个子层将编码器的输出作为keyvalue,解码器自身的输出作为query,实现跨模态信息融合。

3.3 解码器层实现

  1. class DecoderLayer(nn.Module):
  2. def __init__(self, embed_dim, num_heads, ffn_dim):
  3. super().__init__()
  4. self.self_attn = MultiHeadAttention(embed_dim, num_heads)
  5. self.enc_dec_attn = MultiHeadAttention(embed_dim, num_heads)
  6. self.ffn = nn.Sequential(
  7. nn.Linear(embed_dim, ffn_dim),
  8. nn.ReLU(),
  9. nn.Linear(ffn_dim, embed_dim)
  10. )
  11. self.norm1 = nn.LayerNorm(embed_dim)
  12. self.norm2 = nn.LayerNorm(embed_dim)
  13. self.norm3 = nn.LayerNorm(embed_dim)
  14. def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
  15. # 掩码自注意力
  16. self_attn_output = self.self_attn(x, x, x, tgt_mask)
  17. x = x + self_attn_output
  18. x = self.norm1(x)
  19. # 编码器-解码器注意力
  20. enc_dec_attn_output = self.enc_dec_attn(x, enc_output, enc_output, src_mask)
  21. x = x + enc_dec_attn_output
  22. x = self.norm2(x)
  23. # 前馈子层
  24. ffn_output = self.ffn(x)
  25. x = x + ffn_output
  26. x = self.norm3(x)
  27. return x

四、关键实现细节与优化策略

4.1 位置编码

由于Transformer缺乏递归结构,需通过正弦/余弦函数注入位置信息:

  1. def positional_encoding(max_len, embed_dim):
  2. position = torch.arange(max_len).unsqueeze(1)
  3. div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
  4. pe = torch.zeros(max_len, embed_dim)
  5. pe[:, 0::2] = torch.sin(position * div_term)
  6. pe[:, 1::2] = torch.cos(position * div_term)
  7. return pe

4.2 性能优化建议

  1. 混合精度训练:使用FP16/FP32混合精度加速训练,减少显存占用。
  2. 梯度累积:模拟大batch效果,稳定梯度更新。
  3. 注意力下采样:在长序列场景中,可采用稀疏注意力或局部注意力降低计算复杂度。
  4. 模型并行:将编码器/解码器层分配到不同设备,突破单机显存限制。

五、工程化实践与扩展应用

5.1 部署优化

  • 量化:将模型权重从FP32转为INT8,减少推理延迟。
  • 蒸馏:用大模型指导小模型训练,平衡精度与速度。
  • ONNX转换:导出为标准化格式,兼容多硬件后端。

5.2 扩展架构

  • BERT:仅用编码器,通过掩码语言模型(MLM)预训练。
  • GPT:仅用解码器,通过自回归生成预训练。
  • T5:统一编码器-解码器框架,支持多种NLP任务。

六、总结与未来方向

Transformer架构通过自注意力机制革新了序列建模方式,其模块化设计支持灵活扩展。开发者在实现时需重点关注:

  1. 多头注意力的头数与维度权衡。
  2. 层归一化与残差连接的稳定性作用。
  3. 位置编码的注入方式选择。

未来,Transformer将进一步向高效长序列处理(如Linear Attention)、多模态融合(如Vision Transformer)和轻量化部署方向发展。掌握其核心架构后,开发者可快速适配到机器翻译、文本生成、语音识别等场景,构建高性能AI应用。