如何获取np.maximum的索引



我知道np.maximum计算元素最大值,例如

>>> b = np.array([3, 6, 1])
>>> c = np.array([4, 2, 9])
>>> np.maximum(b, c)
array([4, 6, 9])

但是有什么方法可以得到这个指数吗?就像上面的例子一样,我也想要这样的东西,其中每个元组表示(哪个数组、索引(,它可以是元组、字典或其他东西。如果它能在3d数组上工作,就像输入的两个数组是3d数组一样,那就太好了。

array([(1, 0), (0, 1), (1, 2)])

您可以将两个1d数组堆叠以获得一个2d数组,并使用argmax:

arr = np.vstack((b, c))
indices = np.argmax(arr, axis=0)

这将为您提供一个整数列表,而不是元组列表,但正如您所知,您对每列进行比较,每个元组的最后元素无论如何都是不必要的。它们只是从0开始的递增整数。如果你真的需要它们,你可以添加

indices = list(zip(indices, range(len(b)))

最新更新