我正在使用tf.contrib.learn.ReadBatchFeatures
(https://www.tensorflow.org/versions/master/api_docs/python/contrib.learn/input_processing#read_batch_features)来读取Example
原型,作为我的输入函数的一部分,该函数返回Tensor
对象的字典。训练模型后,对我的Estimator
调用 predict
会以数组的形式返回一批预测,我想将其与已知值进行比较。
我尝试通过调用 tf.Session().run(labels)
来获取已知值,其中 labels
是从输入函数返回的已知值的Tensor
。但是,此时,我的程序挂起。我怀疑它陷入了从磁盘读取标签的无限循环中,而不是像我想要的那样只读取一批。
这是获取labels
Tensor
中一批值的正确方法吗?
编辑:我尝试启动队列运行器,以下内容是否正确?
_, labels = eval_input_fn()
with tf.Session().as_default():
tf.local_variables_initializer()
tf.train.start_queue_runners()
label_values = labels.eval()
print(label_values)
您需要的整个设置是:
_, labels = eval_input_fn()
with tf.Session() as sess:
sess.run([
tf.local_variables_initializer(),
tf.global_variables_initializer()
])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
print(sess.run(label))
except tf.errors.OutOfRangeError as error:
coord.request_stop(error)
finally:
coord.request_stop()
coord.join(threads)