Fast combinations without replacement for arrays - NumPy / Python



I am trying to generate pairwise combinations from a 1D array efficiently. itertools is too slow when n > 1000.

E.g. [1, 2, 3, 4]
magic code...
Out[2]:
array([[1, 2],
       [1, 3],
       [1, 4],
       [2, 3],
       [2, 4],
       [3, 4]])

The closest thing I have found is here.

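I. Pairwise combinations

One approach is to use numba, which fills a preallocated output array in compiled loops and so is efficient in both memory and run time: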

import numpy as np
from numba import njit

@njit
def pairwise_combs_numba(a):
    n = len(a)
    L = n*(n-1)//2  # number of pairs: n choose 2
    out = np.empty((L,2),dtype=a.dtype)
    iterID = 0  # next output row to fill
    for i in range(n):
        for j in range(i+1,n):
            out[iterID,0] = a[i]
            out[iterID,1] = a[j]
            iterID += 1
    return out
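
If numba is not available, the pairs produced by the nested loops above can be cross-checked with plain NumPy. The following is a minimal sketch, not part of the original answer: `pairwise_combs_triu` is a hypothetical helper name, and it relies on `np.triu_indices` returning the indices strictly above the diagonal in row-major order, which matches the order of `itertools.combinations`.

```python
import itertools
import numpy as np

def pairwise_combs_triu(a):
    """Hypothetical helper: all pairs (a[i], a[j]) with i < j."""
    a = np.asarray(a)
    # Row and column indices strictly above the main diagonal (k=1)
    i, j = np.triu_indices(len(a), k=1)
    return np.stack((a[i], a[j]), axis=1)

a = np.array([3, 9, 4, 1, 7])
pairs = pairwise_combs_triu(a)

# Reference result from the standard library, for comparison only
expected = np.array(list(itertools.combinations(a, 2)))
print(np.array_equal(pairs, expected))
```

This version materializes two index arrays of length n*(n-1)/2, so it is a convenience check, not a replacement for the numba loop when memory is tight.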

Another approach, based on NumPy alone, is to use np.broadcast_to to get grid views of the input and then mask them:

def pairwise_combs_mask(a):
    n = len(a)
    L = n*(n-1)//2
    out = np.empty((L,2),dtype=a.dtype)
    m = ~np.tri(len(a),dtype=bool)  # True strictly above the diagonal (i < j)
    out[:,0] = np.broadcast_to(a[:,None],(n,n))[m]
    out[:,1] = np.broadcast_to(a,(n,n))[m]
    return out
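
To make the masking step concrete, here is a small sketch on a four-element array that prints the mask and the values it selects from each broadcast view. The variable names `first` and `second` are chosen for illustration only.

```python
import numpy as np

a = np.array([1, 2, 3, 4])
n = len(a)

# np.tri is the lower triangle including the diagonal; negating it leaves
# True only where the column index exceeds the row index (i < j).
m = ~np.tri(n, dtype=bool)

# Read-only views: no n x n copy of the data is made at this point.
first = np.broadcast_to(a[:, None], (n, n))   # row i holds a[i] repeated
second = np.broadcast_to(a, (n, n))           # every row is a itself

print(m.astype(int))
print(first[m])    # first element of each pair
print(second[m])   # second element of each pair
```

Boolean indexing walks the mask in row-major order, which is why the selected pairs come out in the same order as the nested loops. The mask itself is still an n x n boolean array, so memory grows quadratically with n.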

II. Triplet combinations

The same two approaches extend to triplet combinations:

@njit
def triplet_combs_numba(a):
    n = len(a)
    L = n*(n-1)*(n-2)//6  # number of triplets: n choose 3
    out = np.empty((L,3),dtype=a.dtype)
    iterID = 0
    for i in range(n):
        for j in range(i+1,n):
            for k in range(j+1,n):
                out[iterID,0] = a[i]
                out[iterID,1] = a[j]
                out[iterID,2] = a[k]
                iterID += 1
    return out

def triplet_combs_mask(a):
    n = len(a)
    L = n*(n-1)*(n-2)//6
    out = np.empty((L,3),dtype=a.dtype)
    r = np.arange(n)
    # m[i,j,k] is True where i < j < k
    m = (r[:,None,None]<r[:,None]) & (r[:,None]<r)
    out[:,0] = np.broadcast_to(a[:,None,None],(n,n,n))[m]
    out[:,1] = np.broadcast_to(a[None,:,None],(n,n,n))[m]
    out[:,2] = np.broadcast_to(a[None,None,:],(n,n,n))[m]
    return out

Higher-order combinations extend in the same way.

Sample run:
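
As an illustration of how the masking idea might generalize, here is a sketch for arbitrary k. It is not from the original answer: `combs_mask` is a hypothetical name, and the function builds a k-dimensional boolean mask of n**k entries, so it is only practical for small n and k.

```python
import itertools
import numpy as np

def combs_mask(a, k):
    """Hypothetical generalization: all k-combinations of a, in index order."""
    a = np.asarray(a)
    n = len(a)
    r = np.arange(n)
    full = (n,) * k

    # m[i0, i1, ..., i_{k-1}] ends up True where i0 < i1 < ... < i_{k-1}
    m = np.ones(full, dtype=bool)
    for axis in range(k - 1):
        lo = [1] * k
        hi = [1] * k
        lo[axis] = n        # index along this axis ...
        hi[axis + 1] = n    # ... must be below the index along the next axis
        m &= r.reshape(lo) < r.reshape(hi)

    out = np.empty((int(m.sum()), k), dtype=a.dtype)
    for axis in range(k):
        shape = [1] * k
        shape[axis] = n
        # View of a laid out along one axis, repeated along all the others
        out[:, axis] = np.broadcast_to(a.reshape(shape), full)[m]
    return out

a = np.array([3, 9, 4, 1, 7])
quads = combs_mask(a, 4)
print(quads.shape)
```

For larger k, the numba approach with one more nested loop per element avoids the n**k mask entirely.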

In [54]: a = np.array([3,9,4,1,7])
In [55]: pairwise_combs_numba(a)
Out[55]:
array([[3, 9],
       [3, 4],
       [3, 1],
       [3, 7],
       [9, 4],
       [9, 1],
       [9, 7],
       [4, 1],
       [4, 7],
       [1, 7]])
In [56]: triplet_combs_numba(a)
Out[56]:
array([[3, 9, 4],
       [3, 9, 1],
       [3, 9, 7],
       [3, 4, 1],
       [3, 4, 7],
       [3, 1, 7],
       [9, 4, 1],
       [9, 4, 7],
       [9, 1, 7],
       [4, 1, 7]])

Timings (including Python's built-in itertools.combinations):

In [68]: a = np.random.rand(4000)
In [69]: %timeit pairwise_combs_numba(a)
    ...: %timeit pairwise_combs_mask(a)
    ...: %timeit list(itertools.combinations(a, 2))
10 loops, best of 3: 52.2 ms per loop
10 loops, best of 3: 146 ms per loop
1 loop, best of 3: 597 ms per loop
In [70]: a = np.random.rand(400)
In [71]: %timeit triplet_combs_numba(a)
    ...: %timeit triplet_combs_mask(a)
    ...: %timeit list(itertools.combinations(a, 3))
10 loops, best of 3: 98.5 ms per loop
1 loop, best of 3: 352 ms per loop
1 loop, best of 3: 795 ms per loop
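
The `%timeit` magic only works inside IPython. To repeat a comparison like this from a plain script, something along the following lines should work with the standard `timeit` module. This is a sketch: `best_time` is a hypothetical helper, the array is smaller than the one used above to keep the run short, and the numbers will differ by machine. When timing a numba function this way, call it once beforehand so that compilation time is not included.

```python
import itertools
import timeit
import numpy as np

def pairwise_combs_mask(a):
    # Same masking approach as in the answer, repeated so the script is self-contained
    n = len(a)
    out = np.empty((n * (n - 1) // 2, 2), dtype=a.dtype)
    m = ~np.tri(n, dtype=bool)
    out[:, 0] = np.broadcast_to(a[:, None], (n, n))[m]
    out[:, 1] = np.broadcast_to(a, (n, n))[m]
    return out

def best_time(func, repeat=3, number=1):
    """Hypothetical helper: best wall-clock time in seconds over `repeat` runs."""
    return min(timeit.repeat(func, repeat=repeat, number=number)) / number

a = np.random.rand(1000)
t_mask = best_time(lambda: pairwise_combs_mask(a))
t_iter = best_time(lambda: list(itertools.combinations(a, 2)))
print(f"mask:      {t_mask * 1e3:.1f} ms")
print(f"itertools: {t_iter * 1e3:.1f} ms")
```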
