用另一个多维张量索引多维火炬张量



我在pytorch中有一个张量x,假设形状(5,3,2,6(和另一个张量idx的形状(5,3,2,1(,其中包含第一个张量中每个元素的索引。我想用第二个张量的索引对第一个张量进行切片。我尝试了 x= x[idx],但当我真的希望它的形状为 (5,3,2( 或 (5,3,2,1( 时,我得到了一个奇怪的维度。

我将尝试举一个更简单的例子: 比方说

x=torch.Tensor([[10,20,30],
[8,4,43]])
idx = torch.Tensor([[0],
[2]])

我想要类似的东西

y = x[idx]

这样"y"输出[[10],[43]]或类似的东西。

索引表示所需元素在最后一个维度上的位置。 对于上面的示例,其中 x.shape = (2,3( 最后一个维度是列,然后 'idx' 中的索引是列。我想要这个,但超过 2 个维度

根据我从评论中的理解,您需要idx最后一个维度中的索引,并且idx中的每个索引对应于x中的类似索引(最后一个维度除外(。在这种情况下(这是numpy版本,您可以将其转换为火炬(:

ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]

输出:

[[10]
[43]]

您可以使用range; 和squeeze来获得正确的idx维度,例如

x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])
# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
[43.]])

这是使用gather在PyTorch中工作的一个。idx需要采用以下行将确保的torch.int64格式(请注意tensor中"t"的小写(。

idx = torch.tensor([[0],
[2]])
torch.gather(x, 1, idx) # 1 is the axis to index here
tensor([[10.],
[43.]])

最新更新