为了用一个最小的例子来说明我的问题,假设我想创建一个具有
精神的类class test_a:
def __init__(self, X):
self.X = X
def predict(self, a):
return a * self.X
重要的是,如果我将新的X
分配给test_a
的实例,predict()
函数应该改变。在这个例子中,它工作得很好:
X = tf.ones((1, 1))
a = test_a(X)
y = tf.ones((1, 1))
a.predict(y) # output [[1.]]
# now I want to change the value of a.X
Xnew = 2 * tf.ones((1, 1))
a.X = Xnew
a.predict(y) # output [[2.]], as desired.
现在假设我想使用@tf.function
装饰器来加速predict()
。
class test_b:
def __init__(self, X):
self.X = X
@tf.function
def predict(self, a):
return a * self.X
现在出现了以下不希望出现的行为:
X = tf.ones((1, 1))
b = test_b(X)
y = tf.ones((1, 1))
b.predict(y) # output [[1.]]
# now I want to change the value of b.X
Xnew = 2 * tf.ones((1, 1))
b.X = Xnew
b.predict(y) # output is still [[1.]], but I would like it to be [[2.]]
到目前为止,我唯一的想法是有一个方法_predict(X, a)
,然后我可以装饰,然后在(未装饰)方法predict(self, a)
中调用_predict(self.X, a)
。如果能帮助我们做得更好,我将不胜感激。
为了避免python的副作用- self.X
应该是tf.Variable(trainable=False)
。你必须使用tf.Variable.assign()
来改变它。