Filling a queue from a Python iterator



I want to create a queue that is filled from an iterator. However, in the following MWE the same value is enqueued every time:

import tensorflow as tf
import numpy as np
# data
imgs = [np.random.randn(i,i) for i in [2,3,4,5]]
# iterate through the data infinitely
def data_iterator():
    while True:
        for img in imgs:
            yield img
it = data_iterator()
# create queue for data
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64])
# feed next element from iterator
enqueue_op = q.enqueue(list(next(it)))
# setup queue runner
numberOfThreads = 1 
qr = tf.train.QueueRunner(q, [enqueue_op] * numberOfThreads)
tf.train.add_queue_runner(qr) 
# dequeue
dequeue_op  = q.dequeue() 
dequeue_op = tf.Print(dequeue_op, data=[dequeue_op], message="dequeue()")
# We start the session as usual ...
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        data = sess.run(dequeue_op)
        print(data)
    coord.request_stop()
    coord.join(threads)

Do I have to use feed_dict? If so, how do I combine it with a QueueRunner?

When you run

enqueue_op = q.enqueue(list(next(it)))

TensorFlow evaluates list(next(it)) only once, while the graph is being built. After that it keeps that first array and enqueues the same value into q every time enqueue_op runs. To avoid this you have to feed the data through a placeholder, but feeding a placeholder is not compatible with tf.train.QueueRunner. Use this instead:

import tensorflow as tf
import numpy as np
import threading
# data
imgs = [np.random.randn(i,i) for i in [2,3,4,5]]
# iterate through the data infinitely
def data_iterator():
    while True:
        for img in imgs:
            yield img
it = data_iterator()
# create queue for data
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64])
# feed next element from iterator
img_p = tf.placeholder(tf.float64, [None, None])
enqueue_op = q.enqueue(img_p)
dequeue_op  = q.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    def enqueue_thread():
        with coord.stop_on_exception():
            while not coord.should_stop():
                sess.run(enqueue_op, feed_dict={img_p: list(next(it))})
    numberOfThreads = 1
    for i in range(numberOfThreads):
        threading.Thread(target=enqueue_thread).start()

    for i in range(3):
        data = sess.run(dequeue_op)
        print(data)
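
One caveat with the snippet above: the enqueue thread only stops once the session closes and sess.run(enqueue_op) fails, which coord.stop_on_exception() then turns into a stop request. A cleaner shutdown is to close the queue with pending enqueues cancelled and join the threads explicitly. The following is a minimal sketch reusing q, enqueue_op, dequeue_op, img_p and it from the snippet above; close_op and enqueue_threads are names introduced here, not part of the original answer:

# build a close op next to enqueue_op / dequeue_op
close_op = q.close(cancel_pending_enqueues=True)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    def enqueue_thread():
        with coord.stop_on_exception():
            while not coord.should_stop():
                sess.run(enqueue_op, feed_dict={img_p: list(next(it))})

    numberOfThreads = 1
    enqueue_threads = [threading.Thread(target=enqueue_thread)
                       for _ in range(numberOfThreads)]
    for t in enqueue_threads:
        t.start()

    for i in range(3):
        print(sess.run(dequeue_op))

    # ask the threads to stop, cancel any enqueue blocked on a full queue,
    # then wait for the threads to exit
    coord.request_stop()
    sess.run(close_op)
    coord.join(enqueue_threads)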

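For completeness: on TensorFlow 1.4 and later, the same "feed a pipeline from a Python generator" pattern can be expressed with tf.data.Dataset.from_generator, with no placeholders or hand-rolled threads. A minimal sketch reusing the imgs / data_iterator definitions from the question (the dataset and next_img names are introduced here):

import tensorflow as tf
import numpy as np

imgs = [np.random.randn(i, i) for i in [2, 3, 4, 5]]

def data_iterator():
    while True:
        for img in imgs:
            yield img

# pass the generator callable itself; element shapes vary, so use [None, None]
dataset = tf.data.Dataset.from_generator(
    data_iterator,
    output_types=tf.float64,
    output_shapes=tf.TensorShape([None, None]))
next_img = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(10):
        print(sess.run(next_img))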