我正在尝试执行一些与条件相关的函数,例如,每个函数都需要根据张量的形状来不同地收缩张量。然而,我意识到tf.cond
和tf.case
正在执行所有功能,而与条件无关。准备了以下代码作为示例;
def a():
print("a")
return tf.constant(2)
def b():
print("b")
return tf.constant(3)
def c():
print("c")
return tf.constant(4)
def d():
print("default")
return tf.constant(1)
x = tf.constant(1)
@tf.function
def f():
return tf.case([
(tf.equal(x,1), a),
(tf.equal(x,2), b),
(tf.equal(x,2), c)
], default=d, exclusive=True)
@tf.function
def f1():
def cond3():
return tf.cond(tf.equal(x,2), c, d)
def cond2():
return tf.cond(tf.equal(x,2), b, cond3)
return tf.cond(tf.equal(x,1), a, cond2)
print(f())
print(f1())
# Output:
# a
# b
# c
# default
# tf.Tensor(2, shape=(), dtype=int32)
# a
# b
# c
# default
# tf.Tensor(2, shape=(), dtype=int32)
正如您所看到的,对于这两种情况,结果都是预期的,但每个函数都是在得出结论的同时执行的。因此,在我的特殊情况下,由于我根据张量的形状进行不同的计算,所以我会得到很多错误。我见过很多这样的错误报告,但还没有找到解决方案。有没有其他方法可以执行条件执行,但我不知道根据条件可以在哪里执行不同的函数?注意,我尝试过简单地使用if tf.equal(x,2): ...
,但在这种情况下,我得到了一个错误,说张量输出不能用作python布尔值。注意,这个例子是我问题的简化版本,我的条件是基于张量形状,如tf.equal(tf.size(tensor), N)
,所以我真的需要一种方法来针对不同的情况执行不同的事情。
在@LaplaceRicky的回答之后,我意识到我提供的代码不够有代表性,所以我提供了一个更好的例子来展示我需要做什么;
x = tf.ones((3,2,1))
y = tf.ones((1,2,3))
z = tf.ones((4,3,5))
k = tf.ones((3,5,5))
def a(t):
def exe():
return tf.einsum("ijk,lmi", t, y)
return exe
def b(t):
def exe():
return tf.einsum("ijk,ljm", t, z)
return exe
def d(t):
def exe():
return tf.einsum("ijk,klm", t, z)
return exe
c = tf.constant(1)
@tf.function
def f(t):
y = tf.case([
(tf.equal(tf.shape(t)[0], 3), a(t)),
(tf.equal(tf.shape(t)[1], 3), b(t)),
], default=d, exclusive=True)
return y
print(f(x))
此函数将在没有tf.function
装饰器导致的情况下正常执行
tf.Tensor(
[[[[3. 3.]]]
[[[3. 3.]]]], shape=(2, 1, 1, 2), dtype=float32
然而,当包含decorator时,我得到了一个ValueError
,它显示所有的案例都被执行了。
系统信息
- TensorFlow版本:2.4.1
- Python版本:3.8.2
简短回答:使用tf.print
而不是print
来检查特定分支是否真的在tensorflow图模式中执行。
说明:print
不工作,不会在图形模式下打印,但会在跟踪过程中打印。打印的消息实际上意味着所有分支都已添加到tensorflow图中,但这并不意味着所有的分支都将始终以图形模式执行。应使用tf.print
进行调试。
有关详细信息:https://www.tensorflow.org/guide/function#conditionals
演示:
def a():
tf.print('a')
return tf.constant(10)
def b():
tf.print('b')
return tf.constant(11)
def c():
tf.print('c')
return tf.constant(12)
@tf.function
def cond_fn(x):
return tf.switch_case(x, {0:a,1:b}, default=c)
print(cond_fn(tf.constant(0)))
print(cond_fn(tf.constant(1)))
print(cond_fn(tf.constant(2)))
预期输出:
a
tf.Tensor(10, shape=(), dtype=int32)
b
tf.Tensor(11, shape=(), dtype=int32)
c
tf.Tensor(12, shape=(), dtype=int32)
ValueError
错误消息是因为tensorflow图不太支持这种功能,至少tf.einsum
不支持。一种解决方法是使用tf.function(f).get_concrete_function(tf.TensorSpec(shape=[None,None,None]))
生成一个支持可变形状输入的图。
此外,tf.einsum
在该过程中存在问题,必须用tf.transpose
和tf.tensordot
代替。
示例代码:
x = tf.random.normal((3,2,1))
y = tf.random.normal((1,2,3))
z = tf.random.normal((4,3,5))
k = tf.random.normal((3,5,5))
#for checking the values
def f2(t):
p = tf.case([
(tf.equal(tf.shape(t)[0], 3), lambda:tf.einsum("ijk,lmi", t, y)),
(tf.equal(tf.shape(t)[1], 3), lambda:tf.einsum("ijk,ljm", t, z)),
], default=lambda:tf.einsum("ijk,klm", t, k), exclusive=True)
return p
#work around
def f(t):
if tf.shape(t)[0] == 3:
tf.print('branch a executed')
return tf.tensordot(tf.transpose(t,[1,2,0]), tf.transpose(y,[2,0,1]),1)
elif tf.shape(t)[1] == 3:
tf.print('branch b executed')
return tf.tensordot(tf.transpose(t,[0,2,1]), tf.transpose(z,[1,0,2]),1)
else:
tf.print('branch c executed')
return tf.tensordot(t, k,1)
graph_f=tf.function(f).get_concrete_function(tf.TensorSpec(shape=[None,None,None]))
print(np.allclose(graph_f(x),f2(x)))
print(np.allclose(graph_f(y),f2(y)))
print(np.allclose(graph_f(z),f2(z)))
预期输出:
branch a executed
True
branch c executed
True
branch b executed
True