一、核心概念:两种注意力机制的定位
Self-Attention和Cross-Attention均属于Transformer架构的核心组件,但目标不同:
- Self-Attention:处理单一序列内部元素的关系,例如将句子中的每个词与所有其他词关联,捕捉上下文依赖。典型应用包括文本生成、分类任务中的序列编码。
- Cross-Attention:处理两个不同序列之间的关系,例如将图像特征与文本描述对齐,捕捉跨模态交互。常见于图文检索、多模态生成任务。
两者的核心差异体现在输入数据的组织方式上:Self-Attention的输入是单一序列(Q=K=V),而Cross-Attention的输入是两个独立序列(Q来自一个序列,K/V来自另一个序列)。
二、代码实现对比:从矩阵运算看本质
1. Self-Attention的代码实现
import torchimport torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.embed_dim = embed_dimself.qkv_proj = nn.Linear(embed_dim, embed_dim * 3) # 输出Q,K,Vself.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):# x: [batch_size, seq_len, embed_dim]qkv = self.qkv_proj(x) # [batch_size, seq_len, 3*embed_dim]q, k, v = torch.split(qkv, self.embed_dim, dim=-1) # 拆分Q,K,V# 计算注意力分数attn_scores = torch.bmm(q, k.transpose(1, 2)) # [batch_size, seq_len, seq_len]attn_weights = torch.softmax(attn_scores / (self.embed_dim ** 0.5), dim=-1)# 加权求和output = torch.bmm(attn_weights, v) # [batch_size, seq_len, embed_dim]return self.out_proj(output)
关键点:
- Q、K、V均来自同一输入序列
x。 - 注意力分数矩阵
attn_scores的形状为[seq_len, seq_len],表示序列内每个位置与其他位置的关联强度。 - 输出维度与输入序列长度相同,保留了原始序列结构。
2. Cross-Attention的代码实现
class CrossAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.embed_dim = embed_dimself.q_proj = nn.Linear(embed_dim, embed_dim) # 仅投影Qself.kv_proj = nn.Linear(embed_dim, embed_dim * 2) # 投影K,Vself.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, query_seq, key_value_seq):# query_seq: [batch_size, query_len, embed_dim]# key_value_seq: [batch_size, kv_len, embed_dim]q = self.q_proj(query_seq) # [batch_size, query_len, embed_dim]kv = self.kv_proj(key_value_seq) # [batch_size, kv_len, 2*embed_dim]k, v = torch.split(kv, self.embed_dim, dim=-1) # 拆分K,V# 计算跨序列注意力分数attn_scores = torch.bmm(q, k.transpose(1, 2)) # [batch_size, query_len, kv_len]attn_weights = torch.softmax(attn_scores / (self.embed_dim ** 0.5), dim=-1)# 加权求和(使用key_value_seq的V)output = torch.bmm(attn_weights, v) # [batch_size, query_len, embed_dim]return self.out_proj(output)
关键点:
- Q来自
query_seq,K和V来自key_value_seq(两个独立序列)。 - 注意力分数矩阵
attn_scores的形状为[query_len, kv_len],表示查询序列中每个位置与键值序列中每个位置的关联强度。 - 输出长度与查询序列
query_seq相同,但内容基于键值序列key_value_seq的信息聚合。
三、核心差异解析:输入、计算与输出
1. 输入结构差异
| 机制 | Q来源 | K/V来源 | 典型场景 |
|---|---|---|---|
| Self-Attention | 同一序列 | 同一序列 | 文本编码、序列分类 |
| Cross-Attention | 查询序列 | 键值序列 | 图文匹配、多模态生成 |
代码验证:在Self-Attention中,x同时作为Q、K、V的输入;而在Cross-Attention中,query_seq和key_value_seq是分离的。
2. 计算流程差异
- Self-Attention:计算序列内所有位置对的相似度,生成对称的注意力矩阵(如句子中“猫”与“狗”的关联)。
- Cross-Attention:计算两个序列间的非对称关联,生成非对称矩阵(如图像区域与文本描述的匹配度)。
性能优化建议:
- Self-Attention可通过稀疏注意力(如局部窗口)减少计算量。
- Cross-Attention需优化K/V的投影层,避免因序列长度差异导致的内存爆炸。
3. 输出语义差异
- Self-Attention的输出是输入序列的“上下文增强版”,保留原始位置信息。
- Cross-Attention的输出是查询序列对键值序列的“信息聚合”,可能丢失键值序列的原始结构。
应用场景示例:
- 在机器翻译中,Self-Attention用于编码源语言句子,Cross-Attention用于将目标语言生成与源语言对齐。
- 在图像描述生成中,Cross-Attention将文本查询与图像区域特征匹配,生成更准确的描述。
四、实践中的选择策略
1. 如何选择机制?
- 单一序列任务(如文本分类):优先使用Self-Attention,捕捉序列内部依赖。
- 跨序列任务(如图文检索):必须使用Cross-Attention,建立序列间关联。
- 混合任务(如视频描述生成):可组合两种机制,例如用Self-Attention处理视频帧序列,用Cross-Attention对齐文本与视频。
2. 代码复用技巧
class UnifiedAttention(nn.Module):def __init__(self, embed_dim, is_cross=False):super().__init__()self.is_cross = is_crossself.embed_dim = embed_dimif is_cross:self.q_proj = nn.Linear(embed_dim, embed_dim)self.kv_proj = nn.Linear(embed_dim, embed_dim * 2)else:self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, *inputs):if self.is_cross:query_seq, key_value_seq = inputsq = self.q_proj(query_seq)kv = self.kv_proj(key_value_seq)k, v = torch.split(kv, self.embed_dim, dim=-1)else:x = inputs[0]qkv = self.qkv_proj(x)q, k, v = torch.split(qkv, self.embed_dim, dim=-1)attn_scores = torch.bmm(q, k.transpose(1, 2)) / (self.embed_dim ** 0.5)attn_weights = torch.softmax(attn_scores, dim=-1)output = torch.bmm(attn_weights, v)return self.out_proj(output)
使用方式:
# Self-Attention模式self_attn = UnifiedAttention(embed_dim=512, is_cross=False)output = self_attn(x)# Cross-Attention模式cross_attn = UnifiedAttention(embed_dim=512, is_cross=True)output = cross_attn(query_seq, key_value_seq)
五、总结与建议
- 本质区别:Self-Attention是“序列内自关联”,Cross-Attention是“序列间交互”。
- 代码实现:核心差异在于Q/K/V的来源和注意力矩阵的形状。
- 实践建议:
- 优先使用成熟的Transformer库(如百度智能云提供的NLP工具包),避免重复造轮子。
- 在自定义注意力机制时,注意序列长度的匹配(如Cross-Attention中
query_len与kv_len的差异)。 - 通过可视化注意力矩阵(如使用
torchviz)调试模型,验证关联是否符合预期。
通过理解这两种机制的底层逻辑,开发者可以更灵活地设计多模态、跨序列任务模型,提升任务效果与计算效率。