How to extract loss and accuracy from logger by each epoch in pytorch lightning?
I want to extract all of the logged data to make my own plots instead of using TensorBoard. My understanding is that all of the logged losses and accuracies are stored in the specified directory, since TensorBoard draws its line charts from them.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/
However, I would like to know how to extract all of those logs from the logger in PyTorch Lightning. Below is a code example of the training part:
# model
ssl_classifier = SSLImageClassifier(lr=lr)

# train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')
trainer = pl.Trainer(progress_bar_refresh_rate=20,
                     gpus=1,
                     max_epochs=max_epoch,
                     logger=logger)
trainer.fit(ssl_classifier, train_loader, val_loader)
I have confirmed that trainer.logger.log_dir returns the directory where the logs seem to be saved, and that trainer.logger.log_metrics returns
<bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>.
trainer.logged_metrics only returns the logs of the last epoch, such as:
{'epoch': 19,
'train_acc': tensor(1.),
'train_loss': tensor(0.1038),
'val_acc': 0.6499999761581421,
'val_loss': 1.2171183824539185}
Do you know how to solve this?
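Side note: since the TensorBoard event files are saved under trainer.logger.log_dir, they can also be read back directly with tensorboard's EventAccumulator. A minimal sketch (assuming the default event-file layout; the tag names are just the metrics logged above):

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# trainer.logger.log_dir holds the events.out.tfevents.* file(s)
ea = EventAccumulator(trainer.logger.log_dir)
ea.Reload()

print(ea.Tags()['scalars'])  # e.g. ['train_loss', 'train_acc', 'val_loss', 'val_acc']
for event in ea.Scalars('val_acc'):
    print(event.step, event.value)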
Lightning does not store all the logs by itself. All it does is stream them into the logger instance, and the logger decides what to do with them.
The best way to retrieve all logged metrics is with a custom callback:
from pytorch_lightning.callbacks import Callback

class MetricTracker(Callback):
    def __init__(self):
        self.collection = []

    def on_validation_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx):
        vacc = outputs['val_acc']  # you can access the batch-level outputs here
        self.collection.append(vacc)  # track them

    def on_validation_epoch_end(self, trainer, module):
        elogs = trainer.logged_metrics  # access the epoch-level logs here
        self.collection.append(elogs)
        # do whatever is needed
You can then access everything that was logged from the callback instance:
cb = MetricTracker()
trainer = pl.Trainer(callbacks=[cb])
cb.collection  # do your plotting and stuff
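For instance, a minimal plotting sketch (assuming matplotlib, that trainer.fit has already run, and the 'val_acc' key from the question's metrics):

import matplotlib.pyplot as plt

# keep only the epoch-level dicts appended in on_validation_epoch_end,
# skipping the batch-level tensors appended in on_validation_batch_end
epoch_logs = [entry for entry in cb.collection if isinstance(entry, dict)]
val_acc = [float(entry['val_acc']) for entry in epoch_logs if 'val_acc' in entry]

plt.plot(val_acc)
plt.xlabel('epoch')
plt.ylabel('val_acc')
plt.show()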
The accepted answer is not fundamentally wrong, but it does not follow the official (current) PyTorch Lightning guide, given here: https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#make-a-custom-logger
It suggests writing a class like this:
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment

class MyLogger(LightningLoggerBase):
    @property
    def name(self):
        return "MyLogger"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics)
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass
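Presumably (the docs snippet does not show this step) the custom logger is then passed to the Trainer like any built-in one; the variable names here are reused from the question:

trainer = pl.Trainer(logger=MyLogger(), max_epochs=max_epoch)
trainer.fit(ssl_classifier, train_loader, val_loader)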
Looking at the LightningLoggerBase class shows which other methods can be overridden.
Here is a minimalist logger of mine. It is highly unoptimized, but it should be a good first attempt. I will edit this if I improve it.
import collections

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only

class History_dict(LightningLoggerBase):
    def __init__(self):
        super().__init__()
        # a defaultdict(list) creates any key on first access,
        # so every metric name maps to a growing list of values
        self.history = collections.defaultdict(list)

    @property
    def name(self):
        return "Logger_custom_plot"

    @property
    def version(self):
        return "1.0"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        for metric_name, metric_value in metrics.items():
            if metric_name != 'epoch':
                self.history[metric_name].append(metric_value)
            else:
                # 'epoch' arrives once per logged metric; only append it when
                # it differs from the last recorded value, to avoid duplicates
                if not len(self.history['epoch']) or self.history['epoch'][-1] != metric_value:
                    self.history['epoch'].append(metric_value)

    def log_hyperparams(self, params):
        pass
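A hedged usage sketch (assuming matplotlib and that the model logs 'train_acc' and 'val_acc' as in the question; note the lists can have different lengths if the metrics are logged at different frequencies):

import matplotlib.pyplot as plt

history = History_dict()
trainer = pl.Trainer(logger=history, max_epochs=max_epoch)
trainer.fit(ssl_classifier, train_loader, val_loader)

# every metric name passed to log_metrics ends up as a list in history.history
plt.plot(history.history['train_acc'], label='train_acc')
plt.plot(history.history['val_acc'], label='val_acc')
plt.xlabel('logging step')
plt.ylabel('accuracy')
plt.legend()
plt.show()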