How to Sort the Elements of a Tensor in PyTorch?

Last modified: 2022-05-13 01:55:47.822000             Author: Mango

In this article, we will see how to sort the elements of a PyTorch tensor in Python.

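To sort the elements of a PyTorch tensor, we use the torch.sort() method. It returns the sorted values together with the indices those values had in the original tensor. When the tensor is two-dimensional, we can sort the elements along either the columns or the rows by setting the dim argument.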
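
To make the return value concrete, here is a minimal sketch. The small tensor t is a hypothetical sample chosen only for illustration and is not one of the examples in this article. It shows that torch.sort() gives back two tensors, and that indexing the input with the returned indices reproduces the sorted values.

```python
import torch

# Hypothetical sample data, used only for this illustration
t = torch.tensor([3.0, -1.0, 2.0])

# torch.sort returns a pair: the sorted values and their original positions
values, indices = torch.sort(t)

print(values.tolist())   # [-1.0, 2.0, 3.0]
print(indices.tolist())  # [1, 2, 0]

# Indexing the input with the indices reproduces the sorted values
reordered = t[indices]
print(torch.equal(reordered, values))  # True
```

The indices are handy when a second tensor has to be reordered in the same way as the one being sorted.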

Example 1:

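In the example below, we sort the elements of a one-dimensional tensor in ascending and then in descending order. By default, torch.sort() sorts in ascending order. To sort in descending order, pass descending=True to the method.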

Python3
# importing required library
import torch
  
# defining a PyTorch Tensor
tensor = torch.tensor([-12, -23, 0.0, 32,
                       1.32, 201, 5.02])
print("Tensor:\n", tensor)
  
# sorting the tensor in ascending order
print("Sorting tensor in ascending order:")
values, indices = torch.sort(tensor)
  
# printing values of sorted tensor
print("Sorted values:\n", values)
  
# printing indices of sorted value
print("Indices:\n", indices)
  
# sorting the tensor in descending order
print("Sorting tensor in descending order:")
values, indices = torch.sort(tensor, descending=True)
  
# printing values of sorted tensor
print("Sorted values:\n", values)
  
# printing indices of sorted value
print("Indices:\n", indices)


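Output:

Ascending: values [-23, -12, 0, 1.32, 5.02, 32, 201], indices [1, 0, 2, 4, 6, 3, 5]
Descending: values [201, 32, 5.02, 1.32, 0, -12, -23], indices [5, 3, 6, 4, 2, 0, 1]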
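
As a quick sanity check on the two sort orders, the sketch below reuses the tensor from Example 1 and compares the descending result with the ascending result reversed. The use of torch.flip() is our own choice for the comparison and does not appear in the original example.

```python
import torch

# Same tensor as in Example 1
tensor = torch.tensor([-12, -23, 0.0, 32, 1.32, 201, 5.02])

# Sort once in each direction
asc_values, asc_indices = torch.sort(tensor)
desc_values, desc_indices = torch.sort(tensor, descending=True)

# Reversing the ascending values should give the descending values
flipped = torch.flip(asc_values, dims=[0])
same = torch.equal(flipped, desc_values)
print(same)  # True
```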

Example 2:

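In this example, we sort the elements of a two-dimensional tensor along the columns (dim=0), first in ascending and then in descending order.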

Python3

# importing the library
import torch

# define a 2D torch tensor
tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])
print("Tensor:\n", tensor)

# sorting each column in ascending order (dim=0)
print("Sorting tensor in ascending order along the column:")
values, indices = torch.sort(tensor, dim=0)

# printing values in sorted tensor
print("Sorted values:\n", values)

# printing indices of values in sorted tensor
print("Indices:\n", indices)

# sorting each column in descending order
print("Sorting tensor in descending order along the column:")
values, indices = torch.sort(tensor, dim=0, descending=True)

# printing values in sorted tensor
print("Sorted values:\n", values)

# printing indices of values in sorted tensor
print("Indices:\n", indices)

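Output:

Ascending along the columns:
values [[-4.2, -4.3, -92], [3, 7, -6.2], [43, 31, 53]]
indices [[2, 1, 0], [1, 2, 2], [0, 0, 1]]
Descending along the columns:
values [[43, 31, 53], [3, 7, -6.2], [-4.2, -4.3, -92]]
indices [[0, 0, 1], [1, 2, 2], [2, 1, 0]]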
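
With a two-dimensional tensor, each entry of the indices tensor is a position along the sorted dimension only. The following sketch shows one way to apply such indices, assuming torch.gather() is acceptable for the purpose (the original example does not use it). Gathering along dim=0 with the returned indices should rebuild the sorted values column by column.

```python
import torch

# Same tensor as in Example 2
tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])

# Sort every column independently
values, indices = torch.sort(tensor, dim=0)

# indices[i][j] is the row of the original tensor that supplies values[i][j]
gathered = torch.gather(tensor, 0, indices)
print(torch.equal(gathered, values))  # True
```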

Example 3:

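In this example, we sort the elements of a two-dimensional tensor along the rows (dim=1), first in ascending and then in descending order.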

Python3

# importing the library
import torch

# define a 2D torch tensor
tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])
print("Tensor:\n", tensor)

# sorting each row in ascending order (dim=1)
print("Sorting tensor in ascending order along the row:")
values, indices = torch.sort(tensor, dim=1)

# printing values in sorted tensor
print("Sorted values:\n", values)

# printing indices of values in sorted tensor
print("Indices:\n", indices)

# sorting each row in descending order
print("Sorting tensor in descending order along the row:")
values, indices = torch.sort(tensor, dim=1, descending=True)

# printing values in sorted tensor
print("Sorted values:\n", values)

# printing indices of values in sorted tensor
print("Indices:\n", indices)

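Output:

Ascending along the rows:
values [[-92, 31, 43], [-4.3, 3, 53], [-6.2, -4.2, 7]]
indices [[2, 1, 0], [1, 0, 2], [2, 0, 1]]
Descending along the rows:
values [[43, 31, -92], [53, 3, -4.3], [7, -4.2, -6.2]]
indices [[0, 1, 2], [2, 0, 1], [1, 0, 2]]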
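
When dim is omitted, torch.sort() sorts along the last dimension (dim=-1), which for a 2D tensor is the same as sorting along the rows. The sketch below reuses the tensor from Example 3 to check this. The variable names are ours and are only illustrative.

```python
import torch

# Same tensor as in Example 3
tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])

# No dim given: the last dimension is sorted
default_values, default_indices = torch.sort(tensor)

# Explicit dim=1: each row is sorted
row_values, row_indices = torch.sort(tensor, dim=1)

print(torch.equal(default_values, row_values))    # True
print(torch.equal(default_indices, row_indices))  # True
```

Passing dim explicitly, as the examples in this article do, keeps the intent clear when the same code is later applied to tensors with more dimensions.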