How to customize a CNN autoencoder for an input depth of 4?



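I have a CNN autoencoder with the structure below. The model works well with input shape (200, 800, 2). Now I want to test the autoencoder with input shape (200, 800, 4). However, I get an error because the input and output shapes differ: the output layer add_4 still has shape (200, 800, 2).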

ValueError: Dimensions must be equal, but are 2 and 4 for '{{node mean_squared_error/SquaredDifference}} = SquaredDifference[T=DT_FLOAT](model/add/add, IteratorGetNext:1)' with input shapes: [?,200,800,2], [?,200,800,4].
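The error is raised by the loss, not by the model itself. Mean squared error takes an element-wise squared difference (the `SquaredDifference` node in the message) between the model output and the target, and element-wise ops only accept shapes that are equal or broadcastable. As an illustration, here is a minimal pure-Python sketch of the NumPy-style broadcasting check; `broadcast_shape` is a hypothetical helper, not a TensorFlow function, and it treats `None` (the unknown batch size) loosely.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of a and b, or raise ValueError.

    None stands for an unknown batch dimension, as in Keras' (None, ...).
    """
    result = []
    # Walk both shapes from the last (channel) axis backwards.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da is None or db is None:
            # Unknown dimension: keep whichever one is known (simplification).
            result.append(da if db is None else db)
        elif da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"Dimensions must be equal, but are {da} and {db}")
    return tuple(reversed(result))


y_pred = (None, 200, 800, 2)  # output of the Add layer
y_true = (None, 200, 800, 4)  # the 4-channel training target

try:
    broadcast_shape(y_pred, y_true)
except ValueError as e:
    print(e)  # Dimensions must be equal, but are 2 and 4
```

Comparing from the last axis, 2 and 4 differ and neither is 1, so the shapes are incompatible; equal depths (or a depth of 1 on one side) would pass.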

How can I customize the network so that the output of the add_4 layer matches an input of shape (200, 800, 4)?

```python
from tensorflow.keras.layers import (Add, Conv2D, Dense, Input, Lambda,
                                     MaxPooling2D, Reshape, UpSampling2D)
from tensorflow.keras.models import Model

input_img = Input(shape=(200, 800, 2))
## Encoder
x = Conv2D(16, (3, 3), activation='tanh', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(4, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(4, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Reshape([4*13*4])(x)  # flatten the (4, 13, 4) bottleneck feature map
encoded = Dense(2,activation='tanh')(x)
## Two variables
val1 = Lambda(lambda x: x[:,0:1])(encoded)
val2 = Lambda(lambda x: x[:,1:2])(encoded)
## Decoder 1
x1 = Dense(4*13*4,activation='tanh')(val1)
x1 = Reshape([4,13,4])(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(4,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(8,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(8,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(8,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(16,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1d = Conv2D(2,(3,3),activation='linear',padding='same')(x1)  # output depth = 2 filters
## Decoder 2
x2 = Dense(4*13*4,activation='tanh')(val2)
x2 = Reshape([4,13,4])(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(4,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(16,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2d = Conv2D(2,(3,3),activation='linear',padding='same')(x2)  # output depth = 2 filters
decoded = Add()([x1d,x2d])

# Model wrapping the graph above (the name `autoencoder` is illustrative)
autoencoder = Model(input_img, decoded)
```
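The shapes in this network follow from simple arithmetic. The pure-Python sketch below (`pool_same` and `trace` are illustrative helpers, not Keras functions) traces the code as listed, assuming every `UpSampling2D` uses `(2, 2)`. It shows that the input depth never enters the calculation, so `Reshape([4*13*4])` (208 values) is still valid for a 4-channel input, and that each decoder's depth is simply the filter count of its last `Conv2D`, which `Add()` passes through unchanged.

```python
def pool_same(n):
    # MaxPooling2D((2, 2), padding='same') with stride 2 gives ceil(n / 2).
    return (n + 1) // 2


def trace(height, width, up_factors=((2, 2),) * 6, out_filters=2):
    """Return (bottleneck_shape, decoder_output_shape) for the network above.

    The input depth is deliberately not a parameter: nothing after the
    first Conv2D depends on it.
    """
    # Encoder: six 'same' poolings; the bottleneck depth comes from the
    # last encoder Conv2D (4 filters).
    for _ in range(6):
        height, width = pool_same(height), pool_same(width)
    bottleneck = (height, width, 4)

    # Decoder branch: UpSampling2D multiplies H and W by its factors, and a
    # 'same' Conv2D keeps H and W while setting the depth to its filter count.
    for fh, fw in up_factors:
        height, width = height * fh, width * fw
    return bottleneck, (height, width, out_filters)


print(trace(200, 800))
# ((4, 13, 4), (256, 832, 2))
print(trace(200, 800, up_factors=((2, 1),) + ((2, 2),) * 5))
# ((4, 13, 4), (256, 416, 2))
```

As written, every up-sampling doubles both axes and the decoder ends at (256, 832, 2). The model summary reports (256, 416, 2) and an (8, 13, 4) first up-sampling, which is what a `(2, 1)` factor in that first layer gives, so the run that produced it apparently differed slightly from the listed code. In either case the depth is 2 because `x1d` and `x2d` are `Conv2D(2, ...)`; under this arithmetic a 4-channel output needs 4 filters in both of those layers.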

Output shapes (model summary):

```text
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_6 (InputLayer)            [(None, 200, 800, 2) 0
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 200, 800, 16) 304         input_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_30 (MaxPooling2D) (None, 100, 400, 16) 0           conv2d_78[0][0]
__________________________________________________________________________________________________
conv2d_79 (Conv2D)              (None, 100, 400, 8)  1160        max_pooling2d_30[0][0]
__________________________________________________________________________________________________
max_pooling2d_31 (MaxPooling2D) (None, 50, 200, 8)   0           conv2d_79[0][0]
__________________________________________________________________________________________________
conv2d_80 (Conv2D)              (None, 50, 200, 8)   584         max_pooling2d_31[0][0]
__________________________________________________________________________________________________
max_pooling2d_32 (MaxPooling2D) (None, 25, 100, 8)   0           conv2d_80[0][0]
__________________________________________________________________________________________________
conv2d_81 (Conv2D)              (None, 25, 100, 8)   584         max_pooling2d_32[0][0]
__________________________________________________________________________________________________
max_pooling2d_33 (MaxPooling2D) (None, 13, 50, 8)    0           conv2d_81[0][0]
__________________________________________________________________________________________________
conv2d_82 (Conv2D)              (None, 13, 50, 4)    292         max_pooling2d_33[0][0]
__________________________________________________________________________________________________
max_pooling2d_34 (MaxPooling2D) (None, 7, 25, 4)     0           conv2d_82[0][0]
__________________________________________________________________________________________________
conv2d_83 (Conv2D)              (None, 7, 25, 4)     148         max_pooling2d_34[0][0]
__________________________________________________________________________________________________
max_pooling2d_35 (MaxPooling2D) (None, 4, 13, 4)     0           conv2d_83[0][0]
__________________________________________________________________________________________________
reshape_13 (Reshape)            (None, 208)          0           max_pooling2d_35[0][0]
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 2)            418         reshape_13[0][0]
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 1)            0           dense_12[0][0]
__________________________________________________________________________________________________
lambda_9 (Lambda)               (None, 1)            0           dense_12[0][0]
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 208)          416         lambda_8[0][0]
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 208)          416         lambda_9[0][0]
__________________________________________________________________________________________________
reshape_14 (Reshape)            (None, 4, 13, 4)     0           dense_13[0][0]
__________________________________________________________________________________________________
reshape_15 (Reshape)            (None, 4, 13, 4)     0           dense_14[0][0]
__________________________________________________________________________________________________
up_sampling2d_48 (UpSampling2D) (None, 8, 13, 4)     0           reshape_14[0][0]
__________________________________________________________________________________________________
up_sampling2d_54 (UpSampling2D) (None, 8, 13, 4)     0           reshape_15[0][0]
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 8, 13, 4)     148         up_sampling2d_48[0][0]
__________________________________________________________________________________________________
conv2d_90 (Conv2D)              (None, 8, 13, 4)     148         up_sampling2d_54[0][0]
__________________________________________________________________________________________________
up_sampling2d_49 (UpSampling2D) (None, 16, 26, 4)    0           conv2d_84[0][0]
__________________________________________________________________________________________________
up_sampling2d_55 (UpSampling2D) (None, 16, 26, 4)    0           conv2d_90[0][0]
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 16, 26, 8)    296         up_sampling2d_49[0][0]
__________________________________________________________________________________________________
conv2d_91 (Conv2D)              (None, 16, 26, 8)    296         up_sampling2d_55[0][0]
__________________________________________________________________________________________________
up_sampling2d_50 (UpSampling2D) (None, 32, 52, 8)    0           conv2d_85[0][0]
__________________________________________________________________________________________________
up_sampling2d_56 (UpSampling2D) (None, 32, 52, 8)    0           conv2d_91[0][0]
__________________________________________________________________________________________________
conv2d_86 (Conv2D)              (None, 32, 52, 8)    584         up_sampling2d_50[0][0]
__________________________________________________________________________________________________
conv2d_92 (Conv2D)              (None, 32, 52, 8)    584         up_sampling2d_56[0][0]
__________________________________________________________________________________________________
up_sampling2d_51 (UpSampling2D) (None, 64, 104, 8)   0           conv2d_86[0][0]
__________________________________________________________________________________________________
up_sampling2d_57 (UpSampling2D) (None, 64, 104, 8)   0           conv2d_92[0][0]
__________________________________________________________________________________________________
conv2d_87 (Conv2D)              (None, 64, 104, 8)   584         up_sampling2d_51[0][0]
__________________________________________________________________________________________________
conv2d_93 (Conv2D)              (None, 64, 104, 8)   584         up_sampling2d_57[0][0]
__________________________________________________________________________________________________
up_sampling2d_52 (UpSampling2D) (None, 128, 208, 8)  0           conv2d_87[0][0]
__________________________________________________________________________________________________
up_sampling2d_58 (UpSampling2D) (None, 128, 208, 8)  0           conv2d_93[0][0]
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 128, 208, 16) 1168        up_sampling2d_52[0][0]
__________________________________________________________________________________________________
conv2d_94 (Conv2D)              (None, 128, 208, 16) 1168        up_sampling2d_58[0][0]
__________________________________________________________________________________________________
up_sampling2d_53 (UpSampling2D) (None, 256, 416, 16) 0           conv2d_88[0][0]
__________________________________________________________________________________________________
up_sampling2d_59 (UpSampling2D) (None, 256, 416, 16) 0           conv2d_94[0][0]
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 256, 416, 2)  290         up_sampling2d_53[0][0]
__________________________________________________________________________________________________
conv2d_95 (Conv2D)              (None, 256, 416, 2)  290         up_sampling2d_59[0][0]
__________________________________________________________________________________________________
add_4 (Add)                     (None, 256, 416, 2)  0           conv2d_89[0][0]
                                                                 conv2d_95[0][0]
```

I changed the input shape to (200, 800, 4), i.e.:

```python
input_img = Input(shape=(200, 800, 4))
```
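Changing only the `Input` shape leaves the rest of the network valid, because the input depth affects just the first convolution. The `Param #` column in the summary follows the standard Conv2D count with bias, kernel_h × kernel_w × in_channels × filters + filters. The sketch below (`conv2d_params` is a hypothetical helper) reproduces the encoder rows and shows the effect of a 4-channel input.

```python
def conv2d_params(in_channels, filters, kernel=(3, 3), use_bias=True):
    """Trainable weights of a Conv2D layer: kh * kw * in_channels * filters (+ filters biases)."""
    kh, kw = kernel
    return kh * kw * in_channels * filters + (filters if use_bias else 0)


# Encoder Conv2D layers as (in_channels, filters); None marks the input depth.
ENCODER = [(None, 16), (16, 8), (8, 8), (8, 8), (8, 4), (4, 4)]

for depth in (2, 4):
    counts = [conv2d_params(depth if cin is None else cin, f) for cin, f in ENCODER]
    print(depth, counts)
# 2 [304, 1160, 584, 584, 292, 148]
# 4 [592, 1160, 584, 584, 292, 148]

# Last decoder Conv2D (16 -> filters): 2 filters as written, 4 for a 4-channel output.
print(conv2d_params(16, 2), conv2d_params(16, 4))  # 290 580
```

Only the first row changes (304 → 592), and nothing later in the network reads the input depth, so the model still builds; the decoder still ends in 2 channels, which is why the mismatch only shows up when the loss compares the output with the 4-channel target.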
