generator()的tensorflow数据集超出范围错误



我正在尝试使用tf.data.Dataset.from_generator()来生成训练和验证数据。

我有自己的数据生成器,它在飞行中进行功能准备:

def data_iterator(self, input_file_list, ...):
for f in input_file_list:
X, y = get_feature(f)
yield X, y

最初,我将其直接输入tensorflow keras模型,但在第一批之后,我遇到了数据超出范围的错误。然后我决定将其封装在tensorflow数据生成器中:

train_gen = lambda: data_iterator(train_files, ...)
valid_gen = lambda: data_iterator(valid_files, ...)
output_types = (tf.float32, tf.float32)
output_shapes = (tf.TensorShape([499, 13]), tf.TensorShape([2]))
train_dat = tf.data.Dataset.from_generator(train_gen,
output_types=output_types,
output_shapes=output_shapes)
valid_dat = tf.data.Dataset.from_generator(valid_gen,
output_types=output_types,
output_shapes=output_shapes)
train_dat = train_dat.repeat().batch(batch_size=128)
valid_dat = valid_dat.repeat().batch(batch_size=128)

然后适合:

model.fit(x=train_dat,
validation_data=valid_dat,
steps_per_epoch=train_steps,
validation_steps=valid_steps,
epochs=100,
callbacks=callbacks)

然而,尽管生成器中有.repeat(),我仍然会收到错误:

BaseCollectiveExecutor::StartAbort超出范围:序列结束

我的问题是:

  • 为什么.repeat()不在这里工作
  • 我应该在自己的迭代器中添加一个while True来避免这种情况吗?我觉得这可以解决问题,但看起来不是正确的方法

我在自己的生成器中添加了while True,这样它就永远不会用完,也不会再出现错误:

def data_iterator(self, input_file_list, ...):
while True;
for f in input_file_list:
X, y = get_feature(f)
yield X, y

然而,我不知道为什么.repeat()不适用于.from_generator()

最新更新