numpy:在不广播的情况下乘以数组



我不知道这个问题的标题是否有意义,但我想不出更好的标题。

请考虑以下情况:我有两个 numpy 数组abb形状(2, 2, 5)a形状(5,)。我想将a的每个元素乘以b中的每个2x2 数组,我不想a中的每个元素乘以b中每个 2x2 数组中的每个元素,这就是如果我简单地这样做会发生什么a * b

以下代码演示了我的问题和我想要的结果:

import numpy as np
np.random.seed(1234)

class MyClass:
def __mul__(self, other):
print(f'{type(self).__name__} * {other}')
return other
def __rmul__(self, other):
print(f'{other} * {type(self).__name__}')
return other

a = np.full((5,), MyClass())
b = np.random.uniform(-1, 1, (2, 2, 5))
a * b
# MyClass * -0.6169610992422154
# MyClass * 0.24421754207966373
# ...
# MyClass * 0.5456532432247481
# MyClass * 0.7652823812722331
# Desired result:
[ai * bi for ai, bi in zip(a, np.moveaxis(b, -1, 0))]
# MyClass * [[-0.6169611  -0.45481479]
#  [-0.28436546  0.12239237]]
# ...
# MyClass * [[ 0.55995162  0.75186527]
#  [-0.25949849  0.76528238]]
# EDIT: Solution suggested by "Guimoute", does the same as a * b but also introduces addition (not a solution).
b @ a
# -0.6169610992422154 * MyClass
# 0.24421754207966373 * MyClass
# ...
# 0.5456532432247481 * MyClass
# 0.7652823812722331 * MyClass
c = np.split(np.moveaxis(b, -1, 0).reshape(-1, 2), 5)
# c is now a list of numpy arrays, with length = 5
a * c
# Raises an exception
# ValueError: operands could not be broadcast together with shapes (5,) (5,2,2)

编辑:我可能应该澄清一下,我要求的是一个numpy解决方案,而不是涉及python循环的解决方案,即:[ai * bi for ai, bi in zip(a, np.moveaxis(b, -1, 0))]

编辑:我添加了"Guimoute"的建议,涉及matmul @运算符,可以看出这显然不是解决方案。

编辑:为了解决难以验证潜在解决方案而不必检查某些打印输出的顺序的投诉,我添加了以下示例,其中包括一个函数来检查函数的作用是否是解决方案:

import numpy as np
np.random.seed(1234)

class MyClass:
def __mul__(self, other):
return self
def __rmul__(self, other):
return self

def solution_involving_python_loop(a, b):
return np.array([ai * bi for ai, bi in zip(a, np.moveaxis(b, -1, 0))])
def is_valid_solution(func):
return func(a, b).ndim == 1

a = np.full((5,), MyClass())
b = np.random.uniform(-1, 1, (2, 2, 5))

print(is_valid_solution(solution_involving_python_loop))
# True

编辑:为了响应"hpaulj"的答案,我添加了以下示例,它的作用不仅仅是乘以对同一对象的引用数组,该数组不执行任何操作。在这里,我再次尝试澄清标准 numpy 乘法和我正在寻找的乘法类型之间的区别

import numpy as np
np.random.seed(1234)

class Symbol:
def __init__(self, name):
self.name = name
def __repr__(self):
return self.name
def __mul__(self, other):
return Mul(self, other)
def __rmul__(self, other):
return Mul(other, self)

class Mul:
def __init__(self, a, b):
self.a = a
self.b = b
def __repr__(self):
return f'{self.a} * {self.b}'

a = np.array([Symbol(chr(i)) for i in range(ord('a'), ord('a') + 5)])
b = np.random.randint(0, 100, (2, 2, 5))
# multiplication with broadcasting
broadcast_result = a * b
print(broadcast_result)
# [[[a * 47 b * 83 c * 38 d * 53 e * 76]
#   [a * 24 b * 15 c * 49 d * 23 e * 26]]
#
#  [[a * 30 b * 43 c * 30 d * 26 e * 58]
#   [a * 92 b * 69 c * 80 d * 73 e * 47]]]
# desired result
desired_result = np.array([ai * bi for ai, bi in zip(a, np.moveaxis(b, -1, 0))])
for x in desired_result:
print(x)
# a * [[47 24]
#  [30 92]]
# b * [[83 15]
#  [43 69]]
# c * [[38 49]
#  [30 80]]
# d * [[53 23]
#  [26 73]]
# e * [[76 26]
#  [58 47]]
import numpy as np
np.random.seed(1234)

class MyClass:
def __mul__(self, other):
print(f'{type(self).__name__} * {other}')
return other
def __rmul__(self, other):
print(f'{other} * {type(self).__name__}')
return other

