How to get the labels of a dataset built with ImageDataGenerator for a two-input CNN model



Can someone help me get the labels when validation_set takes a pair of images as input and the image batches are supplied by ImageDataGenerator, as shown below?

GEN = ImageDataGenerator(rescale=1./255)

def two_inputs(generator, X1, X2, batch_size, img_height, img_width):
    U = generator.flow_from_directory(X1,
                                      target_size=(img_height, img_width),
                                      batch_size=batch_size,
                                      shuffle=False,
                                      class_mode='binary',
                                      seed=1221)
    V = generator.flow_from_directory(X2,
                                      target_size=(img_height, img_width),
                                      batch_size=batch_size,
                                      shuffle=False,
                                      class_mode='binary',
                                      seed=1221)
    while True:
        X1i = U.next()
        X2i = V.next()
        yield [X1i[0], X2i[0]], X2i[1]  # Yield both images and their mutual label

In this setup I can get predictions with preds = base_model.predict_generator(val_flow), where val_flow is:

val_flow = two_inputs(generator=GEN,
                      X1=val_05_dirs,
                      X2=val_06_dirs,
                      batch_size=batch_size,
                      img_height=img_height,
                      img_width=img_width)

I need to compute fpr and tpr using fpr, tpr, _ = metrics.roc_curve(LABELS, preds).

So I am trying to get LABELS for the complete val_flow, which reads from the two folders val_05_dirs and val_06_dirs.

Thanks in advance.

I have put together a simple code example. You can adapt it to fit your use case.
Code:

import tensorflow as tf

GEN = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
folder_path = r'C:\Users\Aniket\.keras\datasets\flower_photos'

def two_inputs(generator, X1, X2, batch_size, img_height, img_width):
    U = generator.flow_from_directory(X1,
                                      target_size=(img_height, img_width),
                                      batch_size=batch_size,
                                      shuffle=False,
                                      class_mode='binary',
                                      seed=1221)
    V = generator.flow_from_directory(X2,
                                      target_size=(img_height, img_width),
                                      batch_size=batch_size,
                                      shuffle=False,
                                      class_mode='binary',
                                      seed=1221)
    while True:
        X1i = next(U)
        X2i = next(V)
        yield [X1i[0], X2i[0]], X2i[1]  # Yield both images and their mutual label

custom_gen = two_inputs(GEN, folder_path, folder_path, 1000, 256, 256)

Here, my flower_photos directory contains 5 subdirectories, and the subdirectory names are used as the image labels.

Output:

Found 3670 images belonging to 5 classes.

Now iterate over the generator.
Code:

val_labels = []
for image, labels in custom_gen:
    val_labels += list(labels.astype('int32'))
    break

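Note: without the break, this loop runs forever, because the generator keeps yielding batches from your data indefinitely.

If you don't want that, let the loop run only this many times:

no_of_times = total_samples // batch_size

Make sure the total number of samples is divisible by your batch size; otherwise you may end up with duplicate labels at the end of the list.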
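The bounded loop described above can be sketched roughly as follows. This is only an illustration, not part of the original answer: collect_labels and fake_two_inputs are hypothetical names, and fake_two_inputs is a pure-Python stand-in for two_inputs so that the snippet runs without TensorFlow. It assumes the generator yields (inputs, labels) pairs and restarts from the beginning after the last batch, which may be smaller than batch_size.

```python
import itertools
import math


def collect_labels(pair_gen, total_samples, batch_size):
    """Collect one pass of labels from an endless (inputs, labels) generator."""
    # ceil() so that a final, smaller batch is still read
    steps = math.ceil(total_samples / batch_size)
    labels = []
    for _inputs, batch_labels in itertools.islice(pair_gen, steps):
        labels.extend(int(label) for label in batch_labels)
    # Trim in case a generator pads its last batch up to batch_size
    return labels[:total_samples]


def fake_two_inputs(all_labels, batch_size):
    """Hypothetical stand-in for two_inputs: loops over the data forever."""
    while True:
        for start in range(0, len(all_labels), batch_size):
            batch = all_labels[start:start + batch_size]
            # Placeholder "images": only the labels matter in this sketch
            yield [batch, batch], batch


demo_labels = [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0]
val_labels = collect_labels(fake_two_inputs(demo_labels, 3), len(demo_labels), 3)
print(val_labels)
```

With the real generator, the predictions would need to be produced over the same number of steps for LABELS and preds to line up. If the directory iterators themselves are accessible and were created with shuffle=False, their classes attribute should also hold the integer labels in file order, which avoids iterating at all.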

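The labels you get are integers. If you want the mapping from class names to those integers, you can read class_indices from one of the directory iterators (U is created inside two_inputs, so it has to be made accessible outside the function first):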

mapping = U.class_indices
mapping

Output:

{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
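
To turn the integer labels back into class names, the mapping can be inverted. A small sketch, assuming the mapping printed above; example_labels is a made-up list used only for illustration:

```python
# Mapping as printed above (class name -> integer index)
mapping = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

# Invert it to go from integer index back to class name
index_to_name = {index: name for name, index in mapping.items()}

# Hypothetical integer labels, standing in for the collected val_labels
example_labels = [0, 4, 2, 2]
example_names = [index_to_name[label] for label in example_labels]
print(example_names)
```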
