使用数据集时,在OutOfRangeError之后重置Tensorflow图



我正在尝试使用数据集API的from_generator接口将多个"轮次"输入注入到图中。

在第一次尝试时,我使用repeat()函数使生成器连续运行3次。然而,对于批处理大小为而不是batch_join调用,它是每轮迭代次数的偶数倍(批处理大小3的10次迭代(,来自不同"轮"/"时期"的数据最终会在同一批中(取决于张量的处理顺序;图中存在一些并行性(。

在我的第二次尝试中,我尝试在每个epoch完成后重新运行迭代器。但是,一旦抛出tf.errors.OutOfRangeError,批处理调用输出上对sess.run()的所有后续调用都会再次抛出OutOfRangeError,即使在重新运行迭代器的初始化器之后也是如此。

我希望将多轮输入连续注入到图中,而不是像第一个示例那样使它们重叠(例如,在批处理选项上使用allow_smaller_final_batch(。我在自定义Tensorflow fork中实例化的一些内核重新启动非常昂贵,例如mmap,它是一个O(10gb(的文件,所以我想以某种方式充分利用这两个世界。

我认为问题源于将tf.contrib.data.Dataset(支持重新初始化(与tf.train.batch_join()(使用TensorFlow队列和队列运行器,因此不支持重新初始化。(一起使用。

我不完全清楚您的代码在做什么,但我认为您可以将整个管道实现为Dataset。替换以下代码片段:

my_iterator = MyIterator(iterations=iterations)
dataset = ds.Dataset.from_generator(my_iterator, 
output_types=my_iterator.output_types, 
output_shapes=my_iterator.output_shapes)
#dataset = dataset.repeat(count=repetitions)
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()
#change constant to 1 or 2 or something to see that the batching is more predictable
ripple_adds = [(tf.stack((next_elem[0], next_elem[1] + constant)),) 
for constant in ripple_add_coefficients]
batch = tf.train.batch_join(ripple_adds, batch_size=batch_size, 
enqueue_many=False, name="sink_queue")

类似以下内容:

my_iterator = MyIterator(iterations=iterations)
dataset = tf.contrib.data.from_generator(my_iterator,
output_types=my_iterator.output_types,
output_shapes=my_iterator.output_shapes)
def ripple_add_map_func(x, y):
return (tf.contrib.data.Dataset.range(num_ripples)
.map(lambda r: tf.stack([x, y + r])))
dataset = dataset.flat_map(ripple_add_map_func).batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()

相关内容

最新更新