InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array



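I want to build a text classification model with TF Hub and export it as a TFLite model, but
I get an error when converting a TensorFlow model that includes a TF Hub layer. Please help me solve this.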

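import tensorflow as tf
import tensorflow_hub as hub

# The model takes a batch of raw strings (scalar string per example).
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(dtype=tf.string, input_shape=()))
# Text embedding module loaded from TF Hub.
model.add(hub.KerasLayer("https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1"))

# The conversion below is the step that raises InvalidArgumentError.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()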

I tried both the TFLite Python API and the command-line API, but in both cases I get an InvalidArgumentError:


InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-15-5a8dbd778645> in <module>()
5 model.add(hub.KerasLayer("https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1"))
6 converter = tf.lite.TFLiteConverter.from_keras_model(model)
----> 7 tflite_model = converter.convert()
6 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/lite.py in convert(self)
850     frozen_func, graph_def = (
851         _convert_to_constants.convert_variables_to_constants_v2_as_graph(
--> 852             self._funcs[0], lower_control_flow=False))
853 
854     input_tensors = [
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in convert_variables_to_constants_v2_as_graph(func, lower_control_flow, aggressive_inlining)
1103       func=func,
1104       lower_control_flow=lower_control_flow,
-> 1105       aggressive_inlining=aggressive_inlining)
1106 
1107   output_graph_def, converted_input_indices = _replace_variables_by_constants(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in __init__(self, func, lower_control_flow, aggressive_inlining, variable_names_allowlist, variable_names_denylist)
804         variable_names_allowlist=variable_names_allowlist,
805         variable_names_denylist=variable_names_denylist)
--> 806     self._build_tensor_data()
807 
808   def _build_tensor_data(self):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in _build_tensor_data(self)
823         data = map_index_to_variable[idx].numpy()
824       else:
--> 825         data = val_tensor.numpy()
826       self._tensor_data[tensor_name] = _TensorData(
827           numpy=data,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in numpy(self)
1069     """
1070     # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
-> 1071     maybe_arr = self._numpy()  # pylint: disable=protected-access
1072     return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
1073 
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _numpy(self)
1037       return self._numpy_internal()
1038     except core._NotOkStatusException as e:  # pylint: disable=protected-access
-> 1039       six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access
1040 
1041   @property
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.

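Last time I checked, TFLite did not support lookup tables, and lookup tables are the main source of resource tensors in TF Hub models (apart from variables, but those definitely work).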
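
To see where such a resource tensor comes from, here is a minimal sketch, assuming TensorFlow 2.x. The toy table and the function name `lookup` are hypothetical; they only stand in for whatever vocabulary lookup the Hub module may perform internally, which I have not verified for this particular module. The traced function captures the table's handle as a tensor of dtype resource, and the `val_tensor.numpy()` call in your traceback is applied to captured tensors of this kind.

```python
import tensorflow as tf

# Hypothetical toy vocabulary; a real text module would ship its own table.
keys = tf.constant(["hello", "world"])
values = tf.constant([0, 1], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)


@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def lookup(tokens):
    # Using the table inside a tf.function makes the graph capture its handle.
    return table.lookup(tokens)


concrete = lookup.get_concrete_function()

# Captured inputs are the external tensors the graph depends on.
# The table handle shows up here with dtype tf.resource.
resource_captures = [
    t for t in concrete.captured_inputs if t.dtype == tf.resource
]
print("captured resource tensors:", len(resource_captures))
```

Unlike a variable, a table handle has no plain array value that could be frozen into a constant, which is consistent with the error message above.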
