存储恢复的检查点的张量流


_ = importer.import_graph_def(input_graph_def, name='')
with session.Session() as sess:
if input_saver_def:
saver = saver_lib.Saver(saver_def=input_saver_def)
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names.split(','),
variable_names_blacklist=variable_names_blacklist)

在上面的代码中,导入器用于将 graphDef 导入到当前默认图形,保护程序加载之前训练的值。问题是这些训练值存储在哪里?在会话中,在input_graph_def中,在当前图形结构(tf.get_default_graph(((中还是在保护程序中?

我检查方法的代码convert_variables_to_constants. https://github.com/tensorflow/tensorflow/blob/235192d47cfb375c0cc93c1deefb9e440715bf35/tensorflow/python/framework/graph_util_impl.py

它使用 sess.run(变量名称(来获取加载的值。这sess.run从哪里获取值?

当我们定义保护程序时,我们应该将注释传递给它(默认情况下它是全局变量(。

In [2]: import tensorflow as tf
In [3]: a = tf.get_variable("a", [])
In [4]: b = tf.get_variable("b", [])
In [5]: saver_a = tf.train.Saver({"my_a_in_ckpt": a}) # here "my_a_in_ckpt" can be any apt name you like, it is the variable name stored only in the checkpoint (1)
In [6]: init = tf.global_variables_initializer()
In [7]: sess = tf.Session()
In [8]: sess.run(init)
In [9]: sess.run(a)
Out[9]: 0.43891537
In [10]: sess.run(b)
Out[10]: 1.5962805
In [11]: saver_a.save(sess, "./temp_model")

在这里,我们首先初始化所有变量并将 a 保存到 "./temp_model"。要恢复变量:

In [2]: import tensorflow as tf
In [3]: a = tf.get_variable("a", []) 
In [5]: saver_a = tf.train.Saver({"my_a_in_ckpt": a})  # here "my_a_in_ckpt" should match that as you defined in step (1)
In [7]: sess = tf.Session()
In [9]: saver_a.restore(sess, tf.train.latest_checkpoint("./temp_model"))
INFO:tensorflow:Restoring parameters from ./temp_model/temp
In [10]: sess.run(a)
Out[10]: 0.43891537
In [11]: sess.run(b)
Out[11]: 1.5962805

我们可以将 a 和 b 保存到不同的地方:

In [12]: saver_b = tf.train.Saver({"b": b})
In [13]: saver_b.save(sess, "./temp_model_b/temp")
Out[13]: './temp_model_b/temp'

并将它们还原为图形:

In [3]: a = tf.get_variable("a", [])
In [4]: b = tf.get_variable("b", [])
In [5]: saver_b = tf.train.Saver({"b": b})
In [6]: saver_a = tf.train.Saver({"my_a_in_ckpt": a})
In [7]: saver_a.restore(sess, tf.train.latest_checkpoint("./temp_model"))                              
INFO:tensorflow:Restoring parameters from ./temp_model/temp
In [8]: saver_b.restore(sess, tf.train.latest_checkpoint("./temp_model_b"))
INFO:tensorflow:Restoring parameters from ./temp_model_b/temp
In [9]: sess.run(a)
Out[9]: 0.43891537
In [10]: sess.run(b)
Out[10]: 1.5962805

最新更新