PyTorch Lightning 问题中的 EarlyStopping 回调
EarlyStopping callback in PyTorch Lightning problem
我尝试在 PyTorch Lightning 中训练神经网络模型,但在执行 EarlyStopping 回调的验证步骤中训练失败。
模型的相关部分如下。请特别参阅 validation_step
,它必须记录提前停止所需的指标。
class DialogActsLightningModel(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
self.model = ContextAwareDAC(
model_name=self.config['model_name'],
hidden_size=self.config['hidden_size'],
num_classes=self.config['num_classes'],
device=self.config['device']
)
self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
def forward(self, batch):
logits = self.model(batch)
return logits
def validation_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
logits = self(batch)
loss = F.cross_entropy(logits, targets)
acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
"val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()
avg_precision = torch.stack([x['val_precision'] for x in outputs]).mean()
avg_recall = torch.stack([x['val_recall'] for x in outputs]).mean()
wandb.log({"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
"val_recall": avg_recall})
return {"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
"val_recall": avg_recall}
当我运行按以下方式训练时:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import os
from trainer import DialogActsLightningModel
import wandb
wandb.init()
logger = WandbLogger(
name="model_name",
entity='myname',
save_dir=config["save_dir"],
project=config["project"],
log_model=True,
)
early_stopping = EarlyStopping(
monitor="val_accuracy",
min_delta=0.001,
patience=5,
)
model = DialogActsLightningModel(config=config)
trainer = pl.Trainer(
logger=logger,
gpus=[0],
checkpoint_callback=True,
callbacks=[early_stopping],
default_root_dir=MODELS_DIRECTORY,
max_epochs=config["epochs"],
precision=config["precision"],
limit_train_batches=10, # run for only 10 batches, debug mode
limit_test_batches=10,
limit_val_batches=10
)
trainer.fit(model)
我遇到了一个错误,但模型应该在验证步骤中计算并记录了指标“val_accuracy”。
Epoch 0: 100%
20/20 [00:28<00:00, 1.41s/it, loss=1.95]
/opt/conda/lib/python3.9/site-packages/pytorch_lightning/callbacks/early_stopping.py in _validate_condition_metric(self, logs)
149 if monitor_val is None:
150 if self.strict:
--> 151 raise RuntimeError(error_msg)
152 if self.verbose > 0:
153 rank_zero_warn(error_msg, RuntimeWarning)
RuntimeError: Early stopping conditioned on metric `val_accuracy` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: ``
我做错了什么?如何解决?
如果您使用 pytorch-lightning 最新版本,您应该在调用提前停止或类似功能时记录 val_accuracy 或 val_loss。有关更多信息,请查看代码 below.i 认为这对您肯定有帮助...
def validation_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
logits = self(batch)
loss = F.cross_entropy(logits, targets)
acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
##########################################################################
##########################################################################
self.log("val_accuracy", torch.tensor([acc]) # try this line
##########################################################################
##########################################################################
return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
"val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}
如果Post有用请投票
我尝试在 PyTorch Lightning 中训练神经网络模型,但在执行 EarlyStopping 回调的验证步骤中训练失败。
模型的相关部分如下。请特别参阅 validation_step
,它必须记录提前停止所需的指标。
class DialogActsLightningModel(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
self.model = ContextAwareDAC(
model_name=self.config['model_name'],
hidden_size=self.config['hidden_size'],
num_classes=self.config['num_classes'],
device=self.config['device']
)
self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
def forward(self, batch):
logits = self.model(batch)
return logits
def validation_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
logits = self(batch)
loss = F.cross_entropy(logits, targets)
acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
"val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()
avg_precision = torch.stack([x['val_precision'] for x in outputs]).mean()
avg_recall = torch.stack([x['val_recall'] for x in outputs]).mean()
wandb.log({"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
"val_recall": avg_recall})
return {"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
"val_recall": avg_recall}
当我运行按以下方式训练时:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import os
from trainer import DialogActsLightningModel
import wandb
wandb.init()
logger = WandbLogger(
name="model_name",
entity='myname',
save_dir=config["save_dir"],
project=config["project"],
log_model=True,
)
early_stopping = EarlyStopping(
monitor="val_accuracy",
min_delta=0.001,
patience=5,
)
model = DialogActsLightningModel(config=config)
trainer = pl.Trainer(
logger=logger,
gpus=[0],
checkpoint_callback=True,
callbacks=[early_stopping],
default_root_dir=MODELS_DIRECTORY,
max_epochs=config["epochs"],
precision=config["precision"],
limit_train_batches=10, # run for only 10 batches, debug mode
limit_test_batches=10,
limit_val_batches=10
)
trainer.fit(model)
我遇到了一个错误,但模型应该在验证步骤中计算并记录了指标“val_accuracy”。
Epoch 0: 100%
20/20 [00:28<00:00, 1.41s/it, loss=1.95]
/opt/conda/lib/python3.9/site-packages/pytorch_lightning/callbacks/early_stopping.py in _validate_condition_metric(self, logs)
149 if monitor_val is None:
150 if self.strict:
--> 151 raise RuntimeError(error_msg)
152 if self.verbose > 0:
153 rank_zero_warn(error_msg, RuntimeWarning)
RuntimeError: Early stopping conditioned on metric `val_accuracy` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: ``
我做错了什么?如何解决?
如果您使用 pytorch-lightning 最新版本,您应该在调用提前停止或类似功能时记录 val_accuracy 或 val_loss。有关更多信息,请查看代码 below.i 认为这对您肯定有帮助...
def validation_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
logits = self(batch)
loss = F.cross_entropy(logits, targets)
acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
##########################################################################
##########################################################################
self.log("val_accuracy", torch.tensor([acc]) # try this line
##########################################################################
##########################################################################
return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
"val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}
如果Post有用请投票