I'm implementing an autoencoder based on a 3D CNN and would like to fix the error below.
from tensorflow import keras

# Encoder: (8, 160, 160, 3) -> (1, 20, 20, 64)
conv3D_encoder = keras.models.Sequential([
    keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding="same", input_shape=[8, 160, 160, 3], activation="relu"),
    keras.layers.MaxPooling3D(pool_size=(2, 2, 2)),
    keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding="same", activation="relu"),
    keras.layers.MaxPooling3D(pool_size=(2, 2, 2)),
    keras.layers.Conv3D(filters=64, kernel_size=(3, 3, 3), padding="same", activation="relu"),
    keras.layers.MaxPooling3D(pool_size=(2, 2, 2)),
])

# Decoder: declared for inputs of shape (None, 3, 3, 64)
conv3D_decoder = keras.models.Sequential([
    keras.layers.Conv3DTranspose(filters=32, kernel_size=(3, 3, 3), strides=2, padding="valid", activation="relu", input_shape=[None, 3, 3, 64]),
    keras.layers.Conv3DTranspose(filters=16, kernel_size=(3, 3, 3), strides=2, padding="same", activation="relu"),
    keras.layers.Conv3DTranspose(filters=1, kernel_size=(3, 3, 3), strides=2, padding="same", activation="sigmoid"),
])

conv3D_ae = keras.models.Sequential([conv3D_encoder, conv3D_decoder])
conv3D_ae.compile(loss="binary_crossentropy", optimizer=keras.optimizers.SGD(learning_rate=1.5))
# X_train / X_valid: arrays of shape (num_samples, 8, 160, 160, 3)
history = conv3D_ae.fit(X_train, X_train, epochs=10, validation_data=(X_valid, X_valid))
Error message:
WARNING:tensorflow:Model was constructed with shape (None, None, 3, 3, 64) for input Tensor("conv3d_transpose_63_input:0", shape=(None, None, 3, 3, 64), dtype=float32), but it was called on an input with incompatible shape (None, 1, 20, 20, 64).
ValueError: logits and labels must have the same shape ((None, 12, 164, 164, 1) vs (None, 8, 160, 160, 3))
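Both shapes in the log can be reproduced by hand. The sketch below is a pure-Python shape trace (no TensorFlow needed). The helper names conv_same, max_pool and conv_transpose are hypothetical, and their length formulas are assumed to match Keras' defaults: stride-1 "same" convolutions, 2x2x2 max pooling with "valid" padding, and transposed convolutions with output_padding left unset. Because the log shows the decoder still running on the real encoder output (the mismatch only triggers a warning), the loss ends up comparing the decoder output against X_train.

```python
# Hand-trace of tensor shapes (depth, height, width, channels), batch dim omitted.
# Assumed Keras rules: "same" stride-1 Conv3D keeps the spatial size,
# MaxPooling3D(2) floors each axis in half, and Conv3DTranspose with
# stride s and kernel k gives n*s ("same") or n*s + max(k - s, 0) ("valid").

def conv_same(shape, filters):
    d, h, w, _ = shape
    return (d, h, w, filters)

def max_pool(shape, pool=2):
    d, h, w, c = shape
    return (d // pool, h // pool, w // pool, c)

def conv_transpose(shape, filters, kernel=3, stride=2, padding="same"):
    extra = 0 if padding == "same" else max(kernel - stride, 0)
    d, h, w, _ = shape
    return tuple(n * stride + extra for n in (d, h, w)) + (filters,)

# Encoder from the question: three Conv3D + MaxPooling3D stages
x = (8, 160, 160, 3)
for filters in (16, 32, 64):
    x = max_pool(conv_same(x, filters))
encoder_out = x
print("encoder output:", encoder_out)  # → encoder output: (1, 20, 20, 64)

# Decoder from the question, applied to the actual encoder output
y = conv_transpose(encoder_out, 32, padding="valid")
y = conv_transpose(y, 16, padding="same")
y = conv_transpose(y, 1, padding="same")
decoder_out = y
print("decoder output:", decoder_out)  # → decoder output: (12, 164, 164, 1)
```

The first shape is the one in the WARNING; the second is the "logits" shape in the ValueError, which cannot match the labels of shape (8, 160, 160, 3).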
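In an autoencoder, the decoder's output must have the same shape as the encoder's input, and the decoder's input is the encoder's output. Here the encoder produces tensors of shape (1, 20, 20, 64), but the decoder was declared with input_shape=[None, 3, 3, 64] and produces (12, 164, 164, 1) instead of the (8, 160, 160, 3) that the loss compares it against. Changing the decoder's architecture so that its input and output shapes line up fixes the error.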
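The per-axis arithmetic behind such a decoder can be worked out as below. The formulas are how I understand Keras computes output lengths when output_padding is left unset; k is the kernel size and s the stride, with k = 3, s = 2 for the transposed layers and k = 5 for a final stride-1 Conv3D that also sets the channel count back to 3.

```latex
\begin{aligned}
&\text{Conv3DTranspose, ``valid''}: && n_{\text{out}} = s\,n_{\text{in}} + \max(k - s,\, 0) \\
&\text{Conv3DTranspose, ``same''}: && n_{\text{out}} = s\,n_{\text{in}} \\
&\text{Conv3D, ``valid'', stride } 1: && n_{\text{out}} = n_{\text{in}} - k + 1 \\[4pt]
&\text{depth}: && 1 \to 2 \cdot 1 + 1 = 3 \to 6 \to 12 \to 12 - 5 + 1 = 8 \\
&\text{height, width}: && 20 \to 2 \cdot 20 + 1 = 41 \to 82 \to 164 \to 164 - 5 + 1 = 160
\end{aligned}
```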
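conv3D_decoder = keras.models.Sequential([
    # Input is the encoder output: (1, 20, 20, 64)
    keras.layers.Conv3DTranspose(filters=32, kernel_size=(3, 3, 3), strides=2, padding="valid", activation="relu", input_shape=[1, 20, 20, 64]),  # -> (3, 41, 41, 32)
    keras.layers.Conv3DTranspose(filters=16, kernel_size=(3, 3, 3), strides=2, padding="same", activation="relu"),  # -> (6, 82, 82, 16)
    keras.layers.Conv3DTranspose(filters=16, kernel_size=(3, 3, 3), strides=2, padding="same", activation="relu"),  # -> (12, 164, 164, 16)
    # A "valid" 5x5x5 convolution trims 4 from each axis and restores 3 channels;
    # sigmoid keeps the outputs in [0, 1], as binary_crossentropy expects
    keras.layers.Conv3D(filters=3, kernel_size=(5, 5, 5), padding="valid", activation="sigmoid"),  # -> (8, 160, 160, 3)
])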
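As a quick sanity check, the fixed decoder can be traced without TensorFlow. This is a hypothetical pure-Python sketch, not part of the original answer; conv_transpose and conv_valid are illustrative helpers whose length formulas are assumed to match Keras (transposed convolutions without output_padding, stride-1 "valid" Conv3D).

```python
# Trace the fixed decoder starting from the encoder output (1, 20, 20, 64),
# batch dim omitted. Shapes are (depth, height, width, channels).

def conv_transpose(shape, filters, kernel=3, stride=2, padding="same"):
    # Assumed Keras-style length: n*s ("same") or n*s + max(k - s, 0) ("valid")
    extra = 0 if padding == "same" else max(kernel - stride, 0)
    return tuple(n * stride + extra for n in shape[:3]) + (filters,)

def conv_valid(shape, filters, kernel):
    # Stride-1 Conv3D with padding="valid" shrinks each axis by kernel - 1
    return tuple(n - kernel + 1 for n in shape[:3]) + (filters,)

shape = (1, 20, 20, 64)
shape = conv_transpose(shape, 32, padding="valid")  # (3, 41, 41, 32)
shape = conv_transpose(shape, 16)                   # (6, 82, 82, 16)
shape = conv_transpose(shape, 16)                   # (12, 164, 164, 16)
shape = conv_valid(shape, 3, kernel=5)              # (8, 160, 160, 3)
assert shape == (8, 160, 160, 3), shape
print("decoder output:", shape)
```

In Keras itself, conv3D_ae.summary() prints the per-layer output shapes, which is usually the fastest way to spot this kind of mismatch before calling fit.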
Please take a look at this gist; I was able to reproduce the error with random data.