a = np.full((5,), MyClass())
b = np.random.uniform(-1, 1, (2, 2, 5))
your_sol = np.array([ai * bi for ai, bi in zip(a, np.moveaxis(b, -1, 0))])
your_required_array = np.moveaxis(a*b, -1, 0).astype(float)

我仍然对两种乘法之间的区别感到困惑。 随手它们听起来是一样的。 愿a是对象 dtype 的事实很重要。 但是,你知道你的a到底是什么吗?

让我们为您的类添加一个 repr:

In [122]: class MyClass:
...:     def __mul__(self, other):
...:         print(f'{type(self).__name__} * {other}')
...:         return other
...: 
...:     def __rmul__(self, other):
...:         print(f'{other} * {type(self).__name__}')
...:         return other
...:     def __repr__(self):
...:         return f'<{id(self)}>'
...: 
In [123]: MyClass()
Out[123]: <140080214874240>
In [124]: MyClass()
Out[124]: <140080207167392>
In [125]: a = np.full((5,), MyClass())
In [126]: a
Out[126]: 
array([<140080207226960>, <140080207226960>, <140080207226960>,
<140080207226960>, <140080207226960>], dtype=object)

请注意,a有 5 个对 SAME 对象的引用。 包含 5 个不同对象的数组:

In [127]: a1 = np.array([MyClass() for _ in range(5)])
In [128]: a1
Out[128]: 
array([<140080340066944>, <140080340068816>, <140080340066368>,
<140080205271872>, <140080205271488>], dtype=object)

让我们进一步更改它以在乘法中显示 id:

In [135]: class MyClass:
...:     def __mul__(self, other):
...:         print(f'{self} * {other}')
...:         return other
...: 
...:     def __rmul__(self, other):
...:         print(f'{other} * {self}')
...:         return other
...:     def __repr__(self):
...:         return f'<{id(self)}>'
In [138]: a*b
<140080213575568> * 0
<140080213575568> * 1
<140080213575568> * 2
<140080213575568> * 3
...
<140080213575568> * 19
Out[138]: 
array([[[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]],
[[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]]], dtype=object)

当物体不同时,我们看到重复的模式

In [139]: a1*b
<140080187967328> * 0        # 1
<140080187966320> * 1
<140080187967712> * 2
<140080187966368> * 3
<140080187968960> * 4
<140080187967328> * 5        # 2
<140080187966320> * 6
<140080187967712> * 7
<140080187966368> * 8
<140080187968960> * 9
<140080187967328> * 10       # 3
<140080187966320> * 11
<140080187967712> * 12
<140080187966368> * 13
<140080187968960> * 14
<140080187967328> * 15       # 4
<140080187966320> * 16
<140080187967712> * 17
<140080187966368> * 18
<140080187968960> * 19
Out[139]: 
array([[[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]],
[[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]]], dtype=object)

编辑

所以你想要的更像是这个列表理解:

In [142]: [a1[i]*b[:,:,i] for i in range(5)]
<140080187967328> * [[ 0  5]
[10 15]]
<140080187966320> * [[ 1  6]
[11 16]]
<140080187967712> * [[ 2  7]
[12 17]]
<140080187966368> * [[ 3  8]
[13 18]]
<140080187968960> * [[ 4  9]
[14 19]]

在 [139] 中,<140080187967328>也乘以 [0,5,10,15]。

如果我们从b创建一个 (5,) 对象数组,我们得到相同的配对:

In [143]: b1 = np.empty(5,object); b1[:] = [b[:,:,i] for i in range(5)]
In [144]: a1*b1
<140080187967328> * [[ 0  5]
[10 15]]
<140080187966320> * [[ 1  6]
[11 16]]
<140080187967712> * [[ 2  7]
[12 17]]
<140080187966368> * [[ 3  8]
[13 18]]
<140080187968960> * [[ 4  9]
[14 19]]
Out[144]:                         # same as b1
array([array([[ 0,  5],
[10, 15]]), array([[ 1,  6],
[11, 16]]), array([[ 2,  7],
[12, 17]]),
array([[ 3,  8],
[13, 18]]), array([[ 4,  9],
[14, 19]])], dtype=object)

对象 dtype 数组的数学运算发生在列表理解速度上。

解决方案:将b拆分为包含所有 2x2 数组的数组(使用 dtype=object)

c = np.empty((5,), dtype=object)
c[:] = np.split(np.moveaxis(b, -1, 0).reshape(-1, 2), 5)
# c is now an object array of numpy arrays, with length = 5
a * c
# MyClass * [[-0.6169611  -0.45481479]
#  [-0.28436546  0.12239237]]
# ...
# MyClass * [[ 0.55995162  0.75186527]
#  [-0.25949849  0.76528238]]

不是我喜欢的解决方案,但似乎没有人有更好的解决方案。

最新更新