📜  在 pytorch 中保存模型 (1)

📅  最后修改于: 2023-12-03 15:37:27.700000             🧑  作者: Mango

在PyTorch中保存模型

在机器学习项目中,训练模型是一个大工程。然而,成功地训练好一个模型只是完成机器学习项目的一小部分。将训练好的模型部署到生产环境中也是非常重要的。在PyTorch中,我们可以保存训练好的模型参数和相关的信息,以便在需要时重新载入模型。以下是使用PyTorch保存模型的方法。

保存模型

使用PyTorch保存模型的最简单的方法是将模型的参数保存到文件中。我们可以使用torch.save()函数将模型的参数保存到文件中。以下是将模型参数保存到文件的示例代码。

import torch

# 定义模型
model = torch.nn.Sequential(
    torch.nn.Linear(10, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 1),
)

# 保存模型参数
torch.save(model.state_dict(), 'model.pth')

在此示例中,我们定义了一个包含两个线性层和一个ReLU激活函数的神经网络模型。我们使用torch.save()函数将模型的参数保存到了一个名为model.pth的文件中。这个文件可以通过反序列化重建模型,或者在另一个项目中使用。

加载模型

我们可以使用torch.load()函数从文件中加载模型参数。以下是从文件中加载模型参数的示例代码。

import torch

# 定义模型
model = torch.nn.Sequential(
    torch.nn.Linear(10, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 1),
)

# 保存模型参数
torch.save(model.state_dict(), 'model.pth')

# 加载模型参数
model.load_state_dict(torch.load('model.pth'))

在此示例中,我们首先定义了一个神经网络模型,然后使用torch.save()函数将模型参数保存到文件中。我们使用torch.load()函数从文件中加载模型参数,并使用model.load_state_dict()方法将模型参数加载到模型中。

保存完整模型

有时候,我们可能需要将整个模型,包括结构和参数信息都保存下来,以便以后能够完全重建模型。在PyTorch中,我们可以使用torch.save()函数将整个模型保存到文件中。以下是将整个模型保存到文件的示例代码。

import torch

# 定义模型
model = torch.nn.Sequential(
    torch.nn.Linear(10, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 1),
)

# 保存整个模型
torch.save(model, 'model.pth')

在此示例中,我们使用torch.save()函数将整个模型保存到了名为model.pth的文件中。这个文件包含了模型结构以及参数信息。我们可以使用torch.load()函数将保存的模型加载到内存中。

加载完整模型

我们可以使用torch.load()函数从文件中加载整个模型。以下是从文件中加载整个模型的示例代码。

import torch

# 加载整个模型
model = torch.load('model.pth')

在此示例中,我们使用torch.load()函数从文件中加载整个模型。由于文件中包含了模型的结构和参数信息,我们可以直接将加载的模型用于推理或进一步的训练。

总结

在PyTorch中,我们可以使用torch.save()torch.load()函数保存和加载模型。我们可以选择只保存模型参数或整个模型。保存模型的目的是在需要时可以重新载入模型进行推理或训练。这些函数为模型的开发和部署提供了极大的便利。