如果嵌套数组的最大值高于阈值,则用于获取嵌套数组的 numpy 条件



我有以下数组:

arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])

我想获取包含最大值大于或等于 .9 的数组的arr索引。因此,对于这种情况,结果将是[1]的,因为索引为 1 [.9, .1] 的数组是唯一最大值为>= 9 的数组。

我试过了:

>>> condition = np.max(arr) >= .9
>>> arr[condition]
array([ 0.5,  0.5])

但是,正如你所看到的,它产生了错误的答案。

我想你想在这里np.where。此函数返回满足特定条件的任何值的索引:

>>> np.where(arr >= 0.9)[0] # here we look at the whole 2D array
array([1])

np.where(arr >= 0.9) 返回索引数组的元组,数组的每个轴对应一个。预期的输出意味着您只需要行索引(轴 0)。

如果要先取每行的最大值,可以使用arr.max(axis=1)

>>> np.where(arr.max(axis=1) >= 0.9)[0] # here we look at the 1D array of row maximums
array([1])
In [18]: arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])
In [19]: numpy.argwhere(numpy.max(arr, 1) >= 0.9)
Out[19]: array([[1]])

你得到错误答案的原因是np.max(arr)给了你平展数组的最大值。 你想要np.max(arr, axis=1),或者更好的是,arr.max(axis=1).

(arr.max(axis=1)>=.9).nonzero()

沿轴使用 max 获取行最大值,然后where获取最大值的索引:

np.where(arr.max(axis=1)>=0.9)

最新更新