如何使用火炬张量运算符去除"for loop iteration"?



我想去掉"for循环迭代"在我的代码中使用Pytorch函数。但是公式太复杂了,我找不到任何线索。for循环是否可以迭代?在下面用火炬行动代替?

B=10
L=20
H=5
mat_A=torch.randn(B,L,L,H)
mat_B=torch.randn(L,B,B,H)
tmp_B=torch.zeros_like(mat_B)
for x in range(L):
for y in range(B):
for z in range(B):
tmp_B[:,y,z,:]+=mat_B[x,y,z,:]*mat_A[z,x,:,:]

这看起来像是应用torch.einsum的一个很好的设置。然而,我们首先需要通过定义每个单独的累积项来显式:占位符。

为了这样做,考虑中间张量结果的形状。第一个mat_B[x,y,z](H,)型,第二个mat_A[z,x,](L, H)型。

在伪代码中,初始操作如下:
for x, y, z, l, h in LxBxBxLxH:
tmp_B[:,y,z,:] += mat_B[x,y,z,:]*mat_A[z,x,:,:]
知道了这一点,我们可以在伪代码中重新表述你的初始循环:
for x, y, z, l, h in LxBxBxLxH:
tmp_B[l,y,z,h] += mat_B[x,y,z,h]*mat_A[z,x,l,h]
因此,我们可以通过使用与上面相同的符号来应用torch.einsum:
>>> torch.einsum('xyzh,zxlh->lyzh', mat_B, mat_A)