📅  最后修改于: 2023-12-03 15:34:33.052000             🧑  作者: Mango
在PyTorch中,通过张量(Tensor)可以表示各种形状的多维数组。有时候需要更改张量的维度顺序,以便操作数据或将张量用于特定任务时更方便。
本文将介绍如何使用PyTorch API更改张量的维度顺序。
转置是一种常见的操作,可以更改张量的维度顺序。在PyTorch中,使用torch.transpose()
函数来执行转置操作,它的语法如下:
torch.transpose(input, dim0, dim1)
其中,
input
:输入张量。dim0
:要交换的维度。dim1
:要交换的维度。示例代码:
import torch
x = torch.randn(2, 3)
print("原始张量:", x)
print("转置张量:", torch.transpose(x, 0, 1))
输出结果:
原始张量: tensor([[ 1.6755, -0.7950, -0.1255],
[-1.4309, -0.4506, 0.7618]])
转置张量: tensor([[ 1.6755, -1.4309],
[-0.7950, -0.4506],
[-0.1255, 0.7618]])
有时候需要改变张量的形状而不改变其元素数量。在PyTorch中,可以使用torch.reshape()
函数来执行这个操作,它的语法如下:
torch.reshape(input, shape)
其中,
input
:输入张量。shape
:新的张量形状。示例代码:
import torch
x = torch.randn(2, 3, 4)
print("原始张量:", x)
print("重塑张量:", torch.reshape(x, (3, 8)))
输出结果:
原始张量: tensor([[[-1.4154, 0.6704, 0.7255, -0.5188],
[-1.4339, -0.9647, -0.9109, -1.4621],
[ 0.2178, 0.4743, 0.0822, -0.1233]],
[[-1.2858, -0.0610, 0.0333, 1.1043],
[ 0.6891, 0.0268, -0.6805, -0.1237],
[-0.2461, 1.3932, -0.0678, 1.6401]]])
重塑张量: tensor([[-1.4154, 0.6704, 0.7255, -0.5188, -1.4339, -0.9647, -0.9109, -1.4621],
[ 0.2178, 0.4743, 0.0822, -0.1233, -1.2858, -0.0610, 0.0333, 1.1043],
[ 0.6891, 0.0268, -0.6805, -0.1237, -0.2461, 1.3932, -0.0678, 1.6401]])
视图重塑不会改变张量的元素数量,但可以改变张量的维度顺序。在PyTorch中,可以使用torch.view()
函数来执行这个操作,它的语法如下:
torch.view(input, shape)
其中,
input
:输入张量。shape
:新的张量形状。示例代码:
import torch
x = torch.randn(2, 3, 4)
print("原始张量:", x)
print("视图重塑张量:", x.view(-1, 2, 6))
输出结果:
原始张量: tensor([[[-1.3247, 0.1183, -0.9083, 0.0192],
[-1.5420, -0.5240, 0.5576, -0.2223],
[-0.9068, 1.1326, 0.3798, 0.4644]],
[[ 0.0894, 0.0400, -1.1517, -0.4748],
[-0.6091, 0.5464, 0.1776, -0.0931],
[-0.1285, -0.8465, 0.0268, 0.0973]]])
视图重塑张量: tensor([[[-1.3247, 0.1183, -0.9083, 0.0192, -1.5420, -0.5240],
[ 0.5576, -0.2223, -0.9068, 1.1326, 0.3798, 0.4644]],
[[ 0.0894, 0.0400, -1.1517, -0.4748, -0.6091, 0.5464],
[ 0.1776, -0.0931, -0.1285, -0.8465, 0.0268, 0.0973]]])
permute
函数在PyTorch中用于交换张量的维度顺序。与transpose
函数不同,permute
函数可以一次交换多个维度。它的语法如下:
torch.permute(input, *dims)
其中,
input
:输入张量。*dims
:要交换的维度的序列。示例代码:
import torch
x = torch.randn(2, 3, 4)
print("原始张量:", x)
print("permute张量:", torch.permute(x, (0, 2, 1)))
输出结果:
原始张量: tensor([[[ 0.3501, -1.6539, -0.3327, -0.9970],
[-0.5971, -1.5002, -0.9659, -0.7533],
[-2.0233, 0.1796, -1.2937, -0.5366]],
[[ 1.0905, -1.3325, -0.0355, 0.8574],
[-1.7158, 0.8785, -1.8715, -0.2025],
[-0.9074, 2.0618, 1.1326, 0.4023]]])
permute张量: tensor([[[ 0.3501, -0.5971, -2.0233],
[-1.6539, -1.5002, 0.1796],
[-0.3327, -0.9659, -1.2937],
[-0.9970, -0.7533, -0.5366]],
[[ 1.0905, -1.7158, -0.9074],
[-1.3325, 0.8785, 2.0618],
[-0.0355, -1.8715, 1.1326],
[ 0.8574, -0.2025, 0.4023]]])
以上是PyTorch中张量更改维度顺序的介绍。使用这些函数,可以方便地操作和处理多维数组数据。