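I'm trying to convert my TF model to TF Lite so that I can use it on a microcontroller. To convert it I need a representative dataset, and I don't really know how to create one correctly.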
This is my TF model:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
model.add(Dense(1, activation="relu", input_dim=5))
model.summary()
model.compile(loss="mean_squared_error", optimizer="adam", metrics=["accuracy", "mean_squared_error"])
model.fit(X_train, Y_train, batch_size=16, epochs=40, verbose=1)
Code to convert the model:
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def rep_data_gen():
    yield from X_test.iterrows()

converter.representative_dataset = rep_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_quant_model)
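I've tried other ways of writing rep_data_gen, but since I don't fully understand how it works, none of them succeeded, and I haven't found a tutorial that helps me understand it better.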
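A quick way to see why yield from X_test.iterrows() fails: DataFrame.iterrows() yields (index, Series) pairs, while the converter expects every item from the generator to be a list with one array per model input, shaped like that input including the batch dimension. For this input_dim=5 model that should be a float32 array of shape (1, 5). The sketch below assumes pandas and NumPy and uses two rows copied from the X_test snapshot; first and sample are illustrative names only.

```python
import numpy as np
import pandas as pd

# Two rows copied from the X_test snapshot in this question
X_test = pd.DataFrame(
    [[-0.244977, 0.876587, -0.726149, -0.550705, 1.423705],
     [2.914520, 0.884880, -0.840744, -0.550705, -1.404848]],
    columns=["Gas_ohm", "Humidity", "Pressure", "Temperature", "temp_prof"],
)

# What the original rep_data_gen yields: an (index, Series) tuple per row
first = next(iter(X_test.iterrows()))
print(type(first), first[0], type(first[1]))

# What the converter expects per step: a list with one array per model input,
# keeping the batch dimension, i.e. shape (1, 5) and dtype float32 here
sample = [X_test.iloc[[0]].to_numpy(dtype=np.float32)]
print(sample[0].shape, sample[0].dtype)
```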
Here is a snapshot of my data from X_test:
      Gas_ohm  Humidity   Pressure  Temperature  temp_prof
0   -0.244977  0.876587  -0.726149    -0.550705   1.423705
1    2.914520  0.884880  -0.840744    -0.550705  -1.404848
2    0.506850  0.880733  -0.840744    -0.550705  -0.697710
3   -0.020688  0.884880  -0.840744    -0.550705   0.009429
4   -0.238303  0.880733  -0.840744    -0.550705   0.716567
5   -0.246105  0.880733  -0.840744    -0.550705   1.423705
6    2.928245  0.872440  -0.840744    -0.550705  -1.404848
7    0.485888  0.897321  -0.840744    -0.588416  -0.697710
8   -0.038831  0.913908  -0.840744    -0.626127   0.009429
9   -0.248267  0.918055  -0.840744    -0.626127   0.716567
10  -0.253719  0.909761  -0.955339    -0.588416   1.423705
11   2.843737  0.909761  -0.955339    -0.588416  -1.404848
12   0.470190  0.901468  -0.840744    -0.588416  -0.697710
13  -0.040429  0.901468  -0.840744    -0.588416   0.009429
14  -0.249395  0.901468  -0.840744    -0.588416   0.716567
15  -0.252121  0.901468  -0.840744    -0.550705   1.423705
16   2.856991  0.884880  -0.840744    -0.550705  -1.404848
17   0.477992  0.884880  -0.840744    -0.588416  -0.697710
18  -0.034037  0.897321  -0.840744    -0.550705   0.009429
19  -0.246105  0.897321  -0.840744    -0.550705   0.716567
20  -0.247797  0.893174  -0.840744    -0.550705   1.423705
Can anyone help me create, or at least understand, the representative dataset?
Here are 3 good examples of how to create representative_dataset():
- Point 9.1 (Quantize model): https://github.com/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi/blob/master/Train_TFLite2_Object_Detction_Model.ipynb
- Point "Convert to TFLite": https://github.com/google-coral/tutorials/blob/master/retrain_classification_ptq_tf2.ipynb
- The simplest complete example, without representative_dataset(), point "Convert using TFLite's Converter": https://github.com/https-deeplearning-ai/tensorflow-2-public/blob/main/C2_Device-based-TF-lite/W2/assignment_optional/C2_W2_Assignment_Solution.ipynb
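Remember that TFLite offers 4 types of optimization, and whether you need representative_dataset() depends on which one you choose: post-training integer quantization needs it, while dynamic range and float16 quantization do not. Check the types here: https://www.tensorflow.org/lite/performance/model_optimization?hl=en-419#quantization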
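Why only integer quantization needs sample data: the weights are known at conversion time, but the value ranges of the input and activation tensors are not, so the converter runs the representative samples through the model and records those ranges to choose each tensor's int8 scale and zero point (TFLite's convention is real_value = (int8_value - zero_point) * scale). Below is a simplified illustration of that min/max calibration on the Gas_ohm values of the first five X_test rows; the helper names are hypothetical and the real converter's details differ.

```python
import numpy as np

def int8_params_from_range(rmin, rmax):
    """Simplified asymmetric int8 scale/zero point for an observed range."""
    # Extend the range to include 0.0 so that zero is exactly representable
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    qmin, qmax = -128, 127
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = int(round(qmin - rmin / scale))
    return scale, zero_point

def quantize(x, scale, zero_point):
    # real -> int8, clipping anything outside the calibrated range
    q = np.round(np.asarray(x) / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize(q, scale, zero_point):
    # int8 -> real, using real = (q - zero_point) * scale
    return (q.astype(np.float32) - zero_point) * scale

# "Calibration": observe the range of Gas_ohm over a few representative rows
gas = np.array([-0.244977, 2.914520, 0.506850, -0.020688, -0.238303], dtype=np.float32)
scale, zp = int8_params_from_range(float(gas.min()), float(gas.max()))
q = quantize(gas, scale, zp)
print(scale, zp, q, dequantize(q, scale, zp))
```

Values outside the calibrated range are clipped to -128 or 127, which is why the samples should cover the range the sensors will actually produce.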
In short, your case is "post-training integer quantization", so you do need this representative_dataset() (check the TODO comments):
# CODE FROM Point 9.1: Quantize model
# https://github.com/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi/blob/master/Train_TFLite2_Object_Detction_Model.ipynb
# Read the official documentation: https://www.tensorflow.org/lite/performance/model_optimization?hl=en-419#quantization
import glob
import random
# Get list of all images in train directory
image_path = 'Images_folder'
jpg_file_list = glob.glob(image_path + '/*.jpg')
JPG_file_list = glob.glob(image_path + '/*.JPG')
png_file_list = glob.glob(image_path + '/*.png')
bmp_file_list = glob.glob(image_path + '/*.bmp')
quant_image_list = jpg_file_list + JPG_file_list + png_file_list + bmp_file_list
# A generator that provides a representative dataset
# Code modified from https://colab.research.google.com/github/google-coral/tutorials/blob/master/retrain_classification_ptq_tf2.ipynb
import tensorflow as tf
# First, get input details for model so we know how to preprocess images
# SAVED_MODEL_PATH_LITE = "model_mobil_v2_C/frozen/model_simple_sigNo_Mdata.tflite"
# interpreter =tf.lite.Interpreter(model_path=SAVED_MODEL_PATH_LITE)
# # interpreter = tf.lite.Interpreter(model_path=PATH_TO_MODEL) # PATH_TO_MODEL is defined in Step 7 above
# interpreter.allocate_tensors()
# input_details = interpreter.get_input_details()
# output_details = interpreter.get_output_details()
height = 640  # TODO: change to your model's input height, e.g. input_details[0]['shape'][1]
width = 640   # TODO: change to your model's input width, e.g. input_details[0]['shape'][2]

def representative_data_gen():
    dataset_list = quant_image_list
    quant_num = 200  # number of randomly picked images used for calibration
    for _ in range(quant_num):
        pick_me = random.choice(dataset_list)
        image = tf.io.read_file(pick_me)
        if pick_me.endswith('.jpg') or pick_me.endswith('.JPG'):
            image = tf.io.decode_jpeg(image, channels=3)
        elif pick_me.endswith('.png'):
            image = tf.io.decode_png(image, channels=3)
        elif pick_me.endswith('.bmp'):
            image = tf.io.decode_bmp(image, channels=3)
        # tf.image.resize expects the size as [new_height, new_width]
        image = tf.image.resize(image, [height, width])
        image = tf.cast(image / 255., tf.float32)
        image = tf.expand_dims(image, 0)  # add the batch dimension
        yield [image]
# Finally, initialize the TFLiteConverter, point it at the TFLite-compatible SavedModel generated in Step 6, and give it the representative dataset
# generator defined above. The converter is configured to quantize the model's weights and activations to INT8.
# Initialize converter module
converter = tf.lite.TFLiteConverter.from_saved_model('model_101_C/frozen/saved_model')
# This enables quantization
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# This sets the representative dataset for quantization
converter.representative_dataset = representative_data_gen
# TFLITE_BUILTINS lets ops that can't be quantized fall back to float; list only TFLITE_BUILTINS_INT8 if the converter should throw an error instead
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# For full integer quantization, though supported types defaults to int8 only, we explicitly declare it for clarity.
converter.target_spec.supported_types = [tf.int8]
# These set the input tensors to uint8 and output tensors to float32
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.float32
tflite_model = converter.convert()
TF_LITE_PATH_DATA = "model_101_C/frozen/model_DATA_GEN.tflite"
with open(TF_LITE_PATH_DATA, 'wb') as f:
    f.write(tflite_model)
print("Optimized TFLite model saved in: " + TF_LITE_PATH_DATA)
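For the question's tabular model, the representative dataset only has to yield real input rows in the shape the model expects. A minimal sketch, assuming X_test (or X_train) holds the same scaled features the model was trained on; make_representative_dataset, num_samples and seed are hypothetical names, and the TensorFlow part is left in comments so the generator itself can be checked with NumPy alone.

```python
import numpy as np

def make_representative_dataset(features, num_samples=100, seed=0):
    """Return a generator function for TFLiteConverter.representative_dataset.

    `features` is any 2-D array-like (e.g. the X_test DataFrame) with one row
    per sample and one column per model input feature.
    """
    data = np.asarray(features, dtype=np.float32)
    rng = np.random.default_rng(seed)
    count = min(num_samples, len(data))
    # Fixed random subset of rows, reused every time the converter calls us
    indices = rng.choice(len(data), size=count, replace=False)

    def representative_data_gen():
        for i in indices:
            # One list entry per model input, keeping the batch dim: (1, 5)
            yield [data[i:i + 1]]

    return representative_data_gen

# Hypothetical usage with the question's model (requires TensorFlow):
# converter = tf.lite.TFLiteConverter.from_keras_model(model)
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = make_representative_dataset(X_test)
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.int8   # int8 I/O, common on microcontrollers
# converter.inference_output_type = tf.int8
# tflite_quant_model = converter.convert()
```

Drawing a random subset rather than the first N rows avoids calibrating on only one part of the data if X_test is ordered in time. With inference_input_type = tf.int8, the microcontroller code has to quantize each input itself, using the scale and zero point stored in the model's input tensor.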