Keras 提前停止和监控

Keras Early Stop and Monitor

如何仅在监控值大于阈值时才激活keras.EarlyStopping。例如,如何仅在 val accuracy > 0.9 时触发 earlystop = EarlyStopping(monitor='val_accuracy', min_delta=0.0001, patience=5, verbose=1, mode='auto')?另外,我应该如何正确导出中间模型,例如每 50 个时期?

我没有太多知识,EarlyStopping 的基线参数似乎意味着阈值以外的东西。

在指标阈值处停止的最佳方法是使用 Keras 自定义回调。下面是将完成这项工作的自定义回调代码(SOMT - 停止在指标阈值上)。 SOMT 回调可用于根据训练准确度或验证准确度或两者的值结束训练。 使用形式为callbacks=[SOMT(model, train_thold, valid_thold)] where

  • model 是您编译的模型的名称
  • train_thold 是一个浮点数。为了有条件地停止训练,模型必须达到的准确度值(以百分比为单位)
  • valid_threshold 是一个浮点数。它是模型必须达到的验证准确度值(以百分比表示) 为了有条件地停止训练

注意停止训练 train_thold 和 valid_thold 必须在同一个纪元中超过。
如果您想停止仅基于训练准确度的训练,请将 valid_thold 设置为 0.0.
同样,如果您只想停止对验证精度集的训练 train_thold= 0.0.
请注意,如果在同一纪元中未达到两个阈值,训练将持续到纪元的值。如果在同一时期达到两个阈值,则停止训练并将模型权重设置为该时期的权重。
举个例子,当
时你想停止训练 训练准确率达到或超过95%,验证准确率至少达到85%
那么代码将是 callbacks=[SOMT(my_model, .95, .85)]

# the callback uses the time module so
import time
class SOMT(keras.callbacks.Callback):
    def __init__(self, model,  train_thold, valid_thold):
        super(SOMT, self).__init__()
        self.model=model        
        self.train_thold=train_thold
        self.valid_thold=valid_thold
        
    def on_train_begin(self, logs=None):
        print('Starting Training - training will halt if training accuracy achieves or exceeds ', self.train_thold)
        print ('and validation accuracy meets or exceeds ', self.valid_thold) 
        msg='{0:^8s}{1:^12s}{2:^12s}{3:^12s}{4:^12s}{5:^12s}'.format('Epoch', 'Train Acc', 'Train Loss','Valid Acc','Valid_Loss','Duration')
        print (msg)                                                                                    
            
    def on_train_batch_end(self, batch, logs=None):
        acc=logs.get('accuracy')* 100  # get training accuracy 
        loss=logs.get('loss')
        msg='{0:1s}processed batch {1:4s}  training accuracy= {2:8.3f}  loss: {3:8.5f}'.format(' ', str(batch),  acc, loss)
        print(msg, '\r', end='') # prints over on the same line to show running batch count 
        
    def on_epoch_begin(self,epoch, logs=None):
        self.now= time.time()
        
    def on_epoch_end(self,epoch, logs=None): 
        later=time.time()
        duration=later-self.now 
        tacc=logs.get('accuracy')           
        vacc=logs.get('val_accuracy')
        tr_loss=logs.get('loss')
        v_loss=logs.get('val_loss')
        ep=epoch+1
        print(f'{ep:^8.0f} {tacc:^12.2f}{tr_loss:^12.4f}{vacc:^12.2f}{v_loss:^12.4f}{duration:^12.2f}')
        if tacc>= self.train_thold and vacc>= self.valid_thold:
            print( f'\ntraining accuracy and validation accuracy reached the thresholds on epoch {epoch + 1}' )
            self.model.stop_training = True # stop training

注意在编译模型之后和拟合模型之前包含此代码

train_thold= .98
valid_thold=.95
callbacks=[SOMT(model, train_thold, valid_thold)]
# training will halt if train accuracy meets or exceeds train_thold
# AND validation accuracy meets or exceeds valid_thold in the SAME epoch

在 model.fit 中包括 callbacks=callbacks,verbose=0。 在每个纪元结束时,回调会生成一个电子表格,例如

形式的打印输出
Epoch   Train Acc   Train Loss  Valid Acc   Valid_Loss   Duration  
   1         0.90       4.3578       0.95       2.3982      84.16    
   2         0.95       1.6816       0.96       1.1039      63.13    
   3         0.97       0.7794       0.95       0.5765      63.40 
training accuracy and validation accuracy reached the thresholds on epoch 3.