numpy 元素乘法,具有多列

  • 本文关键字:元素 numpy python numpy
  • 更新时间 :
  • 英文 :


我有以下 numpy 代码。我有一个带有 3d 点的数组 (a(,另一个带有权重 (b(。我需要将a中的每一行乘以相应行中b的每个权重。我希望使这段代码更容易理解并消除循环。

a = np.array([[1, 2, 3], [4, 5, 6]])
b = np.array([[7, 8, 9, 10], [11, 12, 13, 14]])
c = np.zeros((2, 4, 3))
for i, row in enumerate(b):
for j, col in enumerate(row):
print('Mult:', a[i, :], '*', col)
c[i, j, :] = a[i, :] * col
print(c[0, :, :])
print(c[1, :, :])

这是输出。

Mult: [1 2 3] * 7
Mult: [1 2 3] * 8
Mult: [1 2 3] * 9
Mult: [1 2 3] * 10
Mult: [4 5 6] * 11
Mult: [4 5 6] * 12
Mult: [4 5 6] * 13
Mult: [4 5 6] * 14
[[ 7. 14. 21.]
[ 8. 16. 24.]
[ 9. 18. 27.]
[10. 20. 30.]]
[[44. 55. 66.]
[48. 60. 72.]
[52. 65. 78.]
[56. 70. 84.]]

您可以以不同的方式调整矩阵,然后执行逐元素乘法:

a[:,None,:] * b[:,:,None]

因此,如果a是一个m×n-矩阵,b是一个m×p-矩阵,我们得到一个m×p×n-张量。对于给定的示例数据,我们得到:

>>> a[:,None,:] * b[:,:,None]
array([[[ 7, 14, 21],
[ 8, 16, 24],
[ 9, 18, 27],
[10, 20, 30]],
[[44, 55, 66],
[48, 60, 72],
[52, 65, 78],
[56, 70, 84]]])

最新更新