tf.keras.models.save_model not saving the probabilistic_model



系统信息

  • 是否编写了自定义代码(而不是使用 TensorFlow 自带的现成示例脚本):是
  • 操作系统平台和发行版(例如 Linux Ubuntu 16.04):Ubuntu 20.04 LTS
  • TensorFlow 安装方式(源码或二进制):二进制
  • TensorFlow 版本(通过 `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"` 获取):v2.5.0-rc3-213-ga4dfb8d1a71 2.5.0
  • TensorFlow Probability 版本:0.13.0
  • Python 版本:3.8.10
  • CUDA/cuDNN 版本:CUDA 11.2(nvcc 构建号 cuda_11.2.r11.2/compiler.29373293_0)
  • GPU 型号和显存:NVIDIA Titan Xp,12 GB

**当前行为描述**

包含 TensorFlow Probability 层(此处为 `tfpl.OneHotCategorical`)的 TensorFlow 模型在保存时会报错。

**使用以下代码创建模型**


def get_probabilistic_model(input_shape, loss, optimizer, metrics):
    """Build and compile a small CNN whose output is a OneHotCategorical distribution.

    Args:
        input_shape: Shape of a single input example, passed to the first Conv2D
            layer (e.g. ``(28, 28, 1)``).
        loss: Loss passed to ``model.compile``. For a distribution-valued output
            this is typically a negative log-likelihood such as ``nll`` below.
        optimizer: Optimizer instance passed to ``model.compile``.
        metrics: List of metrics passed to ``model.compile``.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential([
        Conv2D(8, 5, activation='relu', padding='valid', input_shape=input_shape),
        MaxPooling2D(6),
        Flatten(),
        # One logit per class feeds the OneHotCategorical output layer.
        Dense(10),
        tfpl.OneHotCategorical(10),
    ])
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return model


# NOTE(review): `nll`, `x_train` and `y_train_oh` are defined elsewhere in the
# original notebook and are not shown in this report.
probabilistic_model = get_probabilistic_model(
    input_shape=(28, 28, 1),
    loss=nll,
    optimizer=RMSprop(),
    metrics=['accuracy'],
)

probabilistic_model.fit(x_train, y_train_oh, epochs=5)

使用以下代码保存模型:


# Save without an .h5 suffix, so Keras takes the SavedModel path
# (the traceback below goes through saved_model_save.save).
probabilistic_model.save('/tmp/model/probabilistic_model')

保存步骤会产生如下所示的错误。

OperatorNotAllowedInGraphError            Traceback (most recent call last)
/tmp/ipykernel_11377/1109926494.py in <module>
----> 1 probabilistic_model.save('/tmp/model/probabilistic_model')
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
2109     """
2110     # pylint: enable=line-too-long
-> 2111     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
2112                     signatures, options, save_traces)
2113 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
148   else:
149     with generic_utils.SharedObjectSavingScope():
--> 150       saved_model_save.save(model, filepath, overwrite, include_optimizer,
151                             signatures, options, save_traces)
152 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
87   with K.deprecated_internal_learning_phase_scope(0):
88     with utils.keras_option_scope(save_traces):
---> 89       saved_nodes, node_paths = save_lib.save_and_return_nodes(
90           model, filepath, signatures, options)
91 
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, raise_metadata_warning, experimental_skip_checkpoint)
1101 
1102   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1103       _build_meta_graph(obj, signatures, options, meta_graph_def,
1104                         raise_metadata_warning))
1105   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def, raise_metadata_warning)
1288 
1289   with save_context.save_context(options):
-> 1290     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
1291                                   raise_metadata_warning)
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def, raise_metadata_warning)
1205   checkpoint_graph_view = _AugmentedGraphView(obj)
1206   if signatures is None:
-> 1207     signatures = signature_serialization.find_function_to_export(
1208         checkpoint_graph_view)
1209 
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
97   # If the user did not specify signatures, check the root object for a function
98   # that can be made into a signature.
---> 99   functions = saveable_view.list_functions(saveable_view.root)
100   signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
101   if signature is not None:
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj)
152     obj_functions = self._functions.get(obj, None)
153     if obj_functions is None:
--> 154       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
155           self._serialization_cache)
156       self._functions[obj] = obj_functions
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
2711     self.test_function = None
2712     self.predict_function = None
-> 2713     functions = super(
2714         Model, self)._list_functions_for_serialization(serialization_cache)
2715     self.train_function = train_function
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
3014 
3015   def _list_functions_for_serialization(self, serialization_cache):
-> 3016     return (self._trackable_saved_model_saver
3017             .list_functions_for_serialization(serialization_cache))
3018 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
90       return {}
91 
---> 92     fns = self.functions_to_serialize(serialization_cache)
93 
94     # The parent AutoTrackable class saves all user-defined tf.functions, and
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
71 
72   def functions_to_serialize(self, serialization_cache):
---> 73     return (self._get_serialized_attributes(
74         serialization_cache).functions_to_serialize)
75 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
87       return serialized_attr
88 
---> 89     object_dict, function_dict = self._get_serialized_attributes_internal(
90         serialization_cache)
91 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
51     # the ones serialized by Layer.
52     objects, functions = (
---> 53         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
54             serialization_cache))
55     functions['_default_save_signature'] = default_signature
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
97     """Returns dictionary of serialized attributes."""
98     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
---> 99     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
100     # Attribute validator requires that the default save signature is added to
101     # function dict, even if the value is None.
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
202           if isinstance(fn, LayerCall):
203             fn = fn.wrapped_call
--> 204           fn.get_concrete_function()
205 
206   # Restore overwritten functions and losses
/usr/lib/python3.8/contextlib.py in __exit__(self, type, value, traceback)
118         if type is None:
119             try:
--> 120                 next(self.gen)
121             except StopIteration:
122                 return False
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in tracing_scope()
365       if training is not None:
366         with K.deprecated_internal_learning_phase_scope(training):
--> 367           fn.get_concrete_function(*args, **kwargs)
368       else:
369         fn.get_concrete_function(*args, **kwargs)
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
1365       ValueError: if this object has not yet been called on concrete values.
1366     """
-> 1367     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1368     concrete._garbage_collector.release()  # pylint: disable=protected-access
1369     return concrete
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
1282       # In this case we have not created variables on the first call. So we can
1283       # run the first trace but we should fail if variables are created.
-> 1284       concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
1285           *args, **kwargs)
1286       if self._created_variables:
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
3098       args, kwargs = None, None
3099     with self._lock:
-> 3100       graph_function, _ = self._maybe_define_function(args, kwargs)
3101       seen_names = set()
3102       captured = object_identity.ObjectIdentitySet(
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442 
3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args, kwargs)
3445           self._function_cache.primary[cache_key] = graph_function
3446 
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277     arg_names = base_arg_names + missing_arg_names
3278     graph_function = ConcreteFunction(
-> 3279         func_graph_module.func_graph_from_py_func(
3280             self._name,
3281             self._python_function,
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997         _, original_func = tf_decorator.unwrap(python_func)
998 
--> 999       func_outputs = python_func(*func_args, **func_kwargs)
1000 
1001       # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670         # the function a weak reference to itself to avoid a reference cycle.
671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673         return out
674 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
597       with autocast_variable.enable_auto_cast_variables(
598           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 599         ret = method(*args, **kwargs)
600     _restore_layer_losses(original_losses)
601     return ret
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
163       return wrapped_call(*args, **kwargs)
164 
--> 165     return control_flow_util.smart_cond(
166         training, lambda: replace_training_and_call(True),
167         lambda: replace_training_and_call(False))
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
107     return control_flow_ops.cond(
108         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 109   return smart_module.smart_cond(
110       pred, true_fn=true_fn, false_fn=false_fn, name=name)
111 
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52   if pred_value is not None:
53     if pred_value:
---> 54       return true_fn()
55     else:
56       return false_fn()
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in <lambda>()
164 
165     return control_flow_util.smart_cond(
--> 166         training, lambda: replace_training_and_call(True),
167         lambda: replace_training_and_call(False))
168 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
161     def replace_training_and_call(training):
162       set_training_arg(training, training_arg_index, args, kwargs)
--> 163       return wrapped_call(*args, **kwargs)
164 
165     return control_flow_util.smart_cond(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call(inputs, *args, **kwargs)
679     return layer.keras_api.__call__  # pylint: disable=protected-access
680   def call(inputs, *args, **kwargs):
--> 681     return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
682   return _create_call_fn_decorator(layer, call)
683 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
637   def __call__(self, *args, **kwargs):
638     self._maybe_trace(args, kwargs)
--> 639     return self.wrapped_call(*args, **kwargs)
640 
641   def get_concrete_function(self, *args, **kwargs):
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
887 
888       with OptionalXlaContext(self._jit_compile):
--> 889         result = self._call(*args, **kwds)
890 
891       new_tracing_count = self.experimental_get_tracing_count()
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
922       # In this case we have not created variables on the first call. So we can
923       # run the first trace but we should fail if variables are created.
--> 924       results = self._stateful_fn(*args, **kwds)
925       if self._created_variables:
926         raise ValueError("Creating variables on a non-first call to a function"
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
3020     with self._lock:
3021       (graph_function,
-> 3022        filtered_flat_args) = self._maybe_define_function(args, kwargs)
3023     return graph_function._call_flat(
3024         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442 
3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args, kwargs)
3445           self._function_cache.primary[cache_key] = graph_function
3446 
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277     arg_names = base_arg_names + missing_arg_names
3278     graph_function = ConcreteFunction(
-> 3279         func_graph_module.func_graph_from_py_func(
3280             self._name,
3281             self._python_function,
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997         _, original_func = tf_decorator.unwrap(python_func)
998 
--> 999       func_outputs = python_func(*func_args, **func_kwargs)
1000 
1001       # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670         # the function a weak reference to itself to avoid a reference cycle.
671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673         return out
674 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
597       with autocast_variable.enable_auto_cast_variables(
598           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 599         ret = method(*args, **kwargs)
600     _restore_layer_losses(original_losses)
601     return ret
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
163       return wrapped_call(*args, **kwargs)
164 
--> 165     return control_flow_util.smart_cond(
166         training, lambda: replace_training_and_call(True),
167         lambda: replace_training_and_call(False))
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
107     return control_flow_ops.cond(
108         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 109   return smart_module.smart_cond(
110       pred, true_fn=true_fn, false_fn=false_fn, name=name)
111 
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52   if pred_value is not None:
53     if pred_value:
---> 54       return true_fn()
55     else:
56       return false_fn()
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in <lambda>()
164 
165     return control_flow_util.smart_cond(
--> 166         training, lambda: replace_training_and_call(True),
167         lambda: replace_training_and_call(False))
168 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
161     def replace_training_and_call(training):
162       set_training_arg(training, training_arg_index, args, kwargs)
--> 163       return wrapped_call(*args, **kwargs)
164 
165     return control_flow_util.smart_cond(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
661   def call_and_return_conditional_losses(*args, **kwargs):
662     """Returns layer (call_output, conditional losses) tuple."""
--> 663     call_output = layer_call(*args, **kwargs)
664     if version_utils.is_v1_layer_or_model(layer):
665       conditional_losses = layer.get_losses_for(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py in call(self, inputs, training, mask)
378       if not self.built:
379         self._init_graph_network(self.inputs, self.outputs)
--> 380       return super(Sequential, self).call(inputs, training=training, mask=mask)
381 
382     outputs = inputs  # handle the corner case where self.layers is empty
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in call(self, inputs, training, mask)
418         a list of tensors if there are more than one outputs.
419     """
--> 420     return self._run_internal_graph(
421         inputs, training=training, mask=mask)
422 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
554 
555         args, kwargs = node.map_arguments(tensor_dict)
--> 556         outputs = node.layer(*args, **kwargs)
557 
558         # Update tensor_dict.
~/tf2/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py in __call__(self, inputs, *args, **kwargs)
228   def __call__(self, inputs, *args, **kwargs):
229     self._enter_dunder_call = True
--> 230     distribution, _ = super(DistributionLambda, self).__call__(
231         inputs, *args, **kwargs)
232     self._enter_dunder_call = False
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
518   def __iter__(self):
519     if not context.executing_eagerly():
--> 520       self._disallow_iteration()
521 
522     shape = self._shape_tuple()
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_iteration(self)
511       self._disallow_when_autograph_disabled("iterating over `tf.Tensor`")
512     elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
--> 513       self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
514     else:
515       # Default: V1-style Graph execution.
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_when_autograph_enabled(self, task)
487 
488   def _disallow_when_autograph_enabled(self, task):
--> 489     raise errors.OperatorNotAllowedInGraphError(
490         "{} is not allowed: AutoGraph did convert this function. This might"
491         " indicate you are trying to use an unsupported feature.".format(task))
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.*_

