为什么TensorFlow无法恢复由常量初始化的变量



我是TensorFlow的新手。当我阅读张量流保存和恢复变量手册时,我遇到了一个问题。我保存了一个由常量初始化的变量,但无法恢复该变量。代码如下:

a = tf.get_variable("name_a", initializer=[1,2,3])
op1 = a.assign(a+1)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    op1.op.run()
    print(a.eval())
    saver.save(sess,"log1/model.ckpt")

然后我恢复它。

a = tf.get_variable("name_a", shape=[3])
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "log1/model.ckpt")
    print(a.eval())

我想像[2,3,4]一样获得输出,但我得到了[ 2.80259693e-45 4.20389539e-45 5.60519386e-45].都是零。

但是,当我将第一个代码片段中的第一行修改为

a = tf.get_variable("name_a", initializer=tf.zeros([3]))

我可以得到正确的恢复变量:[ 1. 1. 1.]

我想知道这种情况的原因。

我不是 100% 确定,但看起来原因是您的两个变量:

  • tf.get_variable("name_a", initializer=[1,2,3])

  • tf.get_variable("name_a", shape=[3])

等效,不能轻易地用于另一个(更新dtype不同,感谢@BlueSun注意到这一点)。

如果您在恢复代码中定义张量,就像在保存中一样,您将获得稳定的输出:a = tf.get_variable("name_a", initializer=[1,2,3]) .但是,更好的方法是直接使用保存的图形:

saver = tf.train.import_meta_graph('log1/model.ckpt.meta')
with tf.Session() as sess:
  saver.restore(sess, "log1/model.ckpt")
  saved = sess.graph.get_tensor_by_name('name_a:0')
  print(sess.run(saved))

无论您如何定义初始值设定项,它都可以正常工作。

您必须使用相同的数据类型定义变量a。如果未指定它并且没有任何初始值设定项,则默认情况下 dtype 将为 tf.float32,并且加载 tf.int32 将失败。只需将数据类型设置为 int32 即可解决问题:

a = tf.get_variable("name_a", shape=[3], dtype=tf.int32)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "log1/model.ckpt")
    print(a.eval())

使用 a = tf.get_variable("name_a", initializer=tf.zeros([3])) 有效tf.zeros([3])因为它具有与 [2, 3, 4] 相同的 dtype 。在创建变量时始终设置dtype更安全。

相关内容

  • 没有找到相关文章

最新更新