我想在使用 scipy csr_matrix进行矩阵向量乘法时保留 ndarray sublcass 的类型。
我的子类是
import numpy as np
class FlattenedMeshVector(np.ndarray):
__array_priority__ = 15
def __new__(cls, input_array):
obj = np.asarray(input_array).view(cls)
return obj
def __array_finalize__(self, obj):
if obj is None: return
self.nx = getattr(obj, 'nx', None)
def __array_wrap__(self, out_arr, context=None):
return super().__array_wrap__(self, out_arr, context)
由于我使用常规 ndarray(具有__array_priority__ = 0
)设置__array_priority__ = 15
矩阵向量乘法很好地保留了子
>>> a = FlattenedMeshVector([1,1,1,1])
>>> id_mat = np.diag(np.ones(4))
>>> id_mat.dot(a)
FlattenedMeshVector([1., 1., 1., 1.])
但是,当对 scipy 稀疏矩阵执行相同的操作时,子类类型被销毁,即使csr_matrix具有__array_priority__ = 10.1
。 当使用自python 3.5以来首选的@
运算符时,也会发生这种情况。
>>> from scipy.sparse import csr_matrix
>>> a = FlattenedMeshVector([1,1,1,1])
>>> id_mat = csr_matrix(np.diag(np.ones(4)))
>>> id_mat.dot(a)
array([1., 1., 1., 1.])
>>> id_mat @ a
array([1., 1., 1., 1.])
我假设 csr_matrix.dot 在某个时候会向 ndarray 进行一些转换。知道我该如何规避这一点吗?
Signature: M.__mul__(other)
有几个测试用例
if other.__class__ is np.ndarray:
self._mul_vector(other) # or, depending on dimensions
self._mul_multivector(other)
if issparse(other):
self._mul_sparse_matrix(other)
other_a = np.asanyarray(other) # anything else
乘法后,可能会将ndarray
变成np.matrix
:
if isinstance(other, np.matrix):
result = asmatrix(result)
所以基本上这3种情况是:
In [646]: M@M
Out[646]:
<11x11 sparse matrix of type '<class 'numpy.int64'>'
with 31 stored elements in Compressed Sparse Row format>
In [647]: type(M@M.A)
Out[647]: numpy.ndarray
In [648]: type(M@M.todense())
Out[648]: numpy.matrix
看起来在您的情况下,这些是相同的:
In [671]: id_mat@a
Out[671]: array([1., 1., 1., 1.])
In [672]: id_mat._mul_vector(a)
Out[672]: array([1., 1., 1., 1.])
_mul_vector
确实:
result = np.zeros(M, dtype=upcast_char(self.dtype.char,
other.dtype.char))
# csr_matvec or csc_matvec
fn = getattr(_sparsetools, self.format + '_matvec')
fn(M, N, self.indptr, self.indices, self.data, other, result)
sparse._sparsetools.csr_matvec
是"内置的",即编译的,可能来自cython
代码。 在任何情况下,result
都是具有正确形状和 dtype 以及计算值的np.zeros
。
所以从它对np.matrix
的处理中获取线索,我认为你唯一的选择是
In [678]: id_mat@(a)
Out[678]: array([1., 1., 1., 1.])
In [679]: FlattenedMeshVector(id_mat@(a))
Out[679]: FlattenedMeshVector([1., 1., 1., 1.])