我很难理解变压器。一切都在一点一点地明朗起来,但有一件事让我头疼src_mask和src_keypadding_mask之间的区别是什么?src_key_padding_mask在编码器层和解码器层都作为前向函数中的参数传递。
https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#Transformer
src_mask和src_key_padding_mask-6ypu之间的差异一般的事情是注意张量_mask
和_key_padding_mask
的使用之间的差异。在变换器内部,当注意力集中时,我们通常会得到一个带有所有比较的平方中间张量大小为[Tx, Tx]
(用于编码器的输入)、[Ty, Ty]
(用于移位输出-解码器的输入之一)和CCD_ 5(对于存储器掩码-编码器/存储器的输出和解码器的输入/移位输出之间的注意)。
所以我们知道这是变压器中每个口罩的用途(注意pytorch文档中的注释如下,其中Tx=S is the source sequence length
(例如输入批次的最大值),CCD_ 7(例如,目标长度的最大值),CCD_ 8,D=E is the feature number
):
src_mask
[Tx, Tx] = [S, S]
–src序列的附加掩码(可选)。这在执行atten_src + src_mask
时适用。我不确定示例输入-请参阅tgt_mask以获取示例但典型的用途是添加CCD_ 12,因此如果需要的话可以以这种方式屏蔽src_。如果提供ByteTensor,则不允许参与非零位置,而零位置将保持不变。如果提供布尔张量,则不允许出现True的位置,而False值将保持不变。如果提供FloatTensor,它将被添加到注意力权重中。tgt_mask
[Ty, Ty] = [T, T]
–tgt序列的附加掩码(可选)。这在执行atten_tgt + tgt_mask
时适用。一个例子是对角线,以避免解码器作弊。因此,tgt被右移,第一个令牌是嵌入SOS/BOS的序列令牌的开始,因此第一个令牌条目为零,其余为。具体示例见附录。如果提供ByteTensor,则不允许参与非零位置,而零位置将保持不变。如果提供布尔张量,则不允许出现True的位置,而False值将保持不变。如果提供FloatTensor,它将被添加到注意力权重中。memory_mask
[Ty, Tx] = [T, S]
–编码器输出的附加掩码(可选)。这在执行atten_memory + memory_mask
时适用。不确定示例用法,但如前所述,添加-inf
会将一些注意力权重设置为零。如果提供ByteTensor,则不允许参与非零位置,而零位置将保持不变。如果提供布尔张量,则不允许出现True的位置,而False值将保持不变。如果提供FloatTensor,它将被添加到注意力权重中。src_key_padding_mask
[B, Tx] = [N, S]
–每批src密钥的ByteTensor掩码(可选)。由于src通常具有不同的长度序列,因此通常会删除填充向量你在末尾附加了。为此,您可以指定批次中每个示例的每个序列的长度。具体示例见附录。如果提供ByteTensor,则不允许参与非零位置,而零位置将保持不变。如果提供布尔张量,则不允许出现True的位置,而False值将保持不变。如果提供FloatTensor,它将被添加到注意力权重中。tgt_key_padding_mask
[B, Ty] = [N, t]
–每批tgt密钥的ByteTensor掩码(可选)。与之前相同。具体示例见附录。如果提供ByteTensor,则不允许参与非零位置,而零位置将保持不变。如果提供布尔张量,则不允许出现True的位置,而False值将保持不变。如果提供FloatTensor,它将被添加到注意力权重中。memory_key_padding_mask
[B, Tx] = [N, S]
–每批内存密钥的ByteTensor掩码(可选)。与之前相同。具体示例见附录。如果提供ByteTensor,则不允许参与非零位置,而零位置将保持不变。如果提供布尔张量,则不允许出现True的位置,而False值将保持不变。如果提供FloatTensor,它将被添加到注意力权重中。
附录
pytorch教程中的示例(https://pytorch.org/tutorials/beginner/translation_transformer.html):
1个src_mask示例
src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
返回大小为[Tx, Tx]
:的布尔值的张量
tensor([[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False]])
2 tgt_mask示例
mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1)
mask = mask.transpose(0, 1).float()
mask = mask.masked_fill(mask == 0, float('-inf'))
mask = mask.masked_fill(mask == 1, float(0.0))
生成作为解码器的输入的右移输出的对角线。
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
...,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.]])
通常右移输出在开始时有BOS/SOS,教程只是简单地得到了右移通过在前面附加该BOS/SOS,然后用CCD_。
3_padding
填充只是为了掩盖最后的填充。src填充通常与内存填充相同。tgt有自己的序列,因此也有自己的填充。示例:
src_padding_mask = (src == PAD_IDX).transpose(0, 1)
tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
memory_padding_mask = src_padding_mask
输出:
tensor([[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., True, True, True]])
注意,False
意味着那里没有填充令牌(所以是的,在transformer前向传递中使用该值),True
意味着有填充令牌(因此屏蔽了它,所以transformer后向传递不会受到影响)。
答案有点分散,但我发现只有这3个参考文献有用(单独的层文档/东西不是很有用的诚实):
- 长教程:https://pytorch.org/tutorials/beginner/translation_transformer.html
- MHA文件:https://pytorch.org/docs/master/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention
- 变压器文档:https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
我必须说PyTorch实现有点令人困惑,因为它包含了太多的掩码参数。但我可以解释一下你所指的两个掩码参数。src_mask
和src_key_padding_mask
都用于MultiheadAttention
机制。根据MultiheadAttention的文档:
key_padding_mask–如果提供,则关注将忽略键中指定的填充元素。
attn_mask–2D或3D遮罩,可防止注意力集中在特定位置。
正如您从论文中所知,注意力就是您所需要的,MultiheadAttention同时用于编码器和解码器。然而,在解码器中,有两种类型的多头注意。一个被称为CCD_ 28,另一个是规则CCD_。为了适应这两种技术,PyTorch在其MultiheadAttention实现中使用了上述两个参数。
所以,长话短说-
attn_mask
和key_padding_mask
用于编码器的MultiheadAttention
和解码器的Masked MultiheadAttention
- CCD_ 34用于解码器的CCD_ 35机制
研究MultiheadAttention的实现可能会对您有所帮助。
从这里和这里可以看到,首先src_mask
用于阻止特定位置参与,然后key_padding_mask
用于阻止参与填充令牌。
注意根据@michael jungo的评论更新了答案。
举一个小例子,考虑我想建立一个顺序推荐器,即,给定用户在"t+1"预测下一个项目之前已经购买的项目
u1 - [i1, i2, i7]
u2 - [i2, i5]
u3 - [i6, i7, i1, i2]
对于这个任务,我可以使用一个转换器,通过在左边填充0来使序列的长度相等。
u1 - [0, i1, i2, i7]
u2 - [0, 0, i2, i5]
u3 - [i6, i7, i1, i2]
我将使用key_padding_mask告诉PyTorch 0的shd将被忽略。现在,考虑用户u3
,其中给定[i6]
,我想预测[i7]
,稍后给定[i6, i7]
,我想要预测[i1]
,即,我想要因果注意力,这样注意力就不会窥探到未来的元素。为此,我将使用attn_mask。因此,对于用户u3
attn_mask将类似
[[True, False, False, False],
[True, True , False, False],
[True, True , True , False]
[True, True , True , True ]]