深度复制tf.function修饰的函数



我使用的一些代码正在尝试copy.deepcopy一个@tf.function修饰的函数,但失败了(作为pickle的一部分(。有正确的方法吗?

简化复制

定义任何简单的@tf.function修饰函数:

>>> import tensorflow as tf
>>> @tf.function
... def foo(x): return x * 2
>>> foo
<tensorflow.python.eager.def_function.Function object at 0x7fc98bec0950>

在初始化之前深度复制效果良好:

>>> import copy
>>> copy.deepcopy(foo)
<tensorflow.python.eager.def_function.Function object at 0x7fc98bec8b50>

然而,它在初始化后失败:

>>> foo(tf.constant([3.]))
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>
>>> copy.deepcopy(foo)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/user/.pyenv/versions/3.7.13/lib/python3.7/copy.py", line 180, in deepcopy
y = _reconstruct(x, memo, *rv)
....
File "/home/user/.pyenv/versions/3.7.13/lib/python3.7/copy.py", line 169, in deepcopy
rv = reductor(4)
TypeError: can't pickle _thread.RLock objects

评论

  1. 我不介意复制的版本是否未初始化,因此在第一次使用时必须重新编译,但我不想每次复制时都必须重新编译原始函数
  2. CCD_ 4类具有省略CCD_ 6属性的CCD_;不可拾取对象";。然而,虽然这可能允许浅复制,但它不处理深复制,因为类对象还包含以下类型的属性,这些属性本身在其中的某个位置包含锁:
_stateful_fn: <class 'tensorflow.python.eager.function.Function'>
_stateless_fn: <class 'tensorflow.python.eager.function.Function'>
_lifted_initializer_graph: <class 'tensorflow.python.framework.func_graph.FuncGraph'>
_graph_deleter: <class 'tensorflow.python.eager.def_function.FunctionDeleter'>
_concrete_stateful_fn: <class 'tensorflow.python.eager.function.ConcreteFunction'>

环境

在Ubuntu上的Python 3.7.13中测试了TF 2.8和TF 2.4。

发布后就想出了一个可能的解决方案。

将函数放入类中似乎可以复制它:

>>> class Foo:
...    @tf.function
...    def __call__(self, x): return x * 2
>>> foo = Foo()
>>> foo(tf.constant([3.]))
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>
>>> foo_copy = copy.deepcopy(foo)
>>> foo_copy(tf.constant([3.]))
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>

有趣的是,复制的函数似乎确实没有被初始化(但对我来说没关系(。

>>> foo.__call__._get_tracing_count()
1
>>> copy.deepcopy(foo).__call__._get_tracing_count()
0

最新更新