我有以下数组:
arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])
我想获取包含最大值大于或等于 .9 的数组的arr
索引。因此,对于这种情况,结果将是[1]
的,因为索引为 1 [.9, .1]
的数组是唯一最大值为>= 9 的数组。
我试过了:
>>> condition = np.max(arr) >= .9
>>> arr[condition]
array([ 0.5, 0.5])
但是,正如你所看到的,它产生了错误的答案。
我想你想在这里np.where
。此函数返回满足特定条件的任何值的索引:
>>> np.where(arr >= 0.9)[0] # here we look at the whole 2D array
array([1])
( np.where(arr >= 0.9)
返回索引数组的元组,数组的每个轴对应一个。预期的输出意味着您只需要行索引(轴 0)。
如果要先取每行的最大值,可以使用arr.max(axis=1)
:
>>> np.where(arr.max(axis=1) >= 0.9)[0] # here we look at the 1D array of row maximums
array([1])
In [18]: arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])
In [19]: numpy.argwhere(numpy.max(arr, 1) >= 0.9)
Out[19]: array([[1]])
你得到错误答案的原因是np.max(arr)
给了你平展数组的最大值。 你想要np.max(arr, axis=1)
,或者更好的是,arr.max(axis=1)
.
(arr.max(axis=1)>=.9).nonzero()
沿轴使用 max
获取行最大值,然后where
获取最大值的索引:
np.where(arr.max(axis=1)>=0.9)