📜  pytorch 张量更改维度顺序 - Python (1)

📅  最后修改于: 2023-12-03 15:34:33.052000             🧑  作者: Mango

PyTorch 张量更改维度顺序 - Python

在PyTorch中,通过张量(Tensor)可以表示各种形状的多维数组。有时候需要更改张量的维度顺序,以便操作数据或将张量用于特定任务时更方便。

本文将介绍如何使用PyTorch API更改张量的维度顺序。

转置

转置是一种常见的操作,可以更改张量的维度顺序。在PyTorch中,使用torch.transpose()函数来执行转置操作,它的语法如下:

torch.transpose(input, dim0, dim1)

其中,

  • input:输入张量。
  • dim0:要交换的维度。
  • dim1:要交换的维度。

示例代码:

import torch

x = torch.randn(2, 3)
print("原始张量:", x)
print("转置张量:", torch.transpose(x, 0, 1))

输出结果:

原始张量: tensor([[ 1.6755, -0.7950, -0.1255],
        [-1.4309, -0.4506,  0.7618]])
转置张量: tensor([[ 1.6755, -1.4309],
        [-0.7950, -0.4506],
        [-0.1255,  0.7618]])
维度重塑

有时候需要改变张量的形状而不改变其元素数量。在PyTorch中,可以使用torch.reshape()函数来执行这个操作,它的语法如下:

torch.reshape(input, shape)

其中,

  • input:输入张量。
  • shape:新的张量形状。

示例代码:

import torch

x = torch.randn(2, 3, 4)
print("原始张量:", x)
print("重塑张量:", torch.reshape(x, (3, 8)))

输出结果:

原始张量: tensor([[[-1.4154,  0.6704,  0.7255, -0.5188],
         [-1.4339, -0.9647, -0.9109, -1.4621],
         [ 0.2178,  0.4743,  0.0822, -0.1233]],

        [[-1.2858, -0.0610,  0.0333,  1.1043],
         [ 0.6891,  0.0268, -0.6805, -0.1237],
         [-0.2461,  1.3932, -0.0678,  1.6401]]])
重塑张量: tensor([[-1.4154,  0.6704,  0.7255, -0.5188, -1.4339, -0.9647, -0.9109, -1.4621],
        [ 0.2178,  0.4743,  0.0822, -0.1233, -1.2858, -0.0610,  0.0333,  1.1043],
        [ 0.6891,  0.0268, -0.6805, -0.1237, -0.2461,  1.3932, -0.0678,  1.6401]])
视图重塑

视图重塑不会改变张量的元素数量,但可以改变张量的维度顺序。在PyTorch中,可以使用torch.view()函数来执行这个操作,它的语法如下:

torch.view(input, shape)

其中,

  • input:输入张量。
  • shape:新的张量形状。

示例代码:

import torch

x = torch.randn(2, 3, 4)
print("原始张量:", x)
print("视图重塑张量:", x.view(-1, 2, 6))

输出结果:

原始张量: tensor([[[-1.3247,  0.1183, -0.9083,  0.0192],
         [-1.5420, -0.5240,  0.5576, -0.2223],
         [-0.9068,  1.1326,  0.3798,  0.4644]],

        [[ 0.0894,  0.0400, -1.1517, -0.4748],
         [-0.6091,  0.5464,  0.1776, -0.0931],
         [-0.1285, -0.8465,  0.0268,  0.0973]]])
视图重塑张量: tensor([[[-1.3247,  0.1183, -0.9083,  0.0192, -1.5420, -0.5240],
         [ 0.5576, -0.2223, -0.9068,  1.1326,  0.3798,  0.4644]],

        [[ 0.0894,  0.0400, -1.1517, -0.4748, -0.6091,  0.5464],
         [ 0.1776, -0.0931, -0.1285, -0.8465,  0.0268,  0.0973]]])
permute

permute函数在PyTorch中用于交换张量的维度顺序。与transpose函数不同,permute函数可以一次交换多个维度。它的语法如下:

torch.permute(input, *dims)

其中,

  • input:输入张量。
  • *dims:要交换的维度的序列。

示例代码:

import torch

x = torch.randn(2, 3, 4)
print("原始张量:", x)
print("permute张量:", torch.permute(x, (0, 2, 1)))

输出结果:

原始张量: tensor([[[ 0.3501, -1.6539, -0.3327, -0.9970],
         [-0.5971, -1.5002, -0.9659, -0.7533],
         [-2.0233,  0.1796, -1.2937, -0.5366]],

        [[ 1.0905, -1.3325, -0.0355,  0.8574],
         [-1.7158,  0.8785, -1.8715, -0.2025],
         [-0.9074,  2.0618,  1.1326,  0.4023]]])
permute张量: tensor([[[ 0.3501, -0.5971, -2.0233],
         [-1.6539, -1.5002,  0.1796],
         [-0.3327, -0.9659, -1.2937],
         [-0.9970, -0.7533, -0.5366]],

        [[ 1.0905, -1.7158, -0.9074],
         [-1.3325,  0.8785,  2.0618],
         [-0.0355, -1.8715,  1.1326],
         [ 0.8574, -0.2025,  0.4023]]])

以上是PyTorch中张量更改维度顺序的介绍。使用这些函数,可以方便地操作和处理多维数组数据。