如何将 PyTorch 张量的每一行中的重复值归零?



我想编写一个函数来实现这个问题中描述的行为。

也就是说,我想将 PyTorch 中矩阵的每一行中的重复值归零。例如,给定一个矩阵

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
[1, 6, 3, 5, 3, 5, 4]])

我想得到

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 0, 0, 4]])

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 4, 0, 0]])

根据相关的问题,仅靠torch.unique()是不够的。我想知道如何在没有循环的情况下实现此功能。

x = torch.tensor([
[1, 2, 3, 4, 3, 3, 4],
[1, 6, 3, 5, 3, 5, 4]
], dtype=torch.long)
# sorting the rows so that duplicate values appear together
# e.g., first row: [1, 2, 3, 3, 3, 4, 4]
y, indices = x.sort(dim=-1)
# subtracting, so duplicate values will become 0
# e.g., first row: [1, 2, 3, 0, 0, 4, 0]
y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) !=0).long()
# retrieving the original indices of elements
indices = indices.sort(dim=-1)[1]
# re-organizing the rows following original order
# e.g., first row: [1, 2, 3, 4, 0, 0, 0]
result = torch.gather(y, 1, indices)
print(result) # => output

输出

tensor([[1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 0, 0, 4]])

最新更新