与 scipy csr_matrix 相乘时保留 numpy ndarray 子类



我想在使用 scipy csr_matrix进行矩阵向量乘法时保留 ndarray sublcass 的类型。

我的子类是

import numpy as np
class FlattenedMeshVector(np.ndarray):
__array_priority__ = 15

def __new__(cls, input_array):
obj = np.asarray(input_array).view(cls)
return obj

def __array_finalize__(self, obj):
if obj is None: return
self.nx = getattr(obj, 'nx', None)

def __array_wrap__(self, out_arr, context=None):
return super().__array_wrap__(self, out_arr, context)

由于我使用常规 ndarray(具有__array_priority__ = 0)设置__array_priority__ = 15矩阵向量乘法很好地保留了子

>>> a = FlattenedMeshVector([1,1,1,1])
>>> id_mat = np.diag(np.ones(4))
>>> id_mat.dot(a)
FlattenedMeshVector([1., 1., 1., 1.])

但是,当对 scipy 稀疏矩阵执行相同的操作时,子类类型被销毁,即使csr_matrix具有__array_priority__ = 10.1。 当使用自python 3.5以来首选的@运算符时,也会发生这种情况。

>>> from scipy.sparse import csr_matrix
>>> a = FlattenedMeshVector([1,1,1,1])
>>> id_mat = csr_matrix(np.diag(np.ones(4)))
>>> id_mat.dot(a)
array([1., 1., 1., 1.])
>>> id_mat @ a
array([1., 1., 1., 1.])

我假设 csr_matrix.dot 在某个时候会向 ndarray 进行一些转换。知道我该如何规避这一点吗?

Signature: M.__mul__(other)

有几个测试用例

if other.__class__ is np.ndarray:
self._mul_vector(other)     # or, depending on dimensions
self._mul_multivector(other)
if issparse(other):
self._mul_sparse_matrix(other)
other_a = np.asanyarray(other)      # anything else

乘法后,可能会将ndarray变成np.matrix

if isinstance(other, np.matrix):
result = asmatrix(result)

所以基本上这3种情况是:

In [646]: M@M
Out[646]: 
<11x11 sparse matrix of type '<class 'numpy.int64'>'
with 31 stored elements in Compressed Sparse Row format>
In [647]: type(M@M.A)
Out[647]: numpy.ndarray
In [648]: type(M@M.todense())
Out[648]: numpy.matrix

看起来在您的情况下,这些是相同的:

In [671]: id_mat@a
Out[671]: array([1., 1., 1., 1.])
In [672]: id_mat._mul_vector(a)
Out[672]: array([1., 1., 1., 1.])

_mul_vector确实:

result = np.zeros(M, dtype=upcast_char(self.dtype.char,
other.dtype.char))
# csr_matvec or csc_matvec
fn = getattr(_sparsetools, self.format + '_matvec')
fn(M, N, self.indptr, self.indices, self.data, other, result)

sparse._sparsetools.csr_matvec是"内置的",即编译的,可能来自cython代码。 在任何情况下,result都是具有正确形状和 dtype 以及计算值的np.zeros

所以从它对np.matrix的处理中获取线索,我认为你唯一的选择是

In [678]: id_mat@(a)
Out[678]: array([1., 1., 1., 1.])
In [679]: FlattenedMeshVector(id_mat@(a))
Out[679]: FlattenedMeshVector([1., 1., 1., 1.])

最新更新