Swin Transformer: A Complete Walkthrough of the Principles and Code Implementation

I. Design Motivation and Core Innovations

Conventional Transformer models face two major challenges when processing high-resolution images: the quadratic cost of global self-attention (O(N²) in the number of tokens), and the lack of hierarchical feature representations. The Vision Transformer (ViT) demonstrated that a pure-Transformer architecture is viable for vision, but its computational cost rises sharply with image size, making it hard to apply directly to dense prediction tasks such as object detection and segmentation.

The core innovation of Swin Transformer is a hierarchical window-attention mechanism: global attention is decomposed into attention computed within local windows, and a shifted-window scheme provides cross-window information exchange. This keeps the complexity linear in the number of tokens (O(N)) while building multi-scale features. The design lets the model plug into a wide range of vision tasks and made it an important breakthrough in computer vision.

II. Core Architecture

1. Hierarchical Feature Construction

The model adopts a CNN-like four-stage pyramid. Stage 1 embeds patches directly; each subsequent stage uses a patch merging layer for downsampling and channel expansion:

  • Stage 1: input image (H×W) → 4×4 patch partition → linear embedding (C=96)
  • Stage 2: 2×2 patch merging → channel count doubled (C=192)
  • Stage 3: repeat the Stage 2 operation → C=384
  • Stage 4: final downsampling → C=768

Key advantage: the feature-map resolution shrinks stage by stage while the effective receptive field grows, yielding a multi-scale feature representation. A sketch of the patch-merging layer follows.
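Below is a minimal sketch of a patch-merging layer consistent with the stage description above. The module name `PatchMerging` and the exact layer composition follow the common Swin-style implementation and are assumptions here: each 2×2 group of neighboring tokens is concatenated into a 4C-dimensional vector and linearly projected down to 2C.

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Downsample a [B, H*W, C] token map by 2x and double the channels (Swin-style sketch)."""
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution  # (H, W)
        self.dim = dim
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W and H % 2 == 0 and W % 2 == 0
        x = x.view(B, H, W, C)
        # Gather the four tokens of every 2x2 neighborhood
        x0 = x[:, 0::2, 0::2, :]   # top-left
        x1 = x[:, 1::2, 0::2, :]   # bottom-left
        x2 = x[:, 0::2, 1::2, :]   # top-right
        x3 = x[:, 1::2, 1::2, :]   # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)            # [B, H/2, W/2, 4C]
        x = x.view(B, (H // 2) * (W // 2), 4 * C)
        x = self.reduction(self.norm(x))                   # [B, H*W/4, 2C]
        return x
```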

2. Window Multi-Head Self-Attention (W-MSA)

Self-attention is computed independently inside each window:

  Attn(Q, K, V) = Softmax(QKᵀ/√d + B)V

where B is a relative position bias realized with a learnable parameter table, giving the model awareness of spatial positions within the window. The window partitioning is implemented as follows (feature maps are kept channel-last, [B, H, W, C], matching the block code in section III):

```python
import torch

window_size = 7  # fixed window size

def window_partition(x, window_size):
    """Split a [B, H, W, C] feature map into non-overlapping windows.

    Returns a tensor of shape [B * num_windows, window_size, window_size, C].
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, C)
```
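The block implementation in section III also relies on a `window_reverse` helper, the inverse of `window_partition`. A minimal sketch matching the channel-last layout above:

```python
def window_reverse(windows, window_size, H, W):
    """Merge [B * num_windows, window_size, window_size, C] windows back to [B, H, W, C]."""
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)
```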

Complexity analysis: for an h×w feature map with C channels split into M×M windows, the attention part of W-MSA costs O(M²·hw·C), i.e. linear in the number of tokens hw, whereas global attention costs O((hw)²·C). A rough comparison is sketched below.
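A back-of-the-envelope comparison for a Stage-1-sized feature map (56×56 tokens, C=96, 7×7 windows); only the QKᵀ and attention-value products are counted, and the numbers are illustrative assumptions rather than measurements:

```python
# Illustrative attention-only FLOP count on a typical Stage-1 feature map
H, W, C, M = 56, 56, 96, 7
tokens = H * W
global_flops = 2 * tokens ** 2 * C          # every token attends to every token
window_flops = 2 * (M * M) * tokens * C     # every token attends only to its own M x M window
print(f"global:   {global_flops / 1e9:.2f} GFLOPs")
print(f"windowed: {window_flops / 1e9:.3f} GFLOPs  (~{global_flops // window_flops}x cheaper)")
```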

3. Shifted Window Partitioning (SW-MSA)

To overcome the information isolation between windows, a cyclic shift strategy is introduced:

```python
def shift_window(x, shift_size):
    """Cyclically shift a [B, C, H, W] feature map toward the top-left."""
    return torch.roll(x, shifts=(-shift_size, -shift_size), dims=(2, 3))

def reverse_shift_window(x, shift_size):
    """Undo the cyclic shift and restore the original spatial layout."""
    # Note: the block in section III stores features channel-last ([B, H, W, C])
    # and therefore applies the same roll on dims (1, 2).
    return torch.roll(x, shifts=(shift_size, shift_size), dims=(2, 3))
```
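A quick sanity check (shapes and values are illustrative):

```python
x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).view(2, 3, 8, 8)
shifted = shift_window(x, shift_size=3)
restored = reverse_shift_window(shifted, shift_size=3)
assert torch.equal(restored, x)   # the cyclic shift is exactly invertible
```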

Shift pattern: consecutive blocks alternate between regular and shifted partitioning; in the shifted blocks, the window grid is displaced by ⌊M/2⌋ pixels toward the top-left (M being the window size) and shifted back after attention. A mask mechanism handles windows that wrap around the border:

```python
def get_window_attention_mask(H, W, window_size, shift_size):
    """Build the attention mask for shifted-window attention.

    After the cyclic shift, a window may contain tokens that were not spatially
    adjacent; tokens coming from different regions must not attend to each other.
    Returns a mask of shape [num_windows, window_size*window_size, window_size*window_size].
    """
    # Label the nine regions produced by the shift with distinct ids
    img_mask = torch.zeros((1, H, W, 1))
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # Partition the label map exactly like the features
    mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
    # Pairs of tokens with different labels get a large negative bias before the softmax
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
```
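For example, on a 14×14 feature map with 7×7 windows and a shift of 3, four windows are produced and the mask holds one 49×49 matrix per window (numbers are illustrative):

```python
mask = get_window_attention_mask(H=14, W=14, window_size=7, shift_size=3)
print(mask.shape)      # torch.Size([4, 49, 49])
print(mask.unique())   # tensor([-100.,    0.])
```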

III. Full Code Implementation with Comments

1. Basic Module Definitions

```python
import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size          # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # QKV and output projections
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Learnable relative position bias: one entry per (dy, dx) offset and head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        # Precompute the pairwise relative position index of tokens inside a window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))          # [2, Wh, Ww]
        coords_flatten = torch.flatten(coords, 1)                           # [2, Wh*Ww]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()     # [Wh*Ww, Wh*Ww, 2]
        relative_coords[:, :, 0] += window_size[0] - 1                      # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)                   # [Wh*Ww, Wh*Ww]
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        # x: [num_windows * B, N, C] with N = Wh * Ww
        B_, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]        # each [B_, num_heads, N, head_dim]
        # Scaled dot-product attention plus relative position bias
        attn = (q @ k.transpose(-2, -1)) * self.scale
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(N, N, -1)           # [N, N, num_heads]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # mask: [nW, N, N]; broadcast over batch and heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
```
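A quick shape check with illustrative dimensions: eight 7×7 windows of 96-channel tokens go in, and the same shape comes out.

```python
attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
windows = torch.randn(8, 49, 96)   # [num_windows * B, N, C] with N = 7 * 7
out = attn(windows)
print(out.shape)                   # torch.Size([8, 49, 96])
```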

2. Swin Transformer Block

```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        # If the feature map is no larger than one window, do not partition or shift
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        # LayerNorm, window attention and MLP
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (self.window_size, self.window_size), num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )
        # Precompute the shifted-window attention mask (see get_window_attention_mask above)
        if self.shift_size > 0:
            H, W = self.input_resolution
            attn_mask = get_window_attention_mask(H, W, self.window_size, self.shift_size)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # Cyclic shift (SW-MSA) or identity (W-MSA)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # Partition into windows and run window attention
        x_windows = window_partition(shifted_x, self.window_size)              # [nW*B, M, M, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # Merge windows and reverse the cyclic shift
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)       # [B, H, W, C]
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # Residual connections and MLP (FFN)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x
```
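A minimal end-to-end check with two consecutive blocks, one regular and one shifted, on a 56×56 feature map (sizes are illustrative):

```python
blocks = nn.Sequential(
    SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0),
    SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=3),
)
x = torch.randn(2, 56 * 56, 96)    # [B, H*W, C]
y = blocks(x)
print(y.shape)                     # torch.Size([2, 3136, 96])
```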

IV. Practical Tips and Optimization Directions

  1. Window size: typically 7×7 or 14×14; trade off compute efficiency against receptive field
  2. Shift strategy: a dynamic shift size can be explored instead of the fixed ⌊M/2⌋
  3. Relative position encoding: a learnable parameter table is recommended over a fixed pattern
  4. Multi-scale training: mix input resolutions to improve generalization
  5. Hardware adaptation: optimize the memory-access pattern of the window-partition operation for the target accelerator

V. Typical Application Scenarios

  1. Image classification: up to 87.3% top-1 accuracy on ImageNet
  2. Object detection: backbone network for detectors such as Mask R-CNN
  3. Semantic segmentation: high-accuracy segmentation with frameworks such as UperNet
  4. Video understanding: extended to the spatio-temporal domain via 3D window attention

With its window-attention design, this architecture retains the strengths of Transformer models while removing the computational bottleneck of high-resolution image processing, offering an efficient solution for vision tasks. Complete implementations are available as official examples in the major deep-learning frameworks.