我有一个形状NxM的numpy数组,其值在0到1之间。只有当值大于0.9
,否则大于-1
时,我才希望沿着列轴获得最大值的索引。
,
import numpy as np
arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])
所以我需要上面数组array([2, -1])
的最大索引输出。
我试过使用np.where
arr_filtered = np.where(arr>0.9,arr,-1)
max_index = np.argmax(arr_filtered,axis=1)
以上代码片段的输出是array([2, 0])
。这和我的预期输出不匹配。有更简单的方法吗?
你可以试试:
- 查找每一行的max索引
- 检查每一行的最大值>0.9
- 合并以上两步的结果
- 如果值为零,用-1替换
arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])
a = np.argmax(arr, axis=1)
# 1 -> array([2, 2])
b = np.max(arr,axis=1) > 0.9
# 2 -> array([ True, False])
c = a*b
# 3 -> array([2, 0])
c[c==0] = -1
print(c)
# 4 -> array([ 2, -1])