如何使用BERT模型和TensorflowLite预测(分类)用户句子



我正在尝试用TFLite model Maker训练MobileBERT模型;训练部分还可以,测试也可以(我可以使用mb_model.evaluate(mb_test_data)(。

但我完全不知道如何用Python用字符串句子预测结果。。。

这是一个训练示例脚本:

import os
import tensorflow as tf
assert tf.__version__.startswith('2')
from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.text_classifier import DataLoader
mb_spec = model_spec.get('mobilebert_classifier')
mb_train_data = DataLoader.from_csv(
filename=os.path.join(os.path.join(data_dir, 'nlu_train.tsv')),
text_column='sentence',
label_column='label',
model_spec=mb_spec,
delimiter='t',
is_training=True)
mb_test_data = DataLoader.from_csv(
filename=os.path.join(os.path.join(data_dir, 'nlu_test.tsv')),
text_column='sentence',
label_column='label',
model_spec=mb_spec,
delimiter='t',
is_training=False)
mb_model = text_classifier.create(mb_train_data, model_spec=mb_spec, epochs=30, batch_size=8)
config = configs.QuantizationConfig.for_float16()
config._experimental_new_quantizer = True
mb_model.export(export_dir='/')

导出/model.tflite

我可以用现有的句子来测试:

import numpy as np
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path="nlu (6).tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.int32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

但我想使用一个自定义句子,而不是input_data = np.array(np.random.random_sample(input_shape), dtype=np.int32),比如:

input_data = "My user sentence"
output_data = interpreter.predict(input_data)

有人知道怎么做吗?我找不到任何文档,TFLite Model Maker(以及official.nlp.datarepository上的BERT(的反向源很难。。。

我没有发现在字符串和标记化过程中使用的完整预处理,以获得替换原始句子的int32列表:/

谢谢!

您可以使用BertNLClassifier进行推理。它将处理预处理和后处理部分。

最新更新