优化了在单热像素标签中制作空类



我正在为图像分割模型准备数据。我每个像素有 5 个类,累积起来不会覆盖整个图像,所以我想创建一个"空"类作为第 6 个类。现在我有一个独热编码的ndarray和一个解决方案,可以进行一堆我想要优化的Python调用。我现在的草图代码:

arrs.shape
(25, 25, 5)
null_class = np.zeros(arrs.shape[:-1])
for i in range(arrs.shape[0]):
    for j in range(arrs.shape[1]):
        if not np.any(arrs[i][j] == 1):
            null_class[i][j] = 1

理想情况下,我找到一种几行且性能更高的方法来计算空示例 - 我的实际训练数据来自 20K x 20K 图像,我想一次计算和存储所有数据。有什么建议吗?

我相信

你可以用numpy.wherenumpy.all的组合来做到这一点。使用 all 检查最后一个维度上的所有零将为您提供一个布尔数组,该数组True null_class1的位置。为了显示,我将使用(2,2,5)数组。

arr = np.random.randint(0, 2, size=(2,2,5))
null_class = np.zeros(arr.shape[:-1])
arr[0, 0] = [0, 0, 0, 0, 0]
arr
array([[[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1]],
       [[0, 0, 1, 0, 0],
        [0, 1, 1, 1, 0]]])
np.all(arr[:, :] == 0, axis=2)
array([[ True, False],
       [False, False]], dtype=bool)
np.where(np.all(arr[:, :] == 0, axis=2))
(array([0]), array([0]))
null_class[np.where(np.all(arr[:, :] == 0, axis=2)] = 1
null_class
array([[ 1.,  0.],
       [ 0.,  0.]])

最新更新