如何执行批处理masked_select
?
给定:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
所需输出为:
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
以下是一种可能的方法:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
ones = torch.tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
masks = torch.tensor([[ True, False, False, False, True],
[ True, False, True, True, False]])
for i in range(x.size(0)):
mask = masks[i]
s = torch.masked_select(x[i], mask)
ones[i][:s.size(0)] = s
是否有其他解决方案?
这类问题的主要问题是中间结果是非同质的:在批处理中,元素将具有不同数量的掩码值。如果我们想应用PyTorch内置程序,这是一个问题。在这里,我提供了两种执行此操作的解决方案。
1-使用列表理解
按适当的量检查批次元素、掩模和衬垫:
>>> pad = lambda v: F.pad(v, [0, len(m)-len(v)], value=1)
>>> torch.stack([pad(r[m]) for r, m in zip(x, masks)])
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
这相当简单,与您的方法类似。
2-使用torch.scatter
矢量化的替代方案是构造正确的值和索引张量,以便我们可以应用torch.scatter
并获得所需的结果。这里的诀窍是使用展平张量。从x
和masks
,我们首先想要访问定义为:的nz
和idx
nz
:来自x
的非屏蔽值(当然使用masks
),即我们需要找到:tensor([1., 3., 1., 4., 3.])
CCD_ 11:当平坦化时,它们在输出张量中的相应索引。
tensor([ 0, 1, 5, 6, 7])
然后我们可以用类似out = ones.scatter(dim=0, idx, nz)
的东西应用散射,它将有效地执行:out[idx[i]] = nz[i]
。
为了构造nz
,我们可以使用masks
:直接用非零值的masks
索引x
>>> nz = x[masks]
tensor([1., 3., 1., 4., 3.])
对于idx
来说,这将有点棘手。我们可以对掩码本身进行排序,将其展平,并使用torch.Tensor.nonzero
获得非零值。排序时,True
值最终位于每行的开头:
>>> idx = masks.sort(1, True).values.view(-1).nonzero()[:,0]
tensor([ 0, 1, 5, 6, 7])
最后,我们可以应用torch.scatter
并进行整形以获得期望的结果:
>>> torch.ones(x.numel()).scatter(0, idx, nz).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
这里torch.scatter
的使用是有限的,因为输入是一维的。等效的方法是简单地:
>>> o = torch.ones(x.numel())
>>> o[idx] = nz
>>> o.view_as(x)
完整方法:
>>> idx = masks.sort(1, True)[0].view(-1).nonzero()[:,0]
>>> torch.ones(x.numel()).scatter(0, idx, x[masks]).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])