📜  在 Pytorch 中加载数据

📅  最后修改于: 2022-05-13 01:54:44.536000             🧑  作者: Mango

在 Pytorch 中加载数据

在本文中,我们将讨论如何在 PyTorch 中加载不同类型的数据。

出于演示目的,Pytorch 附带了 3 个数据集部分,即 torchaudio、torchvision 和 torchtext。我们可以利用这些演示数据集来了解如何使用 Pytorch 加载声音、图像和文本数据。

Torchaudio 数据集

使用 Pytorch 在 torchaudio 中加载演示 yes_no 音频数据集。

Yes_No 数据集是一个音频波形数据集,它的值以 3 个值的元组形式存储,即波形、采样率、标签,其中波形表示音频信号,采样率表示频率,标签表示是或否。

  • 导入 torch 和 torchaudio 包。 (如有必要,使用 pip install torchaudio 安装)
  • 使用带有数据集访问器的 torchaudio函数,后跟数据集名称。
  • 现在,传递必须存储数据集的路径并指定 download = True 以下载数据集。这里,'./' 指定根目录。
  • 现在,使用 for 循环遍历加载的数据集,并访问存储在元组中的 3 个值以查看数据集的样本。

要加载您的自定义数据:

Python3
# import the torch and torchaudio dataset packages.
import torch
import torchaudio
  
# access the dataset in torchaudio package using
# datasets followed by dataset name.
# './' makes sure that the dataset is stored
# in a root directory.
# download = True ensures that the
# data gets downloaded
yesno_data = torchaudio.datasets.YESNO('./', 
                                       download=True)
  
# loading the first 5 data from yesno_data
for i in range(5):
    waveform, sample_rate, labels = yesno_data[i]
    print("Waveform: {}\nSample rate: {}\nLabels: {}".format(
        waveform, sample_rate, labels))


Python3
# import the torch and
# torchvision dataset packages.
import torch
import torchvision
  
# access the dataset in torchvision package using
# .datasets followed by dataset name.
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')


Python3
# import necessary function
# from torchvision package
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
  
# specify the image dataset folder
data_dir = r'path to dataset\train'
  
# perform some transformations like resizing,
# centring and tensorconversion
# using transforms function
transform = transforms.Compose(
    [transforms.Resize(255),
     transforms.CenterCrop(224),
     transforms.ToTensor()])
  
# pass the image data folder and
# transform function to the datasets
# .imagefolder function
dataset = datasets.ImageFolder(data_dir, 
                               transform=transform)
  
# now use dataloder function load the
# dataset in the specified transformation.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=32,
                                         shuffle=True)
  
# iter function iterates through all the
# images and labels and stores in two variables
images, labels = next(iter(dataloader))
  
# print the total no of samples
print('Number of samples: ', len(images))
image = images[2][0]  # load 3rd sample
  
# visualize the image
plt.imshow(image, cmap='gray')
  
# print the size of image
print("Image Size: ", image.size())
  
# print the label
print(label)


Python3
# import the torch and torchtext dataset packages.
import torch
import torchtext
  
# access the dataset in torchtext package
# using .datasets followed by dataset name.
text_data = torchtext.datasets.IMDB(split='train')
  
# define a function to tokenize
# the words in the corpus
def tokenize(label, line):
    return line.split()
  
  
# define a empty list to store
# the tokenized words
tokens = []
  
# iterate over the text_data and
# tokenize each line and store
# it in the list tokens
for label, line in text_data:
    tokens += tokenize(label, line)
  
print('The total no. of tokens in imdb dataset is',
      len(tokens))


输出:

Torchvision 数据集

使用 Pytorch 在 torchvision 中加载演示 ImageNet 视觉数据集。单击此处通过注册下载数据集。

Python3

# import the torch and
# torchvision dataset packages.
import torch
import torchvision
  
# access the dataset in torchvision package using
# .datasets followed by dataset name.
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')

代码说明:

  • 该过程与加载音频数据几乎相同。
  • 在这里,必须导入 torchvision 而不是 torchaudio。
  • 使用带有数据集访问器的 torchvision函数,后跟数据集名称。
  • 现在,传递数据集所在的路径。由于 ImageNet 数据集不再可公开访问,因此请在本地系统中下载根数据并将路径传递给此函数。这将轻松加载视觉数据。

要加载您的自定义图像数据,请使用上面提到的 torch.utils.data.DataLoader(data, batch_size, shuffle)。

Python3

# import necessary function
# from torchvision package
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
  
# specify the image dataset folder
data_dir = r'path to dataset\train'
  
# perform some transformations like resizing,
# centring and tensorconversion
# using transforms function
transform = transforms.Compose(
    [transforms.Resize(255),
     transforms.CenterCrop(224),
     transforms.ToTensor()])
  
# pass the image data folder and
# transform function to the datasets
# .imagefolder function
dataset = datasets.ImageFolder(data_dir, 
                               transform=transform)
  
# now use dataloder function load the
# dataset in the specified transformation.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=32,
                                         shuffle=True)
  
# iter function iterates through all the
# images and labels and stores in two variables
images, labels = next(iter(dataloader))
  
# print the total no of samples
print('Number of samples: ', len(images))
image = images[2][0]  # load 3rd sample
  
# visualize the image
plt.imshow(image, cmap='gray')
  
# print the size of image
print("Image Size: ", image.size())
  
# print the label
print(label)

输出:

Image size: torch.Size([224,224])
tensor([0, 0, 0, 1, 1, 1])

Torchtext 数据集

使用 Pytorch 在 torchtext 中加载演示 IMDB 文本数据集。要加载您的自定义文本数据,我们使用 torch.utils.data.DataLoader() 方法。

代码说明:

  • 该过程与加载图像和音频数据几乎相同。
  • 在这里,必须导入torchtext 而不是torchvision。
  • 将 torchtext函数与数据集访问器一起使用,后跟数据集名称 (IMDB)。
  • 现在,将 split函数传递给 torchtext函数以拆分数据集以训练和测试数据。
  • 现在定义一个函数,通过迭代语料库中的每一行,将语料库中的每一行拆分为单独的标记,如图所示。这样,我们就可以轻松地使用 Pytorch 加载文本数据。

Python3

# import the torch and torchtext dataset packages.
import torch
import torchtext
  
# access the dataset in torchtext package
# using .datasets followed by dataset name.
text_data = torchtext.datasets.IMDB(split='train')
  
# define a function to tokenize
# the words in the corpus
def tokenize(label, line):
    return line.split()
  
  
# define a empty list to store
# the tokenized words
tokens = []
  
# iterate over the text_data and
# tokenize each line and store
# it in the list tokens
for label, line in text_data:
    tokens += tokenize(label, line)
  
print('The total no. of tokens in imdb dataset is',
      len(tokens))

输出: