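I am loading a dataset using ImageDataGenerator with flow_from_dataframe.
With flow_from_dataframe and shuffle=True, the individual images in the dataset are shuffled.
I want a different kind of shuffling. If I have 12 images and batch_size=3, I get 4 batches: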
batch1 = [image1, image2, image3]
batch2 = [image4, image5, image6]
batch3 = [image7, image8, image9]
batch4 = [image10, image11, image12]
I want to shuffle the order of the batches without shuffling the images inside each batch, so that I get, for example:
batch2 = [image4, image5, image6]
batch1 = [image1, image2, image3]
batch4 = [image10, image11, image12]
batch3 = [image7, image8, image9]
Is this possible with ImageDataGenerator and flow_from_dataframe? Is there a preprocessing function I can use?
Consider using the tf.data.Dataset API instead. It lets you batch before you shuffle, so the shuffle operates on whole batches:
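import tensorflow as tf

file_names = [f'image_{i}' for i in range(1, 10)]
# batch first, then shuffle: the shuffle permutes whole batches
ds = tf.data.Dataset.from_tensor_slices(file_names).batch(3).shuffle(3)

for _ in range(3):  # three passes (epochs) over the dataset
    for batch in ds:
        print(batch.numpy())
    print()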
Output of one run (the batch order is random and is reshuffled on every pass):
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
[b'image_1' b'image_2' b'image_3']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
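The same pipeline can be written with the relevant shuffle settings made explicit. This is a sketch with my own choices, not part of the original answer: the 12 file names are placeholders, `drop_remainder=True` makes every shuffled element a full batch, the shuffle buffer is set to the number of batches (a smaller buffer only partially shuffles the batch order), `seed=42` is arbitrary and only makes runs reproducible, and `reshuffle_each_iteration=True` (the default behavior) gives each epoch a new batch order.

```python
import tensorflow as tf

# placeholder data: 12 file names, batch size 3 -> 4 batches
file_names = [f'image_{i}' for i in range(1, 13)]
batch_size = 3
num_batches = len(file_names) // batch_size

ds = (tf.data.Dataset.from_tensor_slices(file_names)
      .batch(batch_size, drop_remainder=True)  # group first; a partial last batch is dropped
      .shuffle(num_batches, seed=42, reshuffle_each_iteration=True))  # buffer holds every batch

# collect the batch order of two epochs as tuples of Python strings
epochs = []
for _ in range(2):
    epochs.append([tuple(name.decode() for name in batch.numpy()) for batch in ds])

for order in epochs:
    print(order)
```

Each printed epoch contains the same four batches, each with its images in the original order; only the order of the batches differs.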
Then you can use a map operation to load the images from the file names:
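import os
import tensorflow as tf

def read_image(file_name):
    image = tf.io.read_file(file_name)
    # expand_animations=False always returns a 3-D (height, width, channels) tensor
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)  # also scales to [0, 1]
    image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
    # class name = first path component, e.g. 'cat/img1.jpg' -> 'cat'
    label = tf.strings.split(file_name, os.sep)[0]
    # one-hot label; class_categories (a list of class names) must be defined beforehand
    label = tf.cast(tf.equal(label, class_categories), tf.int32)
    return image, label

# read_file expects a single file name, not a batch: unbatch, load, and re-batch
# (the batch boundaries stay the same as long as every batch is full)
ds = ds.unbatch().map(read_image).batch(3)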
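If you would rather stay with ImageDataGenerator, one possible workaround (not from the original answer) is to reorder the DataFrame yourself, one whole batch at a time, and then call flow_from_dataframe with shuffle=False. The sketch below assumes that flow_from_dataframe with shuffle=False yields rows in DataFrame order; the helper `shuffle_batches`, the column names `filename`/`class`, and the example data are all hypothetical.

```python
import numpy as np
import pandas as pd


def shuffle_batches(df, batch_size, seed=None):
    """Return the rows of df reordered batch by batch (hypothetical helper).

    Rows are split into consecutive blocks of batch_size; the blocks are
    permuted randomly while the row order inside each block is kept.
    Rows that do not fill a whole batch are dropped (like drop_remainder).
    """
    rng = np.random.default_rng(seed)
    n_batches = len(df) // batch_size
    block_order = rng.permutation(n_batches)
    # row positions of each block, concatenated in the permuted block order
    positions = np.concatenate(
        [np.arange(b * batch_size, (b + 1) * batch_size) for b in block_order]
    )
    return df.iloc[positions]


# hypothetical example data: 12 images, two classes
df = pd.DataFrame({
    'filename': [f'image{i}.jpg' for i in range(1, 13)],
    'class': ['a', 'b'] * 6,
})
shuffled = shuffle_batches(df, batch_size=3, seed=0)
print(shuffled['filename'].tolist())

# then, assuming shuffle=False keeps the DataFrame row order:
# datagen = ImageDataGenerator(rescale=1 / 255)
# gen = datagen.flow_from_dataframe(shuffled, x_col='filename', y_col='class',
#                                   batch_size=3, shuffle=False)
```

Because the order is fixed inside the DataFrame, you would have to call `shuffle_batches` again and rebuild the generator to get a new batch order each epoch. Also note that flow_from_dataframe skips rows whose files cannot be found when validate_filenames=True, which would shift the batch boundaries.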