我需要计算一个元素梯度作为神经网络的输入。我决定使用tf.data.Dataset
来存储输入数据。预处理数据和计算梯度是昂贵的,我想批量处理,然后存储数据。
我简化了函数来处理形状(batch_size, x, y)
,我想分别计算每个y
的梯度。
使用tf.GradientTape
,如下所示:
import tensoflow as tf
# @tf.function
def test(inp):
with tf.GradientTape(persistent=True) as tape:
tape.watch(inp)
out = inp**2
out = tf.unstack(out, axis=-1)
grad = []
for x in out:
grad.append(tape.gradient(x, inp))
del tape
return tf.stack(grad, axis=-1)
inp = tf.random.normal((32, 100, 50))
test(inp)
这段代码使用~76 ms
来执行,而使用tf.function
装饰符执行3.1 s
。不幸的是,当它与tf.data.Dataset.map
一起使用时,也会发生同样的减速,我认为它会将其转换为tf.function
我尝试使用tf.batch_jacobian
代替,它不会受到tf.function
的影响,但计算方式更多的梯度,我必须减少它们。执行大约需要15秒。
@tf.function
def test(inp):
with tf.GradientTape() as tape:
tape.watch(inp)
out = inp**2
grad = tape.batch_jacobian(out, inp)
return tf.math.reduce_sum(grad, axis=3)
x = test(inp)
对于更大的数据集和更多的资源繁重的计算,我试图避免这样的减速,但我还没有找到一个解决方案,我也不明白,为什么它计算这么慢。有没有一种方法可以重塑数据并使用雅可比法或其他方法,来克服这个问题?
让我们用IPython的%timeit
做一个快速的实验。我定义了两个函数,一个带有tf.function
修饰符,一个没有:
import tensorflow as tf
def test_no_tracing(inp):
with tf.GradientTape(persistent=True) as tape:
tape.watch(inp)
out = inp**2
out = tf.unstack(out, axis=-1)
grad = []
for x in out:
grad.append(tape.gradient(x, inp))
del tape
return tf.stack(grad, axis=-1)
@tf.function
def test_tracing(inp):
print("Tracing")
with tf.GradientTape(persistent=True) as tape:
tape.watch(inp)
out = inp**2
out = tf.unstack(out, axis=-1)
grad = []
for x in out:
grad.append(tape.gradient(x, inp))
del tape
return tf.stack(grad, axis=-1)
inp = tf.random.normal((32, 100, 50))
让我们看看结果:
与tf.function
装饰符:
In [2]: %timeit test_tracing(inp)
Tracing
2021-01-22 15:22:15.003262: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-01-22 15:22:15.076448: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2599990000 Hz
10.3 ms ± 579 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [3]: %timeit test_no_tracing(inp)
71.7 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
用tf.function
修饰的函数大约快了7倍。如果只运行一次函数,可能会显得慢一些,因为修饰函数有跟踪的开销,需要将代码转换成图形。一旦跟踪完成,代码就会快得多。
这可以通过只运行一次函数来验证,当它还没有被跟踪时。我们可以通过告诉%timeit
只做一次循环和一次重复来做到这一点:
In [2]: %timeit -r 1 -n 1 test_tracing(inp)
Tracing
2021-01-22 15:29:47.189850: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-01-22 15:29:47.284413: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2599990000 Hz
4.97 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
这里,时间更大,更接近你在问题中报告的时间。但是一旦这样做了,跟踪函数就快多了!让我们再来一次:
In [3]: %timeit -r 1 -n 1 test_tracing(inp)
29.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
您可以在指南中阅读有关如何使用tf.function
获得更好性能的更多信息:使用tf.function提高性能