如何优化连续的numpy dot产品



我拥有的一些代码的瓶颈是:

for _ in range(n):
W = np.dot(A, W)

其中n可以变化,A是固定大小的MxM矩阵,W是Mx1。

有没有一个好的方法来优化它?

Numpy解决方案

由于np.dot只是形状的矩阵乘法,所以可以将您想要的内容写成A^n*W。其中"^是重复矩阵乘法";matrix_ power";和*矩阵乘法。因此,您可以将代码重写为

np.linalg.matrix_power(A,n)@W

线性代数解

你可以用线性代数做得更好。假设WA的特征向量,即A*W=a*W只有一个数字,那么它就跟在A^n*W=a^n*W.后面。现在你可能会认为可以,但如果W不是特征向量呢。由于矩阵乘法是线性的,所以如果W可以写成特征向量的线性组合也是好的,并且在W不能写成特征向量线性组合的情况下,甚至可以推广这个想法。如果你想阅读更多关于谷歌对角化约旦范式的信息。

最新更新