📜  Python中的 numpy.argpartition()(1)

📅  最后修改于: 2023-12-03 15:19:25.381000             🧑  作者: Mango

Python中的 numpy.argpartition()

numpy.argpartition()是NumPy库中的一个函数,用于对数组进行局部排序并返回原始数组中的索引。该函数提供了一种高效的方法来查找数组中前k个最小(或最大)元素的索引。

语法
numpy.argpartition(arr, kth, axis=-1, kind='introselect', order=None)

参数说明:

  • arr:要排序的数组。
  • kth:指定元素的索引。
  • axis:沿着指定轴进行排序,默认为最后一个轴。
  • kind:排序的算法,默认为 ‘introselect’
  • order:如果数组是结构化类型,则可以指定排序顺序。
功能

numpy.argpartition()函数根据指定的kth值对数组进行局部排序。该函数返回原始数组中的索引,这些索引表示数组中的元素在排序后的数组中的位置。相对于较大的数组,argpartition()可以提供更好的性能,因为它仅对需要的元素进行排序,而不必对整个数组进行全排序。

示例

让我们通过几个示例来了解这个函数的用法。

示例 1

import numpy as np

arr = np.array([7, 2, 1, 3, 6, 5, 4])
indices = np.argpartition(arr, 3)

print("原始数组:", arr)
print("较小元素的索引:", indices[:3])

输出结果:

原始数组: [7 2 1 3 6 5 4]
较小元素的索引: [2 3 1]

在以上示例中,argpartition()函数对数组进行了局部排序,将较小的3个元素移到前面,并返回了它们的索引。

示例 2

import numpy as np

arr = np.array([[1, 5, 4], [3, 2, 6]])
indices = np.argpartition(arr, 2, axis=1)

print("原始数组:")
print(arr)
print("每行较小元素的索引:")
print(indices)

输出结果:

原始数组:
[[1 5 4]
 [3 2 6]]
每行较小元素的索引:
[[0 2 1]
 [1 0 2]]

以上示例展示了如何在二维数组中使用argpartition()函数。函数将每行的较小元素移到前面,并返回它们的索引。

总结

numpy.argpartition()函数是NumPy库中用于对数组进行局部排序并返回原始数组中的索引的强大工具。通过指定kth值,我们可以快速找到数组中前k个最小(或最大)元素的索引,并且对于较大的数组,它相对更高效。这个函数在处理大型数据集时特别有用。