Tensorflow:导入预训练模型(mobilenet,.pb,.ckpt)



我一直在研究在tensorflow中导入预训练模型的检查点。这样做的目的是让我可以检查其结构,并将其用于图像分类。

具体来说,就是这里找到的移动网络模型。 我找不到任何从各种 *.ckpt.* 文件导入模型的合理方法,以及一些论坛嗅探我发现了Github用户StanislawAntol编写的要点,其中据称将所述文件转换为冻结模型,ProtoBuf(.pb(文件。这要点在这里

运行脚本给了我一堆 .pb 文件,我希望我可以工作跟。 事实上,这个问题似乎回应了我的祈祷。

我一直在尝试以下代码的变体,但无济于事。 任何对象被返回 tf.import_graph_def 似乎属于 None 类型。

import tensorflow as tf
from tensorflow.python.platform import gfile
model_filename = LOCATION_OF_PB_FILE
with gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name='')
print(g_in)

我在这里缺少什么吗? 整个转换为 .pb 的过程是错误的吗?

tf.import_graph_def

返回图形,它会填充作用域中的"默认图形"。有关返回值的详细信息,请参阅tf.import_graph_def文档。

在您的情况下,您可以使用 tf.get_default_graph() 检查图形。例如:

with gfile.FastGFile(model_filename, 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
g = tf.get_default_graph()
print(len(g.get_operations()))

有关"默认图形"概念和范围tf.Graph的更多详细信息,请参阅文档。

希望有帮助。