📜  Pytorch——基于索引的操作(1)

📅  最后修改于: 2023-12-03 15:04:42.888000             🧑  作者: Mango

Pytorch——基于索引的操作

在Pytorch中,可以使用索引对Tensor进行设置或取出部分数据。索引操作方法类似于NumPy中的数组索引,但有一些Pytorch独有的细节需要注意。

1. 基本索引方法

首先我们创建一个5x5的Tensor:

import torch

x = torch.arange(25).reshape(5, 5)
print(x)
# >>> tensor([[ 0,  1,  2,  3,  4],
#             [ 5,  6,  7,  8,  9],
#             [10, 11, 12, 13, 14],
#             [15, 16, 17, 18, 19],
#             [20, 21, 22, 23, 24]])

其中第i行第j列的元素可以通过x[i, j]进行访问,例如:

print(x[0, 1])  # >>> 1

除了使用单个索引,也可以使用范围索引:

print(x[:, 1:3])  # >>> tensor([[ 1,  2],
                  #             [ 6,  7],
                  #             [11, 12],
                  #             [16, 17],
                  #             [21, 22]])

除此之外,Pytorch还支持使用掩码数组进行索引,例如:

mask = x > 10
print(mask)
# >>> tensor([[False, False, False, False, False],
#             [False, False, False, False, False],
#             [False,  True,  True,  True,  True],
#             [ True,  True,  True,  True,  True],
#             [ True,  True,  True,  True,  True]])

print(x[mask])  # >>> tensor([11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
2. 数据类型索引

在Pytorch中,采取类似于NumPy的方式对数据类型进行索引。例如:

x = torch.tensor([1, 2, 3], dtype=torch.float)
print(x)
# >>> tensor([1., 2., 3.])

print(x[0].item())
# >>> 1.0
3. tensor类型索引

可以使用LongTensor类型的索引数组对tensor进行索引。例如:

x = torch.randn(3, 4)
print(x)
# >>> tensor([[-0.5474, -0.7157, -0.0070, -1.6918],
#             [ 2.3551, -0.5569,  0.9610,  0.3642],
#             [ 0.1221, -0.6849, -0.4440, -0.3760]])

indices = torch.tensor([0, 2])
print(x[indices])
# >>> tensor([[-0.5474, -0.7157, -0.0070, -1.6918],
#             [ 0.1221, -0.6849, -0.4440, -0.3760]])
4. scatter_操作

scatter_操作是一个重要、灵活的操作,该操作能够根据索引将一些值写入Tensor,例如:

x = torch.zeros(2, 4)
indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
values = torch.tensor([1.0, 2.0, 3.0])
x.scatter_(1, indices, values)
print(x)
# >>> tensor([[1., 3., 0., 0.],
#             [2., 1., 3., 0.]])

上述代码中,我们将值为[1.0, 2.0, 3.0]分别写入x的第0行第1列、第1行第0列和第1行第1列。

5. gather操作

gather操作是另一个非常有用的操作,它与scatter_操作相对应,可以根据索引提取Tensor中的一些值,例如:

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
indices = torch.tensor([1, 0])
y = torch.gather(x, 1, indices.unsqueeze(0).t())
print(y)
# >>> tensor([[2.],
#             [3.]])

上述代码中,我们将x的第0列和第1列互换,并提取出第0行和第1行,最终得到一个列向量。