在PyTorch张量中查找等于零的元素的索引



我的问题与这个问题几乎相同,只是在PyTorch中的显著差异。我宁愿不使用Numpy解决方案,因为这将涉及到将数据移回CPU。我看到,和Numpy一样,PyTorch有一个非零函数,但它的where函数(我链接的Numpy线程中的解决方案(的行为与Numpy的不同。我想要的行为是is_zero()函数,如下所示:

>>> arr.nonzero()
tensor([[0, 1],
[1, 0]])  
>>> arr.is_zero()
tensor([[0, 0],
[1, 1]])

您可以制作一个布尔掩码,然后调用nonzero():

(arr == 0).nonzero()

例如:

arr = torch.randint(high=2, size=(3, 3))
tensor([[1, 1, 0],  # (0, 2)
[1, 1, 0],  # (1, 2)
[1, 0, 0]]) # (2, 1) and (2, 2)
(arr == 0).nonzero()
tensor([[0, 2],
[1, 2],
[2, 1],
[2, 2]])

最新更新