在测试bert模型时分配权重



我有一个基本的概念上的疑问。当我在句子上训练bert模型时,说:

Train: "went to get loan from bank" 
Test :"received education loan from bank"

测试句子如何分配每个标记的权重,因为它没有通过测试的确切句子,并且有轻微添加的单词,如&;education&;这会稍微改变上下文

假设在我的模型中没有训练这样的上下文,在我进一步微调之前,如何为我的bert中的每个令牌分配权重

如果我混淆了我的问题,简单地说,我试图理解在测试过程中如何分配权重,如果上下文发生了轻微的变化,而不是训练。

标记的向量表示(请记住,标记!= word)存储在嵌入层中。当我们加载'bert-base-uncase '模型时,我们可以看到它'知道';30522个标记,每个标记的向量表示由768个元素组成:

from transformers import BertModel
bert= BertModel.from_pretrained('bert-base-uncased')
print(bert.embeddings.word_embeddings)

输出:

Embedding(30522, 768, padding_idx=0)

这个嵌入层不知道任何字符串,但知道id。例如,id101的向量表示为:

print(bert.embeddings.word_embeddings.weight[101])

输出:

tensor([ 1.3630e-02, -2.6490e-02, -2.3503e-02, -7.7876e-03,  8.5892e-03,
-7.6645e-03, -9.8808e-03,  6.0184e-03,  4.6921e-03, -3.0984e-02,
1.8883e-02, -6.0093e-03, -1.6652e-02,  1.1684e-02, -3.6245e-02,
8.3482e-03, -1.2112e-03,  1.0322e-02,  1.6692e-02, -3.0354e-02,
...
5.4162e-03, -3.0037e-02,  8.6773e-03, -1.7942e-03,  6.6826e-03,
-1.1929e-02, -1.4076e-02,  1.6709e-02,  1.6860e-03, -3.3842e-03,
8.6805e-03,  7.1340e-03,  1.5147e-02], grad_fn=<SelectBackward>)

所有"已知"之外的内容id不能被BERT处理。要回答您的问题,我们需要查看将字符串映射到id的组件。这个组件称为标记器。有不同的标记化方法。BERT使用WordPiece标记器,这是一种子词算法。这个算法替换所有不能创建的从它的词汇表中取出一个未知的标记,该标记是的一部分(原实现中的[UNK], id: 100)。

请看下面的小示例,其中WordPiece标记器是从头开始训练以确认行为的:

from tokenizers import BertWordPieceTokenizer
path ='file_with_your_trainings_sentence.txt'
tokenizer = BertWordPieceTokenizer()
tokenizer.train(files=path, vocab_size=30000, special_tokens=['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])
otrain = tokenizer.encode("went to get loan from bank")
otest =  tokenizer.encode("received education loan from bank")
print('Vocabulary size: {}'.format(tokenizer.get_vocab_size()))
print('Train tokens: {}'.format(otrain.tokens))
print('Test tokens: {}'.format(otest.tokens))

输出:

Vocabulary size: 27
Train tokens: ['w', '##e', '##n', '##t', 't', '##o', 'g', '##e', '##t', 'l', '##o', '##an', 'f', '##r', '##o', '##m', 'b', '##an', '##k']
Test tokens: ['[UNK]', '[UNK]', 'l', '##o', '##an', 'f', '##r', '##o', '##m', 'b', '##an', '##k']

最新更新