Swin Transformer PyTorch实战指南:从理论到代码的完整实现

Swin Transformer作为视觉Transformer领域的里程碑式模型,通过引入层级化特征提取和窗口多头自注意力机制,在图像分类、目标检测等任务中取得了显著效果。本文将从PyTorch实现角度,深入解析其核心模块的代码实现,并提供完整的训练优化策略。

一、Swin Transformer核心架构解析

1.1 层级化特征提取设计

与ViT的全局注意力不同,Swin Transformer采用类似CNN的4阶段层级结构,特征图尺寸从H/4×W/4逐步下采样至H/32×W/32。这种设计带来三大优势:

  • 多尺度特征融合能力
  • 减少全局注意力计算量
  • 与下游任务(如检测)的无缝衔接

实现时需注意每个阶段的patch合并操作:

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # 4个patch合并为1个,通道数翻倍

    def forward(self, x):
        # x: (B, H, W, C),H、W需为偶数
        B, H, W, C = x.shape
        # 将相邻2x2的patch在通道维拼接,保留空间排布以衔接后续窗口操作
        x = x.reshape(B, H // 2, 2, W // 2, 2, C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H // 2, W // 2, 4 * C)
        return self.reduction(self.norm(x))  # (B, H/2, W/2, 2C)
```
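一个简单的形状检查(以Swin-T第一阶段 56×56、96维 的特征图为例):

```python
x = torch.randn(2, 56, 56, 96)   # (B, H, W, C)
merge = PatchMerging(dim=96)
print(merge(x).shape)            # torch.Size([2, 28, 28, 192])
```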

1.2 窗口多头自注意力(W-MSA)

窗口注意力将全局计算分解为多个局部窗口计算,显著降低计算复杂度。关键实现步骤:

  1. 特征图划分窗口(如7×7)
  2. 每个窗口内独立计算QKV
  3. 使用相对位置编码
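在进入注意力实现之前,先补上后文反复用到的 window_partition / window_reverse 两个辅助函数(原文未给出定义,下面按官方实现的常见写法给出草稿,特征图假设为 B×H×W×C 排布):

```python
def window_partition(x, window_size):
    # (B, H, W, C) -> (nW*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    # (nW*B, window_size, window_size, C) -> (B, H, W, C)
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)
```

在此基础上,W-MSA的核心实现如下: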
```python
class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # 相对位置编码表:(2M-1)x(2M-1)个可学习偏置,每个head一份
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads))
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size - 1  # 平移到[0, 2M-2]
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1  # 行坐标乘以(2M-1),与列坐标合成一维索引
        # 生成相对位置索引
        self.register_buffer("relative_position_index",
                             relative_coords.sum(-1).long())

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # 计算注意力分数
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # 添加相对位置编码
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
            self.window_size * self.window_size, self.window_size * self.window_size, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:  # SW-MSA的窗口mask,屏蔽不同分组间的注意力
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```
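一个最小调用示例(64个7×7窗口、96维特征、3个注意力头):

```python
attn = WindowAttention(dim=96, num_heads=3, window_size=7)
windows = torch.randn(64, 49, 96)   # (nW*B, window_size*window_size, C)
print(attn(windows).shape)          # torch.Size([64, 49, 96])
```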

1.3 移位窗口机制(SW-MSA)

为解决窗口间信息隔离问题,SW-MSA通过循环移位实现跨窗口交互。实现关键点:

  • 特征图循环移位(向上、向左各移动 shift_size = window_size//2 个patch)
  • 注意力计算后反向移位恢复
  • 特殊mask处理边界问题
```python
def get_shifted_window_mask(H, W, window_size, shift_size):
    # 按移位后的区域划分,给每个位置打上分组编号
    img_mask = torch.zeros((1, H, W, 1))  # 1, H, W, 1
    cnt = 0
    for h in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
        for w in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # 生成mask矩阵:不同分组之间的注意力被置为大负数
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
```
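可以用一个小例子验证mask形状(56×56特征图、7×7窗口、移位3):

```python
mask = get_shifted_window_mask(56, 56, window_size=7, shift_size=3)
print(mask.shape)  # torch.Size([64, 49, 49]):64个窗口,每个窗口一张49×49的mask
```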

二、完整模型实现代码

2.1 Swin Transformer块实现
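块中用到的 DropPath(随机深度)与 MLP 在原文中未给出定义,下面补一份最小实现草稿(也可以直接用timm库中的同名模块替代):

```python
class DropPath(nn.Module):
    """随机深度:训练时以概率drop_prob丢弃整条残差分支"""
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # 对batch内每个样本独立采样,并按保留概率缩放
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        return x * mask / keep_prob


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
```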

```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, drop_path=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size  # 0表示W-MSA,>0表示SW-MSA
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.drop_path = DropPath(drop_path)  # 随机深度
        # MLP部分
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * 4), out_features=dim)

    def forward(self, x):
        # x: (B, H, W, C)
        B, H, W, C = x.shape
        shortcut = x
        x = self.norm1(x)
        # SW-MSA:先整体循环移位,再按普通窗口计算
        # (实际工程中可预计算mask并register_buffer,避免每次前向重建)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = get_shifted_window_mask(H, W, self.window_size, self.shift_size).to(x.device)
        else:
            attn_mask = None
        # 窗口划分与注意力计算
        x_windows = window_partition(x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        # 合并窗口
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_windows, self.window_size, H, W)  # B, H, W, C
        # 反向移位恢复原始布局
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        # 残差连接
        x = shortcut + self.drop_path(x)
        # MLP部分
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```

2.2 完整模型架构
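模型入口的 PatchEmbed 同样未在前文定义,这里给出一个最小草稿(stride等于patch大小的卷积,输出调整为 B×H'×W'×C 排布以衔接后续模块):

```python
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        # stride等于patch大小的卷积,等价于不重叠切patch后做线性投影
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, 3, H, W) -> (B, H/4, W/4, C)
        return self.proj(x).permute(0, 2, 3, 1)
```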

```python
class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        # 层级特征提取:4个阶段,阶段间用PatchMerging下采样
        self.stages = nn.ModuleList()
        dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]  # 随机深度概率线性递增
        cur = 0
        for i in range(4):
            stage = nn.ModuleList([
                SwinTransformerBlock(
                    dim=embed_dim * (2 ** i),
                    num_heads=num_heads[i],
                    window_size=7,                      # 原论文各阶段统一使用7x7窗口
                    shift_size=0 if j % 2 == 0 else 3,  # W-MSA与SW-MSA交替
                    drop_path=dpr[cur + j],
                ) for j in range(depths[i])
            ])
            self.stages.append(stage)
            cur += depths[i]
            if i < 3:  # 除最后一个阶段外都进行下采样
                self.stages.append(PatchMerging(dim=embed_dim * (2 ** i)))
        self.head = nn.Linear(embed_dim * 8, num_classes)

    def forward_features(self, x):
        x = self.patch_embed(x)  # (B, H/4, W/4, C)
        for stage in self.stages:
            if isinstance(stage, PatchMerging):
                x = stage(x)
            else:
                for blk in stage:
                    x = blk(x)
        return x.mean(dim=[1, 2])  # 全局平均池化,(B, 8C)

    def forward(self, x):
        return self.head(self.forward_features(x))
```
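可以用随机输入做一次前向冒烟测试,验证各阶段形状衔接正确:

```python
model = SwinTransformer()
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 1000])
```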

三、训练优化最佳实践

3.1 数据增强策略

官方训练配方在基础增强之上叠加了RandAugment等自动增强策略。下面先给出一组常用的基础增强组合(RandAugment版本见其后):

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([
        transforms.GaussianBlur((3, 3), (0.1, 2.0))
    ], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```
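若要启用RandAugment,torchvision(0.11+)自带 transforms.RandAugment,可插在裁剪之后(下面的num_ops、magnitude仅为示例值):

```python
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandAugment(num_ops=2, magnitude=9),  # 示例参数,可按任务调整
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```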

3.2 优化器配置

采用AdamW优化器配合余弦退火学习率:

```python
def get_optimizer(model, lr=5e-4, weight_decay=0.05):
    # norm层与bias参数不做权重衰减
    param_groups = [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in ["norm", "bias"])],
         "weight_decay": weight_decay},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n for nd in ["norm", "bias"])],
         "weight_decay": 0.}
    ]
    return torch.optim.AdamW(param_groups, lr=lr)

def get_scheduler(optimizer, num_epochs=300):
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=1e-6)
```
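原论文训练时还配合了线性warmup。纯PyTorch下可用 SequentialLR(1.10+)把warmup和余弦退火拼接起来,下面是一个示意写法(warmup轮数仅作示例):

```python
def get_scheduler_with_warmup(optimizer, warmup_epochs=20, num_epochs=300):
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-3, total_iters=warmup_epochs)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs - warmup_epochs, eta_min=1e-6)
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])
```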

3.3 性能优化技巧

  1. 混合精度训练:使用torch.cuda.amp可在几乎不损失精度的情况下降低显存占用并加速训练

```python
scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
# backward与step应在autocast上下文之外
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
  2. 梯度累积:当显存不足时,可累积多个batch的梯度再更新

```python
accum_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(train_loader):
    outputs = model(inputs)
    loss = criterion(outputs, targets) / accum_steps  # 损失按累积步数缩放
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

四、部署优化建议

4.1 模型量化

使用动态量化可减少模型体积并加速推理:

```python
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8)
```
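量化收益可以通过比较序列化后的体积直接验证,下面是一个简单的检查草稿:

```python
import io

def state_dict_size_mb(m):
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)  # 序列化到内存,统计字节数
    return buf.getbuffer().nbytes / 1e6

print(f"原模型: {state_dict_size_mb(model):.1f} MB,量化后: {state_dict_size_mb(quantized_model):.1f} MB")
```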

4.2 TensorRT加速

通过ONNX导出后使用TensorRT优化:

```python
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "swin_tiny.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
```
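转TensorRT之前,建议先用onnxruntime核对ONNX输出与PyTorch输出是否一致(假设已安装onnxruntime):

```python
import numpy as np
import onnxruntime as ort

model.eval()
sess = ort.InferenceSession("swin_tiny.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"input": dummy_input.numpy()})[0]
with torch.no_grad():
    torch_out = model(dummy_input).numpy()
print(np.abs(onnx_out - torch_out).max())  # 差异通常应在1e-4量级
```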

五、常见问题解决方案

5.1 训练不稳定问题

  • 现象:Loss突然增大或出现NaN
  • 解决方案:
    • 减小初始学习率(如从5e-4降至1e-4)
    • 增加梯度裁剪(torch.nn.utils.clip_grad_norm_,示例见下)
    • 检查数据是否存在异常样本
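梯度裁剪应放在 backward 之后、optimizer.step 之前,例如:

```python
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # max_norm可按任务调整
optimizer.step()
```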

5.2 显存不足问题

  • 解决方案
    • 使用梯度累积(如上文所述)
    • 减小batch size(推荐从64开始尝试)
    • 启用混合精度训练
    • 使用torch.cuda.empty_cache()清理缓存

六、总结与展望

Swin Transformer通过创新的窗口注意力机制和层级化设计,成功将Transformer架构应用于密集预测任务。本文提供的PyTorch实现完整覆盖了从基础模块到完整模型的构建过程,并给出了实用的训练优化策略。在实际应用中,开发者可根据具体任务调整模型深度、窗口大小等超参数,以获得最佳的性能-效率平衡。

随着视觉Transformer研究的深入,未来可探索的方向包括:

  1. 更高效的窗口划分策略
  2. 动态窗口大小调整机制
  3. 与CNN的混合架构设计
  4. 轻量化版本在移动端的应用

通过持续优化和工程实践,Swin Transformer系列模型将在更多计算机视觉任务中展现其强大潜力。