我有以下 numpy 代码。我有一个带有 3d 点的数组 (a
(,另一个带有权重 (b
(。我需要将a
中的每一行乘以相应行中b
的每个权重。我希望使这段代码更容易理解并消除循环。
a = np.array([[1, 2, 3], [4, 5, 6]])
b = np.array([[7, 8, 9, 10], [11, 12, 13, 14]])
c = np.zeros((2, 4, 3))
for i, row in enumerate(b):
for j, col in enumerate(row):
print('Mult:', a[i, :], '*', col)
c[i, j, :] = a[i, :] * col
print(c[0, :, :])
print(c[1, :, :])
这是输出。
Mult: [1 2 3] * 7
Mult: [1 2 3] * 8
Mult: [1 2 3] * 9
Mult: [1 2 3] * 10
Mult: [4 5 6] * 11
Mult: [4 5 6] * 12
Mult: [4 5 6] * 13
Mult: [4 5 6] * 14
[[ 7. 14. 21.]
[ 8. 16. 24.]
[ 9. 18. 27.]
[10. 20. 30.]]
[[44. 55. 66.]
[48. 60. 72.]
[52. 65. 78.]
[56. 70. 84.]]
您可以以不同的方式调整矩阵,然后执行逐元素乘法:
a[:,None,:] * b[:,:,None]
因此,如果a
是一个m×n-矩阵,b
是一个m×p-矩阵,我们得到一个m×p×n-张量。对于给定的示例数据,我们得到:
>>> a[:,None,:] * b[:,:,None]
array([[[ 7, 14, 21],
[ 8, 16, 24],
[ 9, 18, 27],
[10, 20, 30]],
[[44, 55, 66],
[48, 60, 72],
[52, 65, 78],
[56, 70, 84]]])