我想在tf.function中的张量上使用for循环,如下所示:
@tf.function
def test(x):
for i in range(tf.shape(x)[0]):
print(i)
我定义:
S = tf.random.uniform([2,2],0,1)
然后
test(S)
给出
Tensor("while/Placeholder:0", shape=(), dtype=int32)
而
for i in range(tf.shape(S)[0]):
print(i)
返回
0
1
为什么我不能在tf.函数中的张量长度上循环?
使用tf.print
:
import tensorflow as tf
@tf.function
def test(x):
for i in range(tf.shape(x)[0]):
tf.print(i)
S = tf.random.uniform([2,2],0,1)
test(S)
# 0
# 1
请在此处检查在tf.function
中使用python操作的副作用。