📜  pytorch 加载模型 - Python (1)

📅  最后修改于: 2023-12-03 15:04:42.774000             🧑  作者: Mango

PyTorch 加载模型

简介

在机器学习中,训练模型只是整个过程的一部分,另一个重要的部分是将已经训练好的模型加载到代码中以进行预测或继续训练。在 PyTorch 中,加载模型非常简单,本文将向您展示如何加载训练好的 PyTorch 模型。

步骤
1. 安装 PyTorch

要使用 PyTorch 加载模型,首先需要确保已安装 PyTorch。可以通过以下命令安装 PyTorch:

pip install torch torchvision
2. 定义模型结构

假设您已经在 PyTorch 中训练了一个模型。在加载模型之前,需要定义模型结构。例如,以下是一个简单的神经网络:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
    
    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Net()
3. 保存训练好的模型

在训练模型完成后,需要将模型保存为文件。可以通过以下方式将模型保存起来:

torch.save(model.state_dict(), "model.pth")
4. 加载模型

要加载模型,只需使用 torch.load() 函数并传递模型文件的路径即可。例如,以下代码将加载名为 model.pth 的模型:

model = Net()
model.load_state_dict(torch.load("model.pth"))

完成后,您可以使用模型进行预测或继续训练。

结论

本文向您介绍了如何加载训练好的 PyTorch 模型。通过了解这些知识,您可以更好地处理机器学习的预测或继续训练任务。