NLP学习笔记(六):Transformer架构与核心机制解析

一、Transformer架构诞生的背景与意义

传统NLP模型(如RNN、LSTM)存在两大核心缺陷:序列依赖导致的并行计算困难长距离依赖捕捉能力不足。以LSTM为例,其时间复杂度为O(n²),当处理长文本时(如超过1000个token),计算效率急剧下降。而Transformer通过自注意力机制(Self-Attention)彻底解决了这一问题,将时间复杂度降至O(n),同时通过多头注意力增强了对不同位置信息的捕捉能力。

2017年《Attention Is All You Need》论文提出的Transformer架构,不仅成为BERT、GPT等预训练模型的基石,更推动了NLP领域从”序列建模”向”注意力建模”的范式转变。其核心思想可概括为:用注意力权重替代循环结构,通过并行计算实现高效信息融合

二、自注意力机制:从数学原理到代码实现

1. 注意力分数计算

自注意力机制的核心是计算查询(Q)、键(K)、值(V)三者间的相似度。以单头注意力为例,其计算流程如下:

  1. import torch
  2. import torch.nn.functional as F
  3. def scaled_dot_product_attention(Q, K, V, mask=None):
  4. # Q/K/V shape: (batch_size, num_heads, seq_len, d_k)
  5. d_k = Q.size(-1)
  6. scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
  7. if mask is not None:
  8. scores = scores.masked_fill(mask == 0, float('-inf'))
  9. attention_weights = F.softmax(scores, dim=-1)
  10. output = torch.matmul(attention_weights, V)
  11. return output

关键点

  • 缩放因子1/√d_k防止点积结果过大导致softmax梯度消失
  • 可选mask机制用于屏蔽无效位置(如解码器的未来信息)

2. 多头注意力:并行信息捕捉

通过将Q/K/V拆分为多个子空间(如8头),每个头独立计算注意力,最后拼接结果:

  1. class MultiHeadAttention(torch.nn.Module):
  2. def __init__(self, d_model, num_heads):
  3. super().__init__()
  4. self.d_model = d_model
  5. self.num_heads = num_heads
  6. self.d_k = d_model // num_heads
  7. self.W_q = torch.nn.Linear(d_model, d_model)
  8. self.W_k = torch.nn.Linear(d_model, d_model)
  9. self.W_v = torch.nn.Linear(d_model, d_model)
  10. self.W_o = torch.nn.Linear(d_model, d_model)
  11. def forward(self, x, mask=None):
  12. batch_size = x.size(0)
  13. # 线性变换并拆分多头
  14. Q = self.W_q(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
  15. K = self.W_k(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
  16. V = self.W_v(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
  17. # 计算多头注意力
  18. attn_output = scaled_dot_product_attention(Q, K, V, mask)
  19. # 拼接并输出
  20. attn_output = attn_output.transpose(1, 2).contiguous()
  21. attn_output = attn_output.view(batch_size, -1, self.d_model)
  22. return self.W_o(attn_output)

工程价值:多头机制使模型能同时关注不同语义维度(如语法、语义、指代),实验表明8头注意力在多数任务中达到性能饱和。

三、位置编码:弥补自注意力的位置缺陷

自注意力机制本身是位置无关的,需通过位置编码(Positional Encoding)注入序列顺序信息。论文采用正弦/余弦函数生成固定位置编码:

  1. def positional_encoding(max_len, d_model):
  2. position = torch.arange(max_len).unsqueeze(1)
  3. div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
  4. pe = torch.zeros(max_len, d_model)
  5. pe[:, 0::2] = torch.sin(position * div_term)
  6. pe[:, 1::2] = torch.cos(position * div_term)
  7. return pe.unsqueeze(0) # shape: (1, max_len, d_model)

设计原理

  • 奇偶维度分别使用sin/cos函数,使不同位置编码具有唯一性
  • 相对位置通过线性组合保留,满足平移不变性需求

四、编码器-解码器架构详解

1. 编码器堆叠

6层编码器结构包含两大子层:

  • 多头注意力层:捕捉输入序列内部关系
  • 前馈网络层:对每个位置独立进行非线性变换

    1. class EncoderLayer(torch.nn.Module):
    2. def __init__(self, d_model, num_heads, d_ff):
    3. super().__init__()
    4. self.self_attn = MultiHeadAttention(d_model, num_heads)
    5. self.feed_forward = torch.nn.Sequential(
    6. torch.nn.Linear(d_model, d_ff),
    7. torch.nn.ReLU(),
    8. torch.nn.Linear(d_ff, d_model)
    9. )
    10. self.norm1 = torch.nn.LayerNorm(d_model)
    11. self.norm2 = torch.nn.LayerNorm(d_model)
    12. def forward(self, x):
    13. # 自注意力子层
    14. attn_output = self.self_attn(x)
    15. x = x + attn_output
    16. x = self.norm1(x)
    17. # 前馈子层
    18. ff_output = self.feed_forward(x)
    19. x = x + ff_output
    20. x = self.norm2(x)
    21. return x

    残差连接与层归一化:解决深层网络梯度消失问题,加速训练收敛。

2. 解码器特殊设计

解码器包含两类注意力:

  • 掩码自注意力:防止解码时看到未来信息
  • 编码器-解码器注意力:Q来自解码器,K/V来自编码器输出

    1. class DecoderLayer(torch.nn.Module):
    2. def __init__(self, d_model, num_heads, d_ff):
    3. super().__init__()
    4. self.self_attn = MultiHeadAttention(d_model, num_heads)
    5. self.cross_attn = MultiHeadAttention(d_model, num_heads)
    6. self.feed_forward = torch.nn.Sequential(...)
    7. # 省略归一化层定义...
    8. def forward(self, x, encoder_output, src_mask, tgt_mask):
    9. # 掩码自注意力
    10. attn_output = self.self_attn(x, mask=tgt_mask)
    11. x = x + attn_output
    12. x = self.norm1(x)
    13. # 编码器-解码器注意力
    14. cross_attn = self.cross_attn(x, encoder_output, mask=src_mask)
    15. x = x + cross_attn
    16. x = self.norm2(x)
    17. # 前馈子层...
    18. return x

五、Transformer的工程化实践建议

  1. 超参数选择

    • 模型维度d_model通常设为512/768
    • 头数num_heads建议为8/16
    • 层数num_layers在6-12层间平衡性能与效率
  2. 训练技巧

    • 使用Adam优化器(β1=0.9, β2=0.98)
    • 学习率采用线性预热+余弦衰减策略
    • 标签平滑(label smoothing=0.1)提升泛化性
  3. 部署优化

    • 通过量化(INT8)减少模型体积
    • 采用知识蒸馏训练小模型(如DistilBERT)
    • 使用TensorRT加速推理

六、Transformer的演进方向

当前研究热点包括:

  • 稀疏注意力:降低计算复杂度(如Longformer)
  • 线性化注意力:通过核方法近似计算(如Performer)
  • 模块化设计:解耦注意力与前馈网络(如GLU变体)

Transformer架构的成功证明,通过设计合理的归纳偏置,可以构建出兼具表达能力和计算效率的通用模型。对于开发者而言,深入理解其机制不仅能优化现有模型,更能为创新架构设计提供灵感。