PyTorch Lightning 问题中的 EarlyStopping 回调

EarlyStopping callback in PyTorch Lightning problem

我尝试在 PyTorch Lightning 中训练神经网络模型,但在执行 EarlyStopping 回调的验证步骤中训练失败。

模型的相关部分如下。请特别参阅 validation_step,它必须记录提前停止所需的指标。

class DialogActsLightningModel(pl.LightningModule):

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.model = ContextAwareDAC(
            model_name=self.config['model_name'],
            hidden_size=self.config['hidden_size'],
            num_classes=self.config['num_classes'],
            device=self.config['device']
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

    def forward(self, batch):
        logits = self.model(batch)
        return logits

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
                "val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()
        avg_precision = torch.stack([x['val_precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['val_recall'] for x in outputs]).mean()
        wandb.log({"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
                   "val_recall": avg_recall})
        return {"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
                "val_recall": avg_recall}

当我运行按以下方式训练时:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import os

from trainer import DialogActsLightningModel
import wandb

wandb.init()

logger = WandbLogger(
    name="model_name",
    entity='myname',
    save_dir=config["save_dir"],
    project=config["project"],
    log_model=True,
)
early_stopping = EarlyStopping(
    monitor="val_accuracy",
    min_delta=0.001,
    patience=5,
)

model = DialogActsLightningModel(config=config)

trainer = pl.Trainer(
    logger=logger,
    gpus=[0],
    checkpoint_callback=True,
    callbacks=[early_stopping],
    default_root_dir=MODELS_DIRECTORY,
    max_epochs=config["epochs"],
    precision=config["precision"],
    limit_train_batches=10, # run for only 10 batches, debug mode
    limit_test_batches=10,
    limit_val_batches=10
)

trainer.fit(model)

我遇到了一个错误,但模型应该在验证步骤中计算并记录了指标“val_accuracy”。

Epoch 0: 100%
20/20 [00:28<00:00, 1.41s/it, loss=1.95]

/opt/conda/lib/python3.9/site-packages/pytorch_lightning/callbacks/early_stopping.py in _validate_condition_metric(self, logs)
    149         if monitor_val is None:
    150             if self.strict:
--> 151                 raise RuntimeError(error_msg)
    152             if self.verbose > 0:
    153                 rank_zero_warn(error_msg, RuntimeWarning)

RuntimeError: Early stopping conditioned on metric `val_accuracy` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: ``

我做错了什么?如何解决?

如果您使用 pytorch-lightning 最新版本,您应该在调用提前停止或类似功能时记录 val_accuracy 或 val_loss。有关更多信息,请查看代码 below.i 认为这对您肯定有帮助...

def validation_step(self, batch, batch_idx):
    input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
    logits = self(batch)
    loss = F.cross_entropy(logits, targets)
    acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
    f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
    precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
    recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])

    ##########################################################################
    ##########################################################################
    self.log("val_accuracy", torch.tensor([acc])     # try this line
    ##########################################################################
    ##########################################################################

    return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
            "val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}

如果Post有用请投票