如果通过image_dataset_from_directory获得,则验证集只包含来自一个类的图像



我有以下函数来返回训练和验证数据集:

def load_from_directory(path, shuffle=False):
train_ds = tfk.preprocessing.image_dataset_from_directory(
directory=path,
image_size=IMAGE_SIZE,
validation_split=VALIDATION_SPLIT,
batch_size=BATCH_SIZE,
seed=SEED,
subset='training',
label_mode='binary',
shuffle=shuffle
)
val_ds = tfk.preprocessing.image_dataset_from_directory(
directory=path,
image_size=IMAGE_SIZE,
validation_split=VALIDATION_SPLIT,
batch_size=BATCH_SIZE,
seed=SEED,
subset='validation',
label_mode='binary',
shuffle=False
)
return train_ds, val_ds
train_ds, val_ds = load_from_directory(path=TRAINING_PATH, shuffle=True)

问题是,在一些奇怪的结果(第二次epoch后验证精度100%)之后,我分析了验证集的组成,并得出结论,它只包含一个类的图像。

这很奇怪,但我不知道该怎么处理。我正在使用来自微软的猫和狗数据集,其中包含了每个类的大量示例。

将类分布放在图表中,我正在做以下工作:

import plotly.graph_objects as go
labels = np.concatenate([y for _, y in train_ds], axis=0)
_, counts = np.unique(labels, return_counts=True)
fig = go.Figure(
data=[
go.Pie(
labels=CLASS_NAMES, 
values=counts, 
hole=.5, 
marker_colors=['rgb(205, 152, 36)', 'rgb(129, 180, 179)', 'rgb(177, 180, 34)']
)], 
layout_title_text='Train Class Frequency'
)
fig.update_layout(width=400, height=400)
fig.show()
labels = np.concatenate([y for _, y in val_ds], axis=0)
_, counts = np.unique(labels, return_counts=True)
fig = go.Figure(
data=[
go.Pie(
labels=CLASS_NAMES, 
values=counts, 
hole=.5, 
marker_colors=['rgb(205, 152, 36)', 'rgb(129, 180, 179)', 'rgb(177, 180, 34)']
)], 
layout_title_text='Validation Class Frequency'
)
fig.update_layout(width=400, height=400)
fig.show()

更奇怪的是,suffle=True用于创建数据集,数据集中有两个类,但将该标志设置为True是没有意义的。

结果

我运行了你的代码,没有看到问题。我使用了一个包含2个类的数据集。然后用shuffle=True和shuffle=False来运行它。要测试val_ds是否有正确数量的类,请使用

print(val_ds.class_names)

相关内容

  • 没有找到相关文章

最新更新