检查tensorflow keras模型中的下一层是什么



我有一个keras模型,它在层之间有快捷方式。对于每一层,我想获得下一个连接层的名称(或索引(,因为简单地迭代所有model.layers不会告诉我该层是否连接到上一个。

一个示例模型可能是:

model = tf.keras.applications.resnet50.ResNet50(
include_top=True, weights='imagenet', input_tensor=None,
input_shape=None, pooling=None, classes=1000)

您可以通过这种方式提取dict格式的信息。。。

首先,定义一个效用函数,并从每个Functional模型(代码参考(中获得model.summary()方法中的相关节点

relevant_nodes = []
for v in model._nodes_by_depth.values():
relevant_nodes += v
def get_layer_summary_with_connections(layer):

info = {}
connections = []
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
connections.append(inbound_layer.name)

name = layer.name
info['type'] = layer.__class__.__name__
info['parents'] = connections

return info

其次,通过分层迭代提取信息:

results = {}
layers = model.layers
for layer in layers:
info = get_layer_summary_with_connections(layer)
results[layer.name] = info

results是嵌套的dict,格式如下:

{
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'},
...
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'}
}

对于ResNet50,结果为:

{
'input_4': {'type': 'InputLayer', 'parents': []},
'conv1_pad': {'type': 'ZeroPadding2D', 'parents': ['input_4']},
'conv1_conv': {'type': 'Conv2D', 'parents': ['conv1_pad']},
'conv1_bn': {'type': 'BatchNormalization', 'parents': ['conv1_conv']},
...
'conv5_block3_out': {'type': 'Activation', 'parents': ['conv5_block3_add']},
'avg_pool': {'type': 'GlobalAveragePooling2D', 'parents' ['conv5_block3_out']},
'predictions': {'type': 'Dense', 'parents': ['avg_pool']}
}

此外,您可以修改get_layer_summary_with_connections以返回您对感兴趣的所有信息

您可以使用keras的模型绘图实用程序查看整个模型及其连接

tf.keras.utils.plot_model(model, to_file='path/to/image', show_shapes=True)

最新更新