如何在每个历元保存检查点,并加载一个随机保存的检查点以继续训练



你能帮我写代码吗?每个epoch保存模型(架构和权重(,以及我如何从第五个检查点继续训练我的模型,就像训练epoch一样,从1到25,而不做检查点(我保存的第五个模型(。

classifier = Sequential()
classifier.add(Conv2D(6, (3, 3), input_shape = (30, 30, 3), data_format="channels_last", activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))
classifier.add(Conv2D(6, (3, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))
classifier.add(Flatten())
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dense(units = 64, activation = 'relu'))
classifier.add(Dense(units = 1, activation = 'sigmoid'))
opt = Adam(learning_rate = 0.001, beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-08, decay = 0.0)
classifier.compile(optimizer = opt, loss = 'binary_crossentropy', metrics = ['accuracy', precision, recall, fmeasure])
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale = 1./255,
horizontal_flip = True,
vertical_flip = True,
rotation_range = 180)
validation_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory('/home/dataset/training_set',
target_size = (30, 30),
batch_size = 32,
class_mode = 'binary')
validation_set = validation_datagen.flow_from_directory('/home/dataset/validation_set',
target_size = (30, 30),
batch_size = 32,
class_mode = 'binary')
history = classifier.fit_generator(training_set,
steps_per_epoch = 208170,
epochs = 15,
validation_data = validation_set,
validation_steps = 89140)

我假设你的意思是,你想在每个epoch之后保存你的模型和权重,然后在稍后阶段,加载在第五个epoch之后存储的模型和权值。

您可以在TensorFlow中使用SaveModel格式,通常如下所示:

classifier.save()

这将保存架构、权重、有关优化器的信息以及您在compile()中设置的配置

由于您使用的是fit_generator,您只需使用ModelCheckpoint()来保存您的模型,如下所示:

from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(path_to_save_to, save_freq = 'epoch', 
save_weights_only = False)
history = classifier.fit_generator(training_set,
steps_per_epoch = 208170,
epochs = 15,
validation_data = validation_set,
validation_steps = 89140,
callbacks = [checkpoint])

您可以格式化路径,使其保存带有epoch/丢失详细信息的模型,如path_name + '-{epoch:02d}-{val_loss:.2f}.h5'

要加载第五个检查点,请执行以下操作:

from keras.models import load_model
classifier = load_model(path_to_fifth_checkpoint)

最新更新