📜  PyTorch 训练(1)

📅  最后修改于: 2023-12-03 15:19:37.127000             🧑  作者: Mango

PyTorch 训练指南

PyTorch 是由 Facebook 开发的深度学习框架,它提供了易于使用的张量操作符,以及自动求导功能。PyTorch 的设计旨在让开发者能够轻松地构建和训练神经网络。

在这篇文章中,我将向你介绍如何使用 PyTorch 来训练你的深度学习模型。我们将涵盖以下主题:

  • 准备数据
  • 构建模型
  • 定义损失函数
  • 训练模型
  • 保存模型
  • 加载模型并预测
准备数据

在 PyTorch 中,数据通常被表示为张量(tensor)。有两种方式可以为模型准备数据:使用现成的数据集(如 MNIST、CIFAR-10),或者自己创建数据集。

加载现成的数据集

对于现成的数据集,PyTorch 提供了便捷的数据加载器,可自动下载并预处理数据集。以下是使用 DataLoader 加载 MNIST 数据集的示例代码:

import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

train_dataset = MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = MNIST(root='data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
创建自己的数据集

如果需要自己创建数据集,则需要创建一个自定义数据集类,继承 PyTorch 中的 Dataset 类,并实现 lengetitem 方法。以下是创建自定义数据集的示例代码:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item["input"], item["target"]

data = [{"input": torch.randn(3, 224, 224), "target": torch.tensor(0)}, 
        {"input": torch.randn(3, 224, 224), "target": torch.tensor(1)}]

dataset = CustomDataset(data)
构建模型

PyTorch 允许你以类的形式构建神经网络模型,模型类需要继承 PyTorch 中的 nn.Module 类,并通过实现 forward 方法来定义模型的计算流程。以下是构建简单的卷积神经网络模型的示例代码:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
定义损失函数

在 PyTorch 中,常见的损失函数包括交叉熵损失和均方差损失。以下是定义交叉熵损失函数的示例代码:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()
训练模型

在训练过程中,我们需要将数据通过模型进行前向传播,计算损失函数值并反向传播,更新模型参数。以下是训练模型的示例代码:

import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
保存模型

在训练完成后,我们需要将模型保存到磁盘上。以下是保存模型的示例代码:

torch.save(model.state_dict(), 'model.pt')
加载模型并预测

当我们需要使用已经训练好的模型进行预测时,可以使用以下代码加载模型:

model = Net()
model.load_state_dict(torch.load('model.pt'))
model.eval()

其中,eval 方法将模型设置为评估模式,以便在预测时能够准确计算。

以上就是使用 PyTorch 训练深度学习模型的完整指南。希望这篇文章对你有所帮助!