tqdm not updating new set_postfix after last iteration

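I want to create a tqdm progress bar for PyTorch training that is similar to the one in tensorflow.keras. These are my requirements:

  1. For each training step, it shows the progress and the training loss
  2. At the last iteration, it additionally shows the validation loss

I am following this tutorial https://towardsdatascience.com/training-models-with-a-progress-a-bar-2b664de3e13e and I managed to satisfy the first requirement.

The only missing feature is reporting the validation loss after each training epoch.
Here is my code: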

import torch
from tqdm import tqdm

for epoch in range(EPOCH):
    with tqdm(train_dataloader, unit=" batch") as tepoch:
        train_loss = 0
        val_loss = 0

        # Training part
        for idx, batch in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch}")
            # <do training stuff>
            train_loss += loss.item()
            tepoch.set_postfix({'Train loss': loss.item()})
        train_loss /= (idx + 1)

        # Evaluation part
        with torch.no_grad():
            for idx, batch in enumerate(val_dataloader):
                # <do inference stuff>
                val_loss += loss.item()
        val_loss /= (idx + 1)

    # Called after the `with` block has exited
    tepoch.set_postfix({'Train loss': train_loss, "Val loss": val_loss})

This code gives:

Epoch 0: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.511]
Epoch 1: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.298]

But what I want is:

Epoch 0: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.511, Val loss={number}]
Epoch 1: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.298, Val loss={number}]

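I have looked at the SO question "tqdm update after last iteration", but it does not seem feasible for me because the validation loss is calculated only after all the training batches are done.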
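A minimal sketch of why the late update never shows up, as far as I understand tqdm's behaviour: a bar that wraps an iterable closes itself once that iterable is exhausted, and a closed bar ignores further redraws. The helper name `render_late_postfix` and the loss values are made up for illustration; the bar is written to an in-memory buffer so the result can be inspected.

```python
import io

from tqdm import tqdm


def render_late_postfix() -> str:
    """Render a short bar into a buffer and try to update it too late."""
    buf = io.StringIO()
    with tqdm(range(3), unit=" batch", file=buf) as bar:
        for _ in bar:
            # Redrawn on every call while the bar is still open
            bar.set_postfix({"Train loss": 0.5})
    # The bar is closed at this point, so this call stores the new
    # postfix but (in the tqdm versions I have checked) draws nothing.
    bar.set_postfix({"Train loss": 0.5, "Val loss": 0.25})
    return buf.getvalue()


output = render_late_postfix()
print("Train loss" in output)  # the in-loop postfix was drawn
print("Val loss" in output)    # the late postfix was not
```

If this holds for your tqdm version, it suggests the final `set_postfix` has to happen while the bar is still open, which in turn means the bar should not be the thing being iterated over.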

Working example

import random
import time

from tqdm import tqdm

EPOCH = 100
BATCH_SIZE = 10  # number of batches per epoch in this toy example

for epoch in range(EPOCH):
    # Use a manual total so the bar stays open until the `with` block exits
    with tqdm(total=BATCH_SIZE, unit=" batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}")
        train_loss = 0
        val_loss = 0

        # Training part
        for idx, batch in enumerate(range(BATCH_SIZE)):
            tepoch.update(1)
            # do training stuff
            time.sleep(0.5)
            loss = random.choice(range(10))
            train_loss += loss
            tepoch.set_postfix({'Batch': idx+1, 'Train loss (in progress)': loss})

        train_loss /= (idx+1)

        # Evaluation part
        time.sleep(0.5)
        val_loss += random.choice(range(10))

        val_loss /= (idx+1)

        # Still inside the `with` block, so the bar is open and redraws
        tepoch.set_postfix({'Train loss (final)': train_loss, 'Val loss': val_loss})
        tepoch.close()

Output

Epoch 1: 100% 10/10 [00:11<00:00, 1.18s/ batch, Train loss (final)=4.4, Val loss=0.5]
Epoch 2: 100% 10/10 [00:06<00:00, 1.62 batch/s, Train loss (final)=4.7, Val loss=0.3]
Epoch 3: 80% 8/10 [00:03<00:00, 2.07 batch/s, Batch=7, Train loss (in progress)=9]
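The same pattern can be wrapped in a small helper. This is only a sketch under stated assumptions: `run_epoch`, `train_step` and `val_step` are hypothetical names, the "batches" are plain floats so the block runs without PyTorch, and in real training the validation loop would sit under `torch.no_grad()` as in the question. The key points are the manual `total=` with `update(1)`, and the final `set_postfix` being called before the `with` block exits.

```python
import io
from typing import Callable, Sequence, Tuple

from tqdm import tqdm


def run_epoch(
    epoch: int,
    train_batches: Sequence[float],
    val_batches: Sequence[float],
    train_step: Callable[[float], float],  # hypothetical: returns a batch loss
    val_step: Callable[[float], float],    # hypothetical: returns a batch loss
    file=None,                             # None lets tqdm use its default stream
) -> Tuple[float, float]:
    """Run one epoch and show train and validation loss on the same bar."""
    train_loss = 0.0
    # Manual total: the bar is not iterated, so it stays open after training
    with tqdm(total=len(train_batches), unit=" batch",
              desc=f"Epoch {epoch}", file=file) as bar:
        for batch in train_batches:
            loss = train_step(batch)
            train_loss += loss
            bar.update(1)
            bar.set_postfix({"Train loss": loss})
        train_loss /= max(len(train_batches), 1)

        # Validation runs while the bar is still open
        val_loss = sum(val_step(b) for b in val_batches) / max(len(val_batches), 1)
        bar.set_postfix({"Train loss": train_loss, "Val loss": val_loss})
    return train_loss, val_loss


# Usage with toy data, rendering into a buffer instead of the terminal
buf = io.StringIO()
losses = run_epoch(0, [1.0, 2.0, 3.0], [4.0, 6.0],
                   train_step=lambda b: b, val_step=lambda b: b, file=buf)
print(losses)
print("Val loss" in buf.getvalue())
```

Unlike the toy example above, this divides the validation loss by the number of validation batches rather than reusing the training batch count.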