我提供了一个我想要解决的最小示例。我已经定义了一个类,其中有一些变量在不同的函数中定义。我想知道如何在函数中跟踪这些变量来得到梯度。我想我必须使用tf.GradientTape
,但我已经尝试了一些变体,但没有成功。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
wt_f1 = f1()
with tf.GradientTape() as tape:
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
a = A()
print(a.f2())
最后一行返回None
。显然wt_f2
对alpha
的导数是50.0。然而,我得到了None
。任何想法?我尝试在__init__
函数中初始化一个持久的梯度磁带,并使用它来观察wt
和self.alpha
等变量,但这没有帮助。任何想法?
更新1:
将wt_f1
呼叫置于tape
下无效
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
wt_f1 = f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
这也返回None
。
您正在打印无。因为f2()
什么也不返回,所以你得到None
。删除打印:
a = A()
a.f2()
此外,一些编辑可能对您编写的代码有好处。
- 您在
f1()
函数之前错过了self
,这是因为您在其他地方定义了f1
函数。无论如何添加self.f1()
. - 将
print
语句移出tape
作用域。因为最好在录制完成的地方得到梯度。 - 添加
tape.watch()
以确保被磁带跟踪。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
tape.watch(self.alpha)
wt_f1 = self.f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))