张量流:保存和重新存储会话 - 多个变量



给定以下代码:

import tensorflow as tf
with tf.Session() as sess:
var = tf.Variable(42, name='var')
sess.run(tf.global_variables_initializer())
tf.train.export_meta_graph('file.meta')
with tf.Session() as sess:
saver = tf.train.import_meta_graph('file.meta')
print sess.run(var)

我在行saver = tf.train.import_meta_graph('file.meta')ValueError: At least two variables have the same name: var

时出错。我该如何解决这个问题?导入元图时是否有覆盖计算图的方法?

编辑:

我已经到达了以下代码:

import tensorflow as tf
file_name = "./file"
with tf.Session() as sess:
var = tf.Variable(42, name='my_var')
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess,file_name)
saver.export_meta_graph(file_name + '.meta')
with tf.Session() as sess:
saver = tf.train.import_meta_graph(file_name + '.meta')
saver.restore(sess, file_name)
print(sess.run(var))
# new code that fails:
saver = tf.train.Saver()
saver.save(sess,file_name)
saver.export_meta_graph(file_name + '.meta')

这将打印正确的值var,但是当我第二次保存图形时,我得到相同的原始错误:ValueError: At least two variables have the same name: var

在这种情况下,您将在已定义变量的默认图形中加载变量。因此 在导入之前,您需要重置 TensorFlow 图。

在导入之前,请使用tf.reset_default_graph(). 执行此操作。查看导出和导入 Metagraph 下的"在默认图形中导入"部分。

当然,您必须使用tf.get_variable()重新定义变量var.试试这段代码,

import tensorflow as tf
with tf.Session() as sess:
var = tf.Variable(42, name='var')
sess.run(tf.global_variables_initializer())
tf.train.export_meta_graph('file.meta')
tf.reset_default_graph()
with tf.Session() as sess:
saver = tf.train.import_meta_graph('file.meta')
var = tf.global_variables()[0]
sess.run(tf.initialize_all_variables())
print sess.run(var)

中间代码不起作用的原因是tf.get_variable()正在创建一个随机初始化的新变量。确保先做tf.get_variable_scope().reuse_variables()。 看看 了解tf.get_variable().

不幸的是,使用tf.Variable()创建的变量不能直接与tf.get_variable()重用。看看这个评论和这个评论,确切地知道为什么。因此,如果您希望将来重用变量,则需要使用tf.get_variable()来创建变量。

最新更新