我正在使用卷积层在 2D 图像上训练自动编码器,并希望将完全连接的层放在编码器部分的顶部进行分类。我的自动编码器定义如下(只是一个简单的说明(:
def encoder(input_img):
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
conv1 = BatchNormalization()(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
conv2 = BatchNormalization()(conv2)
return conv2
def decoder(conv2):
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv2)
conv3 = BatchNormalization()(conv3)
up1 = UpSampling2D((2,2))(conv3)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(up1)
return decoded
autoencoder = Model(input_img, decoder(encoder(input_img)))
我的输入图像大小为 (64,80,1(。现在,当将完全连接的层堆叠在编码器顶部时,我正在执行以下操作:
def fc(enco):
flat = Flatten()(enco)
den = Dense(128, activation='relu')(flat)
out = Dense(num_classes, activation='softmax')(den)
return out
encode = encoder(input_img)
full_model = Model(input_img,fc(encode))
for l1,l2 in zip(full_model.layers[:19],autoencoder.layers[0:19]):
l1.set_weights(l2.get_weights())
对于只有一个自动编码器,这是有效的,但现在的问题是我有 2 个自动编码器在所有大小的图像集(64、80、1(上训练。
对于每个标签,我输入两个大小的图像(64、80、1(和一个标签(0 或 1(。我需要将图像 1 输入第一个自动编码器,将图像 2 输入第二个自动编码器。但是如何在上面的代码中full_model
中组合两个自动编码器呢?
另一个问题也是fit()
方法的输入。到目前为止,只有一个自动编码器,输入仅由图像的numpy数组组成(例如(1000,64,80,1((,但是使用两个自动编码器,我将有两组图像作为输入。如何将其输入fit()
方法,以便第一个自动编码器使用第一组图像,第二个自动编码器使用第二组图像?
问:如何在full_model
中组合两个自动编码器?
答:您可以在fc
内连接瓶颈层enco_1
和两个自动编码器的enco_2
:
def fc(enco_1, enco_2):
flat_1 = Flatten()(enco_1)
flat_2 = Flatten()(enco_2)
flat = Concatenate()([enco_1, enco_2])
den = Dense(128, activation='relu')(flat)
out = Dense(num_classes, activation='softmax')(den)
return out
encode_1 = encoder_1(input_img_1)
encode_2 = encoder_2(input_img_2)
full_model = Model([input_img_1, input_img_2], fc(encode_1, encode_2))
请注意,手动设置编码器权重的最后一部分是不必要的 - 请参阅 https://keras.io/getting-started/functional-api-guide/#shared-layers
问:如何将其输入fit
方法,以便第一个自动编码器使用第一组图像,第二个自动编码器使用第二组图像?
答:在上面的代码中,请注意,两个编码器馈送了不同的输入(每个图像集一个(。现在,假设模型以这种方式定义,则可以按如下方式调用full_model.fit
:
full_model.fit(x=[images_set_1, images_set_2],
y=label,
...)
注:未测试。