I'm playing around with the PyTorch implementation of MultiheadAttention.
In the docs it states that the query dimensions are [N, L, E] (assuming batch_first=True), where N is the batch dimension, L is the target sequence length, and E is the embedding dimension.
It then states that the key and value dimensions are [N, S, E], where S is the source sequence length.
I take this to mean that S and L need not be equal, which makes sense.
However, running the following program:
import torch
import torch.nn as nn
input_size = 10
batch_size = 3
window_size = 2
attention = nn.MultiheadAttention(input_size, num_heads=1)
q = torch.empty(batch_size, 1, input_size)
k = v = torch.empty(batch_size, window_size, input_size)
y = attention(q, k, v, need_weights=False)
produces the following error:
.../lib/python3.8/site-packages/torch/nn/functional.py:5044, in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
5042 q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
5043 if static_k is None:
-> 5044 k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
5045 else:
5046 # TODO finish disentangling control flow so we don't do in-projections when statics are passed
5047 assert static_k.size(0) == bsz * num_heads,
5048 f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
RuntimeError: shape '[3, 1, 10]' is invalid for input of size 60
Am I missing something?
I'm using torch v1.10.2.
My bad — I'll post this in case anyone else runs into the same error.
batch_first defaults to False, and setting it to True fixes the problem:
attention = nn.MultiheadAttention(input_size, num_heads=1, batch_first=True)
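For completeness, here is a minimal sketch of the fixed program (using torch.randn instead of torch.empty so the tensors hold defined values); with batch_first=True the module accepts [N, L, E] queries and [N, S, E] keys/values with L != S, and the output keeps the query's shape:

```python
import torch
import torch.nn as nn

input_size = 10   # E: embedding dimension
batch_size = 3    # N: batch dimension
window_size = 2   # S: source sequence length (differs from L = 1)

# batch_first=True makes the module expect batch-first [N, L, E] inputs
attention = nn.MultiheadAttention(input_size, num_heads=1, batch_first=True)

q = torch.randn(batch_size, 1, input_size)            # [N, L, E], L = 1
k = v = torch.randn(batch_size, window_size, input_size)  # [N, S, E], S = 2

# forward returns (attn_output, attn_weights); weights are None here
y, _ = attention(q, k, v, need_weights=False)
print(y.shape)  # torch.Size([3, 1, 10]) — same shape as the query
```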