Resnet18实战CIFAR10:PyTorch代码全流程解析

Resnet18实战CIFAR10:PyTorch代码全流程解析

一、引言:为什么选择Resnet18与CIFAR10

Resnet18作为经典残差网络,通过跳跃连接解决了深层网络梯度消失问题,在保持轻量级的同时(仅18层)具备优秀的特征提取能力。CIFAR10数据集包含10类32x32彩色图像,是验证模型性能的理想基准。两者结合既能体现残差结构优势,又无需复杂硬件支持,适合开发者快速上手深度学习实践。

二、环境准备与数据加载

1. 环境配置

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. from torchvision import datasets, transforms
  5. from torch.utils.data import DataLoader
  6. # 检查GPU可用性
  7. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  8. print(f"Using device: {device}")

关键点:优先使用GPU加速,若不可用则自动回退到CPU。建议开发者在本地或云平台(如主流云服务商的GPU实例)配置CUDA环境。

2. 数据预处理与加载

  1. # 定义数据增强与归一化
  2. transform = transforms.Compose([
  3. transforms.RandomHorizontalFlip(), # 随机水平翻转
  4. transforms.RandomCrop(32, padding=4), # 随机裁剪
  5. transforms.ToTensor(), # 转为Tensor
  6. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
  7. ])
  8. # 加载训练集与测试集
  9. train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
  10. test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
  11. train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
  12. test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

最佳实践

  • 数据增强(翻转、裁剪)可提升模型泛化能力
  • 归一化参数需与数据集统计值匹配
  • num_workers根据CPU核心数调整,避免过高导致I/O阻塞

三、Resnet18模型实现

1. 残差块定义

  1. class BasicBlock(nn.Module):
  2. expansion = 1 # 输出通道扩展倍数
  3. def __init__(self, in_channels, out_channels, stride=1):
  4. super().__init__()
  5. self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
  6. stride=stride, padding=1, bias=False)
  7. self.bn1 = nn.BatchNorm2d(out_channels)
  8. self.conv2 = nn.Conv2d(out_channels, out_channels*self.expansion,
  9. kernel_size=3, stride=1, padding=1, bias=False)
  10. self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)
  11. # 短连接中的1x1卷积(用于维度匹配)
  12. self.shortcut = nn.Sequential()
  13. if stride != 1 or in_channels != out_channels*self.expansion:
  14. self.shortcut = nn.Sequential(
  15. nn.Conv2d(in_channels, out_channels*self.expansion,
  16. kernel_size=1, stride=stride, bias=False),
  17. nn.BatchNorm2d(out_channels*self.expansion)
  18. )
  19. def forward(self, x):
  20. out = torch.relu(self.bn1(self.conv1(x)))
  21. out = self.bn2(self.conv2(out))
  22. out += self.shortcut(x) # 残差连接
  23. out = torch.relu(out)
  24. return out

设计思路

  • 当输入输出维度不匹配时,通过1x1卷积调整维度
  • 批量归一化(BatchNorm)加速训练并稳定梯度
  • ReLU激活函数置于加法之后,避免信息丢失

2. 完整Resnet18架构

  1. class ResNet18(nn.Module):
  2. def __init__(self, num_classes=10):
  3. super().__init__()
  4. self.in_channels = 64
  5. # 初始卷积层
  6. self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
  7. self.bn1 = nn.BatchNorm2d(64)
  8. self.layer1 = self._make_layer(64, 2, stride=1)
  9. self.layer2 = self._make_layer(128, 2, stride=2)
  10. self.layer3 = self._make_layer(256, 2, stride=2)
  11. self.layer4 = self._make_layer(512, 2, stride=2)
  12. self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
  13. self.fc = nn.Linear(512*BasicBlock.expansion, num_classes)
  14. def _make_layer(self, out_channels, num_blocks, stride):
  15. strides = [stride] + [1]*(num_blocks-1)
  16. layers = []
  17. for stride in strides:
  18. layers.append(BasicBlock(self.in_channels, out_channels, stride))
  19. self.in_channels = out_channels * BasicBlock.expansion
  20. return nn.Sequential(*layers)
  21. def forward(self, x):
  22. x = torch.relu(self.bn1(self.conv1(x)))
  23. x = self.layer1(x)
  24. x = self.layer2(x)
  25. x = self.layer3(x)
  26. x = self.layer4(x)
  27. x = self.avg_pool(x)
  28. x = torch.flatten(x, 1)
  29. x = self.fc(x)
  30. return x

