📜  在Python中将图像转换为 Torch 张量

📅  最后修改于: 2022-05-13 01:55:46.271000             🧑  作者: Mango

在Python中将图像转换为 Torch 张量

在本文中,我们将了解如何将图像转换为 PyTorch 张量。 PyTorch 中的张量就像一个 NumPy 数组,包含相同 dtype 的元素。

张量可以是标量类型、一维或多维的。为了在 PyTorch 中将图像转换为张量,我们使用PILToTensor()ToTensor()转换。这些转换在torchvision.transforms包中提供。使用这些转换,我们可以转换 PIL 图像或numpy.ndarraynumpy.ndarray必须是 [H, W, C] 格式,其中 H、W 和 C 是图像的高度、宽度和通道数。

此转换将PIL 图像转换为数据类型为torch.uint8的张量,范围在0 到 255之间。这里的img是一个 PIL 图像。

此转换将任何numpy.ndarray转换为范围为 0 和 1的数据类型torch.float32的火炬张量。这里img是一个numpy.ndarray

方法:

  • 导入所需的库。
  • 读取输入图像。输入图像是 PIL 图像或 NumPy N 维数组。
  • 定义将图像转换为 Torch 张量的变换。我们使用transforms.Compose()定义一个变换。您可以直接使用transforms.PILToTensor()transforms.ToTensor()
  • 使用上面定义的变换将图像转换为张量。
  • 打印张量值。

下图在两个示例中都用作输入图像:

示例 1:

在下面的示例中,我们将 PIL 图像转换为 Torch 张量。

Python3
# Import necessary libraries
import torch
from PIL import Image
import torchvision.transforms as transforms
  
# Read a PIL image
image = Image.open('iceland.jpg')
  
# Define a transform to convert PIL 
# image to a Torch tensor
transform = transforms.Compose([
    transforms.PILToTensor()
])
  
# transform = transforms.PILToTensor()
# Convert the PIL image to Torch tensor
img_tensor = transform(image)
  
# print the converted Torch tensor
print(img_tensor)


Python3
# Import required libraries
import torch
import cv2
import torchvision.transforms as transforms
  
# Read the image
image = cv2.imread('iceland.jpg')
  
# Convert BGR image to RGB image
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  
# Define a transform to convert
# the image to torch tensor
transform = transforms.Compose([
    transforms.ToTensor()
])
  
# Convert the image to Torch tensor
tensor = transform(image)
  
# print the converted image tensor
print(tensor)


输出:

请注意,输出张量的数据类型是torch.uint8并且值在[0,255]范围内。

示例 2:

在此示例中,我们使用OpenCV读取 RGB 图像。使用 OpenCV 读取的图像类型是numpy.ndarray 。我们使用变换ToTensor()将其转换为火炬张量。

Python3

# Import required libraries
import torch
import cv2
import torchvision.transforms as transforms
  
# Read the image
image = cv2.imread('iceland.jpg')
  
# Convert BGR image to RGB image
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  
# Define a transform to convert
# the image to torch tensor
transform = transforms.Compose([
    transforms.ToTensor()
])
  
# Convert the image to Torch tensor
tensor = transform(image)
  
# print the converted image tensor
print(tensor)

输出:

请注意,输出张量的数据类型是torch.float32 ,值在[0, 1] 范围内。