"tf.case"和"tf.cond"执行TensorFlow中的所有函数



我正在尝试执行一些与条件相关的函数,例如,每个函数都需要根据张量的形状来不同地收缩张量。然而,我意识到tf.condtf.case正在执行所有功能,而与条件无关。准备了以下代码作为示例;

def a(): 
print("a")
return tf.constant(2)
def b(): 
print("b")
return tf.constant(3)
def c(): 
print("c")
return tf.constant(4)
def d(): 
print("default")
return tf.constant(1)
x = tf.constant(1)
@tf.function
def f():
return tf.case([
(tf.equal(x,1), a),
(tf.equal(x,2), b),
(tf.equal(x,2), c)
], default=d, exclusive=True)
@tf.function
def f1():
def cond3():
return tf.cond(tf.equal(x,2), c, d)
def cond2():
return tf.cond(tf.equal(x,2), b, cond3)

return tf.cond(tf.equal(x,1), a,  cond2)
print(f())
print(f1())
# Output:
# a
# b
# c
# default
# tf.Tensor(2, shape=(), dtype=int32)
# a
# b
# c
# default
# tf.Tensor(2, shape=(), dtype=int32)

正如您所看到的,对于这两种情况,结果都是预期的,但每个函数都是在得出结论的同时执行的。因此,在我的特殊情况下,由于我根据张量的形状进行不同的计算,所以我会得到很多错误。我见过很多这样的错误报告,但还没有找到解决方案。有没有其他方法可以执行条件执行,但我不知道根据条件可以在哪里执行不同的函数?注意,我尝试过简单地使用if tf.equal(x,2): ...,但在这种情况下,我得到了一个错误,说张量输出不能用作python布尔值。注意,这个例子是我问题的简化版本,我的条件是基于张量形状,如tf.equal(tf.size(tensor), N),所以我真的需要一种方法来针对不同的情况执行不同的事情。


在@LaplaceRicky的回答之后,我意识到我提供的代码不够有代表性,所以我提供了一个更好的例子来展示我需要做什么;

x = tf.ones((3,2,1))
y = tf.ones((1,2,3))
z = tf.ones((4,3,5))
k = tf.ones((3,5,5))
def a(t): 
def exe():
return tf.einsum("ijk,lmi", t, y)
return exe
def b(t): 
def exe():
return tf.einsum("ijk,ljm", t, z)
return exe
def d(t): 
def exe():
return tf.einsum("ijk,klm", t, z)
return exe
c = tf.constant(1)
@tf.function
def f(t):
y = tf.case([
(tf.equal(tf.shape(t)[0], 3), a(t)),
(tf.equal(tf.shape(t)[1], 3), b(t)),
], default=d, exclusive=True)
return y

print(f(x))

此函数将在没有tf.function装饰器导致的情况下正常执行

tf.Tensor(
[[[[3. 3.]]]
[[[3. 3.]]]], shape=(2, 1, 1, 2), dtype=float32

然而,当包含decorator时,我得到了一个ValueError,它显示所有的案例都被执行了。

系统信息

  • TensorFlow版本:2.4.1
  • Python版本:3.8.2

简短回答:使用tf.print而不是print来检查特定分支是否真的在tensorflow图模式中执行。

说明:print不工作,不会在图形模式下打印,但会在跟踪过程中打印。打印的消息实际上意味着所有分支都已添加到tensorflow图中,但这并不意味着所有的分支都将始终以图形模式执行。应使用tf.print进行调试。

有关详细信息:https://www.tensorflow.org/guide/function#conditionals

演示:

def a():
tf.print('a')
return tf.constant(10)
def b():
tf.print('b')
return tf.constant(11)
def c():
tf.print('c')
return tf.constant(12)

@tf.function
def cond_fn(x):
return tf.switch_case(x, {0:a,1:b}, default=c)
print(cond_fn(tf.constant(0)))
print(cond_fn(tf.constant(1)))
print(cond_fn(tf.constant(2)))

预期输出:

a
tf.Tensor(10, shape=(), dtype=int32)
b
tf.Tensor(11, shape=(), dtype=int32)
c
tf.Tensor(12, shape=(), dtype=int32)

ValueError错误消息是因为tensorflow图不太支持这种功能,至少tf.einsum不支持。一种解决方法是使用tf.function(f).get_concrete_function(tf.TensorSpec(shape=[None,None,None]))生成一个支持可变形状输入的图。

此外,tf.einsum在该过程中存在问题,必须用tf.transposetf.tensordot代替。

示例代码:

x = tf.random.normal((3,2,1))
y = tf.random.normal((1,2,3))
z = tf.random.normal((4,3,5))
k = tf.random.normal((3,5,5))
#for checking the values
def f2(t):
p = tf.case([
(tf.equal(tf.shape(t)[0], 3), lambda:tf.einsum("ijk,lmi", t, y)),
(tf.equal(tf.shape(t)[1], 3), lambda:tf.einsum("ijk,ljm", t, z)),
], default=lambda:tf.einsum("ijk,klm", t, k), exclusive=True)
return p
#work around
def f(t):
if tf.shape(t)[0] == 3:
tf.print('branch a executed')
return tf.tensordot(tf.transpose(t,[1,2,0]), tf.transpose(y,[2,0,1]),1)
elif tf.shape(t)[1] == 3:
tf.print('branch b executed')
return tf.tensordot(tf.transpose(t,[0,2,1]), tf.transpose(z,[1,0,2]),1)
else:
tf.print('branch c executed')
return tf.tensordot(t, k,1)
graph_f=tf.function(f).get_concrete_function(tf.TensorSpec(shape=[None,None,None]))
print(np.allclose(graph_f(x),f2(x)))
print(np.allclose(graph_f(y),f2(y)))
print(np.allclose(graph_f(z),f2(z)))

预期输出:

branch a executed
True
branch c executed
True
branch b executed
True

相关内容

  • 没有找到相关文章

最新更新