基于另一个没有循环的张量的滤波器火炬张量



假设我有以下两个火炬。张量:

x = torch.tensor([0,0,0,1,1,2,2,2,2], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)

我想以某种方式过滤x,以便只保留y中的值:

x_filtered = torch.tensor([0,0,0,2,2,2,2])

例如,如果为y = torch.tensor([0,1]),则为x_filtered = torch.tensor([0,0,0,1,1])。两个x,y总是1D和int64。y总是排序的,如果它更简单,我们可以假设x也总是排序的。

我试着想出各种不使用循环的方法,但都失败了。我不能真正使用循环,因为我的用例涉及数百万的x和数万的y。感谢您的帮助。


刚刚意识到我需要的是相当于numpy.in1d的手电筒

要想在任务中过滤张量,需要使用torch中提供的isin函数。它的使用方式如下:-

import torch
x = torch.tensor([0,0,0,1,1,2,2,2,2,3], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)
# torch.isin(x, y)
c=x[torch.isin(x,y)]
print(c)

运行此代码后,您将得到您喜欢的答案。

答案是https://pytorch.org/docs/master/generated/torch.isin.html?highlight=isin#torch.isin:

>>> torch.isin(x,y)
tensor([ True,  True,  True, False, False,  True,  True,  True,  True])

最新更新