加快简单的多维计数器代码



这是慢代码:

def doCounts(maskA1, maskA2, maskA3, counts, maskB):
counts[0, maskB & maskA1] += 1
counts[1, maskB & maskA2] += 1
counts[2, maskB & maskA3] += 1

有没有办法一次性完成/使其更快?

矢量化可能很困难或不可能。这里的提示是第二维度中的高级索引,例如maskB & maskA1,可以为每行提供任意True值。因此,您无法隔离用于索引的m x n数组。

使用numba的朴素for循环似乎可以通过一个因素提高性能:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np
from numba import njit
@njit
def doCounts(maskA1, maskA2, maskA3, counts, maskB):
mask1, mask2, mask3 = maskB & maskA1, maskB & maskA2, maskB & maskA3
for i in range(counts.shape[0]):
m1, m2, m3 = mask1[i], mask2[i], mask3[i]
for j in range(counts.shape[1]):
if m1:
counts[0, j] += 1
if m2:
counts[1, j] += 1
if m3:
counts[2, j] += 1
return counts
def doCounts_original(maskA1, maskA2, maskA3, counts, maskB):
counts[0, maskB & maskA1] += 1
counts[1, maskB & maskA2] += 1
counts[2, maskB & maskA3] += 1
return counts
n = 100
np.random.seed(0)
m1, m2, m3, mB = (np.random.randint(0, 2, n**3).astype(bool) for _ in range(4))
counts = np.random.randint(0, 100, (3, n**3))
assert np.array_equal(doCounts(m1, m2, m3, counts, mB),
doCounts_original(m1, m2, m3, counts, mB))
%timeit doCounts(m1, m2, m3, counts, mB)           # 5.36 ms
%timeit doCounts_original(m1, m2, m3, counts, mB)  # 40.2 ms

最新更新