在keras回调文件名中包含epoch和型号



我正在for循环中用各种超参数迭代训练模型,我想使用keras回调将多个模型保存在一个文件夹中。我已经能够在每个模型中保存模型编号,但现在我也想包括诸如历元编号之类的变量(并每5个历元保存一次模型(。在下面的代码中,每次运行for循环时,我都会在计数器上加1来表示型号。

filepath = root_path + "/saved_models/model_number_{}.h5".format(counter)
history = final_model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_train, y_train),
shuffle=True,
callbacks= tf.keras.callbacks.ModelCheckpoint(filepath=filepath, monitor='val_accuracy', verbose=0, save_weights_only=True, mode='auto', save_freq='epoch'),
)

我也可以制作这个filepath来保存文件名中的epoch编号和准确性,但我不能将其与我的模型连接起来。有办法做到这一点吗?

filepath = s3_root_path + "/saved_models/weights.{epoch:02d}-{val_loss:.2f}.h5"

在每5个时期将模型保存到同一文件夹时,您需要做一些细微的更改:

  1. 根路径名称应与保存所有模型的位置相同(请参阅上面代码中的名称差异Root_paths3_Root_path(
  2. 保存模型时的文件名格式应正确

请检查以下固定代码:

#Created root 'SAVING/saved_models' folder for saving the entire checkpoints.
!mkdir -p SAVING/saved_models
#'SAVING' folder directory
!ls SAVING/   # Output :saved_models

通过更改save_freq=5*batch_size每5个时期保存一次模型

root_path="SAVING/"
checkpoint_path = root_path + "saved_models/weights.{epoch:02d}.h5"  #-{val_loss:.2f}
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
cp_callback= tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, 
monitor='val_sparse_categorical_accuracy', 
verbose=1, 
save_weights_only=True,  
save_freq=5*batch_size)

history = model.fit(train_images, train_labels, 
batch_size=batch_size,
epochs=50,
validation_data=(test_images, test_labels),
shuffle=True,
callbacks=[cp_callback],verbose=0)

输出:

Epoch 5: saving model to SAVING/saved_models/weights.05.h5
Epoch 10: saving model to SAVING/saved_models/weights.10.h5
Epoch 15: saving model to SAVING/saved_models/weights.15.h5
Epoch 20: saving model to SAVING/saved_models/weights.20.h5
Epoch 25: saving model to SAVING/saved_models/weights.25.h5
Epoch 30: saving model to SAVING/saved_models/weights.30.h5
Epoch 35: saving model to SAVING/saved_models/weights.35.h5
Epoch 40: saving model to SAVING/saved_models/weights.40.h5
Epoch 45: saving model to SAVING/saved_models/weights.45.h5
Epoch 50: saving model to SAVING/saved_models/weights.50.h5

检查根文件夹目录中保存的检查点:

!ls SAVING/saved_models

输出:

weights.05.h5  weights.15.h5  weights.25.h5  weights.35.h5  weights.45.h5
weights.10.h5  weights.20.h5  weights.30.h5  weights.40.h5  weights.50.h5

最新更新