嗨,我正在尝试解决以下等式,其中A
是一个稀疏矩阵,ptotal
是一个数字数组。我必须在对角线位置对所有条目进行汇总。
A[ptotal, ptotal] = -sum(A[ptotal, :])
代码似乎给出了正确的答案,但由于我的 ptotal 数组几乎很长(100000 个条目),因此计算效率低下。有没有快速的方法来解决这个问题。
首先是密集数组版本:
In [87]: A = np.arange(36).reshape(6,6)
In [88]: ptotal = np.arange(6)
假设ptotal
是所有行索引,则可以将其替换为sum
方法调用:
In [89]: sum(A[ptotal,:])
Out[89]: array([ 90, 96, 102, 108, 114, 120])
In [90]: A.sum(axis=0)
Out[90]: array([ 90, 96, 102, 108, 114, 120])
我们可以在对角线上使用这些值创建一个数组:
In [92]: np.diagflat(A.sum(axis=0))
Out[92]:
array([[ 90, 0, 0, 0, 0, 0],
[ 0, 96, 0, 0, 0, 0],
[ 0, 0, 102, 0, 0, 0],
[ 0, 0, 0, 108, 0, 0],
[ 0, 0, 0, 0, 114, 0],
[ 0, 0, 0, 0, 0, 120]])
将其添加到原始数组中 - 结果是一个"零和"数组:
In [93]: A -= np.diagflat(A.sum(axis=0))
In [94]: A
Out[94]:
array([[-90, 1, 2, 3, 4, 5],
[ 6, -89, 8, 9, 10, 11],
[ 12, 13, -88, 15, 16, 17],
[ 18, 19, 20, -87, 22, 23],
[ 24, 25, 26, 27, -86, 29],
[ 30, 31, 32, 33, 34, -85]])
In [95]: A.sum(axis=0)
Out[95]: array([0, 0, 0, 0, 0, 0])
我们可以对稀疏做同样的事情
In [99]: M = sparse.csr_matrix(np.arange(36).reshape(6,6))
In [100]: M
Out[100]:
<6x6 sparse matrix of type '<class 'numpy.int32'>'
with 35 stored elements in Compressed Sparse Row format>
In [101]: M.sum(axis=0)
Out[101]: matrix([[ 90, 96, 102, 108, 114, 120]], dtype=int32)
稀疏对角矩阵:
In [104]: sparse.dia_matrix((M.sum(axis=0),0),M.shape)
Out[104]:
<6x6 sparse matrix of type '<class 'numpy.int32'>'
with 6 stored elements (1 diagonals) in DIAgonal format>
In [105]: _.A
Out[105]:
array([[ 90, 0, 0, 0, 0, 0],
[ 0, 96, 0, 0, 0, 0],
[ 0, 0, 102, 0, 0, 0],
[ 0, 0, 0, 108, 0, 0],
[ 0, 0, 0, 0, 114, 0],
[ 0, 0, 0, 0, 0, 120]], dtype=int32)
取差价,得到一个新的矩阵:
In [106]: M-sparse.dia_matrix((M.sum(axis=0),0),M.shape)
Out[106]:
<6x6 sparse matrix of type '<class 'numpy.int32'>'
with 36 stored elements in Compressed Sparse Row format>
In [107]: _.A
Out[107]:
array([[-90, 1, 2, 3, 4, 5],
[ 6, -89, 8, 9, 10, 11],
[ 12, 13, -88, 15, 16, 17],
[ 18, 19, 20, -87, 22, 23],
[ 24, 25, 26, 27, -86, 29],
[ 30, 31, 32, 33, 34, -85]], dtype=int32)
还有一种setdiag
方法
In [117]: M.setdiag(-M.sum(axis=0).A1)
/usr/local/lib/python3.5/dist-packages/scipy/sparse/compressed.py:774: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.
SparseEfficiencyWarning)
In [118]: M.A
Out[118]:
array([[ -90, 1, 2, 3, 4, 5],
[ 6, -96, 8, 9, 10, 11],
[ 12, 13, -102, 15, 16, 17],
[ 18, 19, 20, -108, 22, 23],
[ 24, 25, 26, 27, -114, 29],
[ 30, 31, 32, 33, 34, -120]], dtype=int32)
Out[101]
是二维矩阵; .A1
将其转换为setdiag
可以使用的一维数组。
稀疏效率警告更多地针对迭代使用,而不是像这样的一次性应用程序。 不过,查看setdiag
代码,我怀疑第一种方法更快。 但我们真的需要做时间测试。