使用另一个张量在二维PyTorch张量中进行索引



假设张量A定义为:

1  2  3  4
5  6  7  8
9 10 11 12
13 14 15 16

我试图通过使用另一个张量作为索引,从这个矩阵中提取一个平面阵列。例如,如果第二个张量定义为:

0
1
2
3

我希望索引的结果是一维张量的内容:

1
6
11
16

它的行为似乎不像NumPy;我试过A[:, B],但它只是因为无法分配大量内存而抛出一个错误,我不知道为什么!

第一种方法:使用torch.gather

torch.gather(A, 1, B.unsqueeze_(dim=1))

如果你想要一维矢量,你可以在末尾添加挤压:

torch.gather(A, 1, B.unsqueeze_(dim=1)).squeeze_()

第二种方法:使用列表综合

您可以使用列表综合来选择特定索引处的项,然后可以使用torch.stack将它们连接起来。这里一个重要的点是,你不应该使用torch.tensor从列表中创建新的张量,如果你这样做,你会破坏链(你无法计算通过该节点的梯度(:

torch.stack([A[i, B[i]] for i in range(A.size()[0])])

您可以将张量转换为NumPy数组。如果您正在使用Cuda,请不要忘记将其传递给cpu。如果没有,就没有必要将其传递给cpu。示例代码如下:

val.data.cpu().numpy()[:,B]

如果它解决了你的问题,请告诉我

PyTorch实现了torch.take,相当于numpy.take

最新更新