基于二维布尔掩码的三维numpy数组的行列掩码



对于类似以下的三维立方体numpy数组:

import numpy as np
a = np.array([[[1,2,3],[4,5,6],[7,8,9]],[[10,11,12],[13,14,15],[16,17,18]],[[19,20,21],[22,23,24],[25,26,27]]])
array([[[ 1,  2,  3],
[ 4,  5,  6],
[ 7,  8,  9]],
[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]],
[[19, 20, 21],
[22, 23, 24],
[25, 26, 27]]])

和一些二维布尔掩码数组,如下所示:

b = np.array([[0,1,1],[1,1,1],[1,1,0]])
array([[0, 1, 1],
[1, 1, 1],
[1, 1, 0]])

我想知道是否有一种方法,使用numpy运算来计算一个结果,这样对于b[i][j] = 0a[i,:,j] = 0a[i,j,:] = 0所在的所有元素。保证了bn x nan x n x n。在上面的例子中,结果看起来像

array([[[ 0,  0,  0],
[ 0,  5,  6],
[ 0,  8,  9]],
[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]],
[[19, 20,  0],
[22, 23,  0],
[ 0,  0,  0]]])
In [111]: b = np.array([[0,1,1],[1,1,1],[1,1,0]])
In [116]: I,J = np.nonzero(b==0)
In [117]: I,J
Out[117]: (array([0, 2]), array([0, 2]))

测试索引:

In [118]: a[I,:,J]
Out[118]: 
array([[ 1,  4,  7],
[21, 24, 27]])
In [119]: a[I,J,:]
Out[119]: 
array([[ 1,  2,  3],
[25, 26, 27]])

应用:

In [120]: a[I,:,J]=0
In [121]: a[I,J,:]=0
In [122]: a
Out[122]: 
array([[[ 0,  0,  0],
[ 0,  5,  6],
[ 0,  8,  9]],
[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]],
[[19, 20,  0],
[22, 23,  0],
[ 0,  0,  0]]])