PyTorch批屏蔽选择实现



如何执行批处理masked_select

给定:

x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])

所需输出为:

tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])

以下是一种可能的方法:

x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
ones = torch.tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
masks = torch.tensor([[ True, False, False, False,  True],
[ True, False,  True,  True, False]])
for i in range(x.size(0)):
mask = masks[i]
s = torch.masked_select(x[i], mask)
ones[i][:s.size(0)] = s

是否有其他解决方案?

这类问题的主要问题是中间结果是非同质的:在批处理中,元素将具有不同数量的掩码值。如果我们想应用PyTorch内置程序,这是一个问题。在这里,我提供了两种执行此操作的解决方案。


1-使用列表理解

按适当的量检查批次元素、掩模和衬垫:

>>> pad = lambda v: F.pad(v, [0, len(m)-len(v)], value=1)
>>> torch.stack([pad(r[m]) for r, m in zip(x, masks)])
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])

这相当简单,与您的方法类似。


2-使用torch.scatter

矢量化的替代方案是构造正确的值和索引张量,以便我们可以应用torch.scatter并获得所需的结果。这里的诀窍是使用展平张量。从xmasks,我们首先想要访问定义为:的nzidx

  • nz:来自x的非屏蔽值(当然使用masks),我们需要找到:

    tensor([1., 3., 1., 4., 3.]) 
    
  • CCD_ 11:当平坦化时,它们在输出张量中的相应索引。

    tensor([ 0,  1,  5,  6,  7])
    

然后我们可以用类似out = ones.scatter(dim=0, idx, nz)的东西应用散射,它将有效地执行:out[idx[i]] = nz[i]

为了构造nz,我们可以使用masks:直接用非零值的masks索引x

>>> nz = x[masks]
tensor([1., 3., 1., 4., 3.])

对于idx来说,这将有点棘手。我们可以对掩码本身进行排序,将其展平,并使用torch.Tensor.nonzero获得非零值。排序时,True值最终位于每行的开头:

>>> idx = masks.sort(1, True).values.view(-1).nonzero()[:,0]
tensor([ 0,  1,  5,  6,  7])

最后,我们可以应用torch.scatter并进行整形以获得期望的结果:

>>> torch.ones(x.numel()).scatter(0, idx, nz).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])

这里torch.scatter的使用是有限的,因为输入是一维的。等效的方法是简单地:

>>> o = torch.ones(x.numel())
>>> o[idx] = nz
>>> o.view_as(x)

完整方法:

>>> idx = masks.sort(1, True)[0].view(-1).nonzero()[:,0]
>>> torch.ones(x.numel()).scatter(0, idx, x[masks]).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])

最新更新