一、为什么需要Multi-Head Attention?
在序列建模任务中,传统注意力机制通过计算查询(Query)与键(Key)的相似度,加权聚合值(Value)信息。然而,单次注意力计算仅能捕捉一种依赖模式(如局部或全局关系),难以同时处理多种语义关联。例如,在自然语言处理中,”bank”一词可能同时关联”河流”(地理含义)和”金融”(机构含义),单头注意力难以区分这两种上下文。
Multi-Head Attention的核心思想是将输入投影到多个子空间,每个子空间独立计算注意力权重,最终合并结果。这种设计使模型能够并行捕捉不同位置的多种依赖关系,显著提升表达能力。以Transformer论文中的实验为例,6层编码器+6层解码器的架构中,Multi-Head Attention(8头)比单头注意力在机器翻译任务上提升了2.3 BLEU分数。
二、从缩放点积注意力到多头并行
1. 缩放点积注意力(Scaled Dot-Product Attention)
单头注意力的计算分为三步:
- 相似度计算:Query与Key的转置做点积,得到原始注意力分数。
- 缩放处理:除以$\sqrt{d_k}$($d_k$为Key维度),避免点积数值过大导致Softmax梯度消失。
- 加权聚合:Softmax归一化后,与Value相乘得到输出。
数学表达式:
2. 多头并行化的实现
Multi-Head Attention通过线性变换将Q、K、V投影到$h$个子空间($h$为头数),每个子空间独立计算注意力,最后拼接结果并通过线性变换输出。具体步骤如下:
-
线性投影:
- $Qi = QW_i^Q$, $K_i = KW_i^K$, $V_i = VW_i^V$,其中$W_i^Q \in \mathbb{R}^{d{model} \times dk}$, $W_i^K \in \mathbb{R}^{d{model} \times dk}$, $W_i^V \in \mathbb{R}^{d{model} \times d_v}$。
- 通常$dk = d_v = d{model}/h$(例如$d_{model}=512$, $h=8$时,$d_k=64$)。
-
并行注意力计算:
- 对每个头$i$,计算$\text{head}_i = \text{Attention}(Q_i, K_i, V_i)$。
-
结果拼接与输出:
- 拼接所有头的输出:$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}1, …, \text{head}_h)W^O$,其中$W^O \in \mathbb{R}^{hd_v \times d{model}}$。
3. 代码示例(PyTorch实现)
import torchimport torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_heads# 线性投影层self.W_Q = nn.Linear(d_model, d_model)self.W_K = nn.Linear(d_model, d_model)self.W_V = nn.Linear(d_model, d_model)self.W_O = nn.Linear(d_model, d_model)def scaled_dot_product(self, Q, K, V):# Q, K, V形状: [batch_size, seq_len, d_model]scores = torch.bmm(Q, K.transpose(1, 2)) / (self.d_k ** 0.5)attn_weights = torch.softmax(scores, dim=-1)return torch.bmm(attn_weights, V)def split_heads(self, x):# 分割为多头: [batch_size, seq_len, d_model] -> [batch_size, num_heads, seq_len, d_k]batch_size, seq_len, _ = x.size()return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)def combine_heads(self, x):# 合并多头: [batch_size, num_heads, seq_len, d_k] -> [batch_size, seq_len, d_model]batch_size, _, seq_len, _ = x.size()return x.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)def forward(self, Q, K, V):# 线性投影Q = self.W_Q(Q)K = self.W_K(K)V = self.W_V(V)# 分割多头Q = self.split_heads(Q)K = self.split_heads(K)V = self.split_heads(V)# 并行计算注意力attn_outputs = []for i in range(self.num_heads):attn_output = self.scaled_dot_product(Q[:, i], K[:, i], V[:, i])attn_outputs.append(attn_output)# 拼接结果concat_output = torch.stack(attn_outputs, dim=1)concat_output = concat_output.view(batch_size, -1, self.d_model) # 简化版,实际可用combine_heads# 输出投影return self.W_O(concat_output)
三、关键设计细节与优化
1. 头数与模型容量的平衡
头数$h$的选择直接影响模型性能:
- 过少(如$h=1$):无法捕捉多种依赖关系,表达能力受限。
- 过多(如$h>32$):增加计算开销,且可能因每个头维度过小($d_k$过小)导致信息丢失。
- 经验值:在$d_{model}=512$时,$h=8$是常见选择(如BERT、GPT等模型)。
2. 缩放因子$\sqrt{d_k}$的作用
点积结果的方差随$d_k$增大而增加,若不缩放,Softmax输入可能落入梯度极小的饱和区。缩放后,无论$d_k$如何变化,点积结果的方差稳定在1附近,保证训练稳定性。
3. 多头注意力的可视化解释
以机器翻译为例,输入序列”The cat sat on the mat”:
- 头1:可能捕捉”cat-sat”的主谓关系。
- 头2:可能捕捉”on-mat”的方位关系。
- 头3:可能忽略无关词(如”the”),聚焦关键依赖。
通过可视化注意力权重矩阵,可观察到不同头确实关注输入序列的不同部分(参考《Attention Is All You Need》论文中的热力图)。
四、应用场景与扩展
1. 自然语言处理
- 机器翻译:编码器-解码器架构中,多头注意力同时处理源语言与目标语言的依赖。
- 文本分类:通过自注意力捕捉句子内长距离依赖(如BERT)。
2. 计算机视觉
- Vision Transformer:将图像分块后作为序列输入,多头注意力捕捉空间关系。
- 目标检测:通过注意力机制融合全局与局部特征。
3. 扩展变体
- 稀疏多头注意力:仅计算部分Query-Key对的注意力,降低计算复杂度(如Longformer)。
- 动态头数:根据输入动态调整头数,平衡效率与性能(如研究中的自适应注意力)。
五、总结与最佳实践
Multi-Head Attention通过并行化注意力计算,显著提升了模型捕捉复杂依赖关系的能力。实际应用中需注意:
- 头数选择:根据任务复杂度与计算资源权衡,常见值为4~16。
- 维度分配:确保$dk = d_v = d{model}/h$,避免维度不匹配。
- 初始化策略:线性层权重可使用Xavier初始化,稳定训练过程。
- 性能优化:通过矩阵运算并行化(如PyTorch的
bmm)加速多头计算。
掌握Multi-Head Attention的设计逻辑后,开发者可更灵活地调整Transformer架构,适应不同场景的需求。