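I want to build a text classification model with TF Hub and export it as a TFLite model, but I get an error when converting a TensorFlow model that includes a TF Hub layer. Please help me solve this.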
import tensorflow as tf
import tensorflow_hub as hub
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(dtype=tf.string, input_shape=()))
model.add(hub.KerasLayer("https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1"))
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
I tried both the TFLite Python API and the command-line API, but I get an InvalidArgumentError:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-15-5a8dbd778645> in <module>()
5 model.add(hub.KerasLayer("https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1"))
6 converter = tf.lite.TFLiteConverter.from_keras_model(model)
----> 7 tflite_model = converter.convert()
/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/lite.py in convert(self)
850 frozen_func, graph_def = (
851 _convert_to_constants.convert_variables_to_constants_v2_as_graph(
--> 852 self._funcs[0], lower_control_flow=False))
853
854 input_tensors = [
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in convert_variables_to_constants_v2_as_graph(func, lower_control_flow, aggressive_inlining)
1103 func=func,
1104 lower_control_flow=lower_control_flow,
-> 1105 aggressive_inlining=aggressive_inlining)
1106
1107 output_graph_def, converted_input_indices = _replace_variables_by_constants(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in __init__(self, func, lower_control_flow, aggressive_inlining, variable_names_allowlist, variable_names_denylist)
804 variable_names_allowlist=variable_names_allowlist,
805 variable_names_denylist=variable_names_denylist)
--> 806 self._build_tensor_data()
807
808 def _build_tensor_data(self):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in _build_tensor_data(self)
823 data = map_index_to_variable[idx].numpy()
824 else:
--> 825 data = val_tensor.numpy()
826 self._tensor_data[tensor_name] = _TensorData(
827 numpy=data,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in numpy(self)
1069 """
1070 # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
-> 1071 maybe_arr = self._numpy() # pylint: disable=protected-access
1072 return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
1073
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _numpy(self)
1037 return self._numpy_internal()
1038 except core._NotOkStatusException as e: # pylint: disable=protected-access
-> 1039 six.raise_from(core._status_to_exception(e.code, e.message), None) # pylint: disable=protected-access
1040
1041 @property
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.
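Last time I checked, TFLite did not support lookup tables, and lookup tables are the main source of resource tensors in TF Hub models (apart from variables, but those definitely work).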
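Here is a minimal sketch of why the freezing step fails. It assumes TensorFlow 2.x with eager execution and is not part of the original answer. A `tf.lookup.StaticHashTable` is referenced through a tensor of dtype `resource`, and the table inside the Hub module is presumably built the same way. In the traceback, `_build_tensor_data` reads captured variables through their variable objects (`map_index_to_variable[idx].numpy()`). Any other captured tensor goes through `val_tensor.numpy()`, and that call cannot succeed for a table handle. The toy vocabulary below is hypothetical.

```python
import tensorflow as tf

# Hypothetical toy vocabulary: word -> integer id; unknown words map to -1.
keys = tf.constant(["hello", "world"])
values = tf.constant([0, 1], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1
)

# The lookup itself works fine in eager mode.
result = table.lookup(tf.constant(["world", "unknown"])).numpy()
print(result)  # [ 1 -1]

# Internally the table is referenced through a tensor of dtype `resource`.
handle = table.resource_handle
print(handle.dtype)

# Freezing turns captured tensors into constants by reading their values with
# .numpy(). A resource handle has no value to copy, so this call is expected to
# fail with "Cannot convert a Tensor of dtype resource to a NumPy array."
try:
    handle.numpy()
except tf.errors.InvalidArgumentError as err:
    print(type(err).__name__, err)
```

Variables avoid this path because the converter reads their current values directly. That matches the answer's point that variables work and lookup tables do not.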