假设张量A定义为:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
我试图通过使用另一个张量作为索引,从这个矩阵中提取一个平面阵列。例如,如果第二个张量定义为:
0
1
2
3
我希望索引的结果是一维张量的内容:
1
6
11
16
它的行为似乎不像NumPy;我试过A[:, B]
,但它只是因为无法分配大量内存而抛出一个错误,我不知道为什么!
第一种方法:使用torch.gather
torch.gather(A, 1, B.unsqueeze_(dim=1))
如果你想要一维矢量,你可以在末尾添加挤压:
torch.gather(A, 1, B.unsqueeze_(dim=1)).squeeze_()
第二种方法:使用列表综合
您可以使用列表综合来选择特定索引处的项,然后可以使用torch.stack
将它们连接起来。这里一个重要的点是,你不应该使用torch.tensor
从列表中创建新的张量,如果你这样做,你会破坏链(你无法计算通过该节点的梯度(:
torch.stack([A[i, B[i]] for i in range(A.size()[0])])
您可以将张量转换为NumPy数组。如果您正在使用Cuda,请不要忘记将其传递给cpu。如果没有,就没有必要将其传递给cpu。示例代码如下:
val.data.cpu().numpy()[:,B]
如果它解决了你的问题,请告诉我
PyTorch实现了torch.take,相当于numpy.take