如何激活keras。提前仅当监视值大于阈值时才停止。例如,如何仅在 val 精度> 0.9 时触发earlystop = EarlyStopping(monitor='val_accuracy', min_delta=0.0001, patience=5, verbose=1, mode='auto')
?另外,我应该如何正确导出中间模型,例如每 50 个 epoch?
我没有太多的知识,早期停止的基线论点似乎意味着阈值以外的其他东西。
在指标阈值上停止的最佳方法是使用 Keras 自定义回调。下面是将完成这项工作的自定义回调(SOMT - 在指标阈值处停止)的代码。 SOMT 回调可用于根据训练准确度和/或验证准确度的值结束训练。 使用形式是 callbacks=[SOMT(model, train_thold, valid_thold)] 其中
模型- 是符合模型的名称
- train_thold是浮子。它是模型必须达到的准确度值(以百分比为单位),以便有条件地停止训练
- valid_threshold是浮子。模型必须达到验证准确性的值(以百分比为单位) 为了有条件地停止训练
如果要仅根据训练精度停止训练,请将valid_thold设置为 0.0.
同样,如果您只想停止训练,则仅根据验证精度集 train_thold= 0.0.
注意,如果在同一 epoch 中未达到两个阈值,训练将继续,直到 epoch 的值。如果在同一纪元中达到两个阈值,则训练将停止,并且模型权重将设置为该纪元的权重.
例如,当
训练准确率达到或超过 95% 并且验证准确度已达到至少 85%时,您希望停止训练
则代码将是回调=[SOMT(my_model, .95, .85)]
# the callback uses the time module so
import time
class SOMT(keras.callbacks.Callback):
def __init__(self, model, train_thold, valid_thold):
super(SOMT, self).__init__()
self.model=model
self.train_thold=train_thold
self.valid_thold=valid_thold
def on_train_begin(self, logs=None):
print('Starting Training - training will halt if training accuracy achieves or exceeds ', self.train_thold)
print ('and validation accuracy meets or exceeds ', self.valid_thold)
msg='{0:^8s}{1:^12s}{2:^12s}{3:^12s}{4:^12s}{5:^12s}'.format('Epoch', 'Train Acc', 'Train Loss','Valid Acc','Valid_Loss','Duration')
print (msg)
def on_train_batch_end(self, batch, logs=None):
acc=logs.get('accuracy')* 100 # get training accuracy
loss=logs.get('loss')
msg='{0:1s}processed batch {1:4s} training accuracy= {2:8.3f} loss: {3:8.5f}'.format(' ', str(batch), acc, loss)
print(msg, 'r', end='') # prints over on the same line to show running batch count
def on_epoch_begin(self,epoch, logs=None):
self.now= time.time()
def on_epoch_end(self,epoch, logs=None):
later=time.time()
duration=later-self.now
tacc=logs.get('accuracy')
vacc=logs.get('val_accuracy')
tr_loss=logs.get('loss')
v_loss=logs.get('val_loss')
ep=epoch+1
print(f'{ep:^8.0f} {tacc:^12.2f}{tr_loss:^12.4f}{vacc:^12.2f}{v_loss:^12.4f}{duration:^12.2f}')
if tacc>= self.train_thold and vacc>= self.valid_thold:
print( f'ntraining accuracy and validation accuracy reached the thresholds on epoch {epoch + 1}' )
self.model.stop_training = True # stop training
注意 在编译模型之后和拟合模型之前包括此代码
train_thold= .98
valid_thold=.95
callbacks=[SOMT(model, train_thold, valid_thold)]
# training will halt if train accuracy meets or exceeds train_thold
# AND validation accuracy meets or exceeds valid_thold in the SAME epoch
在 model.fit 中包括回调=回调,详细=0。 在每个纪元结束时,回调会生成一个电子表格,例如表单的打印输出
Epoch Train Acc Train Loss Valid Acc Valid_Loss Duration
1 0.90 4.3578 0.95 2.3982 84.16
2 0.95 1.6816 0.96 1.1039 63.13
3 0.97 0.7794 0.95 0.5765 63.40
training accuracy and validation accuracy reached the thresholds on epoch 3.