一、Transformer架构的核心机制
Transformer架构自2017年提出以来,凭借自注意力机制(Self-Attention)和并行计算能力,迅速成为自然语言处理(NLP)的主流架构。其核心由多头注意力(Multi-Head Attention)、位置编码(Positional Encoding)、前馈神经网络(Feed-Forward Network)和残差连接(Residual Connection)组成。
1.1 自注意力机制的实现
自注意力机制通过计算输入序列中每个元素与其他元素的关联权重,捕捉长距离依赖关系。其核心公式为:
Attention(Q, K, V) = softmax(QK^T/√d_k)V
其中,Q(Query)、K(Key)、V(Value)通过线性变换从输入序列生成,d_k为Key的维度。多头注意力通过并行计算多个注意力头,增强模型对不同语义特征的捕捉能力。
1.2 位置编码的必要性
由于Transformer缺乏卷积或循环结构的隐式位置信息,需通过位置编码显式注入序列顺序。原始论文采用正弦/余弦函数生成位置编码:
PE(pos, 2i) = sin(pos/10000^(2i/d_model))PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))
其中,pos为位置索引,i为维度索引,d_model为嵌入维度。
1.3 层归一化与残差连接
Transformer在每个子层(注意力层和前馈层)后应用层归一化(Layer Normalization),并通过残差连接缓解梯度消失问题。其结构可表示为:
SublayerOutput = LayerNorm(x + Sublayer(x))
二、ViT架构的诞生:Transformer迁移至视觉任务
2020年,Google提出的Vision Transformer(ViT)首次将纯Transformer架构应用于图像分类任务,通过将图像分割为固定大小的patch(如16×16),将每个patch视为序列中的一个token,从而将2D图像转换为1D序列输入。
2.1 ViT的核心架构设计
ViT的架构可分为三个阶段:
- Patch Embedding:将图像分割为N个patch(如224×224图像分割为14×14=196个16×16 patch),每个patch通过线性投影生成固定维度的嵌入向量(如768维)。
- Transformer Encoder:由L个相同的Transformer层堆叠而成,每层包含多头注意力、层归一化和前馈网络。
- Classification Head:使用第一个token([CLS] token)的输出作为分类特征,通过MLP层输出类别概率。
2.2 关键实现细节
- Patch分割策略:ViT默认采用非重叠patch分割,但后续研究(如Swin Transformer)引入重叠patch和窗口注意力,提升局部特征捕捉能力。
- 位置编码扩展:ViT沿用Transformer的正弦位置编码,但针对2D图像特性,可改用相对位置编码或2D位置嵌入。
- 预训练与微调:ViT依赖大规模预训练(如JFT-300M数据集),在小规模数据集上需谨慎调整学习率。
三、ViT与Transformer的异同对比
| 维度 | Transformer(NLP) | ViT(CV) |
|---|---|---|
| 输入表示 | 离散token序列(如单词) | 图像patch序列 |
| 位置编码 | 1D正弦/余弦编码 | 1D编码(可扩展为2D) |
| 任务适配 | 文本生成、分类、翻译等 | 图像分类、检测、分割等 |
| 数据需求 | 中等规模(如WMT数据集) | 大规模(如JFT-300M) |
| 计算复杂度 | O(n²)(n为序列长度) | O(n²)(n为patch数量) |
四、ViT的实现优化策略
4.1 混合架构设计
为缓解ViT对局部特征捕捉的不足,可引入卷积操作:
- 前馈卷积:在Transformer层前添加卷积块,增强局部特征提取。
- 注意力卷积混合:如CvT架构,在注意力计算中引入深度可分离卷积。
4.2 层次化Transformer
借鉴CNN的层次化设计,通过逐步下采样减少patch数量:
- Pyramid ViT:如PVT、Swin Transformer,采用多阶段架构,每个阶段输出不同尺度的特征图。
- 窗口注意力:Swin Transformer将图像划分为非重叠窗口,在窗口内计算自注意力,降低计算复杂度。
4.3 轻量化设计
针对边缘设备部署,需优化ViT的参数量和计算量:
- 参数共享:如ALiBi架构,共享注意力权重。
- 线性注意力:用线性复杂度近似自注意力,如Performer。
五、ViT的代码实现示例(PyTorch)
以下为ViT的简化实现代码:
import torchimport torch.nn as nnclass PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)self.num_patches = (img_size // patch_size) ** 2def forward(self, x):x = self.proj(x) # [B, embed_dim, num_patches^0.5, num_patches^0.5]x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]return xclass ViT(nn.Module):def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768):super().__init__()self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=embed_dim, nhead=12)for _ in range(12)])self.head = nn.Linear(embed_dim, num_classes)def forward(self, x):x = self.patch_embed(x) # [B, num_patches, embed_dim]cls_tokens = self.cls_token.expand(x.size(0), -1, -1)x = torch.cat((cls_tokens, x), dim=1)x = x + self.pos_embedfor block in self.blocks:x = block(x)return self.head(x[:, 0])
六、ViT的适用场景与挑战
6.1 适用场景
- 大规模数据集:ViT在JFT-300M等超大规模数据集上表现优异。
- 高分辨率图像:结合层次化设计(如Swin Transformer)可处理高分辨率输入。
- 多模态任务:ViT可与文本Transformer结合,用于图像-文本跨模态任务。
6.2 挑战与解决方案
- 小样本问题:通过知识蒸馏(如DeiT)或预训练-微调策略缓解。
- 计算复杂度:采用线性注意力或窗口注意力降低计算量。
- 局部特征缺失:引入卷积或层次化设计增强局部建模能力。
七、总结与展望
ViT的成功证明了Transformer架构在视觉领域的普适性,但其高效应用仍需结合任务特性进行优化。未来发展方向包括:
- 更高效的注意力机制:如稀疏注意力、低秩注意力。
- 统一的多模态架构:构建支持文本、图像、视频的通用Transformer。
- 硬件友好设计:针对GPU/TPU优化计算图,提升推理速度。
开发者可根据任务需求选择基础ViT或改进架构(如Swin、PVT),并结合预训练模型和微调策略,快速构建高性能视觉应用。