当TensorFlow中有一个图形对象时,为什么保存程序在脚本中的位置很重要



我正在训练一些模型,我注意到当我显式定义一个图变量时,在哪里创建我的保护程序对象很重要。首先,我的代码看起来是这样的:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("tmp_MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.truncated_normal([784, 10], mean=0.0, stddev=0.1),name='w')
b = tf.Variable(tf.constant(0.1, shape=[10]),name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) # list of booleans indicating correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(1001):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys})
if i % 100 == 0:
saver.save(sess=sess,save_path='./tmp/mdl_ckpt')
print(sess.run(fetches=accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

然后我决定把它改成这样,我定义变量的地方和定义saver的地方似乎很敏感。例如,如果它们不是在图形变量创建之后才定义的,则会出现错误。类似地,我注意到,saver必须正好在一个变量之后定义(注意,在图的定义之后是还不够),以便所有变量都能被saver一起捕获(这对我来说没有意义,要求它在的定义后面是所有变量,而不是一个变量)。

这就是代码现在的样子(注释显示了我定义saver的位置):

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("tmp_MNIST_data/", one_hot=True)
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
#saver = tf.train.Saver()
x = tf.placeholder(tf.float32, [None, 784])
saver = tf.train.Saver()
y_ = tf.placeholder(tf.float32, [None, 10])
#saver = tf.train.Saver()
W = tf.Variable(tf.truncated_normal([784, 10], mean=0.0, stddev=0.1),name='w')
#saver = tf.train.Saver()
b = tf.Variable(tf.constant(0.1, shape=[10]),name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b)
#saver = tf.train.Saver()
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) # list of booleans indicating correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#saver = tf.train.Saver()
step = tf.Variable(0, name='step')
#saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
#saver = tf.train.Saver()
for i in range(1001):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys})
if i % 100 == 0:
step_assign = step.assign(i)
sess.run(step_assign)
saver.save(sess=sess,save_path='./tmp/mdl_ckpt')
print(step.eval())
print( [ op.name for op in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)] )
print(sess.run(fetches=accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

上面的代码应该可以工作,但我很难理解为什么它会这样,或者为什么会发生这种情况。有人知道什么是正确的做法吗?

我不完全确定这里发生了什么,但我怀疑这个问题与变量没有进入错误的图有关,或者会话中的图版本过时。您创建了一个图,但没有将其设置为默认值,然后使用该图创建一个会话。。。但是,当您创建变量时,您没有指定它们应该进入哪个图。也许会话的创建将指定的图设置为默认值,但tensorflow并不是设计使用的方式,所以如果它没有在这个机制中进行彻底测试,我也不会感到惊讶。

虽然我没有解释或发生了什么,但我可以提出一个简单的解决方案:在运行会话的情况下单独构建图。

graph = tf.Graph()
with graph.as_default():
build_graph()
saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
do_stuff_with(sess)
saver.save(sess, path)

最新更新