获取numpy数组中所有值的相等掩码



假设我创建了以下numpy数组:

a = np.array([[3, 0, 3],
[1, 3, 3],
[1, 1, 3]])
b = np.array([0, 1, 2, 3])

如果我做a == 0,我得到:

array([[False,  True, False],
[False, False, False],
[False, False, False]])

如果我做a == 1,我得到:

array([[False, False, False],
[ True, False, False],
[ True,  True, False]])

以此类推。但是,如果我想得到一个数组包含所有掩码相对于所有条件a == nn属于b,我该如何进行?

np.array([a == n for n in b])做我想要的,但似乎不是很numpythonic。我还尝试了a == b,它只是返回False

只要a == b[:,None,None]和广播就可以了:

>>> a == b[:,None,None]
array([[[False,  True, False],
[False, False, False],
[False, False, False]],
[[False, False, False],
[ True, False, False],
[ True,  True, False]],
[[False, False, False],
[False, False, False],
[False, False, False]],
[[ True, False,  True],
[False,  True,  True],
[False, False,  True]]])

最新更新