python中"torch.Tensor in torch.Tensor"的机制是什么,为什么会有如此混乱的现象?



环境:

谷歌colab

Python 3.7.11

火炬1.9.0+cu102

编码和输出

import torch
b = torch.tensor([[1,1,1],[4,5,6]])
print(b.T)
print(torch.tensor([1,4]) in b.T) # 
print(torch.tensor([2,1]) in b.T) #
print(torch.tensor([1,2]) in b.T) # Not as expected
print(torch.tensor([2,5]) in b.T) # Not as expected
----------------------------------------------------------
tensor([[1, 4],
[1, 5],
[1, 6]])
True
False
True
True

问题

我想判断一个张量是否在另一个张量中。但上述结果令人困惑。

in的作用机制是什么?我应该如何使用它来避免上述意外输出?

这是torch.tensor.T的问题吗?(当未使用.T和初始b = torch.tensor([[1,4],[1,5],[1,6]])时,也可能没有预期输出)

源代码如下(编辑:@Rune找到的源代码片段):

def __contains__(self, element):
r"""Check if `element` is present in tensor
Args:
element (Tensor or scalar): element to be checked
for presence in current tensor"
"""
if has_torch_function_unary(self):
return handle_torch_function(Tensor.__contains__, (self,), self, element)
if isinstance(element, (torch.Tensor, Number)):
# type hint doesn't understand the __contains__ result array
return (element == self).any().item()  # type: ignore[union-attr]
raise RuntimeError(
"Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
type(element)
)

__contains__(由"in"语法使用:x in b)运算符等效于在x == b布尔条件上应用torch.any

>>> b = tensor([[1, 1, 1],
[4, 5, 6]])
>>> check_in = lambda x: torch.any(x == b.T)

然后

>>> check_in(torch.tensor([1,4]))
tensor(True)
>>> check_in(torch.tensor([2,1]))
tensor(False)
>>> check_in(torch.tensor([1,2]))
tensor(True)
>>> check_in(torch.tensor([2,5]))
tensor(True)

重要的是元素在列中的位置,而不是整列的精确匹配。


.T颠倒维度的顺序:相当于b.permute(1, 0),对结果没有影响。使用in时唯一的约束是x的大小需要与b[1]的形状相匹配如果您使用b.T,那么它将是b[0]

>>> check_in = lambda x: torch.any(x == b)
>>> check_in(torch.tensor([1,1,1]))
tensor(True)
>>> check_in(torch.tensor([5,4,2]))
tensor(False)

最新更新