我想了解下面的代码中的"沿 dim 指定的轴收集值"是什么意思。如何在我脑海中构建对数据的功能操作。此函数对数据做什么以及如何执行?
请参考此链接 https://pytorch.org/docs/stable/torch.html#torch.gather
torch.gather(input, dim, index, out=None, sparse_grad=False)
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
是的,它通过张量的给定暗度(维度(,并将提供的索引指定的值收集到一个新的张量中。因此,如果我有一个 1D 张量(允许吗?
MyValues = torch.tensor([0,2,4,6,8])
并做到了
torch.gather(MyValues, 0, torch.tensor([0,1,3]))
我希望返回一个包含[0,2,6]
的一维张量,即位于位置 0
、1
和 3
的值。
因此,它只是使用 index
张量作为指向要从input
张量中提取的内容位置的指针来挑选内容。
dim
是要沿其编制索引的维度。因此,对于2D,您可以选择按行或列进行索引,并且您可以将其推断为任意多个维度。