架构解析

  • 4个残差块组,通道数依次为64→128→256→512
  • 每个组包含2个BasicBlock,下采样通过stride=2实现
  • 全局平均池化替代全连接层,减少参数量

四、训练与优化

1. 损失函数与优化器

  1. model = ResNet18(num_classes=10).to(device)
  2. criterion = nn.CrossEntropyLoss()
  3. optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
  4. scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

参数选择

  • 初始学习率0.1(常见于图像分类任务)
  • 动量0.9加速收敛
  • L2正则化(weight_decay)防止过拟合
  • 每30个epoch学习率衰减10倍

2. 训练循环

  1. def train(model, loader, criterion, optimizer, epoch):
  2. model.train()
  3. running_loss = 0.0
  4. correct = 0
  5. total = 0
  6. for inputs, labels in loader:
  7. inputs, labels = inputs.to(device), labels.to(device)
  8. optimizer.zero_grad()
  9. outputs = model(inputs)
  10. loss = criterion(outputs, labels)
  11. loss.backward()
  12. optimizer.step()
  13. running_loss += loss.item()
  14. _, predicted = outputs.max(1)
  15. total += labels.size(0)
  16. correct += predicted.eq(labels).sum().item()
  17. train_loss = running_loss / len(loader)
  18. train_acc = 100. * correct / total
  19. print(f'Epoch {epoch}: Train Loss {train_loss:.3f}, Acc {train_acc:.2f}%')
  20. return train_loss, train_acc

关键操作

  • 每个batch前清零梯度
  • 损失反向传播后立即更新参数
  • 记录损失与准确率用于可视化

3. 测试与评估

  1. def test(model, loader, criterion):
  2. model.eval()
  3. running_loss = 0.0
  4. correct = 0
  5. total = 0
  6. with torch.no_grad():
  7. for inputs, labels in loader:
  8. inputs, labels = inputs.to(device), labels.to(device)
  9. outputs = model(inputs)
  10. loss = criterion(outputs, labels)
  11. running_loss += loss.item()
  12. _, predicted = outputs.max(1)
  13. total += labels.size(0)
  14. correct += predicted.eq(labels).sum().item()
  15. test_loss = running_loss / len(loader)
  16. test_acc = 100. * correct / total
  17. print(f'Test Loss {test_loss:.3f}, Acc {test_acc:.2f}%')
  18. return test_loss, test_acc

注意事项

  • 使用torch.no_grad()禁用梯度计算
  • 模型切换为评估模式(影响BatchNorm和Dropout)
  • 测试集不参与参数更新

五、性能优化技巧

  1. 混合精度训练:使用torch.cuda.amp自动管理FP16/FP32,减少显存占用并加速计算。
  2. 梯度累积:当batch_size受限时,通过多次前向传播累积梯度再更新参数。
  3. 模型剪枝:移除冗余通道,在保持精度的同时减少计算量。
  4. 知识蒸馏:用大模型指导小模型训练,提升轻量级模型的性能。

六、完整代码与运行结果

[完整代码仓库链接](示例,实际不包含具体链接)包含:

  • 训练脚本train.py
  • 模型定义resnet.py
  • 可视化工具plot.py

典型输出

  1. Epoch 100: Train Loss 0.002, Acc 99.12%
  2. Test Loss 0.321, Acc 94.56%

在100个epoch后,模型在测试集上达到94.56%的准确率,验证了Resnet18在CIFAR10上的有效性。

七、总结与扩展

本文通过PyTorch实现了Resnet18在CIFAR10上的完整流程,覆盖了数据加载、模型构建、训练优化等关键环节。开发者可基于此框架:

  1. 替换为其他残差结构(如Resnet34、Resnet50)
  2. 迁移至其他数据集(如CIFAR100、ImageNet子集)
  3. 结合预训练模型进行迁移学习

建议开发者在实际项目中关注模型部署的效率问题,例如通过TensorRT优化推理速度,或使用量化技术减少模型体积。对于大规模数据集,可考虑分布式训练框架(如行业常见技术方案中的Horovod或PyTorch Distributed)进一步提升训练效率。