A Hands-On Guide to Swin Transformer in PyTorch: From Theory to a Complete Implementation
As a milestone model among vision Transformers, Swin Transformer introduces hierarchical feature extraction and window-based multi-head self-attention, achieving strong results on image classification, object detection, and other tasks. This article walks through the PyTorch implementation of its core modules and provides a complete set of training optimization strategies.
1. Swin Transformer Core Architecture
1.1 Hierarchical Feature Extraction
Unlike ViT's global attention, Swin Transformer adopts a CNN-like four-stage hierarchical structure in which the feature map is progressively downsampled from H/4×W/4 to H/32×W/32. This design brings three main advantages:
- Multi-scale feature fusion
- Far less attention computation than a global design
- Seamless integration with downstream tasks such as detection
When implementing the model, pay attention to the patch merging operation between stages:
```python
import torch
import torch.nn as nn


class PatchMerging(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Linear(4 * dim, 2 * dim)  # merge 4 patches into 1, doubling channels
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape
        # group each 2x2 patch neighborhood and concatenate along channels,
        # keeping the spatial layout expected by the following blocks
        x = x.reshape(B, H // 2, 2, W // 2, 2, C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H // 2, W // 2, 4 * C)
        return self.reduction(self.norm(x))
```
1.2 Window Multi-Head Self-Attention (W-MSA)
Window attention decomposes the global attention computation into many independent local windows, which drastically reduces computational cost (quantified in the sketch after this list). The key implementation steps are:
- Partition the feature map into windows (e.g., 7×7)
- Compute QKV independently within each window
- Add a relative position bias
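The saving can be made concrete with the complexity formulas from the Swin paper: Ω(MSA) = 4hwC² + 2(hw)²C for global attention versus Ω(W-MSA) = 4hwC² + 2M²hwC for windowed attention with window size M. A quick sanity check in plain Python (the numbers are just the stage-1 setting of Swin-T):

```python
def msa_flops(h, w, C):
    # global MSA: 4hwC^2 for the QKV/output projections, 2(hw)^2 C for attention
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M=7):
    # window MSA: the attention term becomes linear in h*w for a fixed window M
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


# stage-1 feature map of a 224x224 input: 56x56 tokens, C=96
ratio = msa_flops(56, 56, 96) / wmsa_flops(56, 56, 96)
print(f"window attention uses ~{ratio:.0f}x fewer FLOPs at this resolution")
```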
```python
class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # relative position bias table: one bias per head for each of the
        # (2M-1) x (2M-1) possible relative offsets inside an M x M window
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # precompute the pairwise relative position index
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size - 1  # shift offsets from [-M+1, M-1] to [0, 2M-2]
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1  # row-major stride for flattening
        self.register_buffer("relative_position_index",
                             relative_coords.sum(-1).long())

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # scaled dot-product attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # add the learned relative position bias
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
            self.window_size * self.window_size,
            self.window_size * self.window_size, -1)
        attn = attn + relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)

        # apply the shifted-window mask (if any), then softmax and the weighted sum with v
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```
1.3 Shifted Window Mechanism (SW-MSA)
To break the information isolation between non-overlapping windows, SW-MSA enables cross-window interaction through cyclic shifting. Key implementation points:
- Cyclically shift the feature map (up and left by shift_size = window_size // 2)
- Shift back after the attention computation
- Use a special attention mask to handle the wrapped-around boundaries
```python
def get_shifted_window_mask(H, W, window_size, shift_size):
    # label each region created by the cyclic shift
    img_mask = torch.zeros((1, H, W, 1))  # 1, H, W, 1
    cnt = 0
    for h in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
        for w in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # token pairs from different regions must not attend to each other
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) \
                         .masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
```
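The mask above and the blocks below rely on window_partition / window_reverse helpers that are not defined elsewhere in this article. Minimal versions matching the official implementation:

```python
def window_partition(x, window_size):
    # (B, H, W, C) -> (nW*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    # (nW*B, window_size, window_size, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
```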
2. Complete Model Implementation
2.1 Swin Transformer Block
```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, drop_path=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.norm1 = nn.LayerNorm(dim)
        # the same WindowAttention serves both W-MSA and SW-MSA;
        # the shift itself is applied in forward via torch.roll
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.drop_path = DropPath(drop_path)  # stochastic depth
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * 4), out_features=dim)

    def forward(self, x):
        # x: (B, H, W, C)
        B, H, W, C = x.shape
        shortcut = x
        x = self.norm1(x)

        # cyclic shift for SW-MSA; W-MSA blocks keep shift_size == 0
        attn_mask = None
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = get_shifted_window_mask(
                H, W, self.window_size, self.shift_size).to(x.device)

        # window partition and attention
        x_windows = window_partition(x, self.window_size)  # nW*B, ws, ws, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, self.dim)
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # merge windows and reverse the cyclic shift
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
        x = window_reverse(attn_windows, self.window_size, H, W)  # B, H, W, C
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        # residual connections around attention and MLP
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```
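The block references DropPath (stochastic depth) and MLP, which this article does not define. Minimal sketches in the spirit of the timm implementations:

```python
class DropPath(nn.Module):
    # stochastic depth: randomly drop the whole residual branch per sample
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # one keep/drop decision per sample, broadcast over all other dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x / keep_prob * mask


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
```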
2.2 Full Model Architecture
```python
class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, num_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        # hierarchical feature extraction
        self.stages = nn.ModuleList()
        dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]  # stochastic depth rates
        cur = 0
        for i in range(4):
            stage = nn.ModuleList([
                SwinTransformerBlock(
                    dim=embed_dim * (2 ** i),
                    num_heads=num_heads[i],
                    window_size=window_size,  # the original Swin uses 7 at every stage
                    # blocks alternate between W-MSA (even) and SW-MSA (odd)
                    shift_size=0 if j % 2 == 0 else window_size // 2,
                    drop_path=dpr[cur + j])
                for j in range(depths[i])])
            self.stages.append(stage)
            cur += depths[i]
            if i < 3:  # downsample after every stage except the last
                self.stages.append(PatchMerging(dim=embed_dim * (2 ** i)))

        self.norm = nn.LayerNorm(embed_dim * 8)
        self.head = nn.Linear(embed_dim * 8, num_classes)

    def forward_features(self, x):
        x = self.patch_embed(x)
        for stage in self.stages:
            if isinstance(stage, PatchMerging):
                x = stage(x)
            else:
                for blk in stage:
                    x = blk(x)
        return x.mean(dim=[1, 2])  # global average pooling over the spatial dims

    def forward(self, x):
        return self.head(self.norm(self.forward_features(x)))
```
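PatchEmbed is the last undefined piece: a strided convolution that turns the input image into a channels-last token grid. A minimal version, followed by a shape sanity check (the 1000-class head above is an assumption for ImageNet-style classification):

```python
class PatchEmbed(nn.Module):
    # non-overlapping patch projection via a strided convolution
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)           # B, C, H/4, W/4
        x = x.permute(0, 2, 3, 1)  # B, H/4, W/4, C: channels-last, as the blocks expect
        return self.norm(x)


# quick shape check
model = SwinTransformer()
print(model(torch.randn(2, 3, 224, 224)).shape)  # expected: torch.Size([2, 1000])
```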
3. Training Optimization Best Practices
3.1 Data Augmentation
Strong augmentation matters when training Swin Transformer from scratch. The official recipe largely follows DeiT (RandAugment combined with Mixup/CutMix); a hand-crafted color-distortion pipeline like the one below is also a reasonable starting point:
```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur((3, 3), (0.1, 2.0))], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```
3.2 Optimizer Configuration
Use AdamW together with a cosine annealing learning rate schedule:
```python
def get_optimizer(model, lr=5e-4, weight_decay=0.05):
    # exclude LayerNorm weights and all biases from weight decay
    param_groups = [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in ["norm", "bias"])],
         "weight_decay": weight_decay},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n for nd in ["norm", "bias"])],
         "weight_decay": 0.},
    ]
    return torch.optim.AdamW(param_groups, lr=lr)


def get_scheduler(optimizer, num_epochs=300):
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=1e-6)
```
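Typical wiring of the two helpers, assuming a hypothetical train_one_epoch function that runs one pass over the training set:

```python
optimizer = get_optimizer(model)
scheduler = get_scheduler(optimizer, num_epochs=300)

for epoch in range(300):
    train_one_epoch(model, train_loader, optimizer)  # hypothetical helper
    scheduler.step()  # CosineAnnealingLR expects one step per epoch here
```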
3.3 Performance Optimization Tips
- Mixed-precision training:

```python
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
- Gradient accumulation: when GPU memory is tight, accumulate gradients over several batches before each update:

```python
accum_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(train_loader):
    outputs = model(inputs)
    loss = criterion(outputs, targets) / accum_steps
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
4. Deployment Optimization
4.1 Model Quantization
Dynamic quantization reduces model size and speeds up inference:
```python
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8)
```
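Dynamic quantization only converts nn.Linear layers, which is where most of Swin's parameters live. A rough way to verify the size reduction, using a temporary checkpoint file:

```python
import os


def state_dict_size_mb(m, path="_tmp.pt"):
    # serialize the state dict and report the file size as a footprint proxy
    torch.save(m.state_dict(), path)
    size = os.path.getsize(path) / 1e6
    os.remove(path)
    return size


print(f"fp32: {state_dict_size_mb(model):.1f} MB, "
      f"int8: {state_dict_size_mb(quantized_model):.1f} MB")
```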
4.2 TensorRT Acceleration
Export the model to ONNX first, then build a TensorRT engine from the exported graph:
```python
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "swin_tiny.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
```
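Before handing the file to TensorRT, it is worth sanity-checking the exported graph with ONNX Runtime (this assumes the onnxruntime package is installed):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("swin_tiny.onnx", providers=["CPUExecutionProvider"])
out = sess.run(["output"],
               {"input": np.random.randn(1, 3, 224, 224).astype(np.float32)})[0]
print(out.shape)  # should match the PyTorch output shape
```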
5. Troubleshooting Common Issues
5.1 Training Instability
- Symptom: the loss suddenly spikes or becomes NaN
- Solutions:
  - Lower the initial learning rate (e.g., from 5e-4 to 1e-4)
  - Add gradient clipping via torch.nn.utils.clip_grad_norm_ (see the sketch after this list)
  - Check the dataset for corrupted or abnormal samples
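A minimal sketch of gradient clipping combined with the AMP loop from section 3.3 (max_norm=5.0 is an assumed threshold, not part of the original text; note that gradients must be unscaled before clipping):

```python
max_norm = 5.0  # assumed clipping threshold

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale first so the norm is measured at true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```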
5.2 Out-of-Memory Issues
- Solutions:
  - Use gradient accumulation (as described in section 3.3)
  - Reduce the batch size (try starting from 64)
  - Enable mixed-precision training
  - Call torch.cuda.empty_cache() to release cached GPU memory
6. Summary and Outlook
With its window attention mechanism and hierarchical design, Swin Transformer successfully brings the Transformer architecture to dense prediction tasks. The PyTorch implementation in this article covers the full build, from basic modules to the complete model, along with practical training optimization strategies. In practice, developers can tune the model depth, window size, and other hyperparameters to strike the best performance-efficiency balance for a given task.
As research on vision Transformers deepens, promising directions for future work include:
- More efficient window partitioning strategies
- Dynamic window size adjustment
- Hybrid architectures combining CNNs and Transformers
- Lightweight variants for mobile deployment
With continued optimization and engineering practice, the Swin Transformer family will demonstrate its potential across an even wider range of computer vision tasks.