How to change the parameters of a Hugging Face pretrained Longformer model



I am using the Hugging Face pretrained LongformerModel to extract sentence embeddings. I want to change the token length / max sentence length parameters, but I am not able to. Here is the code:

import torch
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained('allenai/longformer-base-4096', output_hidden_states=True)
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model.eval()

text = [" I like to play cricket"]
input_ids = torch.tensor(tokenizer.encode(text, max_length=20, padding=True, add_special_tokens=True)).unsqueeze(0)
print(tokenizer.encode(text, max_length=20, padding=True, add_special_tokens=True))
# [0, 38, 101, 7, 310, 5630, 2]

I expected the encoder to give me a padded list of size 20, since I passed max_length=20, but it only returns a list of size 7. Why?

attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
attention_mask[:, [0,-1]] = 2
outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
hidden_states = outputs[2]
print ("Number of layers:", len(hidden_states), "  (initial embeddings + 12 BERT layers)")
layer_i = 0
print ("Number of batches:", len(hidden_states[layer_i]))
batch_i = 0
print ("Number of tokens:", len(hidden_states[layer_i][batch_i]))
token_i = 0
print ("Number of hidden units:", len(hidden_states[layer_i][batch_i][token_i]))

Output:

Number of layers: 13   (initial embeddings + 12 BERT layers)
Number of batches: 1
Number of tokens: 512 # How can I change this parameter to pick up my sentence length during run-time
Number of hidden units: 768

How can I reduce the number of tokens to the sentence length instead of 512? Every time I pass in a new sentence, it should pick up that sentence's length.

Regarding the padding question

padding=True pads your input to the longest sequence in the batch. padding='max_length' pads your input to the specified max_length (documentation):

from transformers import LongformerTokenizer
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
text=[" I like to play cricket"]
print(tokenizer.encode(text[0],max_length=20,padding='max_length',add_special_tokens=True))

Output:

[0, 38, 101, 7, 310, 5630, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

Regarding the number of tokens in the hidden states

The Longformer implementation pads the sequence internally so that its length matches the attention window size. You can see the attention window size in the model config:

model.config.attention_window

Output:

[512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512]

The corresponding line of code: link.
