How do I use @tf.function with the Keras Sequential API?



Below is a snippet of my code. I want to use the @tf.function decorator with the Keras API, but it gives me an error:

import tensorflow as tf
# InstanceNormalization is not part of core tf.keras; it is assumed to be defined
# or imported elsewhere in the same module (pix2pix_new.py).

@tf.function
def convnet(filters, strides, size, norm_type='instancenorm', apply_norm=True, relu='relu', apply_relu=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv3D(filters, size, strides, padding='same',
                                      kernel_initializer=initializer, use_bias=False,
                                      input_shape=(None, None, None, 3)))
    if apply_norm:
        if norm_type.lower() == 'batchnorm':
            result.add(tf.keras.layers.BatchNormalization())
        elif norm_type.lower() == 'instancenorm':
            result.add(InstanceNormalization())
    if apply_relu:
        if relu == 'relu':
            result.add(tf.keras.layers.ReLU())
        elif relu == 'leakyrelu':
            result.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    return result

I get the following error when I run it:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
540     try:
--> 541       str_values = [compat.as_bytes(x) for x in proto_values]
542     except TypeError:
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_util.py in <listcomp>(.0)
540     try:
--> 541       str_values = [compat.as_bytes(x) for x in proto_values]
542     except TypeError:
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/util/compat.py in as_bytes(bytes_or_text, encoding)
70     raise TypeError('Expected binary or unicode string, got %r' %
---> 71                     (bytes_or_text,))
72 
TypeError: Expected binary or unicode string, got <tensorflow.python.keras.engine.sequential.Sequential object at 0x7fa65de7e198>
During handling of the above exception, another exception occurred:
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in convert(x)
875         try:
--> 876           x = ops.convert_to_tensor_or_composite(x)
877         except (ValueError, TypeError):
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in convert_to_tensor_or_composite(value, dtype, name)
1419   return internal_convert_to_tensor_or_composite(
-> 1420       value=value, dtype=dtype, name=name, as_ref=False)
1421 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in internal_convert_to_tensor_or_composite(value, dtype, name, as_ref)
1458         as_ref=as_ref,
-> 1459         accept_composite_tensors=True)
1460 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx, accept_composite_tensors)
1295     if ret is None:
-> 1296       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1297 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
285   _ = as_ref
--> 286   return constant(v, dtype=dtype, name=name)
287 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/constant_op.py in constant(value, dtype, shape, name)
226   return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 227                         allow_broadcast=True)
228 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/constant_op.py in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
264           value, dtype=dtype, shape=shape, verify_shape=verify_shape,
--> 265           allow_broadcast=allow_broadcast))
266   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
544                       "Contents: %s. Consider casting elements to a "
--> 545                       "supported type." % (type(values), values))
546     tensor_proto.string_val.extend(str_values)
TypeError: Failed to convert object of type <class 'tensorflow.python.keras.engine.sequential.Sequential'> to Tensor. Contents: <tensorflow.python.keras.engine.sequential.Sequential object at 0x7fa65de7e198>. Consider casting elements to a supported type.
During handling of the above exception, another exception occurred:
TypeError                                 Traceback (most recent call last)
<ipython-input-12-c3a2ba712dc3> in <module>
1 OUTPUT_CHANNELS = 3
2 
----> 3 generator = pix2pix_new.generator(OUTPUT_CHANNELS, norm_type='instancenorm')
4 
5 discriminator_seq = pix2pix_new.discriminator_seq(norm_type='instancenorm', target=False)
/gpfs-volume/GANs_Work/Scripts/pix2pix_new.py in generator(output_channels, norm_type)
186   """
187 
--> 188   convnets = [first_convnet(128, (1, 1, 1), (7, 7, 4), norm_type, apply_norm=False, relu='relu', apply_relu=False), # (bs, 128, 128, 64)
189                 convnet(128, 2, (3, 3, 2), norm_type),  # (bs, 64, 64, 128)
190                 convnet(256, 2, (3, 3, 1), norm_type),  # (bs, 32, 32, 256)
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
455 
456     tracing_count = self._get_tracing_count()
--> 457     result = self._call(*args, **kwds)
458     if tracing_count == self._get_tracing_count():
459       self._call_counter.called_without_tracing()
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
501       # This is the first call of __call__, so we have to initialize.
502       initializer_map = object_identity.ObjectIdentityDictionary()
--> 503       self._initialize(args, kwds, add_initializers_to=initializer_map)
504     finally:
505       # At this point we know that the initialization is complete (or less
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
406     self._concrete_stateful_fn = (
407         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 408             *args, **kwds))
409 
410     def invalid_creator_scope(*unused_args, **unused_kwds):
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
1846     if self.input_signature:
1847       args, kwargs = None, None
-> 1848     graph_function, _, _ = self._maybe_define_function(args, kwargs)
1849     return graph_function
1850 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2148         graph_function = self._function_cache.primary.get(cache_key, None)
2149         if graph_function is None:
-> 2150           graph_function = self._create_graph_function(args, kwargs)
2151           self._function_cache.primary[cache_key] = graph_function
2152         return graph_function, args, kwargs
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2039             arg_names=arg_names,
2040             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041             capture_by_value=self._capture_by_value),
2042         self._function_attributes,
2043         # Tell the ConcreteFunction to clean up its graph once it goes out of
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
918       # TensorArrays and `None`s.
919       func_outputs = nest.map_structure(convert, func_outputs,
--> 920                                         expand_composites=True)
921 
922       check_mutation(func_args_before, func_args)
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/util/nest.py in map_structure(func, *structure, **kwargs)
533 
534   return pack_sequence_as(
--> 535       structure[0], [func(*x) for x in entries],
536       expand_composites=expand_composites)
537 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/util/nest.py in <listcomp>(.0)
533 
534   return pack_sequence_as(
--> 535       structure[0], [func(*x) for x in entries],
536       expand_composites=expand_composites)
537 
/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in convert(x)
880               "must return zero or more Tensors; in compilation of %s, found "
881               "return value of type %s, which is not a Tensor." %
--> 882               (str(python_func), type(x)))
883       if add_control_dependencies:
884         x = a.mark_as_return(x)
TypeError: To be compatible with tf.contrib.eager.defun, Python functions must return zero or more Tensors; in compilation of <function first_convnet at 0x7fa6ade9bd90>, found return value of type <class 'tensorflow.python.keras.engine.sequential.Sequential'>, which is not a Tensor.

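In short, the error is:

TypeError: To be compatible with tf.contrib.eager.defun, Python functions must return zero or more Tensors; in compilation, found return value of type <class 'tensorflow.python.keras.engine.sequential.Sequential'>, which is not a Tensor.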

I get a similar error when I try to run the following line:

return tf.keras.Model(inputs=inputs, outputs=x)
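For context, a functional-API builder that ends with this line typically looks like the minimal sketch below. The name `build_generator` and the layer sizes are illustrative assumptions, not the asker's actual generator. The function returns a `tf.keras.Model` object rather than Tensors, which is why decorating it with `@tf.function` fails in the same way; left as a plain Python function, it builds and runs.

```python
import tensorflow as tf

def build_generator(output_channels=3):
    # Hypothetical builder; layer sizes are illustrative only.
    # Plain Python function (no @tf.function): it returns a Model object, not Tensors.
    inputs = tf.keras.Input(shape=(None, None, None, 3))
    x = tf.keras.layers.Conv3D(16, 3, strides=1, padding='same')(inputs)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv3D(output_channels, 3, strides=1, padding='same')(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

generator = build_generator(3)
# 'same' padding with stride 1 keeps the spatial size; channels become output_channels.
out = generator(tf.zeros((1, 8, 8, 4, 3)))
print(out.shape)
```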

What is the workaround? I'm new to TF 2.0 and want to use @tf.function to speed up training.
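For reference, the place where `@tf.function` usually pays off in a custom training loop is the per-batch step function, which takes Tensors and returns Tensors, while the model itself is built eagerly beforehand. The sketch below assumes a tiny stand-in model and synthetic data purely to be runnable; it is not the asker's pix2pix setup.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical stand-in model and data, chosen only to make the sketch runnable.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function  # traced on the first call; later calls with the same input signature reuse the graph
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss  # a Tensor, which tf.function is allowed to return

x = tf.random.normal((32, 4))
y = tf.reduce_sum(x, axis=1, keepdims=True)
losses = [float(train_step(x, y)) for _ in range(5)]
print(losses)
```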

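In general, build the model outside the tf.function and pass it in as an argument. A tf.function may only return Tensors (or nested structures of Tensors), so a function that returns a Keras Sequential or Model object cannot itself be decorated: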

@tf.function
def use_model(model, ...):
    ...
    outputs = model(...)
    ...

# Create the model
model = convnet(...)
# It's a good idea to build it too, so its weights are created eagerly
# rather than inside the traced function
model(<dummy input>)  # or model.build(...)
use_model(model, ...)
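A runnable version of the pattern above might look like the following sketch. The `convnet` here is a simplified, hypothetical stand-in for the question's builder (no normalization layer, fixed input channels of 3), and `use_model` just reduces the output to a scalar so it returns a Tensor.

```python
import tensorflow as tf

def convnet(filters, strides, size):
    # Simplified stand-in for the question's builder: a plain Python function,
    # NOT decorated with @tf.function, because it returns a Model rather than Tensors.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(None, None, None, 3)),
        tf.keras.layers.Conv3D(filters, size, strides, padding='same', use_bias=False),
        tf.keras.layers.ReLU(),
    ])

@tf.function
def use_model(model, inputs):
    # The model object is passed in; only Tensors are returned.
    outputs = model(inputs)
    return tf.reduce_mean(outputs)

model = convnet(8, 2, (3, 3, 2))   # build outside tf.function
dummy = tf.zeros((1, 16, 16, 8, 3))
model(dummy)                       # create the weights eagerly
result = use_model(model, tf.ones((1, 16, 16, 8, 3)))
print(result.numpy())
```

Because the weights already exist when `use_model` is first traced, the traced graph only reads them instead of creating new variables.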
