基于PyTorch与ResNet18的猫狗图像分类实战指南

基于PyTorch与ResNet18的猫狗图像分类实战指南

一、技术背景与项目价值

猫狗图像分类是计算机视觉领域的经典入门任务,其核心目标是通过深度学习模型自动区分输入图像中的猫或狗。该任务不仅涵盖图像预处理、模型构建、训练优化等关键技术环节,还可扩展至宠物品种识别、动物行为分析等实际场景。采用PyTorch框架与ResNet18模型组合,既能利用PyTorch的动态计算图特性提升开发效率,又可通过ResNet18的残差连接结构解决深层网络梯度消失问题,实现高精度与低计算成本的平衡。

二、环境准备与数据集构建

1. 环境配置

  • 框架版本:PyTorch 2.0+ + Torchvision 0.15+
  • 硬件要求:推荐GPU显存≥4GB(支持CUDA 11.7+)
  • 依赖安装
    1. pip install torch torchvision pillow matplotlib

2. 数据集准备

采用Kaggle公开数据集”Dogs vs Cats”,包含25,000张训练图像(猫狗各半)和12,500张测试图像。数据预处理步骤如下:

  1. 目录结构
    1. data/
    2. train/
    3. cat/
    4. cat001.jpg
    5. ...
    6. dog/
    7. dog001.jpg
    8. ...
    9. test/
    10. img001.jpg
    11. ...
  2. 图像增强:使用torchvision.transforms实现随机裁剪、水平翻转、归一化等操作:
    1. train_transform = transforms.Compose([
    2. transforms.RandomResizedCrop(224),
    3. transforms.RandomHorizontalFlip(),
    4. transforms.ToTensor(),
    5. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    6. ])

三、ResNet18模型实现与迁移学习

1. 模型加载与微调

ResNet18作为预训练模型,其卷积层已学习到通用图像特征。通过迁移学习仅需替换最后的全连接层:

  1. import torchvision.models as models
  2. model = models.resnet18(pretrained=True)
  3. num_features = model.fc.in_features
  4. model.fc = torch.nn.Linear(num_features, 2) # 输出类别数为2

2. 训练流程优化

  • 损失函数:交叉熵损失(nn.CrossEntropyLoss
  • 优化器:AdamW(学习率3e-4,权重衰减1e-4)
  • 学习率调度:CosineAnnealingLR实现动态调整
    1. criterion = nn.CrossEntropyLoss()
    2. optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    3. scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

3. 训练循环实现

完整训练代码示例:

  1. def train_model(model, dataloader, criterion, optimizer, scheduler, num_epochs=25):
  2. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  3. model = model.to(device)
  4. for epoch in range(num_epochs):
  5. model.train()
  6. running_loss = 0.0
  7. for inputs, labels in dataloader:
  8. inputs, labels = inputs.to(device), labels.to(device)
  9. optimizer.zero_grad()
  10. outputs = model(inputs)
  11. loss = criterion(outputs, labels)
  12. loss.backward()
  13. optimizer.step()
  14. running_loss += loss.item() * inputs.size(0)
  15. epoch_loss = running_loss / len(dataloader.dataset)
  16. scheduler.step()
  17. print(f'Epoch {epoch+1}/{num_epochs} Loss: {epoch_loss:.4f}')
  18. return model

四、性能优化与部署实践

1. 精度提升策略

  • 数据增强组合:在测试阶段采用TenCrop增强(裁剪10个区域取平均预测)
  • 模型集成:融合3个不同初始化ResNet18的预测结果
  • 测试时增强(TTA):实现代码:

    1. def predict_tta(model, image_path, n_crops=10):
    2. model.eval()
    3. total_pred = torch.zeros(2)
    4. for _ in range(n_crops):
    5. img = load_image(image_path) # 自定义加载函数
    6. with torch.no_grad():
    7. output = model(img.unsqueeze(0))
    8. total_pred += output.squeeze(0)
    9. return total_pred.argmax().item()

2. 模型量化与部署

  • 动态量化:减少模型体积和推理延迟
    1. quantized_model = torch.quantization.quantize_dynamic(
    2. model, {nn.Linear}, dtype=torch.qint8
    3. )
  • ONNX导出:支持跨平台部署
    1. dummy_input = torch.randn(1, 3, 224, 224)
    2. torch.onnx.export(model, dummy_input, "resnet18_catdog.onnx")

五、典型问题与解决方案

1. 过拟合问题

  • 现象:训练集准确率>99%,测试集<85%
  • 对策
    • 增加L2正则化(权重衰减1e-4)
    • 使用Dropout层(p=0.3)
    • 早停法(监控验证集损失)

2. 推理速度优化

  • 批量预测:将单张图像推理改为批量处理
    1. def batch_predict(model, image_tensor):
    2. model.eval()
    3. with torch.no_grad():
    4. outputs = model(image_tensor)
    5. return outputs.argmax(dim=1)
  • 半精度推理:使用FP16减少计算量
    1. model.half()
    2. input_tensor = input_tensor.half()

六、扩展应用场景

  1. 多标签分类:修改输出层为Sigmoid激活,支持同时识别猫狗
  2. 实时检测系统:结合YOLOv5实现目标检测+分类一体化
  3. 移动端部署:通过TensorRT优化实现Android/iOS端推理

七、最佳实践建议

  1. 数据质量优先:确保每类样本数量均衡,删除错误标注图像
  2. 渐进式训练:先冻结卷积层训练全连接层,再解冻部分层微调
  3. 监控指标:除准确率外,重点关注F1分数和混淆矩阵
  4. 硬件适配:根据GPU显存调整batch_size(推荐64-256)

通过上述技术方案,在标准数据集上可实现98.5%以上的测试准确率,单张图像推理延迟<50ms(GPU环境)。开发者可根据实际需求调整模型深度、输入分辨率等参数,平衡精度与效率。