一个 numpy.where 错误



我正在使用numpy.where来查找某些值的索引。但是,numpy.where 会产生错误的索引,如下所示。其他人可以解释为什么我得到如此错误的索引吗?

谢谢。

In [1]: d = np.random.rand(3,4)
In [2]: d
Out[2]: 
array([[ 0.11694612,  0.95137658,  0.70099781,  0.06730629],
       [ 0.59989836,  0.52586768,  0.45387929,  0.76093495],
       [ 0.036541  ,  0.91714289,  0.2246452 ,  0.40785078]])
In [3]: np.where(d>0.9)
Out[3]: (array([0, 2]), array([1, 1]))

然而

In [4]: d[0,2]
Out[4]: 0.70099781000000005
In[5]: d[1,1]
Out[5]: 0.52586767999999995

问题是np.where返回一个数组元组,其中索引位于条件所在的给定轴上。所以,也许这让它更清楚:

>>> import numpy as np
>>> d = np.array([[ 0.11694612,  0.95137658,  0.70099781,  0.06730629],
...        [ 0.59989836,  0.52586768,  0.45387929,  0.76093495],
...        [ 0.036541  ,  0.91714289,  0.2246452 ,  0.40785078]])
>>> x, y = np.where(d > 0.9)
>>> d[x[0],y[0]]
0.95137658000000003
>>> d[x[1],y[1]]
0.91714289000000004

请注意,这适用于索引numpy的工作方式:

>>> d[x,y]
array([ 0.95137658,  0.91714289])

请注意,这适用于任何维度:

>>> d.reshape(3,2,2)
array([[[ 0.11694612,  0.95137658],
        [ 0.70099781,  0.06730629]],
       [[ 0.59989836,  0.52586768],
        [ 0.45387929,  0.76093495]],
       [[ 0.036541  ,  0.91714289],
        [ 0.2246452 ,  0.40785078]]])
>>> d = d.reshape(3,2,2)
>>> x, y, z = np.where(d > 0.9)
>>> x
array([0, 2])
>>> y
array([0, 0])
>>> z
array([1, 1])
>>> d[x,y,z]
array([ 0.95137658,  0.91714289])

最新更新