📜  PyTorch 中的张量运算

📅  最后修改于: 2022-05-13 01:55:09.156000             🧑  作者: Mango

PyTorch 中的张量运算

在本文中,我们将讨论 PyTorch 中的张量运算。

PyTorch 是一个科学包,用于对给定数据(如Python中的张量)执行操作。张量是数据的集合,如 numpy 数组。我们可以使用张量函数创建张量:

PyTorch 中应用于张量的操作是:

扩张()

该操作用于将张量展开为张量数、张量中的行数和张量中的列数。

示例:在本例中,我们将张量展开为 4 个张量,每个张量中 2 行 3 列

Python3
# import module
import torch
  
# create a tensor with 2 data
# in 3 three elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89]])
  
# display
print(data)
  
# expand the tensor into 4 tensors , 2
# rows and 3 columns in each tensor
print(data.expand(4, 2, 3))


Python3
# import module
import torch
  
# create a tensor with 2 data
# in 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# permute the tensor first by row
print(data.permute(1, 2, 0))
  
# permute the tensor first by column
print(data.permute(2, 1, 0))


Python3
# import module
import torch
  
# create a tensor with 2 data in
# 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# convert the tensor to list
print(data.tolist())


Python3
# import module
import torch
  
# create a tensor with 2 data in
# 3 three elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89], 
                     [23, 45, 67]])
  
# display
print(data)
  
# narrow the tensor
# with 1 dimension
# starting from 1 st index
# length of each dimension is 2
print(torch.narrow(data, 1, 1, 2))
  
# narrow the tensor
# with 1 dimension
# starting from 0 th  index
# length of each dimension is 2
print(torch.narrow(data, 1, 0, 2))


Python3
# import module
import torch
  
# create a tensor with 3 data in
# 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89],
                      [23, 45, 67]]])
  
# display
print(data)
  
# set the number 100 when the
# number in greater than 45
# otherwise 50
print(torch.where(data > 45, 100, 50))
  
# set the number 100 when the
# number in less than 45
# otherwise 50
print(torch.where(data < 45, 100, 50))
  
# set the number 100 when the number in 
# equal to 23 otherwise 50
print(torch.where(data == 23, 100, 50))


输出:

tensor([[10, 20, 30],
        [45, 67, 89]])
tensor([[[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]]])

置换()

这用于使用行和列重新排序张量

示例:在此示例中,我们将首先按行和按列排列张量。

Python3

# import module
import torch
  
# create a tensor with 2 data
# in 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# permute the tensor first by row
print(data.permute(1, 2, 0))
  
# permute the tensor first by column
print(data.permute(2, 1, 0))

输出:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
tensor([[[10],
         [20],
         [30]],

        [[45],
         [67],
         [89]]])
tensor([[[10],
         [45]],

        [[20],
         [67]],

        [[30],
         [89]]])

列表()

此方法用于从给定张量返回列表或嵌套列表。

示例:在此示例中,我们将把给定的张量转换为列表。

Python3

# import module
import torch
  
# create a tensor with 2 data in
# 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# convert the tensor to list
print(data.tolist())

输出:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
[[[10, 20, 30], [45, 67, 89]]]

狭窄的()

此函数用于缩小张量。换句话说,它将根据输入维度扩展张量。

示例:在本例中,我们将从第 1 个索引开始的 1 维张量进行缩小,每个维度的长度为 2,我们将从第 0索引和长度开始的 1 维张量进行缩小每个维度为 2

Python3

# import module
import torch
  
# create a tensor with 2 data in
# 3 three elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89], 
                     [23, 45, 67]])
  
# display
print(data)
  
# narrow the tensor
# with 1 dimension
# starting from 1 st index
# length of each dimension is 2
print(torch.narrow(data, 1, 1, 2))
  
# narrow the tensor
# with 1 dimension
# starting from 0 th  index
# length of each dimension is 2
print(torch.narrow(data, 1, 0, 2))

输出:

tensor([[10, 20, 30],
        [45, 67, 89],
        [23, 45, 67]])
tensor([[20, 30],
        [67, 89],
        [45, 67]])
tensor([[10, 20],
        [45, 67],
        [23, 45]])

在哪里()

此函数用于通过有条件地检查现有张量来返回新张量。

示例:我们将使用不同的关系运算符来检查功能

Python3

# import module
import torch
  
# create a tensor with 3 data in
# 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89],
                      [23, 45, 67]]])
  
# display
print(data)
  
# set the number 100 when the
# number in greater than 45
# otherwise 50
print(torch.where(data > 45, 100, 50))
  
# set the number 100 when the
# number in less than 45
# otherwise 50
print(torch.where(data < 45, 100, 50))
  
# set the number 100 when the number in 
# equal to 23 otherwise 50
print(torch.where(data == 23, 100, 50))

输出:

tensor([[[10, 20, 30],
         [45, 67, 89],
         [23, 45, 67]]])
tensor([[[ 50,  50,  50],
         [ 50, 100, 100],
         [ 50,  50, 100]]])
tensor([[[100, 100, 100],
         [ 50,  50,  50],
         [100,  50,  50]]])
tensor([[[ 50,  50,  50],
         [ 50,  50,  50],
         [100,  50,  50]]])