我有一个子类模型,有一些自定义属性,像这样:
class MyModel(tf.keras.Model):
def __init__(self, *args, my_var, **kwargs):
super().__init__(*args, **kwargs)
self.my_var = my_var
def my_func(self):
pass
def get_config(self):
config = super().get_config()
config.update(
{
"my_var": self.my_var
}
)
return config
现在我定义模型并克隆它与clone_model
x_in = layers.Input(shape=(100, 100, 3))
x_out = layers.Conv2D(filters=16, kernel_size=3, activation="relu")(x_in)
model = MyModel(inputs=x_in, outputs=x_out, my_var="my_var")
cloned = tf.keras.models.clone_model(model)
print(cloned.my_var)
模型被复制了,但是没有my_var
是否有任何方法来复制这种类型的模型正确与所有属性(my_var和my_func)?
你需要添加
cloned = model.__class__.from_config(model.get_config())
如doc https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model#example
所示