如何保存包含所有权重的Tensorflow 2对象检测模型



我正在Python中使用Tensorflow 2 API进行对象检测。到目前为止效果很好。但是,如果我想保存模型,我使用的是exporter_main_v2.py,它导出一个图(.pb(和一个检查点(checkpointckpt-0.datackpt-0.index(。该图不包含任何权重,我总是必须使用检查点来处理保存的模型。有没有办法将所有权重保存到Protobuf(.pb(文件中?

以下是我尝试过的:

  • 保存冻结模型:TF2显然不再支持冻结图。export_inference_graph.py将冻结包含所有权重的图形,但在TF2下不起作用
  • freeze_graph.py也是如此:仅可使用TF1

您仍然可以使用TF2中TF1的冻结技术,使用compat.v1模块:

在下面的代码段中,我假设您有一个预训练的模型,该模型具有以TF2方式保存的权重tf.saved_model.save

graph = tf.Graph()
with graph.as_default():
sess = tf.compat.v1.Session()
with sess.as_default():
# creating the model/loading it from a TF2 pb file
# (If you have a keras model, you can use 
#`tf.keras.models.load_model` instead). 
model = tf.saved_model.load("/path/to/model")
# the default signature might be different.
sign = model.signatures["serving_default"]
# if using keras, just use model.outputs
tensor_out_names = [out.name.split(":")[0] for out in sign.outputs]

graphdef = tf.compat.v1.graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), tensor_out_names
)
# the following is optional, use only if no more training is required
graphdef = tf.compat.v1.graph_util.remove_training_nodes(graphdef)
tf.python.framework.graph_io.write_graph(graphdef, "./", "/path/to/frozengraph", as_text=False)

然而,除了与旧工具兼容之外,我不会这么做。compat模块可能有一天会被弃用,据我所知,只有一个文件包含图+权重,而不是拆分它们,这并不是一个很大的值。

相关内容

  • 没有找到相关文章

最新更新