Converting a TensorFlow 1.12 model to TensorFlow Lite (TFLite)



I'm trying to convert a model I built in TensorFlow 1.12 to TensorFlow Lite.

I'm using this code:

import numpy as np
import tensorflow as tf
# Generate tf.keras model.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(2, input_shape=(3,)))
model.add(tf.keras.layers.RepeatVector(3))
model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)))
model.compile(loss=tf.keras.losses.MSE,
              optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),
              metrics=[tf.keras.metrics.categorical_accuracy],
              sample_weight_mode='temporal')
x = np.random.random((1, 3))
y = np.random.random((1, 3, 3))
model.train_on_batch(x, y)
model.predict(x)
# Save tf.keras model in HDF5 format.
keras_file = "keras_model.h5"
tf.keras.models.save_model(model, keras_file)
# Convert to TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

I took this code example from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/python_api.md#pre_tensorflow_1.12. Because I'm using TensorFlow 1.12, I changed

converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)

to

converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file)

as suggested in the link above. When I run this code, I get the following output:

INFO:tensorflow:Froze 4 variables.
INFO:tensorflow:Converted 4 variables to const ops.

After that, I get this error:


------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-81a9e7060f2c> in <module>
23 # Convert to TensorFlow Lite model.
24 converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file)
---> 25 tflite_model = converter.convert()
26 open("converted_model.tflite", "wb").write(tflite_model)
~\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\python\lite.py in convert(self)
451           input_tensors=self._input_tensors,
452           output_tensors=self._output_tensors,
--> 453           **converter_kwargs)
454     else:
455       # Graphs without valid tensors cannot be loaded into tf.Session since they
~\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\python\convert.py in toco_convert_impl(input_data, input_tensors, output_tensors, *args, **kwargs)
340   data = toco_convert_protos(model_flags.SerializeToString(),
341                              toco_flags.SerializeToString(),
--> 342                              input_data.SerializeToString())
343   return data
344 
~\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\python\convert.py in toco_convert_protos(model_flags_str, toco_flags_str, input_data_str)
133     else:
134       raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
--> 135                          (stdout, stderr))
136 
137 
RuntimeError: TOCO failed see console for info.
b'C:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\python\framework\dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\r\n  _np_qint8 = np.dtype([("qint8", np.int8, 1)])\r\nC:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\python\framework\dtypes.py:524: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\r\n  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])\r\nC:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\python\framework\dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\r\n  _np_qint16 = np.dtype([("qint16", np.int16, 1)])\r\nC:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\python\framework\dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\r\n  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])\r\nC:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\python\framework\dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\r\n  _np_qint32 = np.dtype([("qint32", np.int32, 1)])\r\nC:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\python\framework\dtypes.py:532: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\r\n  np_resource = np.dtype([("resource", np.ubyte, 1)])\r\nTraceback (most recent call last):\r\n  File "C:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\toco\python\tensorflow_wrap_toco.py", line 18, in swig_import_helper\r\n    fp, pathname, description = imp.find_module('_tensorflow_wrap_toco', [dirname(__file__)])\r\n  File "C:\Users\.\Anaconda3\envs\tensorflow1.12\lib\imp.py", line 297, in find_module\r\n    raise ImportError(_ERR_MSG.format(name), name=name)\r\nImportError: No module named '_tensorflow_wrap_toco'\r\n\r\nDuring handling of the above exception, another exception occurred:\r\n\r\nTraceback (most recent call last):\r\n  File "C:\Users\.\Anaconda3\envs\tensorflow1.12\Scripts\toco_from_protos-script.py", line 6, in <module>\r\n    from tensorflow.contrib.lite.toco.python.toco_from_protos import main\r\n  File "C:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\toco\python\toco_from_protos.py", line 22, in <module>\r\n    from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco\r\n  File "C:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\toco\python\tensorflow_wrap_toco.py", line 28, in <module>\r\n    _tensorflow_wrap_toco = swig_import_helper()\r\n  File "C:\Users\.\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\toco\python\tensorflow_wrap_toco.py", line 20, in swig_import_helper\r\n    import _tensorflow_wrap_toco\r\nModuleNotFoundError: No module named '_tensorflow_wrap_toco'\r\n'
None
Could someone help me solve this?

I suggest you use a newer version of TensorFlow and its new MLIR-based converter instead of TOCO.

I tried your code with 2.4.0 (it should also work with 2.2.x) and changed one line:

converter = tf.lite.TFLiteConverter.from_keras_model(model)

and got a *.tflite model.
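For reference, here is a minimal sketch of the whole script under TensorFlow 2.x. It keeps the question's architecture but makes a few assumptions of mine: the input is declared with `tf.keras.Input`, the loss is passed as the string "mse", the optimizer takes `learning_rate` (newer Keras no longer accepts `lr`), and `sample_weight_mode` and the metric are dropped because they do not affect conversion. The final part loads the converted bytes with `tf.lite.Interpreter` to confirm the model actually runs; that check is my addition, not part of the original answer.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (the answer above used 2.4.0)

# Same architecture as in the question.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(2),
    tf.keras.layers.RepeatVector(3),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)),
])
model.compile(loss="mse",
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001))

x = np.random.random((1, 3)).astype(np.float32)
y = np.random.random((1, 3, 3)).astype(np.float32)
model.train_on_batch(x, y)

# TF 2.x: convert the in-memory Keras model directly, no HDF5 round trip needed.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Explicitly request the MLIR-based converter; recent 2.x releases already
# use it by default, so this line is optional there.
converter.experimental_new_converter = True
tflite_model = converter.convert()

with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

# Sanity check: run the converted model once with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
interpreter.set_tensor(input_index, x)
interpreter.invoke()
out = interpreter.get_tensor(output_index)
print(out.shape)
```

Converting the live `model` object rather than a saved .h5 file is what the one-line change above does; the interpreter check is a cheap way to catch unsupported ops before deploying.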

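In my experience, TF 2.x is friendlier to Keras and lets you quantize models without trouble. With TF 1.x, however, you should switch to exporting a QAT (quantization-aware training) model or a frozen GraphDef to make sure quantization works. For QAT, you can take a look here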
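As an illustration of quantizing in TF 2.x, here is a minimal sketch of post-training quantization (not the QAT workflow mentioned above). It assumes random float32 inputs are acceptable as calibration data for demonstration purposes; in practice you would feed real samples. The generator name `representative_dataset` is hypothetical; the converter only needs a callable that yields lists of input arrays.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x

# Same architecture as in the question (untrained here, just for conversion).
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(2),
    tf.keras.layers.RepeatVector(3),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)),
])

def representative_dataset():
    # Hypothetical calibration data: random float32 batches shaped like the input.
    # Replace with real samples so the calibrated activation ranges are meaningful.
    for _ in range(20):
        yield [np.random.random((1, 3)).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Enable the default post-training quantization.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# A representative dataset lets the converter calibrate activation ranges for
# integer quantization; ops it cannot quantize may remain in float.
converter.representative_dataset = representative_dataset
quantized_model = converter.convert()

with open("quantized_model.tflite", "wb") as f:
    f.write(quantized_model)
```

If you need a strictly integer-only model (for example for an accelerator), the converter's supported-ops and input/output type settings have to be restricted as well; this sketch leaves the float fallback in place.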
