微信登录

模型保存 - 保存格式 - 保存为.pth 或.pt 文件

模型保存 - 保存格式 - 保存为.pth 或.pt 文件

在使用 PyTorch 进行深度学习模型训练时,模型的保存与加载是非常重要的操作。它不仅可以帮助我们保存训练好的模型,以便后续使用,还能在训练过程中出现意外时恢复训练进度。在 PyTorch 中,常见的模型保存文件格式是 .pth.pt,本文将详细介绍这两种格式,并通过实际例子展示如何保存和加载模型。

1. .pth.pt 文件格式概述

其实在 PyTorch 里,.pth.pt 并没有本质区别,它们都可以用来保存 PyTorch 模型的状态信息。.pth 是 “PyTorch” 的缩写,而 .pt 则代表 “PyTorch” 或 “Python Torch”。这两种扩展名都是社区约定俗成的,PyTorch 本身并不强制使用特定的扩展名。

2. 保存和加载模型的基本原理

在 PyTorch 中,模型的状态主要由两部分组成:模型的结构和模型的参数。模型的参数存储在 state_dict 中,它是一个 Python 字典,将每一层的参数名称映射到对应的张量。保存和加载模型时,我们通常保存和加载 state_dict

2.1 保存模型的 state_dict

以下是一个简单的例子,展示如何保存模型的 state_dict.pth 文件:

  1. import torch
  2. import torch.nn as nn
  3. # 定义一个简单的神经网络模型
  4. class SimpleNet(nn.Module):
  5. def __init__(self):
  6. super(SimpleNet, self).__init__()
  7. self.fc = nn.Linear(10, 1)
  8. def forward(self, x):
  9. return self.fc(x)
  10. # 初始化模型
  11. model = SimpleNet()
  12. # 保存模型的 state_dict
  13. torch.save(model.state_dict(), 'model.pth')

2.2 加载模型的 state_dict

加载模型的 state_dict 时,需要先创建一个相同结构的模型实例,然后使用 load_state_dict 方法加载参数:

  1. # 创建一个新的模型实例
  2. new_model = SimpleNet()
  3. # 加载保存的 state_dict
  4. new_model.load_state_dict(torch.load('model.pth'))
  5. # 将模型设置为评估模式
  6. new_model.eval()

2.3 保存和加载整个模型

除了保存 state_dict,我们还可以保存整个模型,包括模型的结构和参数:

  1. # 保存整个模型
  2. torch.save(model, 'whole_model.pth')
  3. # 加载整个模型
  4. loaded_model = torch.load('whole_model.pth')
  5. loaded_model.eval()

不过,保存整个模型可能会导致一些兼容性问题,因为它依赖于模型定义的代码。因此,建议只保存 state_dict

3. 保存和加载多个组件

在实际应用中,我们可能还需要保存和加载其他组件,如优化器的状态、训练的轮数等。以下是一个例子:

  1. import torch.optim as optim
  2. # 定义优化器
  3. optimizer = optim.SGD(model.parameters(), lr=0.01)
  4. # 训练一些轮数
  5. epochs = 10
  6. # 保存模型、优化器和训练轮数
  7. checkpoint = {
  8. 'epoch': epochs,
  9. 'model_state_dict': model.state_dict(),
  10. 'optimizer_state_dict': optimizer.state_dict()
  11. }
  12. torch.save(checkpoint, 'checkpoint.pth')
  13. # 加载保存的组件
  14. checkpoint = torch.load('checkpoint.pth')
  15. new_model = SimpleNet()
  16. new_optimizer = optim.SGD(new_model.parameters(), lr=0.01)
  17. new_model.load_state_dict(checkpoint['model_state_dict'])
  18. new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  19. start_epoch = checkpoint['epoch']

4. 总结

操作 代码示例 说明
保存模型的 state_dict torch.save(model.state_dict(), 'model.pth') 只保存模型的参数,推荐使用
加载模型的 state_dict model.load_state_dict(torch.load('model.pth')) 先创建相同结构的模型实例,再加载参数
保存整个模型 torch.save(model, 'whole_model.pth') 保存模型的结构和参数,可能存在兼容性问题
加载整个模型 loaded_model = torch.load('whole_model.pth') 直接加载整个模型
保存多个组件 checkpoint = {'epoch': epochs, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}; torch.save(checkpoint, 'checkpoint.pth') 保存模型、优化器和训练轮数等信息
加载多个组件 checkpoint = torch.load('checkpoint.pth'); model.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); start_epoch = checkpoint['epoch'] 加载保存的多个组件

5. 结论

在 PyTorch 中,.pth.pt 文件格式都可以用来保存模型。通过保存和加载 state_dict,我们可以方便地管理模型的参数,并在需要时恢复训练进度。同时,保存多个组件的功能也为我们提供了更多的灵活性。希望本文能帮助你更好地理解和使用 PyTorch 中的模型保存和加载功能。