我正试图为我的模型拟合过程编写一个回调,如果Epoch结束后,模型得到了改进,则为我节省了权重和模型作为原型缓冲区。优选地类似于./tmp/weights.hdf5
和./tmp/model.pb
。为此,我使用了两个回调(下面的mnist示例,我使用的是TF2.6.0(:
import tensorflow as tf
import numpy as np
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
steps_per_execution=10)
callback_weights = tf.keras.callbacks.ModelCheckpoint(
filepath="./tmp/weights.hdf5", save_weights_only=True, save_best_only=True)
callback_model = tf.keras.callbacks.ModelCheckpoint(
filepath="./tmp/", save_weights_only=False, save_best_only=True)
model.fit(x=x_train,
y=y_train,
epochs=2,
validation_data=(x_test, y_test),
callbacks=[callback_weights, callback_model])
ls tmp
给出:
assets keras_metadata.pb saved_model.pb variables weights.hdf5
我的问题:
- 我必须使用两个回调吗?还是可以在一个回调中完成
- 如何控制protobuffer文件的名称?当
save_weights_only=True
工作时指定文件名,但当我使用save_weights_only=False
时,它会根据filepath
参数创建一个目录
要回答您的上述具体问题,
(1( 。我们不必为此使用两个回调(节省型号或重量(。只需一次回调即可完成。(2( 。当我们设置save_weights_only=False
时,这意味着程序将保存整个模型及其当前状态(或权重(。为了更清楚地了解,请参阅下面的
# it'll save only weight
callback_weights = tf.keras.callbacks.ModelCheckpoint(
filepath="weights.h5",
save_weights_only=True,
save_best_only=True)
# it'll save model config + weight = entire trained model
callback_weights_model = tf.keras.callbacks.ModelCheckpoint(
filepath="model.h5",
save_weights_only=False, # entire model (config + weight)
save_best_only=True)
# it'll also save model config + weight = entire trained model
callback_model = tf.keras.callbacks.ModelCheckpoint(
filepath="./tmp/",
save_weights_only=False, # entire model (config + weight)
save_best_only=True)
这里,callback_weights_model
和callback_model
将使用不同的输出格式执行相同的工作。但callback_weights
将只保存当前状态或训练的权重文件。因此,如果我们在训练时间使用它们,如下所示:
model.fit(...
callbacks=[callback_weights,
callback_model,
callback_weights_model])
那么我们将有以下文件。
tmp/ [asset, variable, .pb]
model.h5
weight.h5
让我们检查
loaded_model = tf.keras.models.load_model('./tmp/')
assert np.allclose(model.predict(x_test), loaded_model.predict(x_test))
loaded_model = tf.keras.models.load_model('/content/model.h5')
assert np.allclose(model.predict(x_test), loaded_model.predict(x_test))
# ValueError: No model config found in the file at /content/weights.h5.
# loaded_model = tf.keras.models.load_model('/content/weights.h5')
loaded_model = create_model()
loaded_model.load_weights('/content/weights.h5')
assert np.allclose(model.predict(x_test), loaded_model.predict(x_test))
因此,如果我们需要保存整个模型,我们可以选择上面的callback_weights_model
或callback_model
回调。否则,如果我们只需要保存权重文件,我们可以使用callback_weights
。