我有一个 ckpt 文件。我只想得到CNN的权重我已经从 ckpt 检查点文件进行了训练。?inception_resnet_v2_2016_08_30
import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "inception_resnet_v2_2016_08_30.ckpt")
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib
with session.Session() as sess:
var_list = {}
reader =pywrap_tensorflow.NewCheckpointReader("./inception_resnet_v2_2016_08_30.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ":0")
except KeyError:
continue
var_list[key] = tensor
saver = saver_lib.Saver(var_list=var_list)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes)
仅当已构建检查点将还原到的图形结构(包括一组tf.Variable
对象(时,tf.train.Saver.restore()
方法才有效。您(至少(有两种解决此问题的选项:
-
使用
tf.train.NewCheckpointReader("inception_resnet_v2_2016_08_30.ckpt")
打开检查点文件。可以调用返回对象上的get_tensor()
方法以按名称查找保存的变量,或调用get_variable_to_shape_map()
方法来获取可用变量的列表。 -
如果有,请为检查点模式加载一个 MetaGraph,其中包括图形结构以及从该图形结构到检查点中变量的映射。