Keras functional API: replacing intermediate layers



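I want to replace the BatchNorm layers in a built-in Keras model, e.g. ResNet50, with GroupNorm layers. I'm trying to reset the layer of each node to my new layer, but when I call model.summary(), nothing has changed: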

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers

model = tf.keras.applications.resnet.ResNet50(include_top=False, weights=None)
channels = 3
for i, layer in enumerate(model.layers[:]):
    if 'bn' in layer.name:
        inbound_nodes = layer.inbound_nodes
        outbound_nodes = layer.outbound_nodes

        new_name = layer.name.replace('bn', 'gn')
        new_layer = tfa.layers.GroupNormalization(channels)
        new_layer._name = new_name

        for j in range(len(inbound_nodes)):
            inbound_nodes[j].layer = new_layer  # set end of node to this layer

        for k in range(len(outbound_nodes)):
            new_layer.outbound_nodes.append(outbound_nodes[k])

        layer = new_layer
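Part of why nothing changes is plain Python semantics: `layer = new_layer` only rebinds the loop variable. It does not touch the list being iterated over, let alone the graph Keras built when the model was constructed. A minimal illustration, using hypothetical stand-in strings instead of Keras layer objects:

```python
# Stand-in "layers": plain strings instead of Keras layer objects (illustrative only).
layers_list = ["conv1_conv", "conv1_bn", "conv1_relu"]

for layer in layers_list:
    if layer.endswith("bn"):
        # Rebinds the local name only; layers_list itself is unchanged.
        layer = layer.replace("bn", "gn")

print(layers_list)  # ['conv1_conv', 'conv1_bn', 'conv1_relu']

# Assigning through the index does change this list. In Keras, however,
# editing the list returned by model.layers would still not rewire the
# already-built functional graph; the model has to be rebuilt by calling
# each layer on new tensors.
for i, layer in enumerate(layers_list):
    if layer.endswith("bn"):
        layers_list[i] = layer.replace("bn", "gn")

print(layers_list)  # ['conv1_conv', 'conv1_gn', 'conv1_relu']
```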

I wrote the following code, adapting this answer with a few changes so that it works for your case:

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import Model

model = tf.keras.applications.resnet.ResNet50(include_top=False, weights=None)
model.summary()
channels = 64


def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                        insert_layer_name=None, position='after'):
    # Auxiliary dictionary to describe the network graph
    network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

    # Set the input layers of each layer
    for layer in model.layers:
        for node in layer._outbound_nodes:
            layer_name = node.outbound_layer.name
            if layer_name not in network_dict['input_layers_of']:
                network_dict['input_layers_of'].update(
                    {layer_name: [layer.name]})
            else:
                network_dict['input_layers_of'][layer_name].append(layer.name)

    # Set the output tensor of the input layer
    network_dict['new_output_tensor_of'].update(
        {model.layers[0].name: model.input})

    # Iterate over all layers after the input
    model_outputs = []
    for layer in model.layers[1:]:
        # Determine input tensors
        layer_input = [network_dict['new_output_tensor_of'][layer_aux]
                       for layer_aux in network_dict['input_layers_of'][layer.name]]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        # Replace the layer if its name ends with the given suffix
        # (despite the parameter name, this is a suffix match, not a regex)
        if layer.name.endswith(layer_regex):
            if position == 'replace':
                x = layer_input
            else:
                raise ValueError('position must be: replace')
            new_layer = insert_layer_factory()
            new_layer._name = '{}_{}'.format(layer.name, new_layer.name)
            x = new_layer(x)
            # print('New layer: {} Old layer: {} Type: {}'.format(new_layer.name, layer.name, position))
        else:
            # Re-apply the original layer (with its weights) to the new input
            x = layer(layer_input)

        # Set new output tensor (the original one, or the one of the replacement layer)
        network_dict['new_output_tensor_of'].update({layer.name: x})

        # Save tensor in output list if it is an output of the initial model
        if layer.name in model.output_names:
            model_outputs.append(x)

    return Model(inputs=model.inputs, outputs=model_outputs)


def replace_layer():
    return tfa.layers.GroupNormalization(channels)


model = insert_layer_nonseq(model, 'bn', replace_layer, position="replace")
model.summary()
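The core idea of `insert_layer_nonseq` does not depend on Keras: record each layer's parents, then walk the layers in order, feed each one the (possibly new) outputs of its parents, and substitute a fresh layer wherever the name matches. Below is a minimal pure-Python sketch of that walk, assuming a toy graph whose "layers" are plain functions. The names `rebuild`, `inputs_from_outbound`, and the toy layers are hypothetical and not part of Keras.

```python
def inputs_from_outbound(outbound):
    """Invert {layer: [consumers]} into {layer: [parents]}, like the first loop above."""
    inputs_of = {}
    for src, consumers in outbound.items():
        for dst in consumers:
            inputs_of.setdefault(dst, []).append(src)
    return inputs_of


def rebuild(model_input, layer_order, layer_fns, outbound, output_names,
            suffix, factory):
    """Re-run a toy graph in order, replacing layers whose name ends with suffix."""
    inputs_of = inputs_from_outbound(outbound)
    # The first layer is the input layer; its "output" is the model input.
    new_output_of = {layer_order[0]: model_input}
    outputs = []
    for name in layer_order[1:]:
        # Parent outputs, in the order the parents were recorded.
        args = [new_output_of[parent] for parent in inputs_of[name]]
        fn = factory() if name.endswith(suffix) else layer_fns[name]
        x = fn(*args)  # a Keras layer would receive a list instead of *args
        new_output_of[name] = x
        if name in output_names:
            outputs.append(x)
    return outputs


# Toy residual block: input -> conv -> conv_bn -> relu, then add(input, relu).
order = ["input", "conv", "conv_bn", "relu", "add"]
fns = {
    "conv": lambda x: 2 * x,
    "conv_bn": lambda x: x + 1,
    "relu": lambda x: max(0, x),
    "add": lambda a, b: a + b,
}
outbound = {"input": ["conv", "add"], "conv": ["conv_bn"],
            "conv_bn": ["relu"], "relu": ["add"]}

print(rebuild(3, order, fns, outbound, ["add"], "zz", None))  # no match: [10]
print(rebuild(3, order, fns, outbound, ["add"], "bn",
              lambda: (lambda x: x - 10)))  # conv_bn replaced: [3]
```

Because parent order comes from iteration order rather than from the original call, layers whose result depends on input order (such as `Concatenate`) may deserve a check; the `Add` layers in ResNet50 are order-independent.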

Note: I changed your channels variable from 3 to 64, for the following reason.

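According to the documentation of the groups argument:

groups: Integer, the number of groups for Group Normalization. Can be in the range [1, N] where N is the input dimension. The input dimension must be divisible by the number of groups. Defaults to 32.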
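To see why the divisibility requirement exists, here is a minimal pure-Python sketch of the normalization step for a single sample, stored as rows of per-position channel values (channels last). It splits the C channels into `groups` contiguous blocks of C / groups channels and normalizes each block with its own mean and variance. The learnable scale and offset that a GroupNormalization layer also applies are left out, and `group_norm` is a hypothetical helper, not the tfa implementation.

```python
import math


def group_norm(x, groups, eps=1e-5):
    """Group-normalize one sample x, given as rows (positions) of C channel values."""
    channels = len(x[0])
    if channels % groups != 0:
        # Channels cannot be split into equal-sized groups.
        raise ValueError(f"{channels} channels not divisible by {groups} groups")
    size = channels // groups
    out = [row[:] for row in x]
    for g in range(groups):
        cols = range(g * size, (g + 1) * size)
        # Statistics over every position and every channel in this group.
        vals = [row[c] for row in x for c in cols]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        for row_out, row in zip(out, x):
            for c in cols:
                row_out[c] = (row[c] - mean) / math.sqrt(var + eps)
    return out


sample = [[1.0, 2.0, 3.0, 4.0],
          [5.0, 6.0, 7.0, 8.0]]  # 2 positions, 4 channels
print(group_norm(sample, groups=2))
```

Calling `group_norm(sample, groups=3)` raises `ValueError`, since 4 channels cannot be split into 3 equal groups.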

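You should choose whichever value suits your model best, as long as it evenly divides the channel count of every layer being normalized. With channels = 3 this fails: the first BatchNorm layer of ResNet50 (conv1_bn) normalizes 64 channels, and 64 is not divisible by 3.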
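One quick way to find values that work for every BatchNorm layer in ResNet50 is to take the greatest common divisor of their channel counts. The sketch below assumes the standard ResNet50 widths (a 64-channel stem followed by bottleneck blocks with 64/128/256/512 inner channels and 4x expansion); `valid_group_counts` is a hypothetical helper.

```python
from functools import reduce
from math import gcd

# Channel counts seen by ResNet50's BatchNorm layers (assumed standard widths).
resnet50_bn_channels = [64, 256, 128, 512, 1024, 2048]


def valid_group_counts(channel_counts):
    """Return every group count that evenly divides all channel counts."""
    common = reduce(gcd, channel_counts)
    return [g for g in range(1, common + 1) if common % g == 0]


print(valid_group_counts(resnet50_bn_channels))  # [1, 2, 4, 8, 16, 32, 64]
```

Both 64 and the default of 32 appear in the list, while 3 does not.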
