How can I get __matmul__ between two different numpy.ndarray subclasses to return a specific subclass?



I have two np.ndarray subclasses. Tuple @ Matrix returns a Tuple, but Matrix @ Tuple returns a Matrix. How can I make it return a Tuple?

import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

def scaling(x, y, z):
    m = Matrix(np.identity(4))
    m[0, 0] = x
    m[1, 1] = y
    m[2, 2] = z
    return m

Example:

>>> Tuple([1,2,3,4]) @ scaling(2,2,2)
Tuple([2., 4., 6., 4.])
>>> scaling(2,2,2) @ Tuple([1,2,3,4])
Matrix([2., 4., 6., 4.])   # XXXX I'd like this to be a Tuple
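A self-contained reproduction of the behaviour above, with the result types checked explicitly. As far as I can tell, when neither class sets __array_priority__, NumPy wraps the result in the type of the left operand, which is why only the Matrix-on-the-left case comes back as a Matrix:

```python
import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

def scaling(x, y, z):
    m = Matrix(np.identity(4))
    m[0, 0], m[1, 1], m[2, 2] = x, y, z
    return m

left = Tuple([1, 2, 3, 4]) @ scaling(2, 2, 2)   # Tuple on the left
right = scaling(2, 2, 2) @ Tuple([1, 2, 3, 4])  # Matrix on the left: the problem case
print(type(left).__name__, type(right).__name__)
```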

PS: Matrix @ Matrix should still return a Matrix.

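Edit: I made a mistake when copying from the np.matrix example. With __array_priority__ defined as a class attribute, it works: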

class Tuple(np.ndarray):
    __array_priority__ = 10

    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    __array_priority__ = 5.0

    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

In [2]: def scaling(x, y, z):
...:     m = Matrix(np.identity(4))
...:     m[0, 0] = x
...:     m[1, 1] = y
...:     m[2, 2] = z
...:     return m
...:
In [3]: Tuple([1,2,3,4]) @ scaling(2,2,2)
Out[3]: Tuple([2., 4., 6., 4.])
In [4]: scaling(2,2,2) @ Tuple([1,2,3,4])
Out[4]: Tuple([2., 4., 6., 4.])
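A standalone check of the class-attribute version covering all three cases, including the PS requirement that Matrix @ Matrix stays a Matrix. This assumes a NumPy version in which np.matmul picks the output type from the input with the highest __array_priority__ (the behaviour the session above shows):

```python
import numpy as np

class Tuple(np.ndarray):
    # Higher priority than Matrix, so mixed products should be wrapped as Tuple
    __array_priority__ = 10

    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    __array_priority__ = 5.0

    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

def scaling(x, y, z):
    m = Matrix(np.identity(4))
    m[0, 0], m[1, 1], m[2, 2] = x, y, z
    return m

v = Tuple([1, 2, 3, 4])
m = scaling(2, 2, 2)
for expr, result in [("v @ m", v @ m), ("m @ v", m @ v), ("m @ m", m @ m)]:
    print(expr, "->", type(result).__name__)
```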

===

Taking a cue from the np.matrix definition in numpy.matrixlib.defmatrix.py, add an __array_priority__ attribute:

In [382]: class Tuple(np.ndarray): 
...:     def __new__(cls, input_array, info=None): 
...:         __array_priority = 10 
...:         return np.asarray(input_array).view(cls) 
...:  
...: class Matrix(np.ndarray): 
...:     def __new__(cls, input_array, info=None): 
...:         __array_priority = 5 
...:         return np.asarray(input_array).view(cls) 
...:                                                                                            
In [383]:                                                                                            
In [383]: def scaling(x, y, z): 
...:     m = Matrix(np.identity(4)) 
...:     m[0, 0] = x 
...:     m[1, 1] = y 
...:     m[2, 2] = z 
...:     return m 
...:                                                                                            
In [384]: Tuple([1,2,3,4]) @ scaling(2,2,2)                                                          
Out[384]: Tuple([2., 4., 6., 4.])
In [385]: scaling(2,2,2) @ Tuple([1,2,3,4])                                                          
Out[385]: Matrix([2., 4., 6., 4.])
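In the session above, __array_priority = 10 is assigned inside __new__, so it is only a local variable (and the name also lacks the trailing underscores). Neither class ever gets an __array_priority__ attribute, which explains why Out[385] is still a Matrix. A minimal sketch contrasting the two placements; the class names Broken and Fixed are made up for illustration:

```python
import numpy as np

class Broken(np.ndarray):
    def __new__(cls, input_array, info=None):
        __array_priority = 10  # just a local variable inside __new__
        return np.asarray(input_array).view(cls)

class Fixed(np.ndarray):
    __array_priority__ = 10  # class attribute, which is what NumPy looks up

    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    __array_priority__ = 5

    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

m = Matrix(np.identity(4))
print(type(m @ Broken([1, 2, 3, 4])).__name__)  # Matrix: Broken has no class-level priority
print(type(m @ Fixed([1, 2, 3, 4])).__name__)   # Fixed
```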

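You can override the __matmul__ method to return a Tuple. If you want the result to be a Tuple whenever either operand is a Tuple, and a Matrix otherwise, I think this will work: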

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        if isinstance(other, Tuple):
            # M @ v == (v.T @ M.T).T, and Tuple @ Matrix already returns a Tuple
            return (other.T @ self.T).T
        return np.matmul(self, other)
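The diagonal scaling() matrix is symmetric, so it cannot tell M @ v apart from v @ M. A standalone check of the transpose trick with a non-symmetric matrix; translation() is a hypothetical helper, not part of the question:

```python
import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        if isinstance(other, Tuple):
            # M @ v == (v.T @ M.T).T; Tuple @ Matrix already yields a Tuple
            return (other.T @ self.T).T
        return np.matmul(self, other)

def translation(x, y, z):
    # Hypothetical helper: a 4x4 translation matrix, which is not symmetric
    m = Matrix(np.identity(4))
    m[0, 3], m[1, 3], m[2, 3] = x, y, z
    return m

t = translation(5, 6, 7)
p = Tuple([1, 2, 3, 1])
result = t @ p
print(type(result).__name__, result)
```

Comparing against the plain-ndarray product np.asarray(t) @ np.asarray(p) is a quick way to confirm the operand order is preserved.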

One way to solve this is to implement a custom __matmul__ in Matrix and an __rmatmul__ in Tuple:

import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __rmatmul__(self, other):
        # Called for other @ self after Matrix.__matmul__ returns NotImplemented.
        # Multiply in that order (other @ self), then view the result as a Tuple.
        return np.matmul(other, self).view(type(self))

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        # Handle only Matrix @ Matrix here; for anything else, let Python
        # fall back to the right operand's __rmatmul__.
        if not isinstance(other, Matrix):
            return NotImplemented
        return super().__matmul__(other)

def scaling(x, y, z):
    m = Matrix(np.identity(4))
    m[0, 0] = x
    m[1, 1] = y
    m[2, 2] = z
    return m

scaling(2,2,2) @ scaling(2,2,2)
# Matrix([[4., 0., 0., 0.],
#         [0., 4., 0., 0.],
#         [0., 0., 4., 0.],
#         [0., 0., 0., 1.]])
Tuple([1,2,3,4]) @ scaling(2,2,2)
# Tuple([2., 4., 6., 4.])
scaling(2,2,2) @ Tuple([1,2,3,4])
# Tuple([2., 4., 6., 4.])
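Because scaling() is diagonal, the outputs above look the same whether the reflected method computes Matrix @ Tuple or Tuple @ Matrix. A standalone variant using a hypothetical non-symmetric translation() helper (not from the question) makes the operand order visible; here __rmatmul__ computes np.matmul(other, self) and re-views the result as a Tuple:

```python
import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __rmatmul__(self, other):
        # other @ self: multiply in that order, then view the result as a Tuple
        return np.matmul(other, self).view(type(self))

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        if not isinstance(other, Matrix):
            return NotImplemented  # lets Python try Tuple.__rmatmul__
        return super().__matmul__(other)

def translation(x, y, z):
    # Hypothetical helper: a non-symmetric 4x4 translation matrix
    m = Matrix(np.identity(4))
    m[0, 3], m[1, 3], m[2, 3] = x, y, z
    return m

t = translation(5, 6, 7)
p = Tuple([1, 2, 3, 1])
print(type(t @ p).__name__, t @ p)  # matrix applied to the point
print(type(p @ t).__name__, p @ t)  # row vector times matrix, a different product
```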

Just override Matrix's __matmul__ so that it returns a Tuple:

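class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        if isinstance(other, Tuple):
            # Keep the order self @ other; only the result type changes.
            return np.matmul(self, other).view(Tuple)
        # Delegate everything else (e.g. Matrix @ Matrix) to ndarray.
        return super().__matmul__(other)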
