微信登录

模型加载 - 加载已保存模型 - 恢复训练状态

模型加载 - 加载已保存模型 - 恢复训练状态

在深度学习领域,模型训练往往是一个耗时且资源密集的过程。有时候,由于各种原因(如计算资源限制、意外中断等),我们可能无法一次性完成整个训练过程。这时,能够保存模型的状态并在后续恢复训练就显得尤为重要。PyTorch 为我们提供了强大而灵活的工具来实现模型的保存和加载,下面将详细介绍如何在 PyTorch 中加载已保存的模型并恢复训练状态。

1. 保存模型和训练状态

在 PyTorch 中,我们可以使用 torch.save() 函数来保存模型的参数、优化器的状态以及其他训练相关的信息。以下是一个简单的示例:

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. # 定义一个简单的模型
  5. class SimpleModel(nn.Module):
  6. def __init__(self):
  7. super(SimpleModel, self).__init__()
  8. self.fc = nn.Linear(10, 1)
  9. def forward(self, x):
  10. return self.fc(x)
  11. # 初始化模型、优化器和损失函数
  12. model = SimpleModel()
  13. optimizer = optim.SGD(model.parameters(), lr=0.01)
  14. criterion = nn.MSELoss()
  15. # 假设我们已经训练了一些轮次
  16. epochs = 10
  17. for epoch in range(epochs):
  18. # 这里省略具体的训练步骤
  19. pass
  20. # 保存模型和训练状态
  21. checkpoint = {
  22. 'epoch': epochs,
  23. 'model_state_dict': model.state_dict(),
  24. 'optimizer_state_dict': optimizer.state_dict(),
  25. 'loss': criterion
  26. }
  27. torch.save(checkpoint, 'checkpoint.pth')

在上述代码中,我们首先定义了一个简单的线性模型 SimpleModel,然后初始化了优化器和损失函数。接着,我们进行了一些轮次的训练(这里省略了具体的训练步骤)。最后,我们创建了一个字典 checkpoint,其中包含了当前的训练轮次、模型的参数、优化器的状态和损失函数,并使用 torch.save() 函数将其保存到文件 checkpoint.pth 中。

2. 加载已保存的模型和训练状态

要恢复训练状态,我们可以使用 torch.load() 函数加载保存的检查点文件,并将模型和优化器的状态恢复到之前保存的状态。以下是加载并恢复训练状态的示例代码:

  1. # 重新定义模型、优化器和损失函数
  2. model = SimpleModel()
  3. optimizer = optim.SGD(model.parameters(), lr=0.01)
  4. criterion = nn.MSELoss()
  5. # 加载检查点
  6. checkpoint = torch.load('checkpoint.pth')
  7. # 恢复模型和优化器的状态
  8. model.load_state_dict(checkpoint['model_state_dict'])
  9. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  10. start_epoch = checkpoint['epoch']
  11. # 继续训练
  12. remaining_epochs = 10
  13. for epoch in range(start_epoch, start_epoch + remaining_epochs):
  14. # 这里省略具体的训练步骤
  15. print(f'Epoch {epoch + 1}/{start_epoch + remaining_epochs}')

在上述代码中,我们首先重新定义了模型、优化器和损失函数。然后,使用 torch.load() 函数加载保存的检查点文件。接着,我们使用 load_state_dict() 方法将模型和优化器的状态恢复到之前保存的状态,并获取之前训练的轮次 start_epoch。最后,我们继续进行训练,从之前保存的轮次开始,再训练 remaining_epochs 个轮次。

3. 注意事项

3.1 设备一致性

在加载模型时,需要注意设备的一致性。如果保存模型时使用的是 GPU,而加载时使用的是 CPU,需要进行相应的处理。可以在加载时指定 map_location 参数,将模型加载到指定的设备上。例如:

  1. # 加载到 CPU
  2. checkpoint = torch.load('checkpoint.pth', map_location=torch.device('cpu'))
  3. # 加载到 GPU
  4. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  5. checkpoint = torch.load('checkpoint.pth', map_location=device)

3.2 模型定义一致性

加载模型时,模型的定义必须与保存时的模型定义一致。如果模型的结构发生了变化,可能会导致加载失败。

4. 总结

操作 代码示例 说明
保存模型和训练状态 checkpoint = {'epoch': epochs, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': criterion}; torch.save(checkpoint, 'checkpoint.pth') 将模型的参数、优化器的状态和训练轮次等信息保存到文件中
加载模型和训练状态 checkpoint = torch.load('checkpoint.pth'); model.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); start_epoch = checkpoint['epoch'] 从文件中加载模型和优化器的状态,并恢复训练轮次
处理设备一致性 checkpoint = torch.load('checkpoint.pth', map_location=torch.device('cpu'))device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); checkpoint = torch.load('checkpoint.pth', map_location=device) 确保模型加载到指定的设备上

通过以上步骤,我们可以在 PyTorch 中方便地保存和加载模型的状态,实现训练过程的中断和恢复。这不仅提高了训练的灵活性,还能有效地利用计算资源。