从带索引的二维矩阵构建三维布尔矩阵,单位为numpy



我有一个形状为(3, 4)的2D矩阵,索引范围从0到8:

a = array([[0, 4, 1, 2],
[5, 0, 2, 3],
[8, 6, 0, 5]])

目前,我使用for循环来构建形状为(9, 3, 4)的3D布尔数组,该数组将True存储在每个索引的位置,用于0到8:之间的每一行

b = np.zeros((9, 3, 4), dtype=bool)
for i in range(9):
b[i] = np.where(a == i, True, False)

有没有一种方法可以在不迭代的情况下实现相同的结果,也许可以使用numpy函数?

这是你想要的东西吗?

import numpy as np
a = np.array([[0, 4, 1, 2],
[5, 0, 2, 3],
[8, 6, 0, 5]])
y, x = np.mgrid[0:a.shape[0], 0:a.shape[1]]
data = np.zeros((9,) + a.shape, dtype=bool)
data[a, y, x] = True

利用numpy广播的一个非常短的解决方案:

b = np.array([a]*9) == np.arange(9).reshape(-1,1,1)

输出:

>>> b
array([[[ True, False, False, False],
[False,  True, False, False],
[False, False,  True, False]],
[[False, False,  True, False],
[False, False, False, False],
[False, False, False, False]],
[[False, False, False,  True],
[False, False,  True, False],
[False, False, False, False]],
[[False, False, False, False],
[False, False, False,  True],
[False, False, False, False]],
[[False,  True, False, False],
[False, False, False, False],
[False, False, False, False]],
[[False, False, False, False],
[ True, False, False, False],
[False, False, False,  True]],
[[False, False, False, False],
[False, False, False, False],
[False,  True, False, False]],
[[False, False, False, False],
[False, False, False, False],
[False, False, False, False]],
[[False, False, False, False],
[False, False, False, False],
[False, False, False, False]]])