这个问题很简单,但我一直在努力解决这个问题。
import torch
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap
我想要:
torch.tensor([[True,False],[False,True]])
张量和重叠都很大,所以这里希望有效率。
本机方法是使用torch.Tensor.apply_
方法:
t.apply_(f)
然而,根据官方文件,它只适用于CPU上的张量,不鼓励达到高性能。
此外,似乎没有本地的torch函数指示张量的值是否在列表中,唯一的选择应该是在列表overlap
上迭代。看看这里和这里。因此,您可以尝试:
sum(t==i for i in overlap).bool()
我发现第二个函数对大的t
和overlap
更具性能,而第一个函数对小的t
和overlap
更具性能。
我找到了一个简单的方法。由于火炬是通过numpy阵列实现的,因此以下工作是有效的:
import torch
import numpy as np
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap
mask = np.vectorize(f)(t)
在这里找到。