屏蔽:屏蔽指定令牌 (eos) 后的所有内容



我的tgt张量呈[12, 32, 1]的形状,sequence_length, batch_size, token_idx.

创建掩码的最佳方法是什么,该掩码按顺序包含<eos>和之前的条目,之后为零?

目前我正在像这样计算我的面具,它只是将零放在<blank>的位置,否则。

mask = torch.zeros_like(tgt).masked_scatter_((tgt != tgt_padding), torch.ones_like(tgt))

但问题是,我的tgt也可以包含<blank>(在<eos>之前(,在这种情况下,我不想掩盖它。

我的临时解决方案:

mask = torch.ones_like(tgt)
for eos_token in (tgt == tgt_eos).nonzero():
mask[eos_token[0]+1:,eos_token[1]] = 0

我猜您正在尝试为 PAD 令牌创建掩码。有几种方法。其中之一如下。

# tensor is of shape [seq_len, batch_size, 1]
tensor = tensor.mul(tensor.ne(PAD).float())

在这里,PAD代表PAD_TOKEN的索引。tensor.ne(PAD)将创建一个字节张量,其中在PAD_TOKEN位置,将分配 0,在其他地方分配 1。


如果您有类似的例子,"<s> I think <pad> so </s> <pad> <pad>".然后,我建议使用不同的PAD令牌,用于</s>之前和之后。

或者,如果您有每个句子的长度信息(在上面的示例中,句子长度为 6(,则可以使用以下函数创建掩码。

def sequence_mask(lengths, max_len=None):
"""
Creates a boolean mask from sequence lengths.
:param lengths: 1d tensor [batch_size]
:param max_len: int
"""
batch_size = lengths.numel()
max_len = max_len or lengths.max()
return (torch.arange(0, max_len, device=lengths.device)  # (0 for pad positions)
.type_as(lengths)
.repeat(batch_size, 1)
.lt(lengths.unsqueeze(1)))

最新更新