Swin Transformer: Principles and Code Implementation Explained
I. Design Motivation and Core Innovations
Conventional Transformer models face two major challenges on high-resolution images: the quadratic cost of global self-attention (O(N²) in the number of tokens) and the lack of hierarchical feature representations. The Vision Transformer (ViT) demonstrated that a pure Transformer architecture is viable for vision, but its computational cost rises sharply with image size, making it hard to apply directly to dense prediction tasks such as object detection and segmentation.
The core innovation of the Swin Transformer is a hierarchical, window-based attention mechanism: global attention is decomposed into attention computed inside local windows, and a shifted-window strategy lets information flow across windows. This keeps the complexity linear in the number of tokens (O(N)) while building multi-scale features, so the architecture adapts smoothly to a wide range of vision tasks and has become an important advance in computer vision.
II. Core Architecture
1. Hierarchical Feature Construction
The model adopts a CNN-like four-stage pyramid. Stage 1 embeds non-overlapping patches; each subsequent stage uses a patch merging layer for downsampling and channel expansion:
- Stage 1: input image (H×W) → 4×4 patch partition → linear embedding (C=96)
- Stage 2: 2×2 neighboring patches merged → channels doubled (C=192)
- Stage 3: repeat the Stage 2 operation → C=384
- Stage 4: final downsampling → C=768
Key advantage: the feature-map resolution decreases at each stage while the receptive field expands, producing a multi-scale feature representation (see the patch merging sketch below).
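The following is a minimal sketch of a patch merging layer, assuming the input is a flattened token sequence of shape [B, H*W, C]; the class name PatchMerging and its details follow common open-source implementations rather than code quoted elsewhere in this article.

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Downsample by 2x: concatenate each 2x2 group of tokens, then project 4C -> 2C."""
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):                      # x: [B, H*W, C]
        H, W = self.input_resolution
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]               # top-left token of each 2x2 group
        x1 = x[:, 1::2, 0::2, :]               # bottom-left
        x2 = x[:, 0::2, 1::2, :]               # top-right
        x3 = x[:, 1::2, 1::2, :]               # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)          # [B, H/2, W/2, 4C]
        x = x.view(B, (H // 2) * (W // 2), 4 * C)        # flatten back to tokens
        return self.reduction(self.norm(x))              # [B, H*W/4, 2C]
```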
2. Window Multi-Head Self-Attention (W-MSA)
Self-attention is computed independently inside each window:
Attn(Q, K, V) = Softmax(QKᵀ/√d + B)V
where B is a relative position bias drawn from a learnable parameter table indexed by the relative coordinates of each token pair, giving the attention spatial awareness. The window partition step looks like this:
```python
# assume a channel-last feature map of shape [B, H, W, C]
window_size = 7  # fixed window size

def window_partition(x, window_size):
    """Split a feature map into non-overlapping window_size x window_size windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, C)  # [B*num_windows, ws, ws, C]
```
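The Transformer block shown in Section III also calls a window_reverse helper that the original snippets never spell out; here is a minimal sketch consistent with the window_partition above:

```python
def window_reverse(windows, window_size, H, W):
    """Merge windows of shape [B*num_windows, ws, ws, C] back into a [B, H, W, C] map."""
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)
```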
Complexity analysis: for a feature map of H×W tokens with C channels and M×M windows, the attention term of W-MSA costs 2M²·HW·C operations, linear in the number of tokens, whereas global self-attention costs 2(HW)²·C, quadratic in HW (both share the same 4HW·C² projection cost).
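A quick back-of-the-envelope check with hypothetical Stage-1 numbers (H = W = 56, C = 96, M = 7) illustrates the gap in the attention term:

```python
H, W, C, M = 56, 56, 96, 7               # hypothetical Stage-1 settings
proj = 4 * H * W * C ** 2                # QKV + output projections, shared by both
global_attn = 2 * (H * W) ** 2 * C       # global MSA attention term: ~1.89e9
window_attn = 2 * M ** 2 * H * W * C     # W-MSA attention term: ~2.95e7, ~64x cheaper
print(proj, global_attn, window_attn)
```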
3. Shifted Window Partitioning (SW-MSA)
To prevent information from being trapped inside individual windows, a cyclic shift strategy is introduced:
```python
def shift_window(x, shift_size):
    """Cyclically shift a [B, C, H, W] feature map toward the top-left."""
    return torch.roll(x, shifts=(-shift_size, -shift_size), dims=(2, 3))

# reverse shift to restore the original positions
def reverse_shift_window(x, shift_size):
    return torch.roll(x, shifts=(shift_size, shift_size), dims=(2, 3))
```
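As a quick sanity check with hypothetical tensor sizes, shifting and then reversing the shift recovers the original map exactly, since torch.roll is cyclic:

```python
x = torch.randn(1, 96, 56, 56)                   # hypothetical Stage-1 feature map
y = reverse_shift_window(shift_window(x, 3), 3)
assert torch.equal(x, y)                          # cyclic shift is exactly invertible
```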
Shift pattern: consecutive blocks alternate between the regular partition and a partition shifted toward the top-left by ⌊M/2⌋ pixels (half the window size). An attention mask handles the windows that wrap around the border after the cyclic shift:
```python
def get_window_attention_mask(H, W, window_size, shift_size):
    """Build the attention mask that stops tokens from different original regions
    from attending to each other inside shifted windows."""
    img_mask = torch.zeros((1, H, W, 1))
    # label the 3x3 regions produced by the cyclic shift with distinct ids
    h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # partition the label map into windows and compare labels pairwise
    mask_windows = window_partition(img_mask, window_size)             # [nW, ws, ws, 1]
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, ws*ws, ws*ws]
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))
    attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
```
III. Complete Code Implementation with Comments
1. Basic Module Definition
```python
import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """Multi-head self-attention within a local window, with relative position bias."""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        if isinstance(window_size, int):
            window_size = (window_size, window_size)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # QKV and output projection layers
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        # learnable relative position bias table: (2Wh-1)*(2Ww-1) entries per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        # precompute the pairwise relative position index inside a window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)                          # [2, Wh*Ww]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()    # [Wh*Ww, Wh*Ww, 2]
        relative_coords[:, :, 0] += window_size[0] - 1                     # shift indices to start at 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)                  # [Wh*Ww, Wh*Ww]
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        # x: [num_windows*B, N, C] with N = window_size[0] * window_size[1]
        B, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).chunk(3, dim=-1)
        Q, K, V = map(lambda t: t.view(B, N, self.num_heads, head_dim).transpose(1, 2), qkv)
        # look up the relative position bias: [num_heads, N, N]
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        # scaled dot-product attention with bias
        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # mask: [num_windows, N, N]; broadcast over batch and heads
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ V).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```
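A quick smoke test (hypothetical sizes) showing the expected input and output shapes of the module above:

```python
attn = WindowAttention(dim=96, window_size=7, num_heads=3)
x = torch.randn(8, 7 * 7, 96)           # 8 windows of 49 tokens, 96 channels
out = attn(x)
print(out.shape)                         # torch.Size([8, 49, 96])
```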
2. Swin Transformer Block Implementation
```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        # if the window does not fit the feature map, fall back to one non-shifted window
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        # LayerNorm, window attention, and MLP
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim))
        # precompute the shifted-window attention mask
        if self.shift_size > 0:
            attn_mask = get_window_attention_mask(*input_resolution, self.window_size, self.shift_size)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # window partition and attention
        x_windows = window_partition(shifted_x, self.window_size)              # [B*nW, ws, ws, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # merge windows and reverse the cyclic shift
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # residual connections and FFN
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x
```
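A minimal forward-pass check with hypothetical sizes, pairing a regular block with a shifted block the way consecutive blocks alternate within one stage:

```python
blk1 = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0)
blk2 = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=3)
x = torch.randn(2, 56 * 56, 96)          # [batch, tokens, channels]
y = blk2(blk1(x))
print(y.shape)                            # torch.Size([2, 3136, 96])
```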
IV. Practical Recommendations and Optimization Directions
- Window size: typically 7×7 or 14×14; balance computational cost against receptive field
- Shift strategy: dynamic shift sizes can be explored instead of the fixed ⌊M/2⌋
- Relative position encoding: a learnable bias table is recommended over a fixed pattern
- Multi-scale training: mixing input resolutions improves generalization
- Hardware adaptation: tune the memory access pattern of the window partition operation for the target accelerator
V. Typical Application Scenarios
- Image classification: up to 87.3% top-1 accuracy on ImageNet-1K
- Object detection: backbone for detectors such as Mask R-CNN
- Semantic segmentation: high-accuracy segmentation with frameworks such as UperNet
- Video understanding: extended to the spatio-temporal domain via 3D window attention
Through its window-based attention mechanism, this architecture keeps the strengths of the Transformer while removing the computational bottleneck of high-resolution image processing, providing an efficient backbone for vision tasks. Complete implementations can be found in the official examples of mainstream deep learning frameworks.