Keras使用自定义文件名在训练回调中保存protobuffer和weights



我正试图为我的模型拟合过程编写一个回调,如果Epoch结束后,模型得到了改进,则为我节省了权重模型作为原型缓冲区。优选地类似于./tmp/weights.hdf5./tmp/model.pb。为此,我使用了两个回调(下面的mnist示例,我使用的是TF2.6.0(:

import tensorflow as tf
import numpy as np
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
steps_per_execution=10)
callback_weights = tf.keras.callbacks.ModelCheckpoint(
filepath="./tmp/weights.hdf5", save_weights_only=True, save_best_only=True)
callback_model = tf.keras.callbacks.ModelCheckpoint(
filepath="./tmp/", save_weights_only=False, save_best_only=True)
model.fit(x=x_train,
y=y_train,
epochs=2,
validation_data=(x_test, y_test),
callbacks=[callback_weights, callback_model])

ls tmp给出:

assets keras_metadata.pb saved_model.pb variables weights.hdf5

我的问题:

  • 我必须使用两个回调吗?还是可以在一个回调中完成
  • 如何控制protobuffer文件的名称?当save_weights_only=True工作时指定文件名,但当我使用save_weights_only=False时,它会根据filepath参数创建一个目录

要回答您的上述具体问题,

(1( 。我们不必为此使用两个回调(节省型号或重量(。只需一次回调即可完成。(2( 。当我们设置save_weights_only=False时,这意味着程序将保存整个模型及其当前状态(或权重(。为了更清楚地了解,请参阅下面的

# it'll save only weight 
callback_weights = tf.keras.callbacks.ModelCheckpoint(
filepath="weights.h5", 
save_weights_only=True,  
save_best_only=True)
# it'll save model config + weight = entire trained model 
callback_weights_model = tf.keras.callbacks.ModelCheckpoint(
filepath="model.h5", 
save_weights_only=False, # entire model (config + weight)
save_best_only=True)
# it'll also save model config + weight = entire trained model 
callback_model = tf.keras.callbacks.ModelCheckpoint(
filepath="./tmp/", 
save_weights_only=False, # entire model (config + weight)
save_best_only=True)

这里,callback_weights_modelcallback_model将使用不同的输出格式执行相同的工作。但callback_weights将只保存当前状态或训练的权重文件。因此,如果我们在训练时间使用它们,如下所示:

model.fit(...
callbacks=[callback_weights, 
callback_model, 
callback_weights_model])

那么我们将有以下文件。

tmp/ [asset, variable, .pb]
model.h5
weight.h5

让我们检查

loaded_model = tf.keras.models.load_model('./tmp/')
assert np.allclose(model.predict(x_test), loaded_model.predict(x_test))
loaded_model = tf.keras.models.load_model('/content/model.h5')
assert np.allclose(model.predict(x_test), loaded_model.predict(x_test))
#  ValueError: No model config found in the file at /content/weights.h5.
# loaded_model = tf.keras.models.load_model('/content/weights.h5') 
loaded_model = create_model()
loaded_model.load_weights('/content/weights.h5')
assert np.allclose(model.predict(x_test), loaded_model.predict(x_test))

因此,如果我们需要保存整个模型,我们可以选择上面的callback_weights_modelcallback_model回调。否则,如果我们只需要保存权重文件,我们可以使用callback_weights

相关内容

  • 没有找到相关文章

最新更新