tf.function 中的 for 循环返回 Tensor( "while/Placeholder:0" , shape=(), dtype=int32)



我想在tf.function中的张量上使用for循环,如下所示:

@tf.function
def test(x):
for i in range(tf.shape(x)[0]):
print(i)

我定义:

S = tf.random.uniform([2,2],0,1)

然后

test(S)

给出

Tensor("while/Placeholder:0", shape=(), dtype=int32)

for i in range(tf.shape(S)[0]):
print(i)

返回

0
1

为什么我不能在tf.函数中的张量长度上循环?

使用tf.print:

import tensorflow as tf
@tf.function
def test(x):
for i in range(tf.shape(x)[0]):
tf.print(i)
S = tf.random.uniform([2,2],0,1)
test(S)
# 0
# 1

请在此处检查在tf.function中使用python操作的副作用。

相关内容

最新更新