典型的tensorflow模型类如下所示:
class Model:
def __init__(self):
build()
def build(self):
self.x = tf.placeholder()
self.y = f(self.x)
self.z = g(self.y)
如果我们需要稍微修改(即,将self.y=f(self.x)
更改为slef.y=h(self.x)
(,我们真的希望继承这个Model
类并添加一些代码来实现这一点。
然而,一旦调用了build
函数,就会构建一个完整的图。重写属性不会更改图形结构。有没有办法把这项工作做得整整齐齐的?
您可以参数化f
和g
(或您拥有的任何东西(,并将它们传递到构造函数中:
class Model:
def __init__(self, f=default_f, g=default_g):
self.f = f
self.g = g
self.build()
def build(self):
self.x = tf.placeholder()
self.y = self.f(self.x)
self.z = self.g(self.y)
或者,您可以使它们成为可重写的类级变量,以避免构造函数的签名变得臃肿,但您不能在构造函数中隐式调用.build()
:
class Model:
f = default_f
g = default_g
def build(self):
self.x = tf.placeholder()
self.y = self.f(self.x)
self.z = self.g(self.y)
# ...
m = Model()
m.f = some_other_f
m.build()