我的tgt
张量呈[12, 32, 1]
的形状,sequence_length, batch_size, token_idx
.
创建掩码的最佳方法是什么,该掩码按顺序包含<eos>
和之前的条目,之后为零?
目前我正在像这样计算我的面具,它只是将零放在<blank>
的位置,否则。
mask = torch.zeros_like(tgt).masked_scatter_((tgt != tgt_padding), torch.ones_like(tgt))
但问题是,我的tgt
也可以包含<blank>
(在<eos>
之前(,在这种情况下,我不想掩盖它。
我的临时解决方案:
mask = torch.ones_like(tgt)
for eos_token in (tgt == tgt_eos).nonzero():
mask[eos_token[0]+1:,eos_token[1]] = 0
我猜您正在尝试为 PAD 令牌创建掩码。有几种方法。其中之一如下。
# tensor is of shape [seq_len, batch_size, 1]
tensor = tensor.mul(tensor.ne(PAD).float())
在这里,PAD
代表PAD_TOKEN
的索引。tensor.ne(PAD)
将创建一个字节张量,其中在PAD_TOKEN
位置,将分配 0,在其他地方分配 1。
如果您有类似的例子,"<s> I think <pad> so </s> <pad> <pad>"
.然后,我建议使用不同的PAD令牌,用于</s>
之前和之后。
或者,如果您有每个句子的长度信息(在上面的示例中,句子长度为 6(,则可以使用以下函数创建掩码。
def sequence_mask(lengths, max_len=None):
"""
Creates a boolean mask from sequence lengths.
:param lengths: 1d tensor [batch_size]
:param max_len: int
"""
batch_size = lengths.numel()
max_len = max_len or lengths.max()
return (torch.arange(0, max_len, device=lengths.device) # (0 for pad positions)
.type_as(lengths)
.repeat(batch_size, 1)
.lt(lengths.unsqueeze(1)))