在 TensorFlow 中,为什么tf.train.shuffle_batch永远坚持下去并且不会返回批处理?



我是 tensorflow 的新手,目前正在尝试从我的 csv 格式的数据中生成批处理。

我遵循了张量流(https://www.tensorflow.org/programmers_guide/reading_data(的读取数据教程,但我一定误解了什么,因为我的代码永远存在。 我使用了教程中的read_my_file_format函数,它起作用了。现在我想使用批处理来真正训练我的网络,如下所示:

def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
print('read_my_file is done')
min_after_dequeue = 10
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
print('all done but the return')
return example_batch, label_batch
with tf.Session() as sess:
batch_size=5
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
batch_data,batch_label=sess.run(input_pipeline(file_name,batch_size=batch_size))
print('return is done')
print(batch_data,batch_label)
coord.request_stop()
coord.join(threads)

为了调试,在上面的代码中,我只是尝试打印生成的批处理,而不是将其馈送到网络中。通过我的印刷品,我能够看到它挂在哪里:就在 返回example_batch,label_batch。

我的神经网络已经准备好了,我的数据已经被处理了,所以这是阻碍我推进项目(超新星分类(的唯一因素。您有什么建议或建议吗?我已经坚持了一段时间了。

另外,如果需要,我的文件名中只有一个输入文件。

谢谢

您需要初始化变量。

with tf.Session() as sess:
...
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
sess.run(tf.global_variables_initializer())
...

最新更新