Converting a TF2 Object Detection API model to a frozen graph



I trained the ssd_resnet50_v1_fpn_640x640_coco17_tpu-8 model with the TensorFlow Object Detection API, using https://github.com/tensorflow/models/blob/master/research/object_detection/model_main_tf2.py

I then exported it to a SavedModel using https://github.com/tensorflow/models/blob/master/research/object_detection/exporter_main_v2.py:

.\exporter_main_v2.py --input_type image_tensor --pipeline_config_path .\models\my_ssd_resnet50_v1_fpn\pipeline.config --trained_checkpoint_dir .\models\my_ssd_resnet50_v1_fpn --output_directory .\exported-models\models\Bel_model

Up to this point, inference works fine in TensorFlow, both from the SavedModel and from the checkpoint. I tested inference with this code: https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip

Next, I tried to convert the SavedModel to a frozen graph so that I could use it in OpenCV, following this approach: https://github.com/opencv/opencv/issues/16879#issuecomment-603815872

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

print(tf.__version__)

# Load the SavedModel and take its default serving signature
loaded = tf.saved_model.load('models/mnist_test')
infer = loaded.signatures['serving_default']
# Trace a concrete function; the keyword (flatten_input) must match the signature's input name
f = tf.function(infer).get_concrete_function(flatten_input=tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32))
# Replace the variables with constants
f2 = convert_variables_to_constants_v2(f)
graph_def = f2.graph.as_graph_def()
# Export frozen graph
with tf.io.gfile.GFile('frozen_graph.pb', 'wb') as out_file:
    out_file.write(graph_def.SerializeToString())
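Since the goal is to hand frozen_graph.pb to another runtime, it can help to first check that a file written this way loads and runs back in TensorFlow. The sketch below is a minimal round trip under stated assumptions: toy_model and its scale variable are hypothetical stand-ins for the real detector; it is frozen with convert_variables_to_constants_v2 as above, written to a temporary directory, and re-imported with tf.compat.v1.wrap_function and prune. The input and output tensor names are read from the frozen function instead of being hard-coded.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical stand-in for the real model: one variable, one input.
scale = tf.Variable(3.0)


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
def toy_model(x):
    return tf.reduce_sum(x * scale, axis=1)


# Freeze: the variable becomes a constant inside the graph.
frozen_func = convert_variables_to_constants_v2(toy_model.get_concrete_function())
input_names = [t.name for t in frozen_func.inputs]    # placeholder names, e.g. "x:0"
output_names = [t.name for t in frozen_func.outputs]  # output tensor names

out_dir = tempfile.mkdtemp()
tf.io.write_graph(frozen_func.graph.as_graph_def(), out_dir,
                  "frozen_graph.pb", as_text=False)

# Read the serialized GraphDef back from disk.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(os.path.join(out_dir, "frozen_graph.pb"), "rb") as fh:
    graph_def.ParseFromString(fh.read())


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Import a GraphDef and prune it to a callable from inputs to outputs."""
    def _import():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_import, [])
    graph = wrapped.graph
    return wrapped.prune(
        tf.nest.map_structure(graph.as_graph_element, inputs),
        tf.nest.map_structure(graph.as_graph_element, outputs))


restored = wrap_frozen_graph(graph_def, input_names, output_names)
result = restored(tf.constant([[1.0, 2.0, 3.0, 4.0]]))
print(result[0].numpy())  # (1 + 2 + 3 + 4) * 3.0
```

This only shows that TensorFlow can execute the frozen file; whether OpenCV's cv2.dnn.readNetFromTensorflow accepts every op in a full SSD ResNet50 FPN graph is a separate question.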

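Unfortunately, when I run this conversion on my own model (line 8 adapted to input_1 and a [None, 640, 640, 3] float32 spec), I get this error: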

Traceback (most recent call last):
  File ".\frozen_graph.py", line 8, in <module>
    f = tf.function(infer).get_concrete_function(input_1=tf.TensorSpec(shape=[None, 640, 640, 3], dtype=tf.float32))
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1669 __call__  *
        return self._call_impl(args, kwargs)
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1685 _call_impl  **
        raise structured_err
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1678 _call_impl
        return self._call_with_structured_signature(args, kwargs,
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1756 _call_with_structured_signature
        self._structured_signature_check_missing_args(args, kwargs)
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1775 _structured_signature_check_missing_args
        raise TypeError("{} missing required arguments: {}".format(

    TypeError: signature_wrapper(*, input_tensor) missing required arguments: input_tensor

Please help me solve this, or suggest another way to create a frozen graph. Would it be easier to train the model with Keras instead?
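The traceback states that the exported serving_default signature takes a single keyword-only argument named input_tensor, while the script passed input_1. Below is a minimal sketch of tracing with the signature's own keyword. It assumes the image_tensor export expects uint8 images of shape [1, None, None, 3] (an assumption to confirm on the real model); ToyDetector, its scale variable and the detection_scores output are hypothetical stand-ins so the example runs without the real checkpoint.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class ToyDetector(tf.Module):
    """Hypothetical stand-in for an OD API export; not the real detector."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(0.5)

    # Assumed to mirror the image_tensor export: uint8 images, batch of 1.
    @tf.function(input_signature=[tf.TensorSpec([1, None, None, 3], tf.uint8)])
    def serve(self, input_tensor):
        x = tf.cast(input_tensor, tf.float32) * self.scale
        return {"detection_scores": tf.reduce_mean(x, axis=[1, 2])}


export_dir = tempfile.mkdtemp()
module = ToyDetector()
tf.saved_model.save(module, export_dir, signatures=module.serve)

loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]

# Signature functions take keyword-only arguments; list their names.
arg_names = list(infer.structured_input_signature[1].keys())
print(arg_names)

# Trace with the signature's own keyword and dtype (not input_1 / float32).
concrete = tf.function(infer).get_concrete_function(
    input_tensor=tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8))
frozen = convert_variables_to_constants_v2(concrete)
graph_def = frozen.graph.as_graph_def()

pb_path = os.path.join(export_dir, "frozen_graph.pb")
with tf.io.gfile.GFile(pb_path, "wb") as fh:
    fh.write(graph_def.SerializeToString())
```

On the real export, printing loaded.signatures['serving_default'].structured_input_signature should show the actual argument name, dtype and shape to pass; a float32 spec against a uint8 signature would presumably be rejected as well.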

Replace the following 3 lines:

loaded = tf.saved_model.load('models/mnist_test')
infer = loaded.signatures['serving_default']
f = tf.function(infer).get_concrete_function(flatten_input=tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32))

with:

from tensorflow import keras

loaded = keras.models.load_model('models/mnist_test')
f = tf.function(lambda x: loaded(x))
f = f.get_concrete_function(tf.TensorSpec(loaded.inputs[0].shape, loaded.inputs[0].dtype))
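The replacement above loads the model with keras.models.load_model, so it presumably applies to a model saved from Keras rather than to the Object Detection API export. A self-contained sketch of that whole path, assuming a hypothetical tiny functional Keras model (not the MNIST network behind models/mnist_test) saved to a temporary mnist_test.keras file, then traced, frozen and written with the same steps as in the question:

```python
import os
import tempfile

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical tiny model standing in for models/mnist_test.
inputs = keras.Input(shape=(28, 28, 1))
outputs = keras.layers.Dense(10)(keras.layers.Flatten()(inputs))
work_dir = tempfile.mkdtemp()
model_path = os.path.join(work_dir, "mnist_test.keras")
keras.Model(inputs, outputs).save(model_path)

# The three replacement lines: trace the loaded Keras model directly.
loaded = keras.models.load_model(model_path)
f = tf.function(lambda x: loaded(x))
f = f.get_concrete_function(tf.TensorSpec(loaded.inputs[0].shape, loaded.inputs[0].dtype))

# Freeze and serialize, as in the question's script.
frozen = convert_variables_to_constants_v2(f)
graph_def = frozen.graph.as_graph_def()
pb_path = os.path.join(work_dir, "frozen_graph.pb")
with tf.io.gfile.GFile(pb_path, "wb") as fh:
    fh.write(graph_def.SerializeToString())

# Tensor names a consumer of the .pb would use (exact names may vary by version).
print([t.name for t in frozen.inputs], [t.name for t in frozen.outputs])
```

Tracing the Keras model itself avoids the keyword-only signature wrapper, which is why no input name has to be matched here.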
