如何继承一个典型的Tensorflow模型类来重用大多数的图构建代码



典型的tensorflow模型类如下所示:

class Model:
def __init__(self):
build()
def build(self):
self.x = tf.placeholder()
self.y = f(self.x)
self.z = g(self.y)

如果我们需要稍微修改(即,将self.y=f(self.x)更改为slef.y=h(self.x)(,我们真的希望继承这个Model类并添加一些代码来实现这一点。

然而,一旦调用了build函数,就会构建一个完整的图。重写属性不会更改图形结构。有没有办法把这项工作做得整整齐齐的?

您可以参数化fg(或您拥有的任何东西(,并将它们传递到构造函数中:

class Model:
def __init__(self, f=default_f, g=default_g):
self.f = f
self.g = g
self.build()
def build(self):
self.x = tf.placeholder()
self.y = self.f(self.x)
self.z = self.g(self.y)

或者,您可以使它们成为可重写的类级变量,以避免构造函数的签名变得臃肿,但您不能在构造函数中隐式调用.build()

class Model:
f = default_f
g = default_g
def build(self):
self.x = tf.placeholder()
self.y = self.f(self.x)
self.z = self.g(self.y)
# ...
m = Model()
m.f = some_other_f
m.build()

最新更新