我的问题与这个问题几乎相同,只是在PyTorch中的显著差异。我宁愿不使用Numpy解决方案,因为这将涉及到将数据移回CPU。我看到,和Numpy一样,PyTorch有一个非零函数,但它的where函数(我链接的Numpy线程中的解决方案(的行为与Numpy的不同。我想要的行为是is_zero()
函数,如下所示:
>>> arr.nonzero()
tensor([[0, 1],
[1, 0]])
>>> arr.is_zero()
tensor([[0, 0],
[1, 1]])
您可以制作一个布尔掩码,然后调用nonzero()
:
(arr == 0).nonzero()
例如:
arr = torch.randint(high=2, size=(3, 3))
tensor([[1, 1, 0], # (0, 2)
[1, 1, 0], # (1, 2)
[1, 0, 0]]) # (2, 1) and (2, 2)
(arr == 0).nonzero()
tensor([[0, 2],
[1, 2],
[2, 1],
[2, 2]])