我正在为图像分割模型准备数据。我每个像素有 5 个类,累积起来不会覆盖整个图像,所以我想创建一个"空"类作为第 6 个类。现在我有一个独热编码的ndarray和一个解决方案,可以进行一堆我想要优化的Python调用。我现在的草图代码:
arrs.shape
(25, 25, 5)
null_class = np.zeros(arrs.shape[:-1])
for i in range(arrs.shape[0]):
for j in range(arrs.shape[1]):
if not np.any(arrs[i][j] == 1):
null_class[i][j] = 1
理想情况下,我找到一种几行且性能更高的方法来计算空示例 - 我的实际训练数据来自 20K x 20K 图像,我想一次计算和存储所有数据。有什么建议吗?
我相信
你可以用numpy.where
和numpy.all
的组合来做到这一点。使用 all
检查最后一个维度上的所有零将为您提供一个布尔数组,该数组True
null_class
应1
的位置。为了显示,我将使用(2,2,5)
数组。
arr = np.random.randint(0, 2, size=(2,2,5))
null_class = np.zeros(arr.shape[:-1])
arr[0, 0] = [0, 0, 0, 0, 0]
arr
array([[[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1]],
[[0, 0, 1, 0, 0],
[0, 1, 1, 1, 0]]])
np.all(arr[:, :] == 0, axis=2)
array([[ True, False],
[False, False]], dtype=bool)
np.where(np.all(arr[:, :] == 0, axis=2))
(array([0]), array([0]))
null_class[np.where(np.all(arr[:, :] == 0, axis=2)] = 1
null_class
array([[ 1., 0.],
[ 0., 0.]])