火炬闪电 epoch_end/validation_epoch_end

pytorch lightning epoch_end/validation_epoch_end

谁能分解代码并向我解释一下?需要帮助的部分用“#This part”表示。我将不胜感激任何帮助谢谢

def validation_epoch_end(self, outputs):
    batch_losses = [x["val_loss"]for x in outputs] #This part
    epoch_loss = torch.stack(batch_losses).mean() 
    batch_accs =  [x["val_acc"]for x in outputs]   #This part
    epoch_acc = torch.stack(batch_accs).mean()   
    return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

def epoch_end(self, epoch, result):
    print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format( epoch,result['val_loss'], result['val_acc'])) #This part

在您提供的代码段中,outputs 是一个 list 包含 dicts 元素,这些元素似乎至少包含键 "val_loss""val_acc"。假设它们分别对应于验证损失和验证准确性是公平的。

这两行(用 # This path 注释注释)对应于遍历 outputs 列表中的元素的列表理解。第一个收集输出中每个元素的键 "val_loss" 的值。这次第二个收集 "val_acc" 键的值。

一个最小的例子是:

## before
outputs = [{'val_loss': tensor(a), # element 0
            'val_acc': tensor(b)},

           {'val_loss': tensor(c), # element 1
            'val_acc': tensor(d)}]

## after
batch_losses = [tensor(a), tensor(c)]
batch_acc = [tensor(b), tensor(d)]

根据结构,我假设您使用的是 pytorch_lightning

validation_epoch_end() 将从 validation_step() 收集输出,因此它是 dictlist,其长度为验证数据加载器中的批次数。因此,前两个 #This part 只是从验证集中展开结果。

epoch_end()validation_epoch_end().

捕获结果 {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}