尝试理解AutoGraph和tf.function:tf.function中的打印损失


def train_one_step():
with tf.GradientTape() as tape:
a = tf.random.normal([1, 3, 1])
b = tf.random.normal([1, 3, 1])
loss = mse(a, b)
tf.print('inner tf print', loss)
print("inner py print", loss)
return loss

@tf.function
def train():
loss = train_one_step()
tf.print('outer tf print', loss)
print('outer py print', loss)
return loss
loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)

我试图更多地理解tf.function。我用不同的方法在四个地方打印了损失。它产生这样的结果

inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)
  1. TF.print和Python print有什么区别?
  2. 看起来 Python 打印将在图形评估期间执行,但 TF 打印仅在执行时执行?
  3. 以上仅适用于有 tf.function 装饰器的情况? 除此之外,tf.print 在 Python Print 之前运行?

我在一篇由三部分组成的文章中涵盖并回答了您的所有问题:"分析 tf.function 以发现 AutoGraph 的优势和微妙之处":第 1 部分、第 2 部分、第 3 部分。

总结并回答您的3个问题:

  • TF.print和Python print有什么区别?

tf.print是一个 Tensorflow 构造,默认情况下在标准错误上打印,更重要的是,它在评估时会产生一个操作。

当一个操作运行时,在急切执行中,它或多或少地以与Tensorflow 1.x相同的方式生成一个"节点"。

tf.function能够捕获生成的tf.print操作并将其转换为图形节点。

相反,print是一个 Python 构造,默认情况下在标准输出上打印,并且在执行时生成任何操作。因此,tf.function无法将其转换为图形等效项,并且仅在函数跟踪期间执行它。

  • 看起来 Python 打印将在图形评估期间执行,但 TF 打印仅在执行时执行?

我已经在上一点中回答了这个问题,但再一次,print仅在函数跟踪期间执行,而tf.print在跟踪期间和执行其图形表示时执行(在tf.function成功将函数转换为图形之后)。

  • 以上仅适用于有 tf.function 装饰器的情况? 除此之外,tf.print 在 Python Print 之前运行?

是的。tf.printprint之前或之后不会运行。在急切执行中,一旦 Python 解释器找到语句,就会对它们进行评估。预先执行的唯一区别是输出流。

无论如何,我建议您阅读链接的三篇文章,因为它们详细介绍了tf.function的这个和其他特点。

>print是正常的python打印。tf.print是张量流图的一部分。 在渴望模式下,张量流将直接执行图。这就是为什么在你的@tf.function函数之外,python print的输出是一个数字(tensorflow直接执行图形并将数字给出给正常的print函数),这也是tf.print立即打印的原因。

另一方面,在@tf.function函数内部,Tensorflow不会立即执行图。相反,它会将您调用的 tensorflow 函数"堆叠"到一个更大的图中,我们将在@tf.function结束时一次性执行该图。

这就是为什么python打印没有给你@tf.function函数中的数字(图在这一点上还没有执行)。但是在函数结束后,图形将与图形中的tf.print一起执行。因此,tf.print在python打印后打印,并为您提供实际的损失数字。

最新更新