我有一个TensorFlow模型f(x)
,我有时需要它的梯度,有时不需要,这取决于向前传递的结果。为了节省计算时间,我只在需要的时候计算梯度。如果我使用stop_gradient()停止梯度计算,或者不将它们记录在GradientTape上,那么如果不再次计算向前传递,我似乎永远无法获得梯度。我尝试做的一个简化示例如下(在伪代码中):
x = 5
y = f(x)
if y > 0:
compute_gradients(f, x)
是否有可能在TensorFlow中实现这一点,如果是这样,我该怎么做?
是的,您可以使用一个简单的条件跳过渐变更新。
import tensorflow as tf
from tensorflow.python.platform import test as test_lib
# network
x_in = tf.keras.Input([10])
x_out = tf.keras.layers.Dense(1)(x_in)
# optimizer
opt = tf.keras.optimizers.Adam(1e-1)
# forward pass
def train_step(model, X, y, threshold):
with tf.GradientTape() as tape:
y_hat = model(X)
# threshold = tf.math.reduce_mean(y_hat)
loss = tf.math.reduce_mean(tf.keras.losses.MSE(y, y_hat))
if tf.math.greater(threshold, 1.0):
m_vars = model.trainable_variables
m_grads = tape.gradient(loss, m_vars)
opt.apply_gradients(zip(m_grads, m_vars))
return loss
# test cases
class SporaticGradientUpdateTest(test_lib.TestCase):
def setUp(self):
self.model = tf.keras.Model(x_in, x_out)
self.X = tf.random.normal([100, 10])
self.y = tf.random.normal([100])
self.w_before = self.model.get_weights()
def test_weights_dont_change(self):
_ = train_step(self.model, self.X, self.y, 0.99)
# get weights that shouldn't have updated
w_after = self.model.get_weights()
self.assertAllClose(self.w_before, w_after)
def test_weights_change(self):
_ = train_step(self.model, self.X, self.y, 1.01)
# get weights that should updated
w_after = self.model.get_weights()
self.assertNotAllClose(self.w_before, w_after)
if __name__ == "__main__":
test_lib.main()
# [ RUN ] SporaticGradientUpdate.test_weights_change
# [ OK ] SporaticGradientUpdate.test_weights_change
# [ RUN ] SporaticGradientUpdate.test_weights_dont_change
# [ OK ] SporaticGradientUpdate.test_weights_dont_change
根据你的评论,看起来你的用例与这个例子有点不同,但应该适应你想做的任何事情。
在这个例子中,我传入了threshold
作为一个参数,所以我可以测试这两种情况,但通常你会通过对网络的输出做一些事情来创建它(比如注释掉的部分)。