在 TensorFlow .ckpt 文件中使用了预处理模型



我有一个 ckpt 文件。我只想得到CNN的权重我已经从 ckpt 检查点文件进行了训练。?inception_resnet_v2_2016_08_30

import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "inception_resnet_v2_2016_08_30.ckpt")

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib
with session.Session() as sess:
var_list = {}
reader =pywrap_tensorflow.NewCheckpointReader("./inception_resnet_v2_2016_08_30.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    try:
       tensor = sess.graph.get_tensor_by_name(key + ":0")
    except KeyError:
            continue
    var_list[key] = tensor
    saver = saver_lib.Saver(var_list=var_list)
    saver.restore(sess, input_checkpoint)
    if initializer_nodes:
       sess.run(initializer_nodes)

仅当已构建检查点将还原到的图形结构(包括一组tf.Variable对象(时,tf.train.Saver.restore() 方法才有效。您(至少(有两种解决此问题的选项:

  1. 使用 tf.train.NewCheckpointReader("inception_resnet_v2_2016_08_30.ckpt") 打开检查点文件。可以调用返回对象上的 get_tensor() 方法以按名称查找保存的变量,或调用 get_variable_to_shape_map() 方法来获取可用变量的列表。

  2. 如果有,请为检查点模式加载一个 MetaGraph,其中包括图形结构以及从该图形结构到检查点中变量的映射。