How to extract loss and accuracy from logger by each epoch in pytorch lightning?

I want to extract all of the data to make my own plots instead of using tensorboard. My understanding is that all of the logged losses and accuracies are stored in the specified directory, since tensorboard draws its line charts from them.

%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

However, I would like to know how to extract all of the logs from the logger in pytorch lightning. A code example of the training part follows.

#model
ssl_classifier = SSLImageClassifier(lr=lr)

#train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')

trainer = pl.Trainer(progress_bar_refresh_rate=20,
                     gpus=1,
                     max_epochs=max_epoch,
                     logger=logger,
                     )

trainer.fit(ssl_classifier, train_loader, val_loader)

I have confirmed that trainer.logger.log_dir returns the directory where the logs seem to be saved, and that trainer.logger.log_metrics returns <bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>

trainer.logged_metrics only returns the logs of the last epoch, such as

{'epoch': 19,
 'train_acc': tensor(1.),
 'train_loss': tensor(0.1038),
 'val_acc': 0.6499999761581421,
 'val_loss': 1.2171183824539185}

Do you know how to solve it?

Lightning does not store all logs by itself. All it does is stream them into the logger instance, and the logger decides what to do with them.

The best way to retrieve all logged metrics is with a custom callback:

from pytorch_lightning.callbacks import Callback


class MetricTracker(Callback):

  def __init__(self):
    self.collection = []

  def on_validation_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx):
    vacc = outputs['val_acc']  # you can access the metrics here
    self.collection.append(vacc)  # track them

  def on_validation_epoch_end(self, trainer, module):
    elogs = trainer.logged_metrics  # access the aggregated epoch logs here
    self.collection.append(elogs)
    # do whatever is needed

Then you can access everything that was logged from the callback instance:

cb = MetricTracker()
Trainer(callbacks=[cb])

cb.collection  # do your plotting and whatever else here
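Since `trainer.logged_metrics` yields one dict per epoch, `cb.collection` ends up as a list of dicts, which still needs to be transposed into one series per metric before plotting. A minimal sketch of that step (the helper name `collect_series` is my own, and the sample dicts mimic the `logged_metrics` shape shown in the question):

```python
def collect_series(collection):
    """Turn a list of per-epoch metric dicts into one list per metric name."""
    series = {}
    for epoch_logs in collection:
        for name, value in epoch_logs.items():
            # torch tensors expose .item(); plain floats pass through unchanged
            series.setdefault(name, []).append(
                value.item() if hasattr(value, "item") else value
            )
    return series

# Example with dicts shaped like trainer.logged_metrics:
logs = [
    {"epoch": 0, "val_acc": 0.50, "val_loss": 1.40},
    {"epoch": 1, "val_acc": 0.65, "val_loss": 1.21},
]
series = collect_series(logs)
print(series["val_acc"])   # [0.5, 0.65]
print(series["val_loss"])  # [1.4, 1.21]
```

The resulting lists can be handed straight to `matplotlib.pyplot.plot`.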

The accepted answer is not fundamentally wrong, but it does not follow the official (current) Pytorch-Lightning guide.

That guide is here: https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#make-a-custom-logger

It suggests writing a class like this:

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment


class MyLogger(LightningLoggerBase):
    @property
    def name(self):
        return "MyLogger"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics)
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass

Looking at the LightningLoggerBase class shows which other functions can be overridden.

Here is a minimalistic logger of mine. It is highly unoptimized, but it makes a good first shot. I will edit it if I improve it.

import collections

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only

class History_dict(LightningLoggerBase):
    def __init__(self):
        super().__init__()

        self.history = collections.defaultdict(list)
        # a defaultdict creates any missing key on first access,
        # so no per-metric initialization is needed

    @property
    def name(self):
        return "Logger_custom_plot"

    @property
    def version(self):
        return "1.0"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        for metric_name, metric_value in metrics.items():
            if metric_name != 'epoch':
                self.history[metric_name].append(metric_value)
            else:
                # epoch case: avoid appending the same epoch number several
                # times (this happens when multiple metrics are logged per epoch)
                if (not len(self.history['epoch']) or  # len == 0, or
                        self.history['epoch'][-1] != metric_value):  # last epoch differs from the one being added
                    self.history['epoch'].append(metric_value)

    @rank_zero_only
    def log_hyperparams(self, params):
        pass
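The epoch-deduplication branch in `log_metrics` is the only non-trivial part, and it can be exercised without running a `Trainer` at all. A standalone replica of that logic (module-level `history` and `log_metrics` here are my own stand-ins, not the class above):

```python
import collections

# Standalone replica of History_dict.log_metrics, so the epoch
# deduplication can be checked without a Trainer.
history = collections.defaultdict(list)

def log_metrics(metrics, step=None):
    for metric_name, metric_value in metrics.items():
        if metric_name != "epoch":
            history[metric_name].append(metric_value)
        # only record the epoch number when it differs from the last one seen
        elif not history["epoch"] or history["epoch"][-1] != metric_value:
            history["epoch"].append(metric_value)

# Lightning typically calls log_metrics several times per epoch,
# e.g. once for training metrics and once for validation metrics:
log_metrics({"epoch": 0, "train_loss": 0.9})
log_metrics({"epoch": 0, "val_loss": 1.4})
log_metrics({"epoch": 1, "train_loss": 0.5})

print(history["epoch"])       # [0, 1]
print(history["train_loss"])  # [0.9, 0.5]
```

With the real logger, the same `history` dict is available after training as `logger.history`, one list per metric name, ready for plotting.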