Additional pre-training of BERT in TF Keras



I am currently working on a project involving multi-label sequence classification. Since I am working with a highly technical dataset, I think it would be beneficial to give BERT some additional pre-training before fine-tuning it for the classification part. However, I cannot find any guide on pre-training a model using Hugging Face transformers together with Keras. My idea is to pre-train the model on my dataset, then save it and load it again to fine-tune the classifier. Every example I have found is for PyTorch, but I have to use TensorFlow. This is the code I have written so far:

from transformers import TFDistilBertForMaskedLM, AutoTokenizer, AutoConfig
from sklearn.datasets import fetch_20newsgroups
categories = ['alt.atheism', 'soc.religion.christian','comp.graphics', 'sci.med']
twenty_train = fetch_20newsgroups(subset='train',categories=categories, shuffle=True, random_state=42)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
model = TFDistilBertForMaskedLM.from_pretrained("distilbert-base-cased")
model.compile(optimizer="adam")
data = tokenizer(
    twenty_train.data[:10],
    return_tensors="tf",
    padding=True,
    truncation=True,
    max_length=tokenizer.model_max_length,
)

Where do I go from here to feed my data into BERT? I know I should also provide masked inputs to the model, but I cannot figure out where/how to do that.

You can pre-train a BERT model on a custom dataset.

Sample working code:

import os
import tensorflow as tf
import tensorflow_hub as hub

bert_preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
bert_encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4", trainable=True)

# Get sentence embeddings
def get_sentence_embedding(sentences):
    preprocessed_text = bert_preprocess(sentences)
    return bert_encoder(preprocessed_text)['pooled_output']

get_sentence_embedding([
    "How to find which version of TensorFlow is",
    "TensorFlow not found using pip"
])

def build_classifier_model(num_classes):
    class Classifier(tf.keras.Model):
        def __init__(self, num_classes):
            super(Classifier, self).__init__(name="prediction")
            self.encoder = hub.KerasLayer(bert_encoder, trainable=True)
            self.dropout = tf.keras.layers.Dropout(0.1)
            self.dense = tf.keras.layers.Dense(num_classes)

        def call(self, preprocessed_text):
            encoder_outputs = self.encoder(preprocessed_text)
            pooled_output = encoder_outputs["pooled_output"]
            x = self.dropout(pooled_output)
            x = self.dense(x)
            return x

    model = Classifier(num_classes)
    return model

test_classifier_model = build_classifier_model(2)
# Preprocess a sample sentence before feeding it to the classifier
text_preprocessed = bert_preprocess(["TensorFlow not found using pip"])
bert_raw_result = test_classifier_model(text_preprocessed)
print(tf.sigmoid(bert_raw_result))
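The last lines above only run an untrained forward pass through the classifier. To actually fine-tune it, you would compile it with a loss and call fit. A minimal sketch, where train_texts and train_labels are illustrative placeholders for your own data:

# train_texts / train_labels are placeholders for your own labelled data
train_texts = tf.constant(["some technical sentence", "another technical sentence"])
train_labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])  # multi-hot labels, matching the sigmoid output

test_classifier_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
test_classifier_model.fit(bert_preprocess(train_texts), train_labels, epochs=3)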

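For the masked-input part of the original question (pre-training with Hugging Face transformers and Keras), DataCollatorForLanguageModeling can create the randomly masked inputs together with the matching labels, and the TF*ForMaskedLM models compute the MLM loss internally whenever a labels key is present, so no explicit loss is needed at compile time. A minimal sketch continuing from the question's code; the save path and num_labels are illustrative:

import tensorflow as tf
from transformers import (
    AutoTokenizer,
    TFDistilBertForMaskedLM,
    TFDistilBertForSequenceClassification,
    DataCollatorForLanguageModeling,
)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
mlm_model = TFDistilBertForMaskedLM.from_pretrained("distilbert-base-cased")

# Randomly masks 15% of the tokens and builds the matching labels
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15, return_tensors="tf"
)

encodings = tokenizer(twenty_train.data[:10], padding=True, truncation=True)
features = [
    {"input_ids": ids, "attention_mask": mask}
    for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
]
batch = dict(collator(features))  # input_ids, attention_mask and labels as tf tensors

# No loss argument: the model falls back to its internal MLM loss,
# computed from the "labels" key in the inputs
mlm_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5))
mlm_model.fit(batch, batch_size=8, epochs=1)

# Save the pre-trained weights, then reload them to fine-tune a classifier
mlm_model.save_pretrained("distilbert-technical-pretrained")
clf_model = TFDistilBertForSequenceClassification.from_pretrained(
    "distilbert-technical-pretrained", num_labels=4
)

After this, clf_model can be compiled and fitted on the labelled classification data in the same way.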