在 Python 中对矩阵对角线位置的行条目求和的快速方法



嗨,我正在尝试解决以下等式,其中A是一个稀疏矩阵,ptotal是一个数字数组。我必须在对角线位置对所有条目进行汇总。

A[ptotal, ptotal] = -sum(A[ptotal, :])

代码似乎给出了正确的答案,但由于我的 ptotal 数组几乎很长(100000 个条目),因此计算效率低下。有没有快速的方法来解决这个问题。

首先是密集数组版本:

In [87]: A = np.arange(36).reshape(6,6)
In [88]: ptotal = np.arange(6)

假设ptotal是所有行索引,则可以将其替换为sum方法调用:

In [89]: sum(A[ptotal,:])
Out[89]: array([ 90,  96, 102, 108, 114, 120])
In [90]: A.sum(axis=0)
Out[90]: array([ 90,  96, 102, 108, 114, 120])

我们可以在对角线上使用这些值创建一个数组:

In [92]: np.diagflat(A.sum(axis=0))
Out[92]: 
array([[ 90,   0,   0,   0,   0,   0],
       [  0,  96,   0,   0,   0,   0],
       [  0,   0, 102,   0,   0,   0],
       [  0,   0,   0, 108,   0,   0],
       [  0,   0,   0,   0, 114,   0],
       [  0,   0,   0,   0,   0, 120]])

将其添加到原始数组中 - 结果是一个"零和"数组:

In [93]: A -= np.diagflat(A.sum(axis=0))
In [94]: A
Out[94]: 
array([[-90,   1,   2,   3,   4,   5],
       [  6, -89,   8,   9,  10,  11],
       [ 12,  13, -88,  15,  16,  17],
       [ 18,  19,  20, -87,  22,  23],
       [ 24,  25,  26,  27, -86,  29],
       [ 30,  31,  32,  33,  34, -85]])
In [95]: A.sum(axis=0)
Out[95]: array([0, 0, 0, 0, 0, 0])

我们可以对稀疏做同样的事情

In [99]: M = sparse.csr_matrix(np.arange(36).reshape(6,6))
In [100]: M
Out[100]: 
<6x6 sparse matrix of type '<class 'numpy.int32'>'
    with 35 stored elements in Compressed Sparse Row format>
In [101]: M.sum(axis=0)
Out[101]: matrix([[ 90,  96, 102, 108, 114, 120]], dtype=int32)

稀疏对角矩阵:

In [104]: sparse.dia_matrix((M.sum(axis=0),0),M.shape)
Out[104]: 
<6x6 sparse matrix of type '<class 'numpy.int32'>'
    with 6 stored elements (1 diagonals) in DIAgonal format>
In [105]: _.A
Out[105]: 
array([[ 90,   0,   0,   0,   0,   0],
       [  0,  96,   0,   0,   0,   0],
       [  0,   0, 102,   0,   0,   0],
       [  0,   0,   0, 108,   0,   0],
       [  0,   0,   0,   0, 114,   0],
       [  0,   0,   0,   0,   0, 120]], dtype=int32)

取差价,得到一个新的矩阵:

In [106]: M-sparse.dia_matrix((M.sum(axis=0),0),M.shape)
Out[106]: 
<6x6 sparse matrix of type '<class 'numpy.int32'>'
    with 36 stored elements in Compressed Sparse Row format>
In [107]: _.A
Out[107]: 
array([[-90,   1,   2,   3,   4,   5],
       [  6, -89,   8,   9,  10,  11],
       [ 12,  13, -88,  15,  16,  17],
       [ 18,  19,  20, -87,  22,  23],
       [ 24,  25,  26,  27, -86,  29],
       [ 30,  31,  32,  33,  34, -85]], dtype=int32)

还有一种setdiag方法

In [117]: M.setdiag(-M.sum(axis=0).A1)
/usr/local/lib/python3.5/dist-packages/scipy/sparse/compressed.py:774: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.
  SparseEfficiencyWarning)
In [118]: M.A
Out[118]: 
array([[ -90,    1,    2,    3,    4,    5],
       [   6,  -96,    8,    9,   10,   11],
       [  12,   13, -102,   15,   16,   17],
       [  18,   19,   20, -108,   22,   23],
       [  24,   25,   26,   27, -114,   29],
       [  30,   31,   32,   33,   34, -120]], dtype=int32)

Out[101]是二维矩阵; .A1将其转换为setdiag可以使用的一维数组。

稀疏效率警告更多地针对迭代使用,而不是像这样的一次性应用程序。 不过,查看setdiag代码,我怀疑第一种方法更快。 但我们真的需要做时间测试。

最新更新