假设mat = torch.rand((5,7))
和我想通过传递索引(比如idxs=[0,4,2,3,6]
)从第1维(这里是7)获得值。我现在能做的就是做mat[[0,1,2,3,4],idxs]
。我本以为mat[:,idxs]
会起作用,但没有。第一种选择是唯一的方法还是有更好的方法?
torch.gather
就是您想要的:
torch.gather(mat, 1, torch.tensor(idxs)[:, None])