从代码视角拆解:Self-Attention与Cross-Attention的核心差异

一、核心概念:两种注意力机制的定位

Self-Attention和Cross-Attention均属于Transformer架构的核心组件,但目标不同:

  • Self-Attention:处理单一序列内部元素的关系,例如将句子中的每个词与所有其他词关联,捕捉上下文依赖。典型应用包括文本生成、分类任务中的序列编码。
  • Cross-Attention:处理两个不同序列之间的关系,例如将图像特征与文本描述对齐,捕捉跨模态交互。常见于图文检索、多模态生成任务。

两者的核心差异体现在输入数据的组织方式上:Self-Attention的输入是单一序列(Q=K=V),而Cross-Attention的输入是两个独立序列(Q来自一个序列,K/V来自另一个序列)。

二、代码实现对比:从矩阵运算看本质

1. Self-Attention的代码实现

  1. import torch
  2. import torch.nn as nn
  3. class SelfAttention(nn.Module):
  4. def __init__(self, embed_dim):
  5. super().__init__()
  6. self.embed_dim = embed_dim
  7. self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3) # 输出Q,K,V
  8. self.out_proj = nn.Linear(embed_dim, embed_dim)
  9. def forward(self, x):
  10. # x: [batch_size, seq_len, embed_dim]
  11. qkv = self.qkv_proj(x) # [batch_size, seq_len, 3*embed_dim]
  12. q, k, v = torch.split(qkv, self.embed_dim, dim=-1) # 拆分Q,K,V
  13. # 计算注意力分数
  14. attn_scores = torch.bmm(q, k.transpose(1, 2)) # [batch_size, seq_len, seq_len]
  15. attn_weights = torch.softmax(attn_scores / (self.embed_dim ** 0.5), dim=-1)
  16. # 加权求和
  17. output = torch.bmm(attn_weights, v) # [batch_size, seq_len, embed_dim]
  18. return self.out_proj(output)

关键点

  • Q、K、V均来自同一输入序列x
  • 注意力分数矩阵attn_scores的形状为[seq_len, seq_len],表示序列内每个位置与其他位置的关联强度。
  • 输出维度与输入序列长度相同,保留了原始序列结构。

2. Cross-Attention的代码实现

  1. class CrossAttention(nn.Module):
  2. def __init__(self, embed_dim):
  3. super().__init__()
  4. self.embed_dim = embed_dim
  5. self.q_proj = nn.Linear(embed_dim, embed_dim) # 仅投影Q
  6. self.kv_proj = nn.Linear(embed_dim, embed_dim * 2) # 投影K,V
  7. self.out_proj = nn.Linear(embed_dim, embed_dim)
  8. def forward(self, query_seq, key_value_seq):
  9. # query_seq: [batch_size, query_len, embed_dim]
  10. # key_value_seq: [batch_size, kv_len, embed_dim]
  11. q = self.q_proj(query_seq) # [batch_size, query_len, embed_dim]
  12. kv = self.kv_proj(key_value_seq) # [batch_size, kv_len, 2*embed_dim]
  13. k, v = torch.split(kv, self.embed_dim, dim=-1) # 拆分K,V
  14. # 计算跨序列注意力分数
  15. attn_scores = torch.bmm(q, k.transpose(1, 2)) # [batch_size, query_len, kv_len]
  16. attn_weights = torch.softmax(attn_scores / (self.embed_dim ** 0.5), dim=-1)
  17. # 加权求和(使用key_value_seq的V)
  18. output = torch.bmm(attn_weights, v) # [batch_size, query_len, embed_dim]
  19. return self.out_proj(output)

关键点

  • Q来自query_seq,K和V来自key_value_seq(两个独立序列)。
  • 注意力分数矩阵attn_scores的形状为[query_len, kv_len],表示查询序列中每个位置与键值序列中每个位置的关联强度。
  • 输出长度与查询序列query_seq相同,但内容基于键值序列key_value_seq的信息聚合。

三、核心差异解析:输入、计算与输出

1. 输入结构差异

