如何达到满足提前停止标准的纪元号

how to reach the epoch number in which the early stopping criteria is met

如果满足某些条件,我会使用回调来停止训练过程。我想知道如何访问由于回调而停止训练的纪元号。

import numpy as np
import random
import tensorflow as tf
from tensorflow import keras 


class stopAtLossValue(tf.keras.callbacks.Callback):
        def on_batch_end(self, batch, logs={}):
            eps = 0.01 
           
            if logs.get('loss') <= eps:
                 self.model.stop_training = True
                    
training_input=  np.random.random ([30,10])
training_output = np.random.random ([30,1])



model = tf.keras.Sequential([  
    tf.keras.layers.Flatten(input_shape=(10,)),
    tf.keras.layers.Dense(15,activation=tf.keras.activations.linear),
    tf.keras.layers.Dense(15, activation='relu'),
    tf.keras.layers.Dense(1) 
])   
                               
model.compile(loss="mse",optimizer = tf.keras.optimizers.Adam(learning_rate=0.01))

hist = model.fit(training_input, training_output, epochs=100, batch_size=100,  verbose=1, callbacks=[stopAtLossValue()])
    

对于这个例子,我的训练在第 66 个 epoch 完成,因为损失小于 0.01。

Epoch 66/100
1/1 [==============================] - 0s 5ms/step - loss: 0.0099
----------------------------------------------------------------- 

简单的方法是获取 history.history 对象的长度:

len(model.history.history['loss'])

更复杂的方法是从优化器获取迭代次数:

model.optimizer._iterations

如果你想在回调中获取纪元号,你应该使用on_epoch_end函数而不是on_batch_end。回调函数见下面代码

def on_epoch_end(self, epoch, logs={}):
    eps = 0.01 
    print(epoch) # This will print the number of epoch
    if logs.get('loss') <= eps:
          self.model.stop_training = True