NumPy:如何检索多维数组中最大值的索引



使用以下数组:

In [103]: da                                                                                         
Out[103]: 
array([[[ 6, 22,  3],
[ 4,  9, 20],
[21, 16,  0]],
[[ 2, 25, 11],
[ 5, 17, 18],
[23, 13,  7]],
[[10, 14, 26],
[ 8,  1, 19],
[15, 12, 24]]])
In [104]: da.shape                                                                                   
Out[104]: (3, 3, 3)

具有最大值的元素的指数可以通过以下方式确定:

In [114]: np.unravel_index(np.argmax(da), da.shape)                                                  
Out[114]: (2, 0, 2)

并检查:

In [115]: da[2, 0, 2]                                                                                
Out[115]: 26

但是,在不循环/迭代的情况下,是否可以确定包含每组整数da[:, i1, i2]的最大值的9个索引,其中i1i2是0、1或2?

例如,组da[:, 0, 0]是6、2和10。最大值为10,其索引为da[2, 0, 0]

axis参数允许您指定一个操作轴:

i0 = np.argmax(da, axis=0)

这意味着i0是包含每个对应的i1i2的最大值的索引的(3, 3)阵列。在您的示例中,任何i1i2的最大值都是

da[i0[i1, i2], i1, i2]

最新更新