ResNet-18 for CIFAR-10 Image Classification with PyTorch

1. Introduction

The CIFAR-10 dataset is a classic benchmark in computer vision, containing 60,000 32x32 color images across 10 classes, and is commonly used to assess how well a model generalizes from limited data. ResNet-18, the lightweight member of the residual network family, uses skip connections to mitigate the vanishing-gradient problem, retaining high accuracy at modest computational cost. This article walks through a complete PyTorch implementation, from data loading to model deployment, and closes with optimization suggestions.

2. Environment Setup and Data Loading

2.1 Environment Setup

PyTorch 1.8+ with CUDA 10.2+ is recommended for GPU acceleration. Install the base libraries with pip install torch torchvision, and make sure auxiliary tools such as NumPy and Matplotlib are available.
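A quick, optional check that the installation works; a minimal sketch (the versions and GPU availability printed will depend on your machine):

    import torch
    import torchvision

    print(torch.__version__, torchvision.__version__)  # installed library versions
    print(torch.cuda.is_available())                   # True if GPU acceleration is usable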

2.2 Data Preprocessing

The raw CIFAR-10 data needs normalization and augmentation:

    import torchvision.transforms as transforms

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),            # random crop for augmentation
        transforms.RandomHorizontalFlip(),               # random horizontal flip
        transforms.ToTensor(),                           # convert to a tensor in [0, 1]
        transforms.Normalize((0.4914, 0.4822, 0.4465),   # per-channel mean
                             (0.2470, 0.2435, 0.2616)),  # per-channel std
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2470, 0.2435, 0.2616)),
    ])
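The mean and std constants above are the per-channel statistics of the CIFAR-10 training set. As a sanity check, they can be recomputed directly; a sketch (note that this loads the entire training set into memory, roughly 600 MB):

    import torch
    import torchvision
    import torchvision.transforms as transforms

    dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transforms.ToTensor())
    data = torch.stack([img for img, _ in dataset])  # shape: (50000, 3, 32, 32)
    print(data.mean(dim=(0, 2, 3)))  # approximately (0.4914, 0.4822, 0.4465)
    print(data.std(dim=(0, 2, 3)))   # approximately (0.2470, 0.2435, 0.2616)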

Load the dataset with torchvision.datasets.CIFAR10 and read it in batches with DataLoader:

    import torch
    import torchvision

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)

3. Building the ResNet-18 Model

3.1 Residual Block Implementation

The core is the BasicBlock class, which passes features forward through a skip connection:

    import torch.nn as nn
    import torch.nn.functional as F

    class BasicBlock(nn.Module):
        expansion = 1

        def __init__(self, in_channels, out_channels, stride=1):
            super(BasicBlock, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.conv2 = nn.Conv2d(out_channels, out_channels * self.expansion,
                                   kernel_size=3, stride=1, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(out_channels * self.expansion)
            # Identity shortcut by default; 1x1 projection when the shape changes
            self.shortcut = nn.Sequential()
            if stride != 1 or in_channels != out_channels * self.expansion:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels * self.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_channels * self.expansion)
                )

        def forward(self, x):
            residual = x
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out += self.shortcut(residual)   # skip connection
            out = F.relu(out)
            return out

3.2 Full Network Architecture

Stack the residual blocks to build the 18-layer network:

    class ResNet(nn.Module):
        def __init__(self, block, num_blocks, num_classes=10):
            super(ResNet, self).__init__()
            self.in_channels = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
            self.linear = nn.Linear(512 * block.expansion, num_classes)

        def _make_layer(self, block, out_channels, num_blocks, stride):
            # Only the first block in a stage may downsample; the rest use stride 1
            strides = [stride] + [1] * (num_blocks - 1)
            layers = []
            for stride in strides:
                layers.append(block(self.in_channels, out_channels, stride))
                self.in_channels = out_channels * block.expansion
            return nn.Sequential(*layers)

        def forward(self, x):
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = F.avg_pool2d(out, 4)       # 4x4 feature map -> 1x1
            out = out.view(out.size(0), -1)
            out = self.linear(out)
            return out

    def ResNet18():
        return ResNet(BasicBlock, [2, 2, 2, 2])
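A quick shape check confirms the architecture is wired correctly; a minimal sketch assuming the classes above are already defined:

    import torch

    net = ResNet18()
    out = net(torch.randn(2, 3, 32, 32))  # a dummy batch of two CIFAR-10-sized images
    print(out.shape)                      # expected: torch.Size([2, 10])
    print(sum(p.numel() for p in net.parameters()))  # roughly 11 million parameters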

4. Model Training and Optimization

4.1 Training Configuration

Use cross-entropy loss with the Adam optimizer:

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = ResNet18().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

4.2 Training Loop

Train for 100 epochs, reporting the loss and accuracy after each epoch:

    for epoch in range(100):
        running_loss = 0.0
        correct = 0
        total = 0
        net.train()  # keep layers like BatchNorm in training mode
        for i, (inputs, labels) in enumerate(trainloader, 0):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/(i+1):.3f}, '
              f'Acc: {100*correct/total:.2f}%')

4.3 Performance Optimization Tips

  1. Learning-rate scheduling: adjust the learning rate on a schedule with torch.optim.lr_scheduler.StepLR (tips 1-3 are combined in the sketch after this list)
  2. Mixed-precision training: accelerate computation with torch.cuda.amp
  3. Gradient clipping: prevent exploding gradients with nn.utils.clip_grad_norm_
  4. Early stopping: monitor the validation loss and terminate unproductive training early
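A minimal sketch combining tips 1-3, reusing net, criterion, optimizer, trainloader, and device from above; step_size=30, gamma=0.1, and max_norm=1.0 are illustrative values, not tuned settings:

    import torch
    import torch.nn as nn

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    scaler = torch.cuda.amp.GradScaler()  # loss scaler for mixed precision

    for epoch in range(100):
        net.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():    # tip 2: forward pass in float16 where safe
                loss = criterion(net(inputs), labels)
            scaler.scale(loss).backward()      # scale the loss to avoid float16 underflow
            scaler.unscale_(optimizer)         # unscale so clipping sees the true gradients
            nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)  # tip 3: clip gradients
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()                       # tip 1: decay the LR every 30 epochs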

5. Model Evaluation and Deployment

5.1 Test-Set Evaluation

Validate the model's generalization on the test set:

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=128, shuffle=False, num_workers=2)

    net.eval()  # switch layers like BatchNorm to inference mode
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test Accuracy: {100*correct/total:.2f}%')

5.2 Model Export

Export the trained model to ONNX format for cross-platform deployment:

    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    torch.onnx.export(net, dummy_input, "resnet18_cifar10.onnx",
                      input_names=["input"], output_names=["output"])
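The exported file can be loaded back with ONNX Runtime to verify it; a sketch assuming onnxruntime is installed (pip install onnxruntime):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("resnet18_cifar10.onnx")
    dummy = np.random.randn(1, 3, 32, 32).astype(np.float32)
    logits = sess.run(["output"], {"input": dummy})[0]
    print(logits.shape)  # expected: (1, 10)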

6. Summary and Outlook

This article has shown, through a complete code implementation, how to apply ResNet-18 to the CIFAR-10 classification task in PyTorch. Our experiments show that with standard data augmentation the model can exceed 93% test accuracy. Future work could explore:

  1. Introducing attention mechanisms to strengthen feature representations
  2. Compressing the model with knowledge distillation
  3. Extending to more complex datasets (such as ImageNet)

Developers can further improve performance by adjusting the number of residual blocks and tuning hyperparameters, while exploiting PyTorch's dynamic computation graph to iterate on experiments quickly.