Keras 中的 .fit() 方法触发了多少次损失函数

How many times the loss function is triggered from .fit() method in Keras

我正在尝试在自定义损失函数中进行一些自定义计算。但是当我记录自定义损失函数的语句时,自定义损失函数似乎只被调用一次(在 .fit() 方法的开头)。

损失函数示例:

def loss(y_true, y_pred):
    print("--- Starting of the loss function ---")
    print(y_true)
    loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
    print("--- Ending of the loss function ---")
    return loss

使用回调检查批处理何时开始和结束:

class monitor(Callback):
    def on_batch_begin(self, batch, logs=None):
        print("\n >> Starting a new batch (batch index) :: ", batch)

    def on_batch_end(self, batch, logs=None):
        print(">> Ending a batch (batch index) :: ", batch)

.fit() 方法用作:

history = model.fit(
    x=[inputs],
    y=[outputs],
    shuffle=False,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCH,
    verbose=1,
    callbacks=[monitor()]
)

使用的参数:

BATCH_SIZE = 128
NUM_EPOCH = 3
inputs.shape = (512, 8)
outputs.shape = (512, 2)

输出:

Epoch 1/3
 >> Starting a new batch (batch index) ::  0
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
1/4 [======>.......................] - ETA: 0s - loss: 0.5551
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 5ms/step - loss: 0.5307

Epoch 2/3
 >> Starting a new batch (batch index) ::  0
1/4 [======>.......................] - ETA: 0s - loss: 0.5443
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 5ms/step - loss: 0.5246

Epoch 3/3
 >> Starting a new batch (batch index) ::  0
1/4 [======>.......................] - ETA: 0s - loss: 0.5433
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 4ms/step - loss: 0.5219

为什么自定义损失函数只在开始调用而不是每批计算都调用?而且我也想知道什么时候损失函数是called/triggered?

The loss function debug messages were printed only at the beginning of the training.

这是因为为了性能,您的损失函数在内部被转换为张量流图,并且 python 打印函数仅在您的函数被跟踪时才起作用。即它仅在训练开始时打印,这意味着当时正在跟踪您的损失函数。请参阅以下页面以获取更多信息:https://www.tensorflow.org/guide/function

简答:要正确打印,请使用 tf.print() 而不是 print()

And I would also like to know when the loss function is called/triggered?

使用 tf.print() 后,调试消息将正确打印。你会看到你的损失函数每一步至少被调用一次以获得损失值和梯度。