Ubuntu PyTorch如何进行模型保存与加载
在Ubuntu系统中使用PyTorch进行模型的保存与加载是一个常见的操作。以下是详细的步骤和示例代码:
保存模型
- 定义模型:首先,你需要定义一个模型。
- 训练模型:在训练过程中,模型会不断更新其权重。
- 保存模型:使用
torch.save()
函数将整个模型或仅保存模型的状态字典(state_dict)。
示例代码
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(784, 10) # 假设输入是784维,输出是10类
def forward(self, x):
x = x.view(x.size(0), -1) # 将输入展平
x = self.fc(x)
return x
# 创建模型实例
model = SimpleNet()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 假设我们有一些训练数据
inputs = torch.randn(64, 1, 28, 28) # 示例输入
labels = torch.randint(0, 10, (64,)) # 示例标签
# 训练模型(这里省略了训练循环)
for epoch in range(5):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 保存整个模型
torch.save(model, 'model.pth')
# 或者仅保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')
加载模型
- 加载模型:使用
torch.load()
函数加载模型或模型的状态字典。 - 加载状态字典:如果之前保存的是状态字典,需要先创建一个模型实例,然后加载状态字典。
示例代码
# 加载整个模型
loaded_model = torch.load('model.pth')
# 或者加载模型的状态字典
model = SimpleNet() # 创建一个新的模型实例
model.load_state_dict(torch.load('model_state_dict.pth'))
# 确保模型在评估模式
model.eval()
# 使用加载的模型进行预测
with torch.no_grad():
test_inputs = torch.randn(1, 1, 28, 28) # 示例测试输入
predictions = loaded_model(test_inputs)
print(predictions)
注意事项
- 设备兼容性:如果模型是在GPU上训练的,保存的模型文件默认也会包含GPU信息。在加载到CPU上进行推理时,需要将模型移动到CPU:
model = model.to('cpu')
- 版本兼容性:不同版本的PyTorch可能会有不同的模型保存格式,确保加载模型的PyTorch版本与保存模型的版本兼容。
通过以上步骤,你可以在Ubuntu系统中轻松地进行PyTorch模型的保存与加载。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权请联系我们,一经查实立即删除!