系统信息
- 我是否编写了自定义代码(而不是使用TensorFlow中提供的现成示例脚本):是
- 操作系统平台和发行版(例如,Linux Ubuntu 16.04):Ubuntu 20.04 LTS
- TensorFlow安装来源(源码或二进制):二进制
- TensorFlow版本(使用下面的命令):v2.5.0-rc3-213-ga4dfb8d1a71 2.5.0
- TensorFlow Probability版本:0.13.0
- Python版本:3.8.10
- CUDA/cuDNN版本:cuda_11.2.r11.2/compiler.29373293_0
- GPU型号和内存:12Gb TitanXP
描述当前行为:含有TensorFlow Probability层的TensorFlow模型在保存时会产生错误。
**使用以下代码创建模型**
model = Sequential([
Conv2D(8, 5, activation='relu', padding='valid', input_shape=input_shape),
MaxPooling2D(6),
Flatten(),
Dense(10),
tfpl.OneHotCategorical(10)
])
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
probabilistic_model = get_probabilistic_model(
input_shape=(28, 28, 1),
loss=nll,
optimizer=RMSprop(),
metrics=['accuracy']
probabilistic_model.fit(x_train, y_train_oh, epochs=5)
使用以下代码保存模型:
probabilistic_model.save('/tmp/model/probabilistic_model')
保存步骤会产生如下所示的错误。
OperatorNotAllowedInGraphError Traceback (most recent call last)
/tmp/ipykernel_11377/1109926494.py in <module>
----> 1 probabilistic_model.save('/tmp/model/probabilistic_model')
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
2109 """
2110 # pylint: enable=line-too-long
-> 2111 save.save_model(self, filepath, overwrite, include_optimizer, save_format,
2112 signatures, options, save_traces)
2113
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
148 else:
149 with generic_utils.SharedObjectSavingScope():
--> 150 saved_model_save.save(model, filepath, overwrite, include_optimizer,
151 signatures, options, save_traces)
152
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
87 with K.deprecated_internal_learning_phase_scope(0):
88 with utils.keras_option_scope(save_traces):
---> 89 saved_nodes, node_paths = save_lib.save_and_return_nodes(
90 model, filepath, signatures, options)
91
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, raise_metadata_warning, experimental_skip_checkpoint)
1101
1102 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1103 _build_meta_graph(obj, signatures, options, meta_graph_def,
1104 raise_metadata_warning))
1105 saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def, raise_metadata_warning)
1288
1289 with save_context.save_context(options):
-> 1290 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
1291 raise_metadata_warning)
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def, raise_metadata_warning)
1205 checkpoint_graph_view = _AugmentedGraphView(obj)
1206 if signatures is None:
-> 1207 signatures = signature_serialization.find_function_to_export(
1208 checkpoint_graph_view)
1209
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
97 # If the user did not specify signatures, check the root object for a function
98 # that can be made into a signature.
---> 99 functions = saveable_view.list_functions(saveable_view.root)
100 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
101 if signature is not None:
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj)
152 obj_functions = self._functions.get(obj, None)
153 if obj_functions is None:
--> 154 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
155 self._serialization_cache)
156 self._functions[obj] = obj_functions
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
2711 self.test_function = None
2712 self.predict_function = None
-> 2713 functions = super(
2714 Model, self)._list_functions_for_serialization(serialization_cache)
2715 self.train_function = train_function
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
3014
3015 def _list_functions_for_serialization(self, serialization_cache):
-> 3016 return (self._trackable_saved_model_saver
3017 .list_functions_for_serialization(serialization_cache))
3018
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
90 return {}
91
---> 92 fns = self.functions_to_serialize(serialization_cache)
93
94 # The parent AutoTrackable class saves all user-defined tf.functions, and
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
71
72 def functions_to_serialize(self, serialization_cache):
---> 73 return (self._get_serialized_attributes(
74 serialization_cache).functions_to_serialize)
75
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
87 return serialized_attr
88
---> 89 object_dict, function_dict = self._get_serialized_attributes_internal(
90 serialization_cache)
91
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
51 # the ones serialized by Layer.
52 objects, functions = (
---> 53 super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
54 serialization_cache))
55 functions['_default_save_signature'] = default_signature
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
97 """Returns dictionary of serialized attributes."""
98 objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
---> 99 functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
100 # Attribute validator requires that the default save signature is added to
101 # function dict, even if the value is None.
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
202 if isinstance(fn, LayerCall):
203 fn = fn.wrapped_call
--> 204 fn.get_concrete_function()
205
206 # Restore overwritten functions and losses
/usr/lib/python3.8/contextlib.py in __exit__(self, type, value, traceback)
118 if type is None:
119 try:
--> 120 next(self.gen)
121 except StopIteration:
122 return False
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in tracing_scope()
365 if training is not None:
366 with K.deprecated_internal_learning_phase_scope(training):
--> 367 fn.get_concrete_function(*args, **kwargs)
368 else:
369 fn.get_concrete_function(*args, **kwargs)
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
1365 ValueError: if this object has not yet been called on concrete values.
1366 """
-> 1367 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1368 concrete._garbage_collector.release() # pylint: disable=protected-access
1369 return concrete
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
1282 # In this case we have not created variables on the first call. So we can
1283 # run the first trace but we should fail if variables are created.
-> 1284 concrete = self._stateful_fn._get_concrete_function_garbage_collected( # pylint: disable=protected-access
1285 *args, **kwargs)
1286 if self._created_variables:
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
3098 args, kwargs = None, None
3099 with self._lock:
-> 3100 graph_function, _ = self._maybe_define_function(args, kwargs)
3101 seen_names = set()
3102 captured = object_identity.ObjectIdentitySet(
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277 arg_names = base_arg_names + missing_arg_names
3278 graph_function = ConcreteFunction(
-> 3279 func_graph_module.func_graph_from_py_func(
3280 self._name,
3281 self._python_function,
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673 return out
674
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
597 with autocast_variable.enable_auto_cast_variables(
598 layer._compute_dtype_object): # pylint: disable=protected-access
--> 599 ret = method(*args, **kwargs)
600 _restore_layer_losses(original_losses)
601 return ret
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
163 return wrapped_call(*args, **kwargs)
164
--> 165 return control_flow_util.smart_cond(
166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
107 return control_flow_ops.cond(
108 pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 109 return smart_module.smart_cond(
110 pred, true_fn=true_fn, false_fn=false_fn, name=name)
111
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52 if pred_value is not None:
53 if pred_value:
---> 54 return true_fn()
55 else:
56 return false_fn()
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in <lambda>()
164
165 return control_flow_util.smart_cond(
--> 166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
168
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
161 def replace_training_and_call(training):
162 set_training_arg(training, training_arg_index, args, kwargs)
--> 163 return wrapped_call(*args, **kwargs)
164
165 return control_flow_util.smart_cond(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call(inputs, *args, **kwargs)
679 return layer.keras_api.__call__ # pylint: disable=protected-access
680 def call(inputs, *args, **kwargs):
--> 681 return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
682 return _create_call_fn_decorator(layer, call)
683
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
637 def __call__(self, *args, **kwargs):
638 self._maybe_trace(args, kwargs)
--> 639 return self.wrapped_call(*args, **kwargs)
640
641 def get_concrete_function(self, *args, **kwargs):
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
887
888 with OptionalXlaContext(self._jit_compile):
--> 889 result = self._call(*args, **kwds)
890
891 new_tracing_count = self.experimental_get_tracing_count()
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
922 # In this case we have not created variables on the first call. So we can
923 # run the first trace but we should fail if variables are created.
--> 924 results = self._stateful_fn(*args, **kwds)
925 if self._created_variables:
926 raise ValueError("Creating variables on a non-first call to a function"
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
3020 with self._lock:
3021 (graph_function,
-> 3022 filtered_flat_args) = self._maybe_define_function(args, kwargs)
3023 return graph_function._call_flat(
3024 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277 arg_names = base_arg_names + missing_arg_names
3278 graph_function = ConcreteFunction(
-> 3279 func_graph_module.func_graph_from_py_func(
3280 self._name,
3281 self._python_function,
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673 return out
674
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
597 with autocast_variable.enable_auto_cast_variables(
598 layer._compute_dtype_object): # pylint: disable=protected-access
--> 599 ret = method(*args, **kwargs)
600 _restore_layer_losses(original_losses)
601 return ret
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
163 return wrapped_call(*args, **kwargs)
164
--> 165 return control_flow_util.smart_cond(
166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
107 return control_flow_ops.cond(
108 pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 109 return smart_module.smart_cond(
110 pred, true_fn=true_fn, false_fn=false_fn, name=name)
111
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52 if pred_value is not None:
53 if pred_value:
---> 54 return true_fn()
55 else:
56 return false_fn()
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in <lambda>()
164
165 return control_flow_util.smart_cond(
--> 166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
168
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
161 def replace_training_and_call(training):
162 set_training_arg(training, training_arg_index, args, kwargs)
--> 163 return wrapped_call(*args, **kwargs)
164
165 return control_flow_util.smart_cond(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
661 def call_and_return_conditional_losses(*args, **kwargs):
662 """Returns layer (call_output, conditional losses) tuple."""
--> 663 call_output = layer_call(*args, **kwargs)
664 if version_utils.is_v1_layer_or_model(layer):
665 conditional_losses = layer.get_losses_for(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py in call(self, inputs, training, mask)
378 if not self.built:
379 self._init_graph_network(self.inputs, self.outputs)
--> 380 return super(Sequential, self).call(inputs, training=training, mask=mask)
381
382 outputs = inputs # handle the corner case where self.layers is empty
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in call(self, inputs, training, mask)
418 a list of tensors if there are more than one outputs.
419 """
--> 420 return self._run_internal_graph(
421 inputs, training=training, mask=mask)
422
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
554
555 args, kwargs = node.map_arguments(tensor_dict)
--> 556 outputs = node.layer(*args, **kwargs)
557
558 # Update tensor_dict.
~/tf2/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py in __call__(self, inputs, *args, **kwargs)
228 def __call__(self, inputs, *args, **kwargs):
229 self._enter_dunder_call = True
--> 230 distribution, _ = super(DistributionLambda, self).__call__(
231 inputs, *args, **kwargs)
232 self._enter_dunder_call = False
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
518 def __iter__(self):
519 if not context.executing_eagerly():
--> 520 self._disallow_iteration()
521
522 shape = self._shape_tuple()
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_iteration(self)
511 self._disallow_when_autograph_disabled("iterating over `tf.Tensor`")
512 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
--> 513 self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
514 else:
515 # Default: V1-style Graph execution.
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_when_autograph_enabled(self, task)
487
488 def _disallow_when_autograph_enabled(self, task):
--> 489 raise errors.OperatorNotAllowedInGraphError(
490 "{} is not allowed: AutoGraph did convert this function. This might"
491 " indicate you are trying to use an unsupported feature.".format(task))
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
有一个效果有限的变通方法,见 https://github.com/tensorflow/probability/issues/325#issuecomment-477213850 ,但它只能保存权重,而不能保存模型的其他细节。
针对h5格式的变通方法:以h5格式保存可以成功,但之后无法加载模型。
loaded_model = tf.keras.models.load_model('/tmp/model/probabilistic_model.h5')
使用h5格式保存然后加载模型时出错,如下所示。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_11377/686337657.py in <module>
----> 1 loaded_model = tf.keras.models.load_model('/tmp/model/probabilistic_model.h5')
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
199 if (h5py is not None and
200 (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 201 return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
202 compile)
203
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
178 model_config = model_config.decode('utf-8')
179 model_config = json_utils.decode(model_config)
--> 180 model = model_config_lib.model_from_config(model_config,
181 custom_objects=custom_objects)
182
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
57 '`Sequential.from_config(config)`?')
58 from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
---> 59 return deserialize(config, custom_objects=custom_objects)
60
61
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
157 """
158 populate_deserializable_objects()
--> 159 return generic_utils.deserialize_keras_object(
160 config,
161 module_objects=LOCAL.ALL_OBJECTS,
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
666
667 if 'custom_objects' in arg_spec.args:
--> 668 deserialized_obj = cls.from_config(
669 cls_config,
670 custom_objects=dict(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py in from_config(cls, config, custom_objects)
495 model = cls(name=name)
496 for layer_config in layer_configs:
--> 497 layer = layer_module.deserialize(layer_config,
498 custom_objects=custom_objects)
499 model.add(layer)
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
157 """
158 populate_deserializable_objects()
--> 159 return generic_utils.deserialize_keras_object(
160 config,
161 module_objects=LOCAL.ALL_OBJECTS,
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
651 # In this case we are dealing with a Keras config dictionary.
652 config = identifier
--> 653 (cls, cls_config) = class_and_config_for_serialized_keras_object(
654 config, module_objects, custom_objects, printable_module_name)
655
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
554 cls = get_registered_object(class_name, custom_objects, module_objects)
555 if cls is None:
--> 556 raise ValueError(
557 'Unknown {}: {}. Please ensure this object is '
558 'passed to the `custom_objects` argument. See '
ValueError: Unknown layer: OneHotCategorical. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
描述预期行为:应当能够将概率模型保存为TensorFlow SavedModel格式。
custom_objects = {"OneHotCategorical": tfp.layers.OneHotCategorical}
with tf.keras.utils.custom_object_scope(custom_objects):
restored_model = tf.keras.models.load_model(saved_path)