image_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True)
train_data_gen = image_gen.flow_from_directory(batch_size=BATCH_SIZE, directory=train_dir, target_size=(IMG_SHAPE,IMG_SHAPE), shuffle=True)
augmented_images = [train_data_gen[0][0][0] for i in range(5)]
所以我最近开始研究Tensorflow,并试图创建我自己的CNN,同时研究如何使用ImageDataGenerator,我遇到了这个代码。我想知道。flow_from_directory返回它似乎是图像的可迭代对象。然而,让我困惑的是为什么train_data_gen有三个维度,这些维度意味着什么。
train_data_gen就是一个生成器。为了产生输出,需要执行代码
images, labels=next(train_data_gen)
结果是batch_size数量的图像及其相关标签。图片将有形状(batch_size, IMG_SHAPE, IMG_SHAPE, channels)和标签是形状(batch_size,1)。