PyTorch中ResNet18模型保存与再训练全流程解析

PyTorch中ResNet18模型保存与再训练全流程解析

在深度学习项目开发中,模型保存与再训练是两个核心环节。本文以PyTorch框架下的ResNet18模型为例,系统讲解如何实现模型的有效保存,以及如何在不同场景下进行再训练。这种能力对于模型迭代、迁移学习和生产部署具有重要意义。

一、模型保存的两种核心方式

1. 仅保存模型参数(推荐)

这是最常用的保存方式,仅存储模型的学习参数(state_dict),不包含模型结构信息。

  1. import torch
  2. from torchvision.models import resnet18
  3. # 创建并训练模型(示例)
  4. model = resnet18(pretrained=False)
  5. # ... 假设这里进行了训练 ...
  6. # 保存参数
  7. torch.save(model.state_dict(), 'resnet18_weights.pth')

优势

  • 文件体积小(通常几MB到几十MB)
  • 跨平台兼容性好
  • 灵活性强,可配合不同模型结构使用

注意事项

  • 需单独保存模型结构代码
  • 加载时需确保模型结构与保存时一致

2. 保存完整模型

这种方式同时保存模型结构和参数,但存在跨版本兼容性问题。

  1. # 保存完整模型
  2. torch.save(model, 'resnet18_full.pth')

适用场景

  • 快速加载使用
  • 确定不会改变模型结构的场景

潜在问题

  • PyTorch版本升级可能导致加载失败
  • 文件体积较大(包含完整结构定义)

二、模型再训练的完整流程

1. 加载已保存模型

加载参数方式(推荐)

  1. # 定义与保存时相同的模型结构
  2. new_model = resnet18()
  3. # 加载参数
  4. new_model.load_state_dict(torch.load('resnet18_weights.pth'))
  5. new_model.eval() # 切换为评估模式

加载完整模型方式

  1. # 直接加载(需确保环境一致)
  2. loaded_model = torch.load('resnet18_full.pth')
  3. loaded_model.eval()

2. 再训练的数据准备

再训练通常需要准备新的数据集,推荐使用DataLoader进行高效加载:

  1. from torchvision import transforms
  2. from torch.utils.data import DataLoader
  3. from torchvision.datasets import ImageFolder
  4. # 数据预处理
  5. transform = transforms.Compose([
  6. transforms.Resize(256),
  7. transforms.CenterCrop(224),
  8. transforms.ToTensor(),
  9. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  10. std=[0.229, 0.224, 0.225])
  11. ])
  12. # 加载数据集
  13. train_dataset = ImageFolder('path/to/train_data', transform=transform)
  14. train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

3. 迁移学习场景下的再训练

对于迁移学习,通常需要冻结部分层,只训练最后的全连接层:

  1. # 冻结所有卷积层
  2. for param in new_model.parameters():
  3. param.requires_grad = False
  4. # 替换最后的全连接层(假设新类别数为10)
  5. num_ftrs = new_model.fc.in_features
  6. new_model.fc = torch.nn.Linear(num_ftrs, 10)
  7. # 只训练新添加的层
  8. params_to_update = []
  9. for name, param in new_model.named_parameters():
  10. if 'fc' in name:
  11. params_to_update.append(param)
  12. # 定义优化器(仅优化可训练参数)
  13. optimizer = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)

4. 完整再训练循环示例

  1. import torch.nn as nn
  2. # 定义损失函数
  3. criterion = nn.CrossEntropyLoss()
  4. # 训练循环
  5. num_epochs = 10
  6. for epoch in range(num_epochs):
  7. new_model.train() # 切换为训练模式
  8. running_loss = 0.0
  9. for inputs, labels in train_loader:
  10. optimizer.zero_grad()
  11. outputs = new_model(inputs)
  12. loss = criterion(outputs, labels)
  13. loss.backward()
  14. optimizer.step()
  15. running_loss += loss.item()
  16. print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}')
  17. # 保存再训练后的模型
  18. torch.save(new_model.state_dict(), 'resnet18_finetuned.pth')

三、最佳实践与注意事项

1. 模型保存策略

  • 版本控制:为每个保存的模型添加版本号和日期标记
  • 多阶段保存:在训练过程中定期保存检查点
  • 元数据记录:保存时记录训练参数、数据集信息等元数据

2. 再训练优化技巧

  • 学习率调整:再训练时通常需要比初始训练更低的学习率
  • 分阶段解冻:可逐步解冻更多层进行微调
  • 数据增强:根据新任务调整数据增强策略

3. 跨平台兼容性处理

  • 保存时明确指定map_location参数:
    1. model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
  • 对于GPU训练的模型,加载到CPU环境时需要特别处理

4. 生产环境部署建议

  • 保存优化后的模型(使用torch.jit.trace或torch.jit.script)
  • 考虑使用ONNX格式实现跨框架部署
  • 对于百度智能云等平台,可将其转换为平台支持的推理格式

四、性能优化方向

  1. 混合精度训练:使用torch.cuda.amp提升再训练速度
  2. 分布式训练:对于大数据集,可采用分布式数据并行
  3. 模型剪枝:在再训练前进行参数剪枝,减少计算量
  4. 量化感知训练:为后续部署量化模型做准备

五、常见问题解决方案

  1. 加载模型时的尺寸不匹配错误

    • 检查输入数据的预处理是否与原始模型一致
    • 确认最后一层的输出类别数是否正确修改
  2. 再训练时损失不下降

    • 检查学习率是否设置合理
    • 确认数据标签是否正确
    • 尝试使用不同的初始化方法
  3. 跨设备加载失败

    • 明确指定map_location参数
    • 确保PyTorch版本兼容

通过系统掌握这些技术要点,开发者可以高效实现ResNet18模型的保存与再训练,为各种计算机视觉任务提供灵活的模型迭代方案。这种能力在快速原型开发、模型优化和生产部署等场景中都具有重要价值。