如何使用数据生成器来训练自编码器?



我正在尝试训练一个自编码器,并希望使用 Keras 的数据生成器 API(ImageDataGenerator)来提供训练数据。代码片段如下所示。我已经尝试了 fit() 和 fit_generator(),但两者都无法正常工作。

# Question's original (failing) snippet, reproduced as posted. It relies on names
# imported elsewhere (keras, ImageDataGenerator, Conv2D, MaxPooling2D, UpSampling2D,
# Input, Model, os) and on a `path` variable that is not defined in this snippet.
seed=24
batch_size= 8
image_height,image_width =32,32
img_data_gen_args = dict(rescale = 1/255.)
image_data_generator = ImageDataGenerator(**img_data_gen_args)
# NOTE(review): no target_size or color_mode is passed here, so flow_from_directory
# uses its documented defaults (target_size=(256, 256), color_mode='rgb').
# Those defaults do not match the (32, 32, 1) model input declared below.
# class_mode='input' makes each batch's target the input images themselves.
image_generator = image_data_generator.flow_from_directory(path, 
seed=seed, 
batch_size=batch_size,
class_mode='input')
encoder_input = keras.Input(shape=(image_height,image_width,1))
# encoder: Conv2D with padding='same' keeps 32x32, and MaxPooling2D((2, 2),
# padding='same') halves it to 16x16, so `encoded` has shape (16, 16, 1).
x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoder_input)
x = MaxPooling2D((2, 2), padding='same')(x)
encoded = Conv2D(1, (3, 3), activation='relu', padding='same')(x)
encoder = Model(encoder_input, encoded)
# decoder
# NOTE(review): this declares an (8, 8, 1) input, but the encoder above outputs
# (16, 16, 1); the two shapes do not match.
decoder_input= Input(shape=(8, 8, 1))
decoder = Conv2D(32, (3, 3), activation='relu', padding='same')(decoder_input)
x = UpSampling2D((2, 2))(decoder)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
decoder = Model(decoder_input, decoded)
# auto encoder: chain the encoder and decoder sub-models end to end.
auto_input = Input(shape=(image_height,image_width, 1))
encoded = encoder(auto_input)
decoded = decoder(encoded)
autoencoder = Model(auto_input, decoded)
autoencoder.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])     
# NOTE(review): if `path` holds one sub-directory per class (the layout that
# flow_from_directory reads), os.listdir counts those sub-directories, not images.
# image_generator.samples is presumably the image count wanted here; confirm.
num_train_imgs = len(os.listdir(path))
steps_per_epoch = num_train_imgs //batch_size
# NOTE(review): steps_per_epoch is computed above but not passed to this call.
# fit_generator is deprecated in TensorFlow 2; Model.fit accepts the generator directly.
# Validating on the training generator only re-measures training reconstruction.
history = autoencoder.fit_generator(generator=image_generator,
validation_data=image_generator,
epochs=20)
# Alternative attempt, commented out. NOTE(review): `model` is not defined in this
# snippet; presumably `autoencoder` was meant.
# history = model.fit_generator(image_generator, validation_data=image_generator, 
#                     steps_per_epoch=steps_per_epoch, 
#                     validation_steps=steps_per_epoch, epochs=50)      
                                                                  
#history = autoencoder.fit(
#    image_generator, 
#    epochs=10,
#    validation_data=image_generator,
#)

我得到以下错误:

---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
Input In [62], in <cell line: 1>()
----> 1 history = autoencoder.fit_generator(generator=image_generator,
2                                     validation_data=image_generator,
3                                     epochs=20)

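编码器中的 MaxPooling2D((2, 2)) 会把 32×32 的输入缩小为 16×16,因此解码器的输入形状需要是 (16, 16, 1),才能与编码器的输出匹配。另外,请在 flow_from_directory 中设置 color_mode='grayscale' 和 target_size=(32, 32),使生成器输出的图像形状与模型输入 (32, 32, 1) 一致。下面是一个使用虚拟数据的可运行示例: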

# Create dummy data: six random 32x32 grayscale PNGs split across two class
# sub-directories. flow_from_directory reads one sub-directory per class, even
# though class_mode='input' ignores the labels.
import os

import numpy
from PIL import Image

root_dir = '/content/images'
imarray = numpy.random.rand(32, 32, 3) * 255
# One random RGB array converted to single-channel ('L') grayscale.
im = Image.fromarray(imarray.astype('uint8')).convert('L')

# class sub-directory -> file names written into it
layout = {
    'class1': ['image.png', 'image1.png', 'image2.png'],
    'class2': ['image3.png', 'image4.png', 'image5.png'],
}
for class_name, file_names in layout.items():
    class_dir = os.path.join(root_dir, class_name)
    # Image.save does not create missing parent directories, so create them first;
    # exist_ok=True keeps the script safe to re-run.
    os.makedirs(class_dir, exist_ok=True)
    for file_name in file_names:
        im.save(os.path.join(class_dir, file_name))

模型:

import tensorflow as tf

seed = 24
batch_size = 8
image_height, image_width = 32, 32
img_data_gen_args = dict(rescale=1 / 255.)

# Stream single-channel images resized to the model's input size. With
# class_mode='input', each batch's target is the input images themselves,
# which is what an autoencoder reconstructs.
image_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(**img_data_gen_args)
image_generator = image_data_generator.flow_from_directory(
    '/content/images',
    seed=seed,
    target_size=(image_height, image_width),
    batch_size=batch_size,
    class_mode='input',
    color_mode='grayscale',
)

# encoder: 32x32x1 -> 16x16x1 (MaxPooling2D((2, 2)) halves each spatial dimension).
encoder_input = tf.keras.layers.Input(shape=(image_height, image_width, 1))
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(encoder_input)
x = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
encoded = tf.keras.layers.Conv2D(1, (3, 3), activation='relu', padding='same')(x)
encoder = tf.keras.Model(encoder_input, encoded)

# decoder: take its input shape from the encoder's output instead of hard-coding it.
# That way the two halves cannot drift apart; a hand-written shape such as (8, 8, 1)
# silently disagrees with the encoder.
latent_shape = tuple(encoded.shape[1:])  # (16, 16, 1) for a 32x32x1 input
decoder_input = tf.keras.layers.Input(shape=latent_shape)
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(decoder_input)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
decoded = tf.keras.layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
decoder = tf.keras.Model(decoder_input, decoded)

# autoencoder: chain the encoder and decoder end to end.
auto_input = tf.keras.layers.Input(shape=(image_height, image_width, 1))
encoded = encoder(auto_input)
decoded = decoder(encoded)
autoencoder = tf.keras.Model(auto_input, decoded)
# Binary cross-entropy is a valid reconstruction loss here: pixels are rescaled by
# 1/255 and the decoder ends in a sigmoid. 'accuracy' is kept as in the original
# answer, but it is not an informative reconstruction metric (the training log
# reports 0 for it). metrics=['mse'] would be more useful.
autoencoder.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])
# Model.fit accepts the generator directly; fit_generator is deprecated in TF 2.
history = autoencoder.fit(image_generator, epochs=20)
Found 6 images belonging to 2 classes.
Epoch 1/20
1/1 [==============================] - 1s 544ms/step - loss: 0.6932 - accuracy: 0.0000e+00
Epoch 2/20
1/1 [==============================] - 0s 19ms/step - loss: 0.6931 - accuracy: 0.0000e+00
Epoch 3/20
...
...

最新更新