我正在Python中使用Tensorflow 2 API进行对象检测。到目前为止效果很好。但是,如果我想保存模型,我使用的是exporter_main_v2.py
,它导出一个图(.pb(和一个检查点(checkpoint
、ckpt-0.data
、ckpt-0.index
(。该图不包含任何权重,我总是必须使用检查点来处理保存的模型。有没有办法将所有权重保存到Protobuf(.pb(文件中?
以下是我尝试过的:
- 保存冻结模型:TF2显然不再支持冻结图。
export_inference_graph.py
将冻结包含所有权重的图形,但在TF2下不起作用 freeze_graph.py
也是如此:仅可使用TF1
您仍然可以使用TF2中TF1的冻结技术,使用compat.v1
模块:
在下面的代码段中,我假设您有一个预训练的模型,该模型具有以TF2方式保存的权重tf.saved_model.save
。
graph = tf.Graph()
with graph.as_default():
sess = tf.compat.v1.Session()
with sess.as_default():
# creating the model/loading it from a TF2 pb file
# (If you have a keras model, you can use
#`tf.keras.models.load_model` instead).
model = tf.saved_model.load("/path/to/model")
# the default signature might be different.
sign = model.signatures["serving_default"]
# if using keras, just use model.outputs
tensor_out_names = [out.name.split(":")[0] for out in sign.outputs]
graphdef = tf.compat.v1.graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), tensor_out_names
)
# the following is optional, use only if no more training is required
graphdef = tf.compat.v1.graph_util.remove_training_nodes(graphdef)
tf.python.framework.graph_io.write_graph(graphdef, "./", "/path/to/frozengraph", as_text=False)
然而,除了与旧工具兼容之外,我不会这么做。compat
模块可能有一天会被弃用,据我所知,只有一个文件包含图+权重,而不是拆分它们,这并不是一个很大的值。