我想使用transpose转置我的数据,但我遇到了这样的错误。我的数据和相关流程上传到github。
https://github.com/nurkbts/error/blob/main/error.ipynb
使用torch.bmm
(批处理矩阵乘法(时,两个张量都必须具有三维(第一个是批处理(。有关详细信息,请阅读文档。
由于您尝试使用bmm
,因此应该只使用@
运算符(相当于应用torch.matmul
(。另外,不要忘记换位。这将给你一个形状(64, 64)
。
_scores = queries@keys.T / np.sqrt(64)