Conditional in TensorFlow 2.x code raises an error



I am migrating to TensorFlow 2.x on Windows 10, with TF version 2.3.1. Basically, this code

import tensorflow as tf
def do_nothing(x, y):
    m, n = x.shape
    if m==n:
        return x, y
    else:
        raise Exception('should never arrive here')
xys = [[tf.eye(2), tf.eye(3)],
       [tf.eye(4), tf.eye(5)],]
@tf.function
def foo():
    return [do_nothing(x, y) for (x, y) in xys]

ans = foo()

works. Then I only changed the condition m==n to tf.equal(m, n), as in

import tensorflow as tf
def do_nothing(x, y):
    m, n = x.shape
    if tf.equal(m, n):
        return x, y
    else:
        raise Exception('should never arrive here')
xys = [[tf.eye(2), tf.eye(3)],
       [tf.eye(4), tf.eye(5)],]
@tf.function
def foo():
    return [do_nothing(x, y) for (x, y) in xys]

ans = foo()

and the code no longer works. I am really confused. Is this a bug or something else?

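I ran more experiments to reproduce the problem with less code. It looks like, when something like tf.equal or tf.greater is used as the condition, the if and else clauses must return tensors of the same type and size. See the code below.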

import tensorflow as tf
# this piece works
@tf.function
def foo1(x):
    if tf.greater(len(x), 0):
        return True
    else:
        return False
print(foo1(tf.zeros([1])))
print(foo1(tf.zeros([0])))
# this piece works too
@tf.function
def foo2(x):
    if len(x)>0:
        return True
    else:
        raise Exception()
print(foo2(tf.zeros([1])))
# this piece no longer works
@tf.function
def foo3(x):
    if tf.greater(len(x), 0):
        return True
    else:
        raise Exception()
print(foo3(tf.zeros([1])))
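
One possible reading of the three experiments above, offered as a sketch rather than a confirmed diagnosis: when the condition of an `if` is a tensor, AutoGraph rewrites the statement as `tf.cond`, and `tf.cond` traces both branches while the graph is being built. A `raise` placed in either branch would then fire during tracing, whatever value the condition has. The names `traced`, `true_branch`, `false_branch` and `pick` are hypothetical and exist only to make the tracing visible.

```python
import tensorflow as tf

traced = []  # hypothetical log of which branch functions get called


def true_branch():
    traced.append('true')
    return tf.constant(True)


def false_branch():
    traced.append('false')
    return tf.constant(False)


@tf.function
def pick(x):
    # Inside tf.function, tf.cond builds graph code for both branches,
    # so both Python functions are called once while tracing, even though
    # only one branch's result is selected when the graph runs.
    return tf.cond(tf.greater(x, 0.0), true_branch, false_branch)


result = pick(tf.constant(1.0))
print(result.numpy(), traced)
```

Both branch functions show up in `traced` after a single call, which fits the observation that foo3 fails even though its condition is true.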

I think the reason is that tf.equal returns a tensor of boolean dtype rather than a plain Python bool. http://tensorflow.biotecan.com/python/Python_1.8/tensorflow.google.cn/api_docs/python/tf/equal.html
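
A minimal sketch of that difference, assuming the input has a fully static shape so that `m` and `n` are plain Python ints during tracing. `trace_log` and `inspect_conditions` are hypothetical names, not part of the original code.

```python
import tensorflow as tf

trace_log = []  # hypothetical list, filled while the function is traced


@tf.function
def inspect_conditions(x):
    m, n = x.shape            # static shape: plain Python ints
    py_cond = (m == n)        # ordinary Python comparison
    tf_cond = tf.equal(m, n)  # graph op: a symbolic boolean tensor
    trace_log.append((isinstance(py_cond, bool), tf.is_tensor(tf_cond)))
    return x


inspect_conditions(tf.eye(2))
print(trace_log)
```

With a Python bool as the condition, the `if` is decided once at trace time and only the chosen branch runs. With a tensor, the value is not known until the graph executes, so the conditional has to become part of the graph.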

For reference, here is the test I ran in Google Colab:

https://colab.research.google.com/drive/1sR99ScE-IDsWz0rNCH6VsWVclw1wz5oE#scrollTo=FsMqxbpnJ-Xg&line=1&uniqifier=1

import tensorflow as tf
def do_nothing(x, y):
    m, n = x.shape
    print(x)
    print(y)
    print(m,n)
    print(m==m)
    print(n==n)
    print(m==n)
    print(tf.equal(m,m))
    print(tf.equal(n,n))
    print(tf.equal(m,n))
    if tf.equal(m, n):
        return x, y
    else:
        raise Exception('should never arrive here')
xys = [[tf.eye(2), tf.eye(3)],
       [tf.eye(4), tf.eye(5)],]
@tf.function
def foo():
    return [do_nothing(x, y) for (x, y) in xys]

ans = foo()

tf.Tensor(
[[1. 0.]
[0. 1.]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]], shape=(3, 3), dtype=float32)
2 2
True
True
True
Tensor("Equal:0", shape=(), dtype=bool)
Tensor("Equal_1:0", shape=(), dtype=bool)
Tensor("Equal_2:0", shape=(), dtype=bool)
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-12-24121e0806b4> in <module>()
     24     return [do_nothing(x, y) for (x, y) in xys]
     25 
---> 26 ans = foo()
8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise
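
If the only goal is to reject non-square input, one option is to keep the plain Python `m==n` check when the shape is static, as in the first version. When the shape is only known at run time, a graph-level assertion such as `tf.debugging.assert_equal` could stand in for the `raise`. The following is a sketch under that assumption, not a drop-in replacement for do_nothing; `require_square` is a hypothetical name.

```python
import tensorflow as tf


# The input_signature leaves both dimensions unknown, so the check
# cannot be resolved at trace time and must run inside the graph.
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def require_square(x):
    shape = tf.shape(x)  # dynamic shape, evaluated when the graph runs
    tf.debugging.assert_equal(shape[0], shape[1],
                              message='matrix must be square')
    return x


print(require_square(tf.eye(3)).shape)

try:
    require_square(tf.ones([2, 3]))
except tf.errors.InvalidArgumentError:
    print('non-square input rejected at run time')
```

The assertion is an op in the graph rather than a Python statement in a traced branch, so it only fails when the actual values disagree.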
