假设我有以下两个火炬。张量:
x = torch.tensor([0,0,0,1,1,2,2,2,2], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)
我想以某种方式过滤x
,以便只保留y
中的值:
x_filtered = torch.tensor([0,0,0,2,2,2,2])
例如,如果为y = torch.tensor([0,1])
,则为x_filtered = torch.tensor([0,0,0,1,1])
。两个x,y
总是1D和int64。y
总是排序的,如果它更简单,我们可以假设x
也总是排序的。
我试着想出各种不使用循环的方法,但都失败了。我不能真正使用循环,因为我的用例涉及数百万的x
和数万的y
。感谢您的帮助。
刚刚意识到我需要的是相当于numpy.in1d
的手电筒
要想在任务中过滤张量,需要使用torch中提供的isin函数。它的使用方式如下:-
import torch
x = torch.tensor([0,0,0,1,1,2,2,2,2,3], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)
# torch.isin(x, y)
c=x[torch.isin(x,y)]
print(c)
运行此代码后,您将得到您喜欢的答案。
答案是https://pytorch.org/docs/master/generated/torch.isin.html?highlight=isin#torch.isin:
>>> torch.isin(x,y)
tensor([ True, True, True, False, False, True, True, True, True])