ResNet-18 on CIFAR-10: A PyTorch Implementation for Image Classification
1. Introduction
As a classic benchmark in computer vision, the CIFAR-10 dataset contains 60,000 32x32 color images in 10 classes (50,000 for training, 10,000 for testing) and is widely used to test how well a model generalizes from limited data. ResNet-18, the lightweight member of the residual network family, uses skip connections to mitigate vanishing gradients, retaining high accuracy at modest computational cost. This article walks through the full PyTorch workflow, from data loading to model deployment, and closes with optimization tips.
2. Environment Setup and Data Loading
2.1 Environment Configuration
PyTorch 1.8 or later is recommended, together with CUDA 10.2+ for GPU acceleration. Install the base libraries with pip install torch torchvision, and make sure auxiliary tools such as NumPy and Matplotlib are available.
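A quick way to confirm the setup (a minimal sanity check; the printed versions will vary by installation):

import torch
import torchvision

print(torch.__version__)          # expect 1.8.0 or newer
print(torchvision.__version__)
print(torch.cuda.is_available())  # True if a usable GPU is detected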
2.2 Data Preprocessing
The raw CIFAR-10 images need normalization, plus augmentation for the training split:

import torchvision.transforms as transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),            # random-crop augmentation
    transforms.RandomHorizontalFlip(),               # horizontal-flip augmentation
    transforms.ToTensor(),                           # convert to Tensor
    transforms.Normalize((0.4914, 0.4822, 0.4465),   # per-channel mean
                         (0.2470, 0.2435, 0.2616)),  # per-channel std
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])
Load the dataset via torchvision.datasets.CIFAR10 and read it in batches with DataLoader:

import torch
import torchvision

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
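Pulling one batch is an easy way to confirm the pipeline works (an illustrative check, not part of training):

images, labels = next(iter(trainloader))
print(images.shape)  # torch.Size([128, 3, 32, 32])
print(labels.shape)  # torch.Size([128])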
3. Building the ResNet-18 Model
3.1 Residual Block Implementation
The core is the BasicBlock class, which carries features across layers through a skip connection:
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels * self.expansion)

        # Identity shortcut by default; a 1x1 projection when the spatial
        # size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(residual)  # skip connection
        out = F.relu(out)
        return out
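A quick shape check (illustrative only) shows the projection shortcut handling a stride-2 block:

block = BasicBlock(64, 128, stride=2)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 128, 16, 16])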
3.2 Full Network Architecture
Stack the residual blocks to build the 18-layer network:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        # CIFAR-style stem: a 3x3 conv instead of ImageNet's 7x7 conv + max-pool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        # Only the first block in each stage may downsample.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)       # 4x4 feature map -> 1x1
        out = out.view(out.size(0), -1)  # flatten
        out = self.linear(out)
        return out

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])
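To sanity-check the architecture (an optional illustrative step; the local name model is just a throwaway), instantiate it, count parameters, and run a dummy batch:

model = ResNet18()
num_params = sum(p.numel() for p in model.parameters())
print(f'{num_params:,} parameters')            # roughly 11.2 million
print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])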
4. Model Training and Optimization
4.1 Training Configuration
Use cross-entropy loss with the Adam optimizer:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
4.2 Training Loop
Train for 100 epochs, printing the loss and accuracy after each epoch:
net.train()  # make sure BatchNorm uses batch statistics during training
for epoch in range(100):
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Epoch {epoch+1}, Loss: {running_loss/(i+1):.3f}, '
          f'Acc: {100*correct/total:.2f}%')
4.3 Performance Optimization Tips
- Learning-rate scheduling: use torch.optim.lr_scheduler.StepLR to decay the learning rate during training (see the sketch after this list)
- Mixed-precision training: use torch.cuda.amp to speed up computation
- Gradient clipping: call nn.utils.clip_grad_norm_ to guard against exploding gradients
- Early stopping: monitor the validation loss and stop unproductive training runs early
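Here is a minimal sketch of how the first three techniques can be wired into the training step, reusing the net, criterion, optimizer, and trainloader from above; the step_size, gamma, and max_norm values are illustrative choices, not tuned hyperparameters:

from torch.cuda.amp import autocast, GradScaler

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
scaler = GradScaler()  # scales the loss so fp16 gradients stay representable

for epoch in range(100):
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():                  # forward pass in mixed precision
            outputs = net(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)        # unscale gradients before clipping
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()                      # decay the learning rate each epoch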
5. Model Evaluation and Deployment
5.1 Test-Set Evaluation
Validate the model's generalization on the test set:
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

net.eval()  # switch BatchNorm to its running statistics
correct = 0
total = 0
with torch.no_grad():  # no gradients needed during evaluation
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100*correct/total:.2f}%')
5.2 Model Export
Export the trained model to ONNX for cross-platform deployment:
dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(net, dummy_input, "resnet18_cifar10.onnx",
                  input_names=["input"], output_names=["output"])
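The exported file can be smoke-tested with onnxruntime (a quick sketch; onnxruntime is a separate dependency installed with pip install onnxruntime):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("resnet18_cifar10.onnx",
                               providers=["CPUExecutionProvider"])
dummy = np.random.randn(1, 3, 32, 32).astype(np.float32)
(logits,) = session.run(["output"], {"input": dummy})
print(logits.shape)  # (1, 10)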
6. Summary and Outlook
This article presented a complete, code-level PyTorch implementation of ResNet-18 for CIFAR-10 classification. Experiments show that with standard data augmentation, the model can exceed 93% test accuracy. Future work could explore:
- Adding attention mechanisms to strengthen feature representations
- Compressing the model with knowledge distillation
- Scaling to more complex datasets such as ImageNet
Developers can push performance further by adjusting the number of residual blocks and tuning hyperparameters, while PyTorch's dynamic computation graph makes it easy to iterate on experiments quickly.