在python中用2D数组乘和4D数组的最快方法



这是我的问题。我有两个矩阵AB,它们分别具有维度为(n,n,m,m)(n,n)的复杂条目。

以下是我为获得矩阵C-而执行的操作

C = np.sum(B[:,:,None,None]*A, axis=(0,1))

计算上述内容一次大约需要6-8秒。由于我必须计算许多这样的C,所以需要花费大量时间。有更快的方法吗?(我在多核CPU上使用JAX NumPy进行这些操作;正常的NumPy需要更长的时间(

n=77m=512,如果你想知道的话。我可以在处理集群时进行并行化,但阵列的庞大规模会消耗大量内存。

看起来您想要einsum:

C = np.einsum('ijkl,ij->kl', A, B)

在Colab CPU上使用numpy,我得到了这个:

import numpy as np
x = np.random.rand(50, 50, 500, 500)
y = np.random.rand(50, 50)
def f1(x, y):
return np.sum(y[:,:,None,None]*x, axis=(0,1))
def f2(x, y):
return np.einsum('ijkl,ij->kl', x, y)
np.testing.assert_allclose(f1(x, y), f2(x, y))
%timeit f1(x, y)
# 1 loop, best of 5: 1.52 s per loop
%timeit f2(x, y)
# 1 loop, best of 5: 620 ms per loop

相关内容

  • 没有找到相关文章

最新更新