我正在尝试使用tf.data.Dataset.from_generator()
来生成训练和验证数据。
我有自己的数据生成器,它在飞行中进行功能准备:
def data_iterator(self, input_file_list, ...):
for f in input_file_list:
X, y = get_feature(f)
yield X, y
最初,我将其直接输入tensorflow keras模型,但在第一批之后,我遇到了数据超出范围的错误。然后我决定将其封装在tensorflow数据生成器中:
train_gen = lambda: data_iterator(train_files, ...)
valid_gen = lambda: data_iterator(valid_files, ...)
output_types = (tf.float32, tf.float32)
output_shapes = (tf.TensorShape([499, 13]), tf.TensorShape([2]))
train_dat = tf.data.Dataset.from_generator(train_gen,
output_types=output_types,
output_shapes=output_shapes)
valid_dat = tf.data.Dataset.from_generator(valid_gen,
output_types=output_types,
output_shapes=output_shapes)
train_dat = train_dat.repeat().batch(batch_size=128)
valid_dat = valid_dat.repeat().batch(batch_size=128)
然后适合:
model.fit(x=train_dat,
validation_data=valid_dat,
steps_per_epoch=train_steps,
validation_steps=valid_steps,
epochs=100,
callbacks=callbacks)
然而,尽管生成器中有.repeat()
,我仍然会收到错误:
BaseCollectiveExecutor::StartAbort超出范围:序列结束
我的问题是:
- 为什么
.repeat()
不在这里工作 - 我应该在自己的迭代器中添加一个
while True
来避免这种情况吗?我觉得这可以解决问题,但看起来不是正确的方法
我在自己的生成器中添加了while True,这样它就永远不会用完,也不会再出现错误:
def data_iterator(self, input_file_list, ...):
while True;
for f in input_file_list:
X, y = get_feature(f)
yield X, y
然而,我不知道为什么.repeat()
不适用于.from_generator()