fast.ai: How to get per-batch losses during validation

I'm using the AWD-LSTM model as implemented in fast.ai. Right now I can get the validation loss averaged over all batches:

from fastai.text import *  
data_lm = (TextList.from_csv("data/penn", "concatenated.csv", cols='text')
    .split_from_df("is_valid")
    .label_for_lm()
    .databunch())  
learner = language_model_learner(data_lm, AWD_LSTM, pretrained=False)  
learner.fit_one_cycle(10, 1e-2)  
learner.export("exported.pkl")

itemlist = TextList.from_csv("data/penn", "concatenated.csv", cols='text')  
newlearner = load_learner(path="data/penn", test=itemlist, file="exported.pkl")  
loss, acc = newlearner.validate(newlearner.data.test_dl)

But how do I get the validation loss for each individual batch?

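Things I have tried:
1. Attaching a Recorder. However, the Recorder does not seem to track validation per batch: learner.recorder.losses only stores the per-batch training losses, and learner.recorder.val_losses holds just one value per epoch.
2. Calling fastai.basic_train.loss_batch(learner.model, xb, yb, learner.loss_func), where xb and yb are plain torch.Tensors. However, this raises the following AttributeError: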

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-14-5fa44a2d640f> in <module>
      5 xb = torch.ones((64, 20)).cuda().long()
      6 yb = torch.ones((64, 20)).cuda().long()
----> 7 loss_batch(newlearner.model, xb, yb, newlearner.loss_func)

~/anaconda3/envs/pytorch12/lib/python3.7/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     27     out = cb_handler.on_loss_begin(out)
     28     if not loss_func: return to_detach(out), to_detach(yb[0])
---> 29     loss = loss_func(out, *yb)
     30 
     31     if opt is not None:

~/anaconda3/envs/pytorch12/lib/python3.7/site-packages/fastai/layers.py in __call__(self, input, target, **kwargs)
    237 
    238     def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
--> 239         input = input.transpose(self.axis,-1).contiguous()
    240         target = target.transpose(self.axis,-1).contiguous()
    241         if self.floatify: target = target.float()

AttributeError: 'tuple' object has no attribute 'transpose'
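Judging from the traceback, the model's output reaches loss_func as a tuple: fastai v1's AWD-LSTM returns its predictions together with the raw and dropped-out RNN outputs, and a bare loss_batch call has no callback to pick out the predictions. If you would rather not rely on fastai internals, a manual loop that unwraps the tuple is one option. The sketch below is framework-agnostic; per_batch_losses is a hypothetical helper (not a fastai API), and it assumes that the first tuple element holds the predictions and that loss_func takes (predictions, target) and returns a scalar that float() accepts.

```python
from typing import Any, Callable, Iterable, List, Tuple


def per_batch_losses(
    model: Callable[[Any], Any],
    batches: Iterable[Tuple[Any, Any]],
    loss_func: Callable[[Any, Any], Any],
) -> List[float]:
    """Return one loss value per (xb, yb) batch.

    Hypothetical helper: assumes that when the model returns a tuple,
    its first element holds the predictions (as with an AWD-LSTM model).
    """
    losses = []
    for xb, yb in batches:
        out = model(xb)
        if isinstance(out, tuple):
            out = out[0]  # keep the predictions, drop the extra RNN outputs
        losses.append(float(loss_func(out, yb)))
    return losses


if __name__ == "__main__":
    # Toy stand-ins: the "model" returns a tuple like an RNN language model
    def toy_model(x):
        return 2 * x, "raw_outputs", "outputs"

    def toy_loss(out, y):
        return abs(out - y)

    print(per_batch_losses(toy_model, [(1, 2), (3, 1)], toy_loss))  # [0.0, 5.0]
```

With the learner from the question, this would presumably be called after newlearner.model.eval() and newlearner.model.reset() (to switch off dropout and clear the hidden state), inside a torch.no_grad() block, passing newlearner.data.test_dl as batches and newlearner.loss_func as loss_func. Unlike fastai's own validate, it runs no fastai callbacks at all.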

I have now found a solution:

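import fastai.basic_train
from fastai.callback import CallbackHandler

# For a language-model learner the callbacks include RNNTrainer, which extracts
# the predictions from the tuple returned by AWD-LSTM before the loss is computed
cb_handler = CallbackHandler(newlearner.callbacks, None)
# With average=False, validate() returns a list holding one loss tensor per batch
losses = fastai.basic_train.validate(
    newlearner.model,
    newlearner.data.test_dl,
    newlearner.loss_func,
    cb_handler,  # This is necessary
    average=False)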
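Each element of losses is a small tensor, so [l.item() for l in losses] turns the result into plain Python floats. With average=True, fastai v1's validate combines the same per-batch losses into one number by weighting each batch loss by its batch size (the first dimension of the target). The sketch below reproduces that arithmetic in plain Python; weighted_mean_loss is a hypothetical helper, and the batch sizes are assumed to be collected separately, since validate(average=False) does not return them.

```python
from typing import Sequence


def weighted_mean_loss(batch_losses: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Batch-size-weighted mean of per-batch losses (hypothetical helper)."""
    if len(batch_losses) != len(batch_sizes):
        raise ValueError("need exactly one batch size per batch loss")
    total = sum(batch_sizes)
    if total == 0:
        raise ValueError("batch sizes must not sum to zero")
    # Each batch contributes in proportion to the number of rows it holds
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total


if __name__ == "__main__":
    # One full batch of 64 rows and a final batch of 32 rows
    print(weighted_mean_loss([1.0, 4.0], [64, 32]))  # 2.0
```

For a language-model data loader every batch typically has the same number of rows, in which case this reduces to a plain mean of the per-batch losses.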