我正在尝试实现批处理硬三元组丢失,如的第3.2节所示https://arxiv.org/pdf/2004.06271.pdf.
我需要导入我的图像,以便每个批次都有特定批次中每个ID的K个实例。因此,每个批次必须是K的倍数。
我有一个图像目录太大,无法放入内存,因此我使用ImageDataGenerator.flow_from_directory()
导入图像,但我看不到该函数的任何参数来实现我需要的功能。
如何使用Keras实现这种批处理行为
您可以尝试以可控的方式将多个数据流合并在一起。
假设你有K个tf.data.Dataset
实例(无论你如何实例化它们(,它们负责提供特定ID的训练实例,你可以将它们连接起来,以便在一个小批量中均匀分布:
ds1 = ... # Training instances with ID == 1
ds2 = ... # Training instances with ID == 2
...
dsK = ... # Training instances with ID == K
train_dataset = tf.data.Dataset.zip((ds1, ds2, ..., dsK)).flat_map(concat_datasets).batch(batch_size=N * K)
其中concat_datasets
是合并函数:
def concat_datasets(*datasets):
ds = tf.data.Dataset.from_tensors(datasets[0])
for i in range(1, len(datasets)):
ds = ds.concatenate(tf.data.Dataset.from_tensors(datasets[i]))
return ds
从Tensorflow 2.4开始,我看不到用ImageDataGenerator
实现这一点的标准方法。
因此,我认为您需要基于tensorflow.keras.utils.Sequence
类编写自己的内容,这样您就可以自由地自己定义批处理内容。
参考文献:
https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence
https://stanford.edu/~shervine/blog/keras如何生成飞行中的数据