如何在TF 2.0中计算这个梯度?



我提供了一个我想要解决的最小示例。我已经定义了一个类,其中有一些变量在不同的函数中定义。我想知道如何在函数中跟踪这些变量来得到梯度。我想我必须使用tf.GradientTape,但我已经尝试了一些变体,但没有成功。

class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
wt_f1 = f1()
with tf.GradientTape() as tape:
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
a = A()
print(a.f2())

最后一行返回None。显然wt_f2alpha的导数是50.0。然而,我得到了None。任何想法?我尝试在__init__函数中初始化一个持久的梯度磁带,并使用它来观察wtself.alpha等变量,但这没有帮助。任何想法?

更新1:

wt_f1呼叫置于tape下无效

class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
wt_f1 = f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))

这也返回None

您正在打印无。因为f2()什么也不返回,所以你得到None。删除打印:

a = A()
a.f2()

此外,一些编辑可能对您编写的代码有好处。

  1. 您在f1()函数之前错过了self,这是因为您在其他地方定义了f1函数。无论如何添加self.f1().
  2. print语句移出tape作用域。因为最好在录制完成的地方得到梯度。
  3. 添加tape.watch()以确保被磁带跟踪。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
tape.watch(self.alpha)
wt_f1 = self.f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))

相关内容

  • 没有找到相关文章

最新更新