尝试加载带有自定义层的模型时出现 TypeError: ('Keyword argument not understood:', 'pool1')



我正在尝试加载一个带有自定义层的模型,代码如下:

# Load the saved HDF5 model. The custom Localization layer has to be
# registered through custom_objects; compile=False loads it without compiling.
model = load_model(
    BEST_MODEL_PATH,
    custom_objects={'Localization': Localization},
    compile=False,
)

但我得到了错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-73-668917c4ab3c> in <module>()
3 BEST_MODEL_PATH = '/content/data/00002-test-train/model-01-0.903.hdf5'
4 
----> 5 model = load_model(BEST_MODEL_PATH, compile=False, custom_objects={'Localization': Localization})
6 
7 predictions = model.predict(X_test)
14 frames
/usr/local/lib/python3.7/dist-packages/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
199             (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
200           return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
--> 201                                                   compile)
202 
203         filepath = path_to_string(filepath)
/usr/local/lib/python3.7/dist-packages/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
179     model_config = json_utils.decode(model_config)
180     model = model_config_lib.model_from_config(model_config,
--> 181                                                custom_objects=custom_objects)
182 
183     # set weights
/usr/local/lib/python3.7/dist-packages/keras/saving/model_config.py in model_from_config(config, custom_objects)
50                     '`Sequential.from_config(config)`?')
51   from keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 52   return deserialize(config, custom_objects=custom_objects)
53 
54 
/usr/local/lib/python3.7/dist-packages/keras/layers/serialization.py in deserialize(config, custom_objects)
210       module_objects=LOCAL.ALL_OBJECTS,
211       custom_objects=custom_objects,
--> 212       printable_module_name='layer')
/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
676             custom_objects=dict(
677                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 678                 list(custom_objects.items())))
679       else:
680         with CustomObjectScope(custom_objects):
/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in from_config(cls, config, custom_objects)
661     with generic_utils.SharedObjectLoadingScope():
662       input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 663           config, custom_objects)
664       model = cls(inputs=input_tensors, outputs=output_tensors,
665                   name=config.get('name'))
/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
1271   # First, we create all layers and enqueue nodes to be processed
1272   for layer_data in config['layers']:
-> 1273     process_layer(layer_data)
1274   # Then we process nodes in order of layer depth.
1275   # Nodes that cannot yet be processed (if the inbound node
/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in process_layer(layer_data)
1253       from keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
1254 
-> 1255       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
1256       created_layers[layer_name] = layer
1257 
/usr/local/lib/python3.7/dist-packages/keras/layers/serialization.py in deserialize(config, custom_objects)
210       module_objects=LOCAL.ALL_OBJECTS,
211       custom_objects=custom_objects,
--> 212       printable_module_name='layer')
/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
679       else:
680         with CustomObjectScope(custom_objects):
--> 681           deserialized_obj = cls.from_config(cls_config)
682     else:
683       # Then `cls` may be a function returning a class.
/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py in from_config(cls, config)
746         A layer instance.
747     """
--> 748     return cls(**config)
749 
750   def compute_output_shape(self, input_shape):
<ipython-input-55-f9b9fb9036d5> in __init__(self, filters_1, filters_2, fc_units, kernel_size, pool_size, **kwargs)
14         self.fc1 = Dense(fc_units, activation='relu')
15         self.fc2 = Dense(6, activation=None, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), kernel_initializer='zeros')
---> 16         super(Localization, self).__init__(**kwargs)
17 
18     def build(self, input_shape):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
528     self._self_setattr_tracking = False  # pylint: disable=protected-access
529     try:
--> 530       result = method(self, *args, **kwargs)
531     finally:
532       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access
/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py in __init__(self, trainable, name, dtype, dynamic, **kwargs)
321     }
322     # Validate optional keyword arguments.
--> 323     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
324 
325     # Mutable properties
/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py in validate_kwargs(kwargs, allowed_kwargs, error_message)
1141   for kwarg in kwargs:
1142     if kwarg not in allowed_kwargs:
-> 1143       raise TypeError(error_message, kwarg)
1144 
1145 
TypeError: ('Keyword argument not understood:', 'pool1')

我的自定义层:

class Localization(tf.keras.layers.Layer):
    """Small conv net that maps an input feature map to a 2x3 matrix ``theta``.

    Pipeline: pool -> conv -> pool -> conv -> pool -> flatten -> dense -> dense,
    then the 6 outputs of ``fc2`` are reshaped to ``(batch, 2, 3)``.
    ``fc2`` starts with a zero kernel and bias ``[1, 0, 0, 0, 1, 0]``, so at
    initialization the output is ``[[1, 0, 0], [0, 1, 0]]`` for any input
    (presumably the identity affine transform of a spatial transformer —
    confirm with the consuming code).

    Args:
        filters_1: Number of filters of the first Conv2D.
        filters_2: Number of filters of the second Conv2D.
        fc_units: Number of units of the hidden Dense layer ``fc1``.
        kernel_size: Kernel size shared by both Conv2D layers.
        pool_size: Pool size shared by the three MaxPooling2D layers.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments
            (``name``, ``dtype``, ``trainable``, ...).
    """

    # Keys that an earlier version of get_config() wrote into saved configs.
    # They are sub-layer objects, not constructor arguments, so they are
    # stripped in from_config() to keep such checkpoints loadable.
    _LEGACY_SUBLAYER_KEYS = (
        'pool1', 'conv1', 'pool2', 'conv2', 'pool3', 'flatten', 'fc1', 'fc2',
    )

    def __init__(self, filters_1, filters_2, fc_units, kernel_size=(5, 5),
                 pool_size=(2, 2), **kwargs):
        # Initialise the base Layer first so that name/dtype/trainable from
        # kwargs are applied and attribute tracking is ready before the
        # sub-layers are assigned.
        super(Localization, self).__init__(**kwargs)

        # Constructor arguments are stored so get_config() can serialize them.
        self.filters_1 = filters_1
        self.filters_2 = filters_2
        self.fc_units = fc_units
        self.kernel_size = kernel_size
        self.pool_size = pool_size

        self.pool1 = MaxPooling2D(pool_size=pool_size)
        self.conv1 = Conv2D(filters=filters_1, kernel_size=kernel_size,
                            padding='same', strides=1, activation='relu')
        self.pool2 = MaxPooling2D(pool_size=pool_size)
        self.conv2 = Conv2D(filters=filters_2, kernel_size=kernel_size,
                            padding='same', strides=1, activation='relu')
        self.pool3 = MaxPooling2D(pool_size=pool_size)
        self.flatten = Flatten()
        self.fc1 = Dense(fc_units, activation='relu')
        # Zero kernel + this bias makes the initial theta [[1,0,0],[0,1,0]].
        self.fc2 = Dense(
            6,
            activation=None,
            bias_initializer=tf.keras.initializers.constant(
                [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]),
            kernel_initializer='zeros',
        )

    def build(self, input_shape):
        print("Building Localization Network with input shape:", input_shape)
        # Let the base class record the build shape and mark the layer built.
        super(Localization, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        # call() returns (batch, 2, 3), not (batch, 6).
        input_shape = tf.TensorShape(input_shape)
        return tf.TensorShape([input_shape[0], 2, 3])

    def call(self, inputs):
        x = self.pool1(inputs)
        x = self.conv1(x)
        x = self.pool2(x)
        x = self.conv2(x)
        x = self.pool3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        theta = self.fc2(x)
        # Plain reshape instead of constructing a new Reshape layer per call.
        return tf.reshape(theta, (-1, 2, 3))

    def get_config(self):
        # Only constructor arguments belong here: from_config() passes every
        # key back to __init__, and unknown keys (e.g. sub-layers) would reach
        # Layer.__init__ via **kwargs and raise
        # "Keyword argument not understood".
        config = super(Localization, self).get_config()
        config.update({
            'filters_1': self.filters_1,
            'filters_2': self.filters_2,
            'fc_units': self.fc_units,
            'kernel_size': self.kernel_size,
            'pool_size': self.pool_size,
        })
        return config

    @classmethod
    def from_config(cls, config):
        # Copy so the caller's dict is not mutated, then drop sub-layer
        # entries written by the old get_config() before rebuilding.
        config = dict(config)
        for key in cls._LEGACY_SUBLAYER_KEYS:
            config.pop(key, None)
        return cls(**config)

我使用的是 Google Colab,它最近把 TensorFlow 升级到了 2.6。在 TensorFlow 2.5 中我没有遇到这个问题。

为了让它正常工作,我不得不对这个层做了很多修改。在 2.5 中,我既不需要在 __init__ 中保存 filters_1、filters_2 等参数(因为其他地方并不使用它们),也不需要传递 **kwargs 或编写 get_config 函数。

我甚至尝试重新安装 TensorFlow 2.5 和 Keras,但训练时会报错。我尝试了很多办法,查阅了文档,也读了几乎所有类似的问题,但都没有找到解决方案。

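正如 Dr. Snoopy 在评论中指出的,我从 get_config 中删除了子层(pool1、conv1 等),现在可以正常工作了。原因是:加载模型时 from_config 会把 get_config 返回的每个键都作为关键字参数传回 __init__,而 __init__ 又通过 **kwargs 把不认识的键(如 pool1)转交给基类 Layer,于是触发了 TypeError。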

最新更新