我可以围绕二维示例进行思考,并且我对"重复"维度导致乘法以及显式输出中省略维度导致跨维度求和的直觉。但是我正在尝试解析一个算法实现,该算法实现基本上只由三维数组上的 einsum 组成,我需要一本手册来解压缩公式。例如,我有一个形状[F][I][D]
数组x
数组和形状[F][I]
数组y
数组,然后einsum('fid,fi->fd', x, y)
生成一个形状[F][D]
数组,但我无法弄清楚乘法和求和的执行顺序和方向。
x = (np.arange(5*3*2)).reshape(5,3,2)
x
array([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]],
[[24, 25],
[26, 27],
[28, 29]]])
y = (np.arange(5*3)).reshape(5,3)
y
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14]])
np.einsum("fid,fi->fd",x,y)
array([[ 10, 13],
[ 100, 112],
[ 298, 319],
[ 604, 634],
[1018, 1057]])
有没有关于如何解压缩 einsum 字符串的手册,以便我得到"普通人"求和和乘法公式?
据我推断,结果是
out[f][d] = ∑_i x[f][i][d] * y[f][i]
所以再看'fid,fi->fd'
,因为i
不在 RHS 中,所以i
求和,你只需为元素乘法应用重复维度。
In [15]: x = (np.arange(5*3*2)).reshape(5,3,2)
In [16]: y = (np.arange(5*3)).reshape(5,3)
In [17]: np.einsum('fid,fi->fd',x,y)
Out[17]:
array([[ 10, 13],
[ 100, 112],
[ 298, 319],
[ 604, 634],
[1018, 1057]])
一些替代方案,使用广播和总和:
In [18]: (x*y[:,:,None]).sum(axis=1)
Out[18]:
array([[ 10, 13],
[ 100, 112],
[ 298, 319],
[ 604, 634],
[1018, 1057]])
和批量dot
:
In [19]: np.array([np.dot(b,a) for a,b in zip(x,y)])
Out[19]:
array([[ 10, 13],
[ 100, 112],
[ 298, 319],
[ 604, 634],
[1018, 1057]])
求和的i
维度是矩阵乘法的"乘积总和"维度。f
是批次维度,以所有术语出现在同一位置。
matmul/@
也进行批处理矩阵乘法,但它的应用并不那么直观:
In [21]: y[:,None,:]@x
Out[21]:
array([[[ 10, 13]],
[[ 100, 112]],
[[ 298, 319]],
[[ 604, 634]],
[[1018, 1057]]])
这是一个(5,1,2(,必须squeezed
才能摆脱中间维度。