从Transformer到ViT:视觉领域的自注意力架构演进与实现

一、Transformer架构的核心机制

Transformer架构自2017年提出以来,凭借自注意力机制(Self-Attention)和并行计算能力,迅速成为自然语言处理(NLP)的主流架构。其核心由多头注意力(Multi-Head Attention)、位置编码(Positional Encoding)、前馈神经网络(Feed-Forward Network)和残差连接(Residual Connection)组成。

1.1 自注意力机制的实现

自注意力机制通过计算输入序列中每个元素与其他元素的关联权重,捕捉长距离依赖关系。其核心公式为:

  1. Attention(Q, K, V) = softmax(QK^T/√d_k)V

其中,Q(Query)、K(Key)、V(Value)通过线性变换从输入序列生成,d_k为Key的维度。多头注意力通过并行计算多个注意力头,增强模型对不同语义特征的捕捉能力。

1.2 位置编码的必要性

由于Transformer缺乏卷积或循环结构的隐式位置信息,需通过位置编码显式注入序列顺序。原始论文采用正弦/余弦函数生成位置编码:

  1. PE(pos, 2i) = sin(pos/10000^(2i/d_model))
  2. PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))

其中,pos为位置索引,i为维度索引,d_model为嵌入维度。

1.3 层归一化与残差连接

Transformer在每个子层(注意力层和前馈层)后应用层归一化(Layer Normalization),并通过残差连接缓解梯度消失问题。其结构可表示为:

  1. SublayerOutput = LayerNorm(x + Sublayer(x))

二、ViT架构的诞生:Transformer迁移至视觉任务

2020年,Google提出的Vision Transformer(ViT)首次将纯Transformer架构应用于图像分类任务,通过将图像分割为固定大小的patch(如16×16),将每个patch视为序列中的一个token,从而将2D图像转换为1D序列输入。

2.1 ViT的核心架构设计

ViT的架构可分为三个阶段:

  1. Patch Embedding:将图像分割为N个patch(如224×224图像分割为14×14=196个16×16 patch),每个patch通过线性投影生成固定维度的嵌入向量(如768维)。
  2. Transformer Encoder:由L个相同的Transformer层堆叠而成,每层包含多头注意力、层归一化和前馈网络。
  3. Classification Head:使用第一个token([CLS] token)的输出作为分类特征,通过MLP层输出类别概率。

2.2 关键实现细节

  • Patch分割策略:ViT默认采用非重叠patch分割,但后续研究(如Swin Transformer)引入重叠patch和窗口注意力,提升局部特征捕捉能力。
  • 位置编码扩展:ViT沿用Transformer的正弦位置编码,但针对2D图像特性,可改用相对位置编码或2D位置嵌入。
  • 预训练与微调:ViT依赖大规模预训练(如JFT-300M数据集),在小规模数据集上需谨慎调整学习率。

三、ViT与Transformer的异同对比

维度 Transformer(NLP) ViT(CV)
输入表示 离散token序列(如单词) 图像patch序列
位置编码 1D正弦/余弦编码 1D编码(可扩展为2D)
任务适配 文本生成、分类、翻译等 图像分类、检测、分割等
数据需求 中等规模(如WMT数据集) 大规模(如JFT-300M)
计算复杂度 O(n²)(n为序列长度) O(n²)(n为patch数量)

四、ViT的实现优化策略

4.1 混合架构设计

为缓解ViT对局部特征捕捉的不足,可引入卷积操作:

  • 前馈卷积:在Transformer层前添加卷积块,增强局部特征提取。
  • 注意力卷积混合:如CvT架构,在注意力计算中引入深度可分离卷积。

4.2 层次化Transformer

借鉴CNN的层次化设计,通过逐步下采样减少patch数量:

  • Pyramid ViT:如PVT、Swin Transformer,采用多阶段架构,每个阶段输出不同尺度的特征图。
  • 窗口注意力:Swin Transformer将图像划分为非重叠窗口,在窗口内计算自注意力,降低计算复杂度。

4.3 轻量化设计

针对边缘设备部署,需优化ViT的参数量和计算量:

  • 参数共享:如ALiBi架构,共享注意力权重。
  • 线性注意力:用线性复杂度近似自注意力,如Performer。

五、ViT的代码实现示例(PyTorch)

以下为ViT的简化实现代码:

  1. import torch
  2. import torch.nn as nn
  3. class PatchEmbedding(nn.Module):
  4. def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
  5. super().__init__()
  6. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  7. self.num_patches = (img_size // patch_size) ** 2
  8. def forward(self, x):
  9. x = self.proj(x) # [B, embed_dim, num_patches^0.5, num_patches^0.5]
  10. x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
  11. return x
  12. class ViT(nn.Module):
  13. def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768):
  14. super().__init__()
  15. self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
  16. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  17. self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
  18. self.blocks = nn.ModuleList([
  19. nn.TransformerEncoderLayer(d_model=embed_dim, nhead=12)
  20. for _ in range(12)
  21. ])
  22. self.head = nn.Linear(embed_dim, num_classes)
  23. def forward(self, x):
  24. x = self.patch_embed(x) # [B, num_patches, embed_dim]
  25. cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
  26. x = torch.cat((cls_tokens, x), dim=1)
  27. x = x + self.pos_embed
  28. for block in self.blocks:
  29. x = block(x)
  30. return self.head(x[:, 0])

六、ViT的适用场景与挑战

6.1 适用场景

  • 大规模数据集:ViT在JFT-300M等超大规模数据集上表现优异。
  • 高分辨率图像:结合层次化设计(如Swin Transformer)可处理高分辨率输入。
  • 多模态任务:ViT可与文本Transformer结合,用于图像-文本跨模态任务。

6.2 挑战与解决方案

  • 小样本问题:通过知识蒸馏(如DeiT)或预训练-微调策略缓解。
  • 计算复杂度:采用线性注意力或窗口注意力降低计算量。
  • 局部特征缺失:引入卷积或层次化设计增强局部建模能力。

七、总结与展望

ViT的成功证明了Transformer架构在视觉领域的普适性,但其高效应用仍需结合任务特性进行优化。未来发展方向包括:

  1. 更高效的注意力机制:如稀疏注意力、低秩注意力。
  2. 统一的多模态架构:构建支持文本、图像、视频的通用Transformer。
  3. 硬件友好设计:针对GPU/TPU优化计算图,提升推理速度。

开发者可根据任务需求选择基础ViT或改进架构(如Swin、PVT),并结合预训练模型和微调策略,快速构建高性能视觉应用。