PyTorch MultiheadAttention error when the query sequence dimension differs from the key/value dimension



I'm playing around with the PyTorch implementation of MultiheadAttention.

The documentation states that the query dimensions are [N, L, E] (assuming batch_first=True), where N is the batch dimension, L is the target sequence length, and E is the embedding dimension.

It then states that the key and value dimensions are [N, S, E], where S is the source sequence length. I take this to mean that S and L don't need to be equal, which makes sense.

However, running the following program:

import torch
import torch.nn as nn

input_size = 10   # embedding dimension E
batch_size = 3    # batch dimension N
window_size = 2   # source sequence length S

attention = nn.MultiheadAttention(input_size, num_heads=1)
q = torch.empty(batch_size, 1, input_size)                 # intended as [N, L=1, E]
k = v = torch.empty(batch_size, window_size, input_size)   # intended as [N, S=2, E]
y = attention(q, k, v, need_weights=False)

produces the following error:

.../lib/python3.8/site-packages/torch/nn/functional.py:5044, in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
5042 q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
5043 if static_k is None:
-> 5044     k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
5045 else:
5046     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
5047     assert static_k.size(0) == bsz * num_heads, 
5048         f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
RuntimeError: shape '[3, 1, 10]' is invalid for input of size 60

Am I missing something?

I'm using torch v1.10.2.

My mistake. I'll post this in case anyone else runs into the same error.

batch_first defaults to False, so the module was interpreting the inputs as [L, N, E] rather than [N, L, E]: with q of shape (3, 1, 10) the batch size is read as 1, and the 60-element key tensor can no longer be reshaped to [3, 1, 10], which is exactly the RuntimeError above. Setting batch_first=True fixes the problem.

attention = nn.MultiheadAttention(input_size, num_heads=1, batch_first=True)
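
For completeness, here is a minimal sketch of the corrected snippet (same sizes as above, with torch.empty swapped for torch.randn purely to avoid uninitialized values), which runs without the reshape error:

import torch
import torch.nn as nn

input_size = 10   # embedding dimension E
batch_size = 3    # batch dimension N
window_size = 2   # source sequence length S

# batch_first=True makes the module read inputs as [N, L, E] / [N, S, E]
attention = nn.MultiheadAttention(input_size, num_heads=1, batch_first=True)

q = torch.randn(batch_size, 1, input_size)                 # query:     [N, L=1, E]
k = v = torch.randn(batch_size, window_size, input_size)   # key/value: [N, S=2, E]

out, _ = attention(q, k, v, need_weights=False)
print(out.shape)  # torch.Size([3, 1, 10]), i.e. [N, L, E]

Alternatively, batch_first can be left at its default of False, in which case the inputs have to be laid out as [L, N, E] and [S, N, E], e.g. by calling .transpose(0, 1) on the batch-first tensors.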
