括号引起的矩阵乘法执行时间差



给定两个 1Dnumpy数组ab

N = 100000
a = np.randn(N)
b = np.randn(N)

为什么以下两个表达式之间存在相当大的执行时间差异:

# expression 1
c = a @ a * b @ b
# expression 2
c = (a @ a) * (b @ b)

使用Jupyter Notebook的%timeit魔力,我得到了以下结果:

%timeit a @ a * b @ b

每个循环 223 μs ± 6.97 μs(7 次运行的平均标准±,每次 1000 次循环(

%timeit (a @ a( * (b @ b(

每个环路 17.4 μs ± 27.3 ns(平均 ± 标准偏差 7 次运行,每次 100000 次循环(

在这两个版本中,您都执行长度 N 向量的两个点积。但是,此外,第一个解决方案执行 N 次乘法,而第二个解决方案只需要一个。

a @ a * b @ b相当于((a @ a) * b) @ b

aa = a @ a  # N multiplications and additions -> scalar
aab = aa * b  # N multiplications -> vector
aabb = aab @ b  # N multiplications and additions -> scalar

(a @ a) * (b @ b)相当于

aa = a @ a  # N multiplications and additions -> scalar
bb = b @ b  # N multiplications and additions -> scalar
aabb = aa * bb  # 1 multiplication -> scalar

矩阵乘法性能可能取决于如何设置括号这一事实是众所周知的。存在通过利用这一事实来优化矩阵链乘法的算法。

更新:正如我刚刚了解到的,numpy 具有优化多个矩阵乘法的功能:numpy.linalg.multidot

最新更新