这是我第一次实现用于图像聚类的自动编码器,我是CNN的新手。我试图了解它是如何工作的,并通过用图像测试模型来学习。这是我的模型,我只想知道我在这里做的是否有什么错误,或者你是否有任何可以改进这个模型的建议。
input = Input((224, 224,3), name = 'input')
conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same')(input)
conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same')(conv1)
bn1 = BatchNormalization()(conv1)
pool_enc = MaxPooling2D(pool_size = (2,2), strides = (2,2))(bn1)
conv5 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same')(bn1)
bn1 = BatchNormalization()(bn1)
pool1 = MaxPooling2D(pool_size = (2,2), strides = (2,2))(bn1)
#DECODER
pool1 = UpSampling2D((2,2))(pool1)
bn1 = BatchNormalization()(pool1)
bn1 = Conv2DTranspose(512, (3, 3), activation = 'relu', padding = 'same'')(bn1)
opt = Adam(lr=0.01)
model = Model(inputs = inputs, outputs = bn1, name = 'a-coder')
model.compile(optimizer=opt, loss=keras.losses.categorical_crossentropy, metrics=['accuracy'])
单词"改进";是非常模糊的,因为你真的需要具体说明它目前到底出了什么问题;改进";。
但我要对您的代码做的第一件事是重新排列层。
将CNN层组合在一起的最佳方法是
a CNN -> a Batch Normalization -> an activation (like relu)
而你现在所做的就像
a CNN -> a relu -> a CNN -> a relu -> a Batch Normalization
这并不理想。
编码器部分应该看起来像
conv5_enc = Conv2D(512, (3, 3), padding = 'same', name ='conv5_1_enc')(pool4_enc)
conv5_enc = BatchNormalization(name = "bn5_enc_1")(conv5_enc)
conv5_enc = Activation('relu')(conv5_enc)
conv5_enc = Conv2D(512, (3, 3), padding = 'same', name ='conv5_2_enc')(conv5_enc)
conv5_enc = BatchNormalization(name = "bn5_enc_2")(conv5_enc)
conv5_enc = Activation('relu')(conv5_enc)
conv5_enc = Conv2D(512, (3, 3), padding = 'same', name ='conv5_3_enc')(conv5_enc)
conv5_enc = BatchNormalization(name = "bn5_enc_3")(conv5_enc)
conv5_enc = Activation('relu')(conv5_enc)
解码器也是,在cnn 之前使用batchnorm没有意义
pool2_dec = UpSampling2D((2,2), name = 'pool_2')(conv3_dec)
conv2_dec = Conv2DTranspose(128, (3, 3), padding = 'same', name ='conv2_2_dec')(pool2_dec)
conv2_dec = BatchNormalization(name = "bn4_dec_1")(conv2_dec)
conv2_dec = Activation('relu')(conv2_dec)
conv2_dec = Conv2DTranspose(128, (3, 3), padding = 'same', name ='conv2_1_dec')(conv2_dec)
conv2_dec = BatchNormalization(name = "bn4_dec_2")(conv2_dec)
conv2_dec = Activation('relu')(conv2_dec)