机制 Q来源 K/V来源 典型场景
Self-Attention 同一序列 同一序列 文本编码、序列分类
Cross-Attention 查询序列 键值序列 图文匹配、多模态生成

代码验证:在Self-Attention中,x同时作为Q、K、V的输入;而在Cross-Attention中,query_seqkey_value_seq是分离的。

2. 计算流程差异

  • Self-Attention:计算序列内所有位置对的相似度,生成对称的注意力矩阵(如句子中“猫”与“狗”的关联)。
  • Cross-Attention:计算两个序列间的非对称关联,生成非对称矩阵(如图像区域与文本描述的匹配度)。

性能优化建议

  • Self-Attention可通过稀疏注意力(如局部窗口)减少计算量。
  • Cross-Attention需优化K/V的投影层,避免因序列长度差异导致的内存爆炸。

3. 输出语义差异

  • Self-Attention的输出是输入序列的“上下文增强版”,保留原始位置信息。
  • Cross-Attention的输出是查询序列对键值序列的“信息聚合”,可能丢失键值序列的原始结构。

应用场景示例

  • 在机器翻译中,Self-Attention用于编码源语言句子,Cross-Attention用于将目标语言生成与源语言对齐。
  • 在图像描述生成中,Cross-Attention将文本查询与图像区域特征匹配,生成更准确的描述。

四、实践中的选择策略

1. 如何选择机制?

  • 单一序列任务(如文本分类):优先使用Self-Attention,捕捉序列内部依赖。
  • 跨序列任务(如图文检索):必须使用Cross-Attention,建立序列间关联。
  • 混合任务(如视频描述生成):可组合两种机制,例如用Self-Attention处理视频帧序列,用Cross-Attention对齐文本与视频。

2. 代码复用技巧

  1. class UnifiedAttention(nn.Module):
  2. def __init__(self, embed_dim, is_cross=False):
  3. super().__init__()
  4. self.is_cross = is_cross
  5. self.embed_dim = embed_dim
  6. if is_cross:
  7. self.q_proj = nn.Linear(embed_dim, embed_dim)
  8. self.kv_proj = nn.Linear(embed_dim, embed_dim * 2)
  9. else:
  10. self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
  11. self.out_proj = nn.Linear(embed_dim, embed_dim)
  12. def forward(self, *inputs):
  13. if self.is_cross:
  14. query_seq, key_value_seq = inputs
  15. q = self.q_proj(query_seq)
  16. kv = self.kv_proj(key_value_seq)
  17. k, v = torch.split(kv, self.embed_dim, dim=-1)
  18. else:
  19. x = inputs[0]
  20. qkv = self.qkv_proj(x)
  21. q, k, v = torch.split(qkv, self.embed_dim, dim=-1)
  22. attn_scores = torch.bmm(q, k.transpose(1, 2)) / (self.embed_dim ** 0.5)
  23. attn_weights = torch.softmax(attn_scores, dim=-1)
  24. output = torch.bmm(attn_weights, v)
  25. return self.out_proj(output)

使用方式

  1. # Self-Attention模式
  2. self_attn = UnifiedAttention(embed_dim=512, is_cross=False)
  3. output = self_attn(x)
  4. # Cross-Attention模式
  5. cross_attn = UnifiedAttention(embed_dim=512, is_cross=True)
  6. output = cross_attn(query_seq, key_value_seq)

五、总结与建议

  1. 本质区别:Self-Attention是“序列内自关联”,Cross-Attention是“序列间交互”。
  2. 代码实现:核心差异在于Q/K/V的来源和注意力矩阵的形状。
  3. 实践建议
    • 优先使用成熟的Transformer库(如百度智能云提供的NLP工具包),避免重复造轮子。
    • 在自定义注意力机制时,注意序列长度的匹配(如Cross-Attention中query_lenkv_len的差异)。
    • 通过可视化注意力矩阵(如使用torchviz)调试模型,验证关联是否符合预期。

通过理解这两种机制的底层逻辑,开发者可以更灵活地设计多模态、跨序列任务模型,提升任务效果与计算效率。