https://github.com/tensorflow/probability/issues/325#issuecomment-477213850 中给出的变通方法只能部分解决问题:它只保存权重,而不保存模型的其他信息。

另一种变通方法是改用 h5 格式:以 h5 格式保存可以成功,但随后无法加载模型:

# Load the HDF5 copy; without custom_objects this fails with
# "Unknown layer: OneHotCategorical" (see the traceback below).
loaded_model = tf.keras.models.load_model('/tmp/model/probabilistic_model.h5')

使用h5格式保存然后加载模型时出错,如下所示。

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_11377/686337657.py in <module>
----> 1 loaded_model = tf.keras.models.load_model('/tmp/model/probabilistic_model.h5')
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
199         if (h5py is not None and
200             (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 201           return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
202                                                   compile)
203 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
178       model_config = model_config.decode('utf-8')
179     model_config = json_utils.decode(model_config)
--> 180     model = model_config_lib.model_from_config(model_config,
181                                                custom_objects=custom_objects)
182 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
57                     '`Sequential.from_config(config)`?')
58   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 59   return deserialize(config, custom_objects=custom_objects)
60 
61 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
157   """
158   populate_deserializable_objects()
--> 159   return generic_utils.deserialize_keras_object(
160       config,
161       module_objects=LOCAL.ALL_OBJECTS,
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
666 
667       if 'custom_objects' in arg_spec.args:
--> 668         deserialized_obj = cls.from_config(
669             cls_config,
670             custom_objects=dict(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py in from_config(cls, config, custom_objects)
495     model = cls(name=name)
496     for layer_config in layer_configs:
--> 497       layer = layer_module.deserialize(layer_config,
498                                        custom_objects=custom_objects)
499       model.add(layer)
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
157   """
158   populate_deserializable_objects()
--> 159   return generic_utils.deserialize_keras_object(
160       config,
161       module_objects=LOCAL.ALL_OBJECTS,
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
651     # In this case we are dealing with a Keras config dictionary.
652     config = identifier
--> 653     (cls, cls_config) = class_and_config_for_serialized_keras_object(
654         config, module_objects, custom_objects, printable_module_name)
655 
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
554   cls = get_registered_object(class_name, custom_objects, module_objects)
555   if cls is None:
556     raise ValueError(
557         'Unknown {}: {}. Please ensure this object is '
558         'passed to the `custom_objects` argument. See '
ValueError: Unknown layer: OneHotCategorical. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

**预期行为描述**

能够将概率模型保存为 TensorFlow SavedModel 格式。加载时可按上面报错信息的提示,通过 `custom_objects` 注册该层:

# Map the class name stored in the saved config ("OneHotCategorical") to the
# TFP layer class, so Keras can resolve it; without this, loading fails with
# "Unknown layer: OneHotCategorical".
custom_objects = {"OneHotCategorical": tfp.layers.OneHotCategorical}
with tf.keras.utils.custom_object_scope(custom_objects):
    # The load call must be inside the scope for the registration to apply.
    restored_model = tf.keras.models.load_model(saved_path)

相关内容

最新更新