根据条件沿列轴获取最大值的索引



我有一个形状NxM的numpy数组,其值在0到1之间。只有当值大于0.9,否则大于-1时,我才希望沿着列轴获得最大值的索引。

,

import numpy as np
arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])

所以我需要上面数组array([2, -1])的最大索引输出。

我试过使用np.where

arr_filtered = np.where(arr>0.9,arr,-1)
max_index = np.argmax(arr_filtered,axis=1)

以上代码片段的输出是array([2, 0])。这和我的预期输出不匹配。有更简单的方法吗?

你可以试试:

  1. 查找每一行的max索引
  2. 检查每一行的最大值>0.9
  3. 合并以上两步的结果
  4. 如果值为零,用-1替换
arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])
a = np.argmax(arr, axis=1)
# 1 -> array([2, 2])
b = np.max(arr,axis=1) > 0.9
# 2 -> array([ True, False])
c = a*b
# 3 -> array([2, 0])
c[c==0] = -1
print(c)
# 4 -> array([ 2, -1])

最新更新