📌  相关文章
📜  如何在 PyTorch 中对张量的元素进行排序?(1)

📅  最后修改于: 2023-12-03 15:08:46.098000             🧑  作者: Mango

如何在 PyTorch 中对张量的元素进行排序?

在 PyTorch 中,我们可以使用 torch.sort 方法对张量元素进行排序。该方法可以对张量进行升序或降序排序,并返回排序后的结果和对应的索引。

调用方法
sorted_tensor, indices = torch.sort(input, dim=None, descending=False, stable=False)

参数说明:

  • input:需要进行排序的张量。
  • dim:指定进行排序的维度,如果为 None,则默认对整个张量排序。
  • descending:是否采用降序排序,默认为 False
  • stable:当排序中有相等的元素时,是否保持它们在原序列中的相对位置关系不变,默认为 False
代码示例
import torch

# 生成一个大小为5x3的张量
x = torch.randn(5, 3)
print(f"x:\n{x}\n")

# 对整个张量进行升序排序
sorted_tensor, indices = torch.sort(x)
print(f"sorted_tensor:\n{sorted_tensor}\n")
print(f"indices:\n{indices}\n")

# 对第0维度进行降序排序
sorted_tensor, indices = torch.sort(x, dim=0, descending=True)
print(f"sorted_tensor:\n{sorted_tensor}\n")
print(f"indices:\n{indices}\n")

输出结果:

x:
tensor([[ 0.6488, -0.6001, -1.7620],
        [-2.8995,  0.2005, -0.6887],
        [ 0.4121, -0.2223, -0.1305],
        [ 1.1422,  0.2251, -0.6016],
        [ 1.1444, -0.7302,  0.7982]])

sorted_tensor:
tensor([[-1.7620, -0.6001,  0.6488],
        [-2.8995, -0.6887,  0.2005],
        [-0.2223, -0.1305,  0.4121],
        [-0.6016,  0.2251,  1.1422],
        [-0.7302,  0.7982,  1.1444]])

indices:
tensor([[2, 1, 0],
        [0, 2, 1],
        [1, 2, 0],
        [2, 1, 0],
        [1, 2, 0]])

sorted_tensor:
tensor([[ 1.1444,  0.2251,  0.7982],
        [ 1.1422,  0.2005, -0.1305],
        [ 0.6488, -0.2223, -0.6016],
        [ 0.4121, -0.6001, -0.6887],
        [-2.8995, -0.7302, -1.7620]])

indices:
tensor([[4, 3, 4],
        [3, 1, 2],
        [0, 2, 3],
        [2, 0, 1],
        [1, 4, 0]])

从输出结果中可以看出,升序排序之后,结果以升序排列,且每一列的元素在原序列中的相对位置都不变;降序排序之后,结果以降序排列,且每一行的元素在原序列中的相对位置都不变。

以上就是在 PyTorch 中对张量元素进行排序的方法,希望对大家有所帮助。