PyTorch上的转置:IndexError:维度超出范围(应在[-2,1]的范围内,但得到2)



我想使用transpose转置我的数据,但我遇到了这样的错误。我的数据和相关流程上传到github。

https://github.com/nurkbts/error/blob/main/error.ipynb

使用torch.bmm(批处理矩阵乘法(时,两个张量都必须具有三维(第一个是批处理(。有关详细信息,请阅读文档。

由于您尝试使用bmm,因此应该只使用@运算符(相当于应用torch.matmul(。另外,不要忘记换位。这将给你一个形状(64, 64)

_scores = queries@keys.T / np.sqrt(64)

最新更新