📜  Pytorch 中的数据集和数据加载器(1)

📅  最后修改于: 2023-12-03 15:34:33.010000             🧑  作者: Mango

Pytorch 中的数据集和数据加载器

在 Pytorch 中,我们可以使用内置的数据集和数据加载器来加载内置数据集或自己的数据集。数据集用于存储和访问数据,数据加载器用于将数据集中的数据加载到模型中进行训练或测试。

数据集

Pytorch 中的数据集是一个抽象类,用于定义和读取数据集。如果要使用 Pytorch 提供的数据集,可以直接使用 Pytorch 提供的数据集类。如果要使用自己的数据集,则需要继承 Pytorch 的数据集类,并重写 __len____getitem__ 方法。

以下是一个使用 Pytorch 自带 MNIST 数据集的例子:

import torch
from torchvision.datasets import MNIST
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.1307,), (0.3081,))])
                                
train_set = MNIST('/data/', train=True, download=True, transform=transform)
test_set = MNIST('/data/', train=False, download=True, transform=transform)

在上述例子中,我们使用了 MNIST 类来定义 MNIST 数据集,transform 参数用于对数据进行预处理,train=Truetrain=False 分别表示使用训练集和测试集。__len__ 方法用于返回数据集的大小,__getitem__ 方法用于返回数据集中的某个数据。

如果要使用自己的数据集,可以继承 Pytorch 中的 Dataset 类,并在其中实现 __len____getitem__ 方法:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data_path):
        self.data = # 加载数据集

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
数据加载器

数据加载器是将数据集中的数据加载到模型中进行训练或测试的工具。在 Pytorch 中,我们可以使用内置的 DataLoader 类来构建数据加载器。数据加载器可以实现批量加载、乱序加载、多进程加载等功能。

以下是一个使用 MNIST 数据集构建数据加载器的例子:

import torch
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=True)

在上述例子中,我们使用 DataLoader 类来构建两个数据加载器,train_settest_set 分别表示训练集和测试集,batch_size 参数定义了批量大小,shuffle 参数用于对数据进行乱序加载。

在训练模型时,我们可以使用 DataLoader 加载器加载数据,例如:

for idx, (data, label) in enumerate(train_loader):
    # 训练模型

在上述例子中,我们使用 DataLoader 加载器加载数据,并进行模型训练。enumerate 函数用于获取迭代的索引和数据。

总结

Pytorch 中的数据集和数据加载器是进行深度学习的重要工具。使用内置的数据集和数据加载器可以快速地构建深度学习模型并进行训练。自定义数据集和数据加载器可以适应更加复杂的数据集和数据加载需求。