📜  将 torch 转换为 numpy - Python (1)

📅  最后修改于: 2023-12-03 15:25:15.087000             🧑  作者: Mango

将 torch 转换为 numpy - Python

NumPy是Python的一个基础科学计算库,而PyTorch是一个专门为深度学习任务设计的机器学习库。但是,在某些情况下,我们可能需要从PyTorch张量转换为NumPy数组或从NumPy数组转换为PyTorch张量。

PyTorch到NumPy

我们可以使用PyTorch张量的numpy方法将其转换为NumPy数组。下面是一个示例:

import torch
import numpy as np

# 创建一个PyTorch张量
a = torch.tensor([1,2,3])
print("PyTorch张量:", a)

# 转换为NumPy数组
b = a.numpy()
print("NumPy数组:", b)

输出:

PyTorch张量: tensor([1, 2, 3])
NumPy数组: [1 2 3]

我们也可以在转换时保留张量的计算图。为此,我们可以使用torch.Tensor.detach方法:

import torch
import numpy as np

# 创建一个PyTorch张量
a = torch.tensor([1,2,3], requires_grad=True)
print("PyTorch张量:", a)

# 转换为NumPy数组且保留计算图
b = a.detach().numpy()
print("NumPy数组:", b)

输出:

PyTorch张量: tensor([1, 2, 3], requires_grad=True)
NumPy数组: [1 2 3]

注意,我们必须使用detach方法来分离张量的计算图,以便在转换后不干扰任何梯度计算。

NumPy到PyTorch

我们可以使用torch.from_numpy方法将NumPy数组转换回PyTorch张量:

import torch
import numpy as np

# 创建一个NumPy数组
a = np.array([1,2,3])
print("NumPy数组:", a)

# 转换为PyTorch张量
b = torch.from_numpy(a)
print("PyTorch张量:", b)

输出:

NumPy数组: [1 2 3]
PyTorch张量: tensor([1, 2, 3], dtype=torch.int32)

注意,torch.from_numpy方法返回的张量默认使用与NumPy数组相同的数据类型,可以使用dtype参数指定不同的数据类型。