无法使用权重文件离线加载 Keras ResNet50 模型



我想离线训练 keras 预训练 resnet50 模型,但我无法加载模型。

当我设置 weights='imagenet' 时可以正常运行,它会自动下载 ImageNet 预训练权重文件。

from keras.applications.resnet import ResNet50

w, h = 224, 224
# weights='imagenet' downloads the pretrained ImageNet weight file automatically.
# Any other string (e.g. 'resnet') is rejected: Keras only accepts None,
# 'imagenet', or a path to an existing weights file.
base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(w, h, 3), pooling='avg')

但是当我手动下载相同的权重文件并设置weights=resnet_weights_path时,它会抛出 ValueError。

w = h = 224
# Manually downloaded "notop" weight file (same file name as the v0.2
# release file WEIGHTS_PATH_NO_TOP listed further down).
resnet_weights_path = '../input/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
# This call raises the ValueError shown below (incompatible kernel shapes).
# NOTE(review): presumably the file was built for
# keras.applications.resnet50.ResNet50, whose layer order differs from the
# keras.applications.resnet.ResNet50 imported here -- TODO confirm.
base_model = ResNet50(
    include_top=False,
    weights=resnet_weights_path,
    input_shape=(w, h, 3),
    pooling='avg',
)

ValueError:形状 (1, 1, 256, 512) 和 (512, 128, 1, 1) 不兼容。

完整回溯:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-7683562fa2b9> in <module>
1 resnet_weights_path = '../input/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
2 base_model = ResNet50(include_top=False, weights=resnet_weights_path,
----> 3                       pooling='avg')
4 base_model.summary()
/opt/conda/lib/python3.6/site-packages/keras/applications/__init__.py in wrapper(*args, **kwargs)
18         kwargs['models'] = models
19         kwargs['utils'] = utils
---> 20         return base_fun(*args, **kwargs)
21 
22     return wrapper
/opt/conda/lib/python3.6/site-packages/keras/applications/resnet.py in ResNet50(*args, **kwargs)
12 @keras_modules_injection
13 def ResNet50(*args, **kwargs):
---> 14     return resnet.ResNet50(*args, **kwargs)
15 
16 
/opt/conda/lib/python3.6/site-packages/keras_applications/resnet_common.py in ResNet50(include_top, weights, input_tensor, input_shape, pooling, classes, **kwargs)
433                   input_tensor, input_shape,
434                   pooling, classes,
--> 435                   **kwargs)
436 
437 
/opt/conda/lib/python3.6/site-packages/keras_applications/resnet_common.py in ResNet(stack_fn, preact, use_bias, model_name, include_top, weights, input_tensor, input_shape, pooling, classes, **kwargs)
411         model.load_weights(weights_path)
412     elif weights is not None:
--> 413         model.load_weights(weights)
414 
415     return model
/opt/conda/lib/python3.6/site-packages/keras/engine/saving.py in load_wrapper(*args, **kwargs)
490                 os.remove(tmp_filepath)
491             return res
--> 492         return load_function(*args, **kwargs)
493 
494     return load_wrapper
/opt/conda/lib/python3.6/site-packages/keras/engine/network.py in load_weights(self, filepath, by_name, skip_mismatch, reshape)
1228             else:
1229                 saving.load_weights_from_hdf5_group(
-> 1230                     f, self.layers, reshape=reshape)
1231             if hasattr(f, 'close'):
1232                 f.close()
/opt/conda/lib/python3.6/site-packages/keras/engine/saving.py in load_weights_from_hdf5_group(f, layers, reshape)
1235                              ' elements.')
1236         weight_value_tuples += zip(symbolic_weights, weight_values)
-> 1237     K.batch_set_value(weight_value_tuples)
1238 
1239 
/opt/conda/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in batch_set_value(tuples)
2958             `value` should be a Numpy array.
2959     """
-> 2960     tf_keras_backend.batch_set_value(tuples)
2961 
2962 
/opt/conda/lib/python3.6/site-packages/tensorflow_core/python/keras/backend.py in batch_set_value(tuples)
3321     with ops.init_scope():
3322       for x, value in tuples:
-> 3323         x.assign(np.asarray(value, dtype=dtype(x)))
3324   else:
3325     with get_graph().as_default():
/opt/conda/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
817     with _handle_graph(self.handle):
818       value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
--> 819       self._shape.assert_is_compatible_with(value_tensor.shape)
820       assign_op = gen_resource_variable_ops.assign_variable_op(
821           self.handle, value_tensor, name=name)
/opt/conda/lib/python3.6/site-packages/tensorflow_core/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
1108     """
1109     if not self.is_compatible_with(other):
-> 1110       raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1111 
1112   def most_specific_compatible_shape(self, other):
ValueError: Shapes (1, 1, 256, 512) and (512, 128, 1, 1) are incompatible

该问题可能与 Keras 版本有关,我当前使用的 Keras 版本是 2.3.1。
可以按以下步骤解决:
1. 使用 weights='imagenet' 运行代码,Keras 会自动下载权重文件(默认缓存在 ~/.keras/models/ 目录下)。
2. 之后将 weights 参数设置为这个已下载文件的路径。

这是张量形状不匹配:权重文件与模型结构并不对应。除非让模型结构与权重文件一致,否则无法解决。

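从以下地址下载权重后重试,这些是 Keras 官方提供的权重。注意:这两个 v0.2 文件对应的是 `keras.applications.resnet50` 模块中的 ResNet50。问题中的代码使用的是 `keras.applications.resnet`(回溯中可见 `resnet_common.py`),其层的排列顺序不同,即使文件名相同也无法按顺序加载。因此要么改用 `from keras.applications.resnet50 import ResNet50`,要么下载与 `keras.applications.resnet` 对应的权重文件。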

# ResNet50 weight files from the fchollet/deep-learning-models GitHub
# release v0.2: with the top classifier layers, and without them ("notop").
WEIGHTS_PATH = (
    'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/'
    'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
)
WEIGHTS_PATH_NO_TOP = (
    'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/'
    'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
)

要让 ResNet50 可以离线使用,一个简单的办法是:先在联网环境中设置参数 weights='imagenet',让 Keras 自动下载并加载权重:

from keras.applications.resnet import ResNet50

# Online step: weights='imagenet' makes Keras download the weight file and
# load it into the freshly built model.
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_shape=(w, h, 3),
)

然后使用以下代码保存模型:

# Store architecture and weights together in one HDF5 file for later offline use.
base_model.save(filepath="model_name.h5")

之后即可在离线环境中使用以下代码加载该模型:

# Load with the same Keras implementation that built and saved the model
# (standalone `keras`, as imported for ResNet50 above) instead of mixing in
# `tf.keras`. Only the local HDF5 file is read, so no network access is needed.
from keras.models import load_model

resnet = load_model('model_name.h5')

请在 `model.load_weights()` 调用中添加参数 `by_name=True`。这是在线或离线模式下该问题的正确解决方案。我使用的是离线模式,因为权重文件就在我的桌面上。注意:`by_name=True` 只按层名匹配,名称在文件中找不到的层会被静默跳过并保持随机初始化,因此加载不报错并不代表权重已全部加载。

# Build model.
model = Model(inputs, x, name='resnet50')

# Load weights.
if weights == 'imagenet':
    if include_top:
        weights_path = WEIGHTS_PATH
    else:
        weights_path = WEIGHTS_PATH_NO_TOP
    # Patched line; the original was: model.load_weights(weights_path)
    # NOTE(review): by_name=True matches layers by name. Model layers whose
    # names do not appear in the file are silently left at their initial
    # (random) values, so "no error" does not mean all weights were loaded.
    # NOTE(review): WEIGHTS_PATH / WEIGHTS_PATH_NO_TOP are URLs, while
    # load_weights needs a local HDF5 file. For offline use, point
    # weights_path at the local copy (the full library code presumably
    # downloads the file first -- confirm).
    # NOTE(review): calls that pass weights=<file path> do not take this
    # branch; they reach a separate model.load_weights(weights) call.
    model.load_weights(weights_path, by_name=True)

最新更新