两个cnn堆叠的适配模块设计



我正在尝试堆栈两个不同的cnn 使用适配模块来连接它们,但是我很难正确地确定适配模块的层超参数。

更精确地说,我想训练自适应模块来桥接两个卷积层:

  1. 输出形状为:(29,29,256)的图层A
  2. 输入形状为(8,8,384)的图层B

所以,在A层之后,我依次添加适配模块,我选择:

  • Conv2D具有384个滤波器的层,内核大小:(3,3)/输出形状:(29,29,384)
  • MaxPool2D池大小:(2,2),步幅:(4,4),填充:same"/输出形状:(8,8,384)

最后,我尝试将层B添加到模型中,但我得到以下错误从tensorflow:

InvalidArgumentError: Dimensions must be equal, but are 384 and 288 for '{{node batch_normalization_159/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=0.001, exponential_avg_factor=1, is_training=false](Placeholder, batch_normalization_159/scale, batch_normalization_159/ReadVariableOp, batch_normalization_159/FusedBatchNormV3/ReadVariableOp, batch_normalization_159/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [?,8,8,384], [288], [288], [288], [288].

有一个最小的可复制的例子:

from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.mobilenet import MobileNet
from keras.layers import Conv2D, MaxPool2D
from keras.models import Sequential
mobile_model = MobileNet(weights='imagenet')
server_model = InceptionResNetV2(weights='imagenet')
hybrid = Sequential()
for i, layer in enumerate(mobile_model.layers):
if i <= 36:
layer.trainable = False
hybrid.add(layer)
hybrid.add(Conv2D(384, kernel_size=(3,3), padding='same'))
hybrid.add(MaxPool2D(pool_size=(2,2), strides=(4,4), padding='same'))
for i, layer in enumerate(server_model.layers):
if i >= 610:
layer.trainable = False
hybrid.add(layer)

顺序模型只支持层像链表一样排列的模型——每一层只接受一层的输出,每一层的输出只提供给单个层。你的两个基本模型都有残差块,这打破了上面的假设,并将模型体系结构转变为有向无环图(DAG)。

要做你想做的事,你需要使用Functional API。使用Functional API,您可以显式地控制中间激活,即KerasTensors。

对于第一个模型,您可以跳过额外的工作,只需从现有图的子集创建一个新模型,如下所示

sub_mobile = keras.models.Model(mobile_model.inputs, mobile_model.layers[36].output)

连接第二个模型的一些层要困难得多。切掉keras模型的末尾很容易——切掉开始部分要困难得多,因为需要一个tf.keras.Input占位符。要成功地做到这一点,您需要编写一个遍历层的模型遍历算法,跟踪输出KerasTensor,然后用新的输入调用每个层以创建新的输出KerasTensor。

你可以通过简单地为一个InceptionResNet找到一些源代码并通过Python添加层来避免所有的工作,而不是自省一个现有的模型。这里有一个可能符合要求。

https://github.com/yuyang-huang/keras-inception-resnet-v2/blob/master/inception_resnet_v2.py

最新更新