我留出了一个验证拆分,如下所示:
val_samples = 60
train_imgs = coco_imgs[:-val_samples]
train_masks = coco_masks[:-val_samples]
val_imgs = coco_imgs[-val_samples:]
val_masks = coco_masks[-val_samples:]
我的train_imgs
和val_imgs
显示不同的图片:
fig, ax = plt.subplots(ncols =2, figsize = (10,3), sharex = True, sharey = True)
ax[0].imshow(train_imgs[14])
ax[1].imshow(val_imgs[14])
然后我写数据生成器函数:
class DataGenerator(keras.utils.Sequence):
def __init__(self, input_img, input_mask, image_size,
augmentation, batch_size):
self.image_size = img_size
self.augmentation = augmentation
self.batch_size = batch_size
self.input_img = train_imgs
self.input_mask = train_masks
def __len__(self):
return len(self.input_img) // self.batch_size
def __getitem__(self, index):
data_index_min = int(index*self.batch_size)
data_index_max = int(min((index+1)*self.batch_size, len(self.input_img)))
indexes = self.input_img[data_index_min:data_index_max]
this_batch_size = len(indexes)
X = np.empty((this_batch_size, self.image_size , self.image_size , 3), dtype=np.float32)
y = np.empty((this_batch_size, self.image_size , self.image_size , self.nb_y_features), dtype=np.uint8)
for i, sample_index in enumerate(indexes):
X_sample = self.input_img[index * self.batch_size + i]
y_sample = self.input_mask[index * self.batch_size + i]
if self.augmentation is True:
aug = transform(image = X_sample, mask = y_sample)
img_aug = aug['image']
mask_aug = aug['mask']
X[i, ...] = img_aug/255
y[i, ...] = mask_aug.reshape(self.image_size , self.image_size , self.nb_y_features).astype(np.uint8)
else:
pass
return X, y
这是我的train_generator
和val_generator
。
train_generator = DataGenerator(input_img = train_imgs, input_mask = train_masks, image_size = img_size,
augmentation=True, batch_size = 5)
val_generator = DataGenerator(input_img = val_imgs, input_mask = val_masks, image_size = img_size,
augmentation=True, batch_size = 5)
它们显示了与train_imgs
相同的图片。
for i in range(3):
X_sample_temp, y_sample_temp = train_generator[2]
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(X_sample_temp[4])
ax[1].imshow(y_sample_temp[4,:,:,0])
plt.show()
和
for i in range(3):
X_sample_temp, y_sample_temp = val_generator[2]
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(X_sample_temp[4])
ax[1].imshow(y_sample_temp[4,:,:,0])
plt.show()
我希望val_generator
能产生与val_imgs
相同的图片,但我不知道如何修复它。我感谢任何输入。
在__init__
中,您有硬编码
self.input_img = train_imgs
self.input_mask = train_masks
因此所有生成器都使用相同的train_imgs
train_masks
,但您应该使用参数input_img
和input_mask
self.input_img = input_img
self.input_mask = input_mask
我不确定它是否不是self.image_size = img_size
中的打字错误,因为它应该是self.image_size = image_size