我是TensorFlow的新手。当我阅读张量流保存和恢复变量手册时,我遇到了一个问题。我保存了一个由常量初始化的变量,但无法恢复该变量。代码如下:
a = tf.get_variable("name_a", initializer=[1,2,3])
op1 = a.assign(a+1)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
op1.op.run()
print(a.eval())
saver.save(sess,"log1/model.ckpt")
然后我恢复它。
a = tf.get_variable("name_a", shape=[3])
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "log1/model.ckpt")
print(a.eval())
我想像[2,3,4]
一样获得输出,但我得到了[ 2.80259693e-45 4.20389539e-45 5.60519386e-45]
.都是零。
但是,当我将第一个代码片段中的第一行修改为
a = tf.get_variable("name_a", initializer=tf.zeros([3]))
我可以得到正确的恢复变量:[ 1. 1. 1.]
我想知道这种情况的原因。
我不是 100% 确定,但看起来原因是您的两个变量:
-
tf.get_variable("name_a", initializer=[1,2,3])
-
tf.get_variable("name_a", shape=[3])
不等效,不能轻易地用于另一个(更新:dtype
不同,感谢@BlueSun注意到这一点)。
如果您在恢复代码中定义张量,就像在保存中一样,您将获得稳定的输出:a = tf.get_variable("name_a", initializer=[1,2,3])
.但是,更好的方法是直接使用保存的图形:
saver = tf.train.import_meta_graph('log1/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess, "log1/model.ckpt")
saved = sess.graph.get_tensor_by_name('name_a:0')
print(sess.run(saved))
无论您如何定义初始值设定项,它都可以正常工作。
您必须使用相同的数据类型定义变量a
。如果未指定它并且没有任何初始值设定项,则默认情况下 dtype 将为 tf.float32,并且加载 tf.int32 将失败。只需将数据类型设置为 int32 即可解决问题:
a = tf.get_variable("name_a", shape=[3], dtype=tf.int32)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "log1/model.ckpt")
print(a.eval())
使用 a = tf.get_variable("name_a", initializer=tf.zeros([3]))
有效tf.zeros([3])
因为它具有与 [2, 3, 4]
相同的 dtype 。在创建变量时始终设置dtype
更安全。