Tensorflow Dataset API:使用parallel_interleave并行化tf.data.Datas



在生产环境中,我有来自N个生产者的数据,这些数据必须通过网络。我发现了这个关于并行化的评论tf.data.Dataset.from_generator它真正描述了我想要的东西。

def generator(n):
# returns n-th generator function
def dataset(n):
return tf.data.Dataset.from_generator(generator(n))
ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))
# where N is the number of generators you use

但是生成器(n(函数应该是什么样子的。因为当我运行此示例时

def generator(n):
"""Returns the n-th generator function (for consumer n)
"""
consumer = self.consumers[n]
def gen():
for item in consumer:
yield item
return gen

使用self.consumer一个Python列表,那么我将得到错误:

类型错误:列表索引必须是整数或切片,而不是张量

实现几乎是正确的,但你得到一个错误,因为dataset(n)中的n参数是一个"符号"tf.Tensor,而不是可用于查找self.consumers消费者的实际值。

幸运的是,有一个解决方法,它涉及通过可选的args参数传递ntf.data.Dataset.from_generator()

def dataset(n):
return tf.data.Dataset.from_generator(generator, args=(n,))

在幕后,from_generator()插入一些代码,以便在每次调用generator之前将n转换为 Python 整数。

相关内容

  • 没有找到相关文章

最新更新