PyTorch中的批处理index_fill



我有一个大小为(2, 3)的索引张量:

>>> index = torch.empty(6).random_(0,8).view(2,3)
tensor([[6., 3., 2.],
[3., 4., 7.]])

和大小为(2, 8)的值张量:

>>> value = torch.zeros(2,8)
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]])

我想通过dim=-1的索引将value中的元素设置为1.**输出应该是这样的:

>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
[0., 0., 0., 1., 1., 0., 0., 1.]])

我尝试了value[range(2), index] = 1,但它触发了一个错误。我也尝试了torch.index_fill,但它不接受批量索引。torch.scatter需要创建一个额外的张量,大小为2*8,占用1,这会消耗不必要的内存和时间。

实际上可以通过设置value(int)选项来代替src选项(Tensor)来使用torch.Tensor.scatter_

>>> value.scatter_(dim=-1, index=index.long(), value=1)
>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
[0., 0., 0., 1., 1., 0., 0., 1.]])

确保index的类型为int64

最新更新