Keras: losing an axis during image augmentation with brightness_range



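I'm training a U-net-based segmentation network and using Keras's ImageDataGenerator for on-the-fly augmentation of my grayscale images. Everything works as expected unless I include brightness_range in the arguments. When I do, my (512, 512, 1) images seem to turn into (512, 512) images, which breaks things. How can I fix this?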

Here is my augmentation code:

from keras.preprocessing.image import ImageDataGenerator

# train_ct, train_mask, seed, BS and mask_datagen_train are defined elsewhere
data_gen_args = dict(
    rotation_range=15,
    shear_range=45,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=[0.5, 1.5],
    #horizontal_flip=True,
    #vertical_flip=True,
    brightness_range=[0.5, 1.5],  # the error only appears when this is included
    fill_mode='nearest'
)
image_datagen_train = ImageDataGenerator(**data_gen_args)

train_image_generator = image_datagen_train.flow_from_directory(
    train_ct,
    target_size=(512, 512),
    color_mode="grayscale",
    classes=None,
    class_mode=None,
    seed=seed,
    batch_size=BS)
train_mask_generator = mask_datagen_train.flow_from_directory(
    train_mask,
    target_size=(512, 512),
    color_mode="grayscale",
    classes=None,
    class_mode=None,
    seed=seed,
    batch_size=BS)
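If the goal is just to get brightness augmentation working on single-channel images, one possible workaround is to drop brightness_range from data_gen_args and apply a brightness factor yourself through ImageDataGenerator's preprocessing_function argument, which takes one rank-3 image and must return an array of the same shape. The sketch below is a substitute technique, not what brightness_range does internally: random_brightness is a hypothetical helper that multiplies the image by a factor drawn uniformly from the range and clips to [0, 255]. That assumes pixel values are still in the 0-255 range when it runs, which should hold here because no rescale is set (and rescale is applied after the other transformations anyway).

```python
import numpy as np

# A dedicated generator, so this function never draws from NumPy's global
# random state (which Keras seeds to keep image and mask transforms paired).
_brightness_rng = np.random.default_rng()


def random_brightness(x, brightness_range=(0.5, 1.5), rng=_brightness_rng):
    """Multiply an image by a random brightness factor (hypothetical helper).

    Assumes x is a float array with values in [0, 255] and shape (H, W, C),
    e.g. (512, 512, 1) for grayscale images. The returned array has exactly
    the same shape, so the trailing channel axis is kept.
    """
    low, high = brightness_range
    factor = rng.uniform(low, high)
    # Scaling by a scalar never changes the shape; clip back to the 8-bit range.
    return np.clip(x * factor, 0.0, 255.0).astype(x.dtype, copy=False)


# Quick check on a dummy grayscale image.
img = np.full((512, 512, 1), 128.0, dtype=np.float32)
print(random_brightness(img).shape)  # (512, 512, 1)
```

To use it, remove brightness_range from data_gen_args and pass preprocessing_function=random_brightness only to the image generator, so the mask values are left untouched. The separate np.random.default_rng() is a deliberate choice: as I understand it, the Keras iterator seeds NumPy's global random state per batch so that two generators sharing seed draw identical geometric transforms, and extra draws from that global state on the image side only could break the image/mask pairing.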

Here is my error message:

ValueError                                Traceback (most recent call last)
<ipython-input-33-8d701f27a3fa> in <module>
7                     verbose=1,
8                     callbacks=cb_check,
----> 9                     use_multiprocessing = False
10                              )
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89                 warnings.warn('Update your `' + object_name + '` call to the ' +
90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
92         wrapper._original_function = func
93         return wrapper
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1730             use_multiprocessing=use_multiprocessing,
1731             shuffle=shuffle,
-> 1732             initial_epoch=initial_epoch)
1733 
1734     @interfaces.legacy_generator_methods_support
/usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
183             batch_index = 0
184             while steps_done < steps_per_epoch:
--> 185                 generator_output = next(output_generator)
186 
187                 if not hasattr(generator_output, '__len__'):
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in get(self)
740                     "`use_multiprocessing=False, workers > 1`."
741                     "For more information see issue #1638.")
--> 742             six.reraise(*sys.exc_info())
/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
691             if value.__traceback__ is not tb:
692                 raise value.with_traceback(tb)
--> 693             raise value
694         finally:
695             value = None
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in get(self)
709                 try:
710                     future = self.queue.get(block=True)
--> 711                     inputs = future.get(timeout=30)
712                     self.queue.task_done()
713                 except mp.TimeoutError:
/usr/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
642             return self._value
643         else:
--> 644             raise self._value
645 
646     def _set(self, i, obj):
/usr/lib/python3.6/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
117         job, i, func, args, kwds = task
118         try:
--> 119             result = (True, func(*args, **kwds))
120         except Exception as e:
121             if wrap_exception and func is not _helper_reraises_exception:
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in next_sample(uid)
648         The next value of generator `uid`.
649     """
--> 650     return six.next(_SHARED_SEQUENCES[uid])
651 
652 
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in __next__(self, *args, **kwargs)
102 
103     def __next__(self, *args, **kwargs):
--> 104         return self.next(*args, **kwargs)
105 
106     def next(self):
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in next(self)
114         # The transformation of images is not under thread lock
115         # so it can be done in parallel
--> 116         return self._get_batches_of_transformed_samples(index_array)
117 
118     def _get_batches_of_transformed_samples(self, index_array):
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in _get_batches_of_transformed_samples(self, index_array)
244                 x = self.image_data_generator.apply_transform(x, params)
245                 x = self.image_data_generator.standardize(x)
--> 246             batch_x[i] = x
247         # optionally save augmented images to disk for debugging purposes
248         if self.save_to_dir:
ValueError: could not broadcast input array from shape (512,512) into shape (512,512,1)
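The last frame of the traceback shows where it breaks: batch_x[i] has shape (512, 512, 1), but the sample x coming out of apply_transform/standardize has shape (512, 512), and NumPy cannot broadcast a trailing 512 into a trailing 1. The sketch below reproduces the same failure with small arrays and shows a hypothetical helper, ensure_channel_axis, that re-adds the axis when it is missing. One untested way to apply it is to pass preprocessing_function=ensure_channel_axis to ImageDataGenerator, assuming that in the installed keras_preprocessing version standardize (called between apply_transform and batch_x[i] = x in the frame above) runs preprocessing_function.

```python
import numpy as np


def ensure_channel_axis(x):
    """Return x with a trailing channel axis, e.g. (H, W) -> (H, W, 1).

    Hypothetical helper: arrays that already have 3 dimensions are
    returned unchanged.
    """
    if x.ndim == 2:
        return x[..., np.newaxis]
    return x


# Reproduce the failing assignment with small shapes:
# batch_x has room for (4, 4, 1) samples, but the sample is (4, 4).
batch_x = np.zeros((2, 4, 4, 1), dtype=np.float32)
sample = np.ones((4, 4), dtype=np.float32)

try:
    batch_x[0] = sample
except ValueError as err:
    # NumPy reports something like:
    # could not broadcast input array from shape (4,4) into shape (4,4,1)
    print("ValueError:", err)

# Restoring the channel axis makes the assignment succeed.
batch_x[0] = ensure_channel_axis(sample)
print(batch_x[0].shape)  # (4, 4, 1)
```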

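Here is a model that uses brightness_range in ImageDataGenerator without causing any problems; it trains fine. Note that this model loads 3-channel RGB images (the default color_mode of flow_from_directory), not grayscale ones.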

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
import os
import numpy as np
import matplotlib.pyplot as plt
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')  # directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs')  # directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats')  # directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')  # directory with our validation dog pictures
num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val
batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
train_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our training data
validation_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our validation data
train_data_gen = train_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary')
model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1)
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size)

Output:

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
WARNING:tensorflow:From <ipython-input-2-2b8537e7d5b3>:74: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.
Epoch 1/15
15/15 [==============================] - 9s 591ms/step - loss: 1.0527 - accuracy: 0.5010 - val_loss: 0.6918 - val_accuracy: 0.5089
Epoch 2/15
15/15 [==============================] - 9s 609ms/step - loss: 0.6790 - accuracy: 0.5337 - val_loss: 0.6473 - val_accuracy: 0.5647
Epoch 3/15
15/15 [==============================] - 9s 610ms/step - loss: 0.6340 - accuracy: 0.5983 - val_loss: 0.6208 - val_accuracy: 0.6172
Epoch 4/15
15/15 [==============================] - 9s 609ms/step - loss: 0.5899 - accuracy: 0.6464 - val_loss: 0.5938 - val_accuracy: 0.6585
Epoch 5/15
15/15 [==============================] - 9s 599ms/step - loss: 0.5182 - accuracy: 0.7286 - val_loss: 0.6165 - val_accuracy: 0.7042
Epoch 6/15
15/15 [==============================] - 9s 608ms/step - loss: 0.4697 - accuracy: 0.7682 - val_loss: 0.5853 - val_accuracy: 0.7109
Epoch 7/15
15/15 [==============================] - 9s 604ms/step - loss: 0.4393 - accuracy: 0.7746 - val_loss: 0.5826 - val_accuracy: 0.7132
Epoch 8/15
15/15 [==============================] - 9s 608ms/step - loss: 0.4115 - accuracy: 0.7895 - val_loss: 0.6602 - val_accuracy: 0.7042
Epoch 9/15
15/15 [==============================] - 9s 598ms/step - loss: 0.3831 - accuracy: 0.8162 - val_loss: 0.6254 - val_accuracy: 0.7076
Epoch 10/15
15/15 [==============================] - 9s 601ms/step - loss: 0.3151 - accuracy: 0.8531 - val_loss: 0.5924 - val_accuracy: 0.7098
Epoch 11/15
15/15 [==============================] - 9s 611ms/step - loss: 0.2904 - accuracy: 0.8632 - val_loss: 0.6664 - val_accuracy: 0.6964
Epoch 12/15
15/15 [==============================] - 9s 604ms/step - loss: 0.2524 - accuracy: 0.8921 - val_loss: 0.7111 - val_accuracy: 0.6752
Epoch 13/15
15/15 [==============================] - 9s 592ms/step - loss: 0.2143 - accuracy: 0.9081 - val_loss: 0.7246 - val_accuracy: 0.6953
Epoch 14/15
15/15 [==============================] - 9s 599ms/step - loss: 0.1829 - accuracy: 0.9284 - val_loss: 0.7323 - val_accuracy: 0.7221
Epoch 15/15
15/15 [==============================] - 9s 598ms/step - loss: 0.1469 - accuracy: 0.9460 - val_loss: 0.8435 - val_accuracy: 0.6998
