"Gathers values along an axis specified by dim."是什么意思?



我想了解下面的代码中的"沿 dim 指定的轴收集值"是什么意思。如何在我脑海中构建对数据的功能操作。此函数对数据做什么以及如何执行?

请参考此链接 https://pytorch.org/docs/stable/torch.html#torch.gather


torch.gather(input, dim, index, out=None, sparse_grad=False)

Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

是的,它通过张量的给定暗度(维度(,并将提供的索引指定的值收集到一个新的张量中。因此,如果我有一个 1D 张量(允许吗?

MyValues = torch.tensor([0,2,4,6,8])

并做到了

torch.gather(MyValues, 0, torch.tensor([0,1,3])) 

我希望返回一个包含[0,2,6]的一维张量,即位于位置 013 的值。

因此,它只是使用 index 张量作为指向要从input张量中提取的内容位置的指针来挑选内容。

dim是要沿其编制索引的维度。因此,对于2D,您可以选择按行或列进行索引,并且您可以将其推断为任意多个维度。

最新更新