pytorch等价于keras Dot,带有axes参数



我需要将这个keras操作转换为pytorch:

user_vec = keras.layers.Dot((1,1))([user_vecs,user_att])

假设user_vecs具有形状(2,5,10),user_att具有形状(2,5)。输出形状为(2,10)。

pytorchinner仅适用于两个输入的最后一个维度-我想知道我是否应该排列我的轴并调用inner然后排列回来,或者是否有更好的方法。

user_vecs = user_vecs.permute(0,2,1)
torch.inner(user_vecs, user_att)

然而,这返回的是一个形状为(2,10,2)的张量。

看起来最简单的方法是使用einsum:

user_att = user_att.unsqueeze(2)
torch.einsum('ijk,ijk->ik', user_vecs, user_att)

相关内容

  • 没有找到相关文章

最新更新