为什么我的train_generator和val_generator生成相同的图片



我留出了一个验证拆分,如下所示:

val_samples = 60
train_imgs = coco_imgs[:-val_samples]
train_masks = coco_masks[:-val_samples]
val_imgs = coco_imgs[-val_samples:]
val_masks = coco_masks[-val_samples:]

我的train_imgsval_imgs显示不同的图片:

fig, ax = plt.subplots(ncols =2, figsize = (10,3), sharex = True, sharey = True)
ax[0].imshow(train_imgs[14])
ax[1].imshow(val_imgs[14])

然后我写数据生成器函数:

class DataGenerator(keras.utils.Sequence):
def __init__(self, input_img, input_mask, image_size, 
augmentation, batch_size):
self.image_size = img_size
self.augmentation = augmentation
self.batch_size = batch_size
self.input_img = train_imgs
self.input_mask = train_masks
def __len__(self):
return len(self.input_img) // self.batch_size
def __getitem__(self, index):
data_index_min = int(index*self.batch_size)
data_index_max = int(min((index+1)*self.batch_size, len(self.input_img)))
indexes = self.input_img[data_index_min:data_index_max]
this_batch_size = len(indexes)

X = np.empty((this_batch_size, self.image_size , self.image_size , 3), dtype=np.float32)
y = np.empty((this_batch_size, self.image_size , self.image_size , self.nb_y_features), dtype=np.uint8)

for i, sample_index in enumerate(indexes):
X_sample = self.input_img[index * self.batch_size + i]
y_sample = self.input_mask[index * self.batch_size + i]
if self.augmentation is True:
aug = transform(image = X_sample, mask = y_sample)
img_aug = aug['image']
mask_aug = aug['mask']
X[i, ...] = img_aug/255
y[i, ...] = mask_aug.reshape(self.image_size , self.image_size , self.nb_y_features).astype(np.uint8)
else:
pass
return X, y

这是我的train_generatorval_generator

train_generator = DataGenerator(input_img = train_imgs, input_mask = train_masks, image_size = img_size,
augmentation=True, batch_size = 5)
val_generator = DataGenerator(input_img = val_imgs, input_mask = val_masks, image_size = img_size,
augmentation=True, batch_size = 5)

它们显示了与train_imgs相同的图片。

for i in range(3):
X_sample_temp, y_sample_temp = train_generator[2]
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(X_sample_temp[4])
ax[1].imshow(y_sample_temp[4,:,:,0])
plt.show()

for i in range(3):
X_sample_temp, y_sample_temp = val_generator[2]
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(X_sample_temp[4])
ax[1].imshow(y_sample_temp[4,:,:,0])
plt.show()

我希望val_generator能产生与val_imgs相同的图片,但我不知道如何修复它。我感谢任何输入。

__init__中,您有硬编码

self.input_img  = train_imgs
self.input_mask = train_masks

因此所有生成器都使用相同的train_imgstrain_masks
,但您应该使用参数input_imginput_mask

self.input_img  = input_img
self.input_mask = input_mask

我不确定它是否不是self.image_size = img_size
中的打字错误,因为它应该是self.image_size = image_size

相关内容

  • 没有找到相关文章

最新更新