检查元素是否包含在张量中(带有"@tf.function"的Python Tensorflow2)



我要检查张量中是否包含元素,但我遇到了问题。

例如1个

def foo(a):
if 5 in tf.constant([5, 7, 9]):
tf.print(a)
foo(2)
# you'll get '2', and no erros

例如2

@tf.function
def foo(a):
if 5 in tf.constant([5, 7, 9]):
tf.print(a)
foo(2)
# you'll get erros like "TypeError: argument of type 'Tensor' is not iterable"

显然,添加@tf.function后情况有所不同。如果你能帮我解决这个问题,我将不胜感激!:(

使用@tf.function装饰函数时,它将以图形模式运行。在图形模式中,您不能对tf.Tensor进行迭代(这就是您在if语句中所做的(。

最新更新