如何将函数元素应用于二维张量



这个问题很简单,但我一直在努力解决这个问题。

import torch
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap

我想要:

torch.tensor([[True,False],[False,True]])

张量和重叠都很大,所以这里希望有效率。

本机方法是使用torch.Tensor.apply_方法:

t.apply_(f)

然而,根据官方文件,它只适用于CPU上的张量,不鼓励达到高性能。

此外,似乎没有本地的torch函数指示张量的值是否在列表中,唯一的选择应该是在列表overlap上迭代。看看这里和这里。因此,您可以尝试:

sum(t==i for i in overlap).bool()

我发现第二个函数对大的toverlap更具性能,而第一个函数对小的toverlap更具性能。

我找到了一个简单的方法。由于火炬是通过numpy阵列实现的,因此以下工作是有效的:

import torch
import numpy as np
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap
mask = np.vectorize(f)(t)

在这里找到。

最新更新