在使用 PyTorch 进行深度学习模型训练时,模型的保存与加载是非常重要的操作。它不仅可以帮助我们保存训练好的模型,以便后续使用,还能在训练过程中出现意外时恢复训练进度。在 PyTorch 中,常见的模型保存文件格式是 .pth
和 .pt
,本文将详细介绍这两种格式,并通过实际例子展示如何保存和加载模型。
.pth
和 .pt
文件格式概述其实在 PyTorch 里,.pth
和 .pt
并没有本质区别,它们都可以用来保存 PyTorch 模型的状态信息。.pth
是 “PyTorch” 的缩写,而 .pt
则代表 “PyTorch” 或 “Python Torch”。这两种扩展名都是社区约定俗成的,PyTorch 本身并不强制使用特定的扩展名。
在 PyTorch 中,模型的状态主要由两部分组成:模型的结构和模型的参数。模型的参数存储在 state_dict
中,它是一个 Python 字典,将每一层的参数名称映射到对应的张量。保存和加载模型时,我们通常保存和加载 state_dict
。
state_dict
以下是一个简单的例子,展示如何保存模型的 state_dict
到 .pth
文件:
import torch
import torch.nn as nn
# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 初始化模型
model = SimpleNet()
# 保存模型的 state_dict
torch.save(model.state_dict(), 'model.pth')
state_dict
加载模型的 state_dict
时,需要先创建一个相同结构的模型实例,然后使用 load_state_dict
方法加载参数:
# 创建一个新的模型实例
new_model = SimpleNet()
# 加载保存的 state_dict
new_model.load_state_dict(torch.load('model.pth'))
# 将模型设置为评估模式
new_model.eval()
除了保存 state_dict
,我们还可以保存整个模型,包括模型的结构和参数:
# 保存整个模型
torch.save(model, 'whole_model.pth')
# 加载整个模型
loaded_model = torch.load('whole_model.pth')
loaded_model.eval()
不过,保存整个模型可能会导致一些兼容性问题,因为它依赖于模型定义的代码。因此,建议只保存 state_dict
。
在实际应用中,我们可能还需要保存和加载其他组件,如优化器的状态、训练的轮数等。以下是一个例子:
import torch.optim as optim
# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练一些轮数
epochs = 10
# 保存模型、优化器和训练轮数
checkpoint = {
'epoch': epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, 'checkpoint.pth')
# 加载保存的组件
checkpoint = torch.load('checkpoint.pth')
new_model = SimpleNet()
new_optimizer = optim.SGD(new_model.parameters(), lr=0.01)
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
操作 | 代码示例 | 说明 |
---|---|---|
保存模型的 state_dict |
torch.save(model.state_dict(), 'model.pth') |
只保存模型的参数,推荐使用 |
加载模型的 state_dict |
model.load_state_dict(torch.load('model.pth')) |
先创建相同结构的模型实例,再加载参数 |
保存整个模型 | torch.save(model, 'whole_model.pth') |
保存模型的结构和参数,可能存在兼容性问题 |
加载整个模型 | loaded_model = torch.load('whole_model.pth') |
直接加载整个模型 |
保存多个组件 | checkpoint = {'epoch': epochs, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}; torch.save(checkpoint, 'checkpoint.pth') |
保存模型、优化器和训练轮数等信息 |
加载多个组件 | checkpoint = torch.load('checkpoint.pth'); model.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); start_epoch = checkpoint['epoch'] |
加载保存的多个组件 |
在 PyTorch 中,.pth
和 .pt
文件格式都可以用来保存模型。通过保存和加载 state_dict
,我们可以方便地管理模型的参数,并在需要时恢复训练进度。同时,保存多个组件的功能也为我们提供了更多的灵活性。希望本文能帮助你更好地理解和使用 PyTorch 中的模型保存和加载功能。