Python中大型对称半正定矩阵的有效矩阵平方根



我有一个大的(150000 x 150000(对称正半定样本协方差矩阵,我希望在Python中有效地计算其矩阵平方根。

如果矩阵是对称的psd,有什么方法可以加快平方根的计算吗?scipy.linal.sqrtm对我来说很慢。

根据您的应用程序,如果找到Bs @ Bs.T ~ B就足够了,则可以使用Cholesky分解。如果没有,你可以得到基于特征值分解的平方根。

import numpy as np;
import scipy.linalg
A = np.random.randn(1500, 1500)
%%time
Bs = scipy.linalg.sqrtm(B)

Wall time: 4.4 s-我们的基线

%%time
Bs = scipy.linalg.cholesky(B)

Wall time: 52 msCholesky的要快得多

D, V = scipy.linalg.eigh(B)
Bs = (V * np.sqrt(D)) @ V.T

Wall time: 1.62 s比快两倍多(它探索对称性(

使用pytorch

Pytorch支持一些线性代数函数,它们可以跨多个CPU进行矢量化

import torch.linalg
B_cpu = torch.tensor(B, device='cpu')

平方根使用eigh(12个逻辑/6个物理CPU(

%%time
D, V = torch.linalg.eigh(B_cpu)
Bs = (V * torch.sqrt(D)) @ V.T

Wall time: 400 ms

或Cholesky分解

Bs = torch.linalg.cholesky(B_cpu)

Wall time: 27 ms

为什么是新答案

另一个答案是Cholesky分解,它只适用于正定矩阵。虽然该限制非常适合样本协方差矩阵,但该问题指定了一个正的半定矩阵,这将导致Scipy/Numbery中的LinAlgErroreigeigh都适用于半正定,但正如另一个答案所示,它们是缓慢的。有一个与Cholesky有关的半定分解,即LDLt分解。如果D是对角的,我们可以简单地取元素的平方根,就像在本征分解中一样。然而,在实践中,scipy使用产生块对角线D.的快速算法

解决方案

因为LDL算法产生的D最多有2x2个块,所以我们可以从eigh升级到eigh_tridiagonal,这要快得多。此外,保证D具有与原始矩阵相同的(半(确定性,保证了非负特征值。

时间

以下代码将eigh与组合ldl+eigh_tridiagonal进行比较

import numpy as np
from scipy.linalg import ldl, eigh, eigh_tridiagonal, cholesky
import time
ldl_time = 0
eig_time = 0
tri_time = 0
n_trials = 20
size = 400
for trial in range(n_trials):
arr = np.random.normal(size=(size,size))
arr = arr @ arr.T
a = time.time()
vals, vecs = eigh(arr)
eig_time += time.time() - a
a = time.time()
l, d, p = ldl(arr)
ldl_time += time.time() - a
a = time.time()
w, v = eigh_tridiagonal(np.diag(d), np.diag(d, 1))
tri_time += time.time() - a
a = time.time()
L = cholesky(arr)
chol_time += time.time() - a
print(f"LDL alone took {ldl_time:.3f} seconds")
print(f"    Additional factorization of D took {tri_time:.3f} seconds")
print(f"On the other hand, eigendecomposition took {eig_time:.3f} seconds")
LDL alone took 0.444 seconds
Additional factorization of D took 0.041 seconds
On the other hand, eigendecomposition took 2.145 seconds

我们可以用重建原始矩阵

np.linalg.norm(l @ (v * w) @ v.T @ l.T - arr)
3.4043016350272527e-12

最新更新