📜  张量中的最大值索引 - Python (1)

📅  最后修改于: 2023-12-03 14:54:13.089000             🧑  作者: Mango

张量中的最大值索引 - Python

在Python的NumPy库中,可以使用argmax()方法来获取张量中最大值的索引。该方法可以应用于一维,二维或更高维度的张量。

一维张量

对于一维张量,argmax()方法可以直接作用于该张量,如下所示:

import numpy as np

a = np.array([1, 5, 3, 9, 7])
max_index = np.argmax(a)
print(max_index)

以上代码输出结果为:

3

此时,最大值是9,其在一维张量a中的索引为3。

二维张量

对于二维张量,可以指定行或列方向上求最大值。例如,对于以下二维张量:

import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

若想获取每行最大值的索引,可以使用如下代码:

max_indices_rows = np.argmax(a, axis=1)
print(max_indices_rows)

输出结果为:

[2 2 2]

可见,每行最大值的索引分别为2,即第三列。

若想获取每列最大值的索引,可以使用如下代码:

max_indices_cols = np.argmax(a, axis=0)
print(max_indices_cols)

输出结果为:

[2 2 2]

此时,每列最大值的索引同样为2,即第三行。

更高维度张量

对于更高维度的张量,仍然可以指定某个方向求最大值。例如,对于以下三维张量:

import numpy as np

a = np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])

若想获取每个二维子张量中最大值的索引,可以使用如下代码:

max_indices = np.argmax(a, axis=2)
print(max_indices)

输出结果为:

[[1 1 1]
 [1 1 1]]

此时,每个二维子张量中最大值的索引均为1,即第二列。