Implementing MNIST Handwritten Digit Recognition with ResNet18 in PyTorch

As a classic benchmark in computer vision, the MNIST dataset is commonly used to validate the image classification ability of deep learning models. Traditional CNNs such as LeNet handle the task efficiently, but modern practice favors architectures with stronger feature extraction. This article shows how to implement MNIST handwritten digit recognition with the ResNet18 model in PyTorch, and discusses the implementation details and optimization strategies.

I. Technical Background and Model Selection

The MNIST dataset contains 60,000 training images and 10,000 test images; each is a 28×28-pixel single-channel grayscale image labeled with a digit class from 0 to 9. Plain CNNs extract features by stacking convolution and pooling layers, but as depth grows they suffer from vanishing gradients. ResNet (residual network) introduces residual connections (residual blocks) that let gradients flow directly back to the shallow layers, resolving the difficulty of training deep networks.
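To make the residual idea concrete, here is a minimal sketch of a basic residual block in the spirit of torchvision's BasicBlock (the downsampling path and stride handling are omitted for brevity; this is an illustration, not the exact library code):

    import torch
    import torch.nn as nn

    class BasicBlock(nn.Module):
        """Minimal residual block: output = ReLU(F(x) + x)."""
        def __init__(self, channels):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            identity = x                       # skip connection carries the gradient directly
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            return self.relu(out + identity)   # residual addition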

ResNet18, the lightweight member of the family, contains 17 convolutional layers plus 1 fully connected layer, built from 4 stages of residual blocks (2 blocks per stage, each block holding 2 convolutions). Its advantages:

  1. Gradient flow: residual connections keep even the deeper layers effectively trainable;
  2. Parameter efficiency: far fewer parameters than VGG-style models (see the sketch after this list);
  3. Generalization: performance validated on large-scale datasets such as ImageNet transfers well to small-scale tasks.
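As a quick illustration of point 2, the following snippet counts parameters with torchvision's stock models; exact totals vary slightly by version, but resnet18 lands around 11.7M parameters versus roughly 138M for vgg16:

    from torchvision.models import resnet18, vgg16

    def count_params(model):
        # Total number of parameters in the model
        return sum(p.numel() for p in model.parameters())

    print(f"resnet18: {count_params(resnet18()):,}")  # ~11.7M
    print(f"vgg16:    {count_params(vgg16()):,}")     # ~138M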

II. Data Preprocessing and Loading

1. Data Normalization

After ToTensor, MNIST pixel values lie in [0, 1]; normalizing with mean 0.5 and standard deviation 0.5 maps them to [-1, 1]. (This differs from the ImageNet statistics used for the pretrained weights; it is a common simplification for MNIST, especially since the first convolution is modified below anyway.)

    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.ToTensor(),                 # uint8 [0, 255] -> float [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ])
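As a quick sanity check (a minimal sketch, assuming the transform above), the normalized range can be confirmed on a single image:

    from torchvision.datasets import MNIST

    sample, _ = MNIST(root='./data', train=True, transform=transform, download=True)[0]
    print(sample.shape, sample.min().item(), sample.max().item())
    # Expected: torch.Size([1, 28, 28]) with values in [-1.0, 1.0]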

2. Dataset Loading

Use PyTorch's DataLoader for batched loading with multi-process workers:

    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader

    train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
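A one-batch inspection (a minimal sketch) confirms the shapes the model will receive; note that with num_workers > 0 on Windows or macOS, DataLoader iteration should happen under an `if __name__ == '__main__':` guard:

    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # torch.Size([128, 1, 28, 28]) torch.Size([128])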

III. Model Construction and Modification

1. Loading the Pretrained ResNet18

When loading the ImageNet-pretrained model, the fully connected layer must be adjusted to match MNIST's 10 classes:

    import torch.nn as nn
    from torchvision.models import resnet18, ResNet18_Weights

    # `pretrained=True` is deprecated in recent torchvision; use the weights enum
    model = resnet18(weights=ResNet18_Weights.DEFAULT)

    # Replace the classifier head: 1000 ImageNet classes -> 10 digit classes
    model.fc = nn.Linear(model.fc.in_features, 10)
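A hedged shape check (the model still expects 3-channel input at this point; channel adaptation follows in the next subsection):

    import torch

    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 10])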

2. Input Channel Adaptation

ResNet18 expects 3-channel RGB input by default, while MNIST is single-channel. This can be handled in either of the following ways (a weight-preserving variant is sketched after this list):

  • Option 1: repeat the single channel three times (simple but suboptimal):

        class MNISTAdapter(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, x):
                # x.shape = [B, 1, 28, 28] -> [B, 3, 28, 28]
                x = x.repeat(1, 3, 1, 1)
                return self.model(x)

        model = MNISTAdapter(model)
  • Option 2: replace the first convolution in place (recommended; this is what the complete example below does):

        # Swap the stem conv for a single-channel version; its weights are freshly initialized
        model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
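If keeping some benefit of the pretrained stem is desired, a common trick (a hedged sketch, not part of the original recipe) is to initialize the new single-channel filter from the pretrained RGB filters by summing over the channel dimension:

    import torch.nn as nn
    from torchvision.models import resnet18, ResNet18_Weights

    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    rgb_weight = model.conv1.weight.data            # shape [64, 3, 7, 7]
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Sum pretrained RGB filters over the channel dim -> shape [64, 1, 7, 7]
    model.conv1.weight.data = rgb_weight.sum(dim=1, keepdim=True)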

IV. Training Procedure and Optimization

1. Loss Function and Optimizer

Use cross-entropy loss with the Adam optimizer:

    import torch.optim as optim

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

2. Training Loop

    def train(model, train_loader, criterion, optimizer, epochs=10):
        model.train()
        for epoch in range(epochs):
            running_loss = 0.0
            for images, labels in train_loader:
                # `device` is defined in the complete example below
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}')

3. Evaluation

    def test(model, test_loader):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)   # class with the highest logit
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f'Accuracy: {100 * correct / total:.2f}%')

V. Performance Optimization Strategies

1. Learning Rate Scheduling

Use ReduceLROnPlateau to reduce the learning rate when the training loss stops improving:

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    # Call once per epoch in train(), after the inner batch loop:
    scheduler.step(running_loss / len(train_loader))

2. Data Augmentation

Add slight rotation and scaling to the training images. RandomAffine handles both in one transform (and, unlike area-based RandomResizedCrop, accepts scale factors above 1); augmentation should apply to the training set only, as wired up after this block:

    train_transform = transforms.Compose([
        transforms.RandomAffine(degrees=10, scale=(0.9, 1.1)),  # slight rotation and zoom
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
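A minimal sketch of keeping the test pipeline deterministic while the training pipeline is augmented:

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    train_dataset = MNIST(root='./data', train=True, transform=train_transform, download=True)
    test_dataset = MNIST(root='./data', train=False, transform=test_transform, download=True)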

3. Fine-Tuning

Freeze the shallow layers and train only the last residual stage and the classifier head. If conv1 was re-initialized for single-channel input, it should stay trainable too:

    for name, param in model.named_parameters():
        # Freeze everything except the last stage, the head, and the new stem conv
        if not any(k in name for k in ('layer4', 'fc', 'conv1')):
            param.requires_grad = False
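The optimizer can then be built over only the parameters that still require gradients (a minimal sketch):

    optimizer = optim.Adam(
        (p for p in model.parameters() if p.requires_grad),
        lr=0.001,
    )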

VI. Complete Code Example

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision.datasets import MNIST
    from torchvision.models import resnet18, ResNet18_Weights
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms

    # Data preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # Load datasets
    train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

    # Build and adapt the model
    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # single-channel input
    model.fc = nn.Linear(model.fc.in_features, 10)  # 10 output classes

    # Training configuration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    def train(model, train_loader, criterion, optimizer, epochs=10):
        model.train()
        for epoch in range(epochs):
            running_loss = 0.0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}')

    # Evaluation
    def test(model, test_loader):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f'Accuracy: {100 * correct / total:.2f}%')

    # Run training and evaluation (guarded so num_workers > 0 is safe on Windows/macOS)
    if __name__ == '__main__':
        train(model, train_loader, criterion, optimizer, epochs=10)
        test(model, test_loader)
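After training, the weights can be persisted with standard PyTorch serialization (a minimal sketch; the file name is illustrative):

    # Save only the state dict (the recommended PyTorch pattern)
    torch.save(model.state_dict(), 'resnet18_mnist.pth')

    # Later: rebuild the same architecture, then restore the weights
    # model.load_state_dict(torch.load('resnet18_mnist.pth', map_location=device))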

VII. Summary and Outlook

This article implemented MNIST handwritten digit recognition with ResNet18 in PyTorch, confirming that residual networks remain effective on small-scale datasets. In practice, the following directions are worth exploring:

  1. Lightweight models: use more efficient architectures such as MobileNet;
  2. Self-supervised learning: improve feature extraction via contrastive pretraining;
  3. Deployment optimization: convert the model to ONNX or TensorRT format for faster inference (an export sketch follows this list).
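For point 3, a hedged sketch of ONNX export with PyTorch's built-in exporter (the opset version and file name are illustrative choices):

    model.eval()
    dummy_input = torch.randn(1, 1, 28, 28, device=device)  # single-channel MNIST shape
    torch.onnx.export(
        model, dummy_input, "resnet18_mnist.onnx",
        opset_version=13, input_names=["input"], output_names=["logits"],
    )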

For enterprise applications, model training, tuning, and deployment can be managed end to end on an AI development platform such as Baidu AI Cloud's, significantly lowering the cost of putting deep learning into production.