我正在用keras训练神经网络,并希望通过多处理加快预处理/数据扩充速度。原则上,对于fit_generator
中的workers=N
和use_multiprocessing=True
,这似乎很简单,但在我的情况下,要避免从并行生成器中获得类似的数据是很困难的。
我的数据在几个文件中,每个文件都有几百万条记录(直到你到达一个文件的末尾才知道总数(。对于每个文件,生成器逐个记录,通过一些数据扩充将记录处理为网络的正确输入/输出格式。没有唯一的ID,尽管我想我可以在飞行中创建一个。
我想知道让多个生成器并行处理单独的文件列表是否最简单。我实际上并没有在一个批处理中使用所有数据,所以如果一个生成器在其他生成器之前在其文件列表的开头重新启动,那就无关紧要了。如果在生成器中,我可以访问类似工人编号(1到N(的内容,那么这将很容易完成。
我不知道如何实现您的建议。一个更高级的解决方案是实例化一个tf.data.TextLineDataset
,它可以处理多个文本文件。为了用它训练Keras模型,您必须将iterator
的输出与模型的Input
张量联系起来。大致如下:
import tensorflow as tf
# Parsing, augmentation etc
def __parse_record(record):
...
return parsed_record
# Construct a TextLineDataset
ds = tf.data.TextLineDataset(filenames).map(_parse_record)
ds.shuffle().batch(batch_size) # Shuffle and batch
# Turn into an iterator
iterator = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
ds_init = iterator.make_initializer(ds)
# The iterator will yield inputs and labels
x,y = iterator.get_next()
# Tie output of iterator into Input of keras model via the tensor argument
model_input = Input(tensor=x)
# ... model definition
# Upon compiling the model specify target tensors
model.compile(loss, optimizer, target_tensors=[y])
# Now you can use model.fit() instead of fit_generator()
with K.get_session() as sess:
sess.run(ds_init)
model.fit(epochs, steps_per_epoch)
这应该训练得很快,然而,它带来了一些缺点。根据相关Keras示例:
输入张量也有重要的缺点。在里面特别是,输入张量在模型构造时是固定的因为还不支持重新布线网络。因此,更改数据输入源意味着必须保存模型权重并重建模型从头开始连接新的输入数据。验证当前不能作为培训执行进度,并且必须在培训完成后执行。