我已经训练了一个tensorflow模型来预测输入文本的下一个单词。保存为.h5文件。
我可以在另一个python代码中使用该模型来预测word,如下所示:
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
model = load_model('model.h5')
model.compile(
loss = "categorical_crossentropy",
optimizer = "adam",
metrics = ["accuracy"]
)
data = open("dataset.txt").read()
corpus = data.lower().split("n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
seed_text = input()
sequence_text = tokenizer.texts_to_sequences([seed_text])[0]
padded_sequence = np.array(pad_sequences([sequence_text], maxlen = 11 -1))
predicted = np.argmax(model.predict(padded_sequence))
是否有一种方法可以直接使用该模型扑动,在那里我可以从TextField()和按下的输入按钮,显示预测的单词??
步骤
- 将模型转换为
.tflite
模型
# https://www.tensorflow.org/lite/convert/#convert_a_savedmodel_recommended_
import tensorflow as tf
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()
# Save the model.
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
- 将tflite Model添加到App目录。我通常将模型添加到
assets/
目录中。
android/
assets/
model.tflite
ios/
lib/
- 添加tflite作为
pubspec.yaml
的依赖项
dependencies:
flutter:
sdk: flutter
tflite: ^1.0.5
.
.
- 在dart脚本中运行Inference。例如,下面的代码片段是关于如何在图像上运行Inference的示例脚本,其中
labels.txt
是包含类的文本文件:
import 'package:tflite/tflite.dart';
.
.
.
class _MyAppState extends State<MyApp> {
. . .
@override
void initState() {
super.initState();
_loading = true;
loadModel().then((value) {
setState(() {
_loading = false;
});
});
}
classifyImage(File image) async {
var output = await Tflite.runModelOnImage(
path: image.path,
numResults: 2,
threshold: 0.5,
imageMean: 127.5,
imageStd: 127.5,
);
setState(() {
_loading = false;
_outputs = output;
});
}
loadModel() async {
await Tflite.loadModel(
model: "assets/model_unquant.tflite",
labels: "assets/labels.txt",
);
}
@override
void dispose() {
Tflite.close();
super.dispose();
}
. . .
}
旁注
tflite插件不支持文本分类AFAIK,如果你想专门做文本分类我建议使用tflite_flutter
插件。下面是使用文本分类插件的文章链接。
使用TensorFlow Lite Plugin for Flutter进行文本分类
不能使用.h5文件直接在Flutter。您需要将其转换为.tflite文件并使用它或创建一个REST API。
将其转换为.tflite文件是最简单的。您可以查看以下文章了解更多详情:https://medium.com/analytics-vidhya/run-cnn-model-in-flutter-10c944cadcba
如果你想创建一个REST API,查看这篇文章:https://medium.com/analytics-vidhya/deploy-ml-models-using-flask-as-rest-api-and-access-via-flutter-app-7ce63d5c1f3b