在掩码为true的情况下,用另一个张量填充张量



我需要以一定的概率将张量new的元素插入到张量old中,为了简单起见,假设它是0.8。基本上这就是masked_fill的作用,但它只适用于一维张量。实际上我在做

prob = torch.rand(trgs.shape, dtype=torch.float32).to(trgs.device)
mask = prob < 0.8
dim1, dim2, dim3, dim4 = new.shape
for a in range(dim1):
for b in range(dim2):
for c in range(dim3):
for d in range(dim4):
old[a][b][c][d] = old[a][b][c][d] if mask[a][b][c][d] else new[a][b][c][d]

这太可怕了。我想要之类的

prob = torch.rand(trgs.shape, dtype=torch.float32).to(trgs.device)
mask = prob < 0.8
old = trgs.multidimensional_masked_fill(mask, new)

我不确定你的一些对象是什么,但这应该能让你在短时间内完成你需要的:

old是您现有的数据。

mask是您使用概率p 生成的掩码

new是具有要插入的元素的新张量。

# torch.where
result = old.where(mask, new)

最新更新