如何通过访问pytorch中给定索引的2X2张量的特定值来创建张量



假设mat = torch.rand((5,7))和我想通过传递索引(比如idxs=[0,4,2,3,6])从第1维(这里是7)获得值。我现在能做的就是做mat[[0,1,2,3,4],idxs]。我本以为mat[:,idxs]会起作用,但没有。第一种选择是唯一的方法还是有更好的方法?

torch.gather就是您想要的:

torch.gather(mat, 1, torch.tensor(idxs)[:, None])

最新更新