Numpy中数组的数组中的元素明智比较



我有以下shape(5,2,3)数组,它是2 * 3数组的集合。

a = array([[[ 0,  2,  0],
    [ 3,  1,  1]],
   [[ 1,  1,  0],
    [ 2,  2,  1]],
   [[ 0,  1,  0],
    [ 3,  2,  1]],
   [[-1,  2,  0],
    [ 4,  1,  1]],
   [[ 1,  0,  0],
    [ 2,  3,  1]]])

1)我如何检查是否存在一个2 * 3数组在这个数组的数组中至少有一个元素是负的?

#which is this:
[[-1,  2,  0],
[ 4,  1,  1]]

2)之后,我如何从a中删除上述发现的2 * 3数组?

矢量化的实现是非常值得赞赏的,但是循环也很好。

你可以做-

a[~(a<0).any(axis=(1,2))]

或与.all()等价,从而避免inverting -

a[(a>=0).all(axis=(1,2))]

示例运行-

In [35]: a
Out[35]: 
array([[[ 0,  2,  0],
        [ 3,  1,  1]],
       [[ 1,  1,  0],
        [ 2,  2,  1]],
       [[ 0,  1,  0],
        [ 3,  2,  1]],
       [[-1,  2,  0],
        [ 4,  1,  1]],
       [[ 1,  0,  0],
        [ 2,  3,  1]]])
In [36]: a[~(a<0).any(axis=(1,2))]
Out[36]: 
array([[[0, 2, 0],
        [3, 1, 1]],
       [[1, 1, 0],
        [2, 2, 1]],
       [[0, 1, 0],
        [3, 2, 1]],
       [[1, 0, 0],
        [2, 3, 1]]])

使用any:

In [10]: np.any(a<0,axis=-1)
Out[10]: 
array([[False, False],
       [False, False],
       [False, False],
       [ True, False],
       [False, False]], dtype=bool)

或者更完整,如果您想要(2,3)数组的对应索引:

In [22]: np.where(np.any(a<0,axis=-1).any(axis=-1))
Out[22]: (array([3]),)
# Or as mentioned in comment you can pass a tuple to `any` np.where(np.any(a<0,axis=(1, 2)))

你也可以通过一个简单的索引得到数组:

In [27]: a[np.any(a<0, axis=(1, 2))]
Out[27]: 
array([[[-1,  2,  0],
        [ 4,  1,  1]]])

相关内容

  • 没有找到相关文章

最新更新