在生产环境中,我有来自N个生产者的数据,这些数据必须通过网络。我发现了这个关于并行化的评论tf.data.Dataset.from_generator它真正描述了我想要的东西。
def generator(n):
# returns n-th generator function
def dataset(n):
return tf.data.Dataset.from_generator(generator(n))
ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))
# where N is the number of generators you use
但是生成器(n(函数应该是什么样子的。因为当我运行此示例时
def generator(n):
"""Returns the n-th generator function (for consumer n)
"""
consumer = self.consumers[n]
def gen():
for item in consumer:
yield item
return gen
使用self.consumer一个Python列表,那么我将得到错误:
类型错误:列表索引必须是整数或切片,而不是张量
实现几乎是正确的,但你得到一个错误,因为dataset(n)
中的n
参数是一个"符号"tf.Tensor
,而不是可用于查找self.consumers
消费者的实际值。
幸运的是,有一个解决方法,它涉及通过可选的args
参数传递n
tf.data.Dataset.from_generator()
:
def dataset(n):
return tf.data.Dataset.from_generator(generator, args=(n,))
在幕后,from_generator()
插入一些代码,以便在每次调用generator
之前将n
转换为 Python 整数。