If,else语句在tf.function中同时返回两者



我想制作一个函数,该函数可以使用Python中的Tensorflow处理浮点和向量作为输入。我定义了以下功能:

def g(t):
if tf.rank(t) == 0:
print('Rank=0')
return tf.math.reduce_sum(tf.math.exp(t))
else:
print('Rank=higher')
return tf.math.reduce_sum(tf.math.exp(t),1)

然而,我想调用另一个tf.function中的函数。作为测试,我做了以下函数:

@tf.function
def Test(t):
return g(t)

调用g(0.5(给出

Rank=0
Out[218]: <tf.Tensor: shape=(), dtype=float32, numpy=2.7182817>

调用测试(0.5(给出:

rank=0
rank=higher
Traceback (most recent call last):
Input In [219] in <cell line: 1>
Test(0.5)
File ~Anaconda3libsite-packagestensorflowpythonutiltraceback_utils.py:153 in error_handler
raise e.with_traceback(filtered_tb) from None
File ~AppDataLocalTemp__autograph_generated_filegb02ol08.py:12 in tf__Test
retval_ = ag__.converted_call(ag__.ld(gn), (ag__.ld(t),), None, fscope)
File ~AppDataLocalTemp__autograph_generated_filegnzfdu42.py:37 in tf__gn
ag__.if_stmt(ag__.converted_call(ag__.ld(int), (ag__.converted_call(ag__.ld(tf).rank, (ag__.ld(t),), None, fscope),), None, fscope) == 0, if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File ~AppDataLocalTemp__autograph_generated_filegnzfdu42.py:33 in else_body
retval_ = ag__.ld(V0) + ag__.ld(labda) * ag__.ld(theta) * ag__.converted_call(ag__.ld(tf).math.reduce_sum, (ag__.ld(c) / ag__.ld(gamma) * (1 - ag__.converted_call(ag__.ld(tf).math.exp, (-ag__.ld(gamma) * ag__.ld(t),), None, fscope)), 1), None, fscope)
ValueError: in user code:
File "C:UsersjgrouAppDataLocalTempipykernel_118723135092574.py", line 11, in Test  *
return gn(t)
File "C:UsersjgrouAppDataLocalTempipykernel_118723135092574.py", line 7, in gn  *
return V0 + labda * theta * tf.math.reduce_sum(c / gamma * (1 - tf.math.exp(-gamma * t)),1)
ValueError: Invalid reduction dimension 1 for input with 1 dimensions. for '{{node cond/Sum}} = Sum[T=DT_FLOAT, Tidx=DT_INT32, keep_dims=false](cond/mul_1, cond/Sum/reduction_indices)' with input shapes: [1], [] and with computed input tensors: input[1] = <1>.

为什么if-else语句的两个参数都在tf.function中被调用?如何使函数g在tf.function中工作?

看起来有人在最近的Github问题中提出了这种行为。在结束问题之前,强调Tensorflow开发人员之一的回应:

The cause of this problem is due to the behavior of condition tracing in TensorFlow: the same input is applied to both true and false sides for graph tracing, when the condition is based on a non-static value (i.e. tf.rank(v) == 2).

有两种可行的解决方案。

使用常数值

如果使用tf.get_static_value(此处详细说明(返回tf.rank返回的0-D张量的常数值,它将阻止条件跟踪,因为它会评估张量(根据形状和类型将其转换为int、float、numpy数组等(。

def g(t):
if tf.get_static_value(tf.rank(t)) == 0:
print('Rank=0')
return tf.math.reduce_sum(tf.math.exp(t))
else:
print('Rank=higher')
return tf.math.reduce_sum(tf.math.exp(t), 1)

这将返回预期结果:

Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)

直接形状评估

与其使用tf.rank,不如直接评估形状,这也需要将任何非张量输入转换为张量:

def g(t):
if not isinstance(t, tf.Tensor):
t = tf.convert_to_tensor(t)
if t.shape.ndims == 0:
print('Rank=0')
return tf.math.reduce_sum(tf.math.exp(t))
else:
print('Rank=higher')
return tf.math.reduce_sum(tf.math.exp(t), 1)

这种实现也产生了预期的结果:

Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)

最新更新