使用 optuna 在 huggingface 上进行超参数搜索时失败,并出现 wandb 错误

Hyperparam search on huggingface with optuna fails with wandb error

我正在运行下面这个简单的脚本,它取自一篇示例博客文章(blog post)。但是,它因为 wandb 而失败了。将 wandb 设为离线模式也没有用。

from datasets import load_dataset, load_metric
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)
import wandb


# Start a wandb run once, at module level, before any training begins.
# NOTE(review): the traceback reported for this script shows wandb.log()
# raising "You must call wandb.init() before wandb.log()" during the second
# trial, so this single call does not appear to cover every trial -- confirm
# against the installed wandb / transformers versions.
wandb.init()

# Tokenizer for the same 'distilbert-base-uncased' checkpoint that model_init loads.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
# GLUE 'mrpc' dataset and the matching metric object used by compute_metrics.
dataset = load_dataset('glue', 'mrpc')
metric = load_metric('glue', 'mrpc')

def encode(examples):
    """Tokenize the sentence pairs found in ``examples``.

    Passes the 'sentence1' and 'sentence2' entries of ``examples`` to the
    module-level tokenizer with truncation enabled, and returns the
    tokenizer's output unchanged.
    """
    first, second = examples['sentence1'], examples['sentence2']
    return tokenizer(first, second, truncation=True)

# Tokenize the dataset by mapping encode over it in batched mode. The result
# is indexed by split name ("train", "validation") when building the Trainer.
encoded_dataset = dataset.map(encode, batched=True)

def model_init():
    """Load a sequence-classification model from the pretrained checkpoint.

    Takes no arguments; every call loads 'distilbert-base-uncased' with
    return_dict=True and returns the resulting model.
    """
    checkpoint = 'distilbert-base-uncased'
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint, return_dict=True)
    return model

def compute_metrics(eval_pred):
    """Score predictions against reference labels with the module-level metric.

    ``eval_pred`` unpacks into ``(predictions, labels)``. The predicted class
    is the argmax over the last axis of the predictions; the return value is
    whatever ``metric.compute`` produces for those predictions and labels.
    """
    scores, references = eval_pred
    predicted = scores.argmax(axis=-1)
    return metric.compute(predictions=predicted, references=references)

# Run evaluation every 500 steps during training -- a bit more often than
# the default -- so that bad trials can be pruned early.
# Turning off tqdm is purely a matter of preference.
training_args = TrainingArguments(
    "test",
    evaluation_strategy='steps',
    eval_steps=500,
    disable_tqdm=True,
)

# The Trainer receives the model_init factory rather than a model instance.
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

def my_hp_space(trial):
    """Describe the hyperparameter search space for a single trial.

    Asks ``trial`` for one suggestion per hyperparameter and returns them in
    a dict keyed by hyperparameter name:

    * learning_rate: float in [1e-4, 1e-2], requested with ``log=True``
    * weight_decay: float in [0.1, 0.3]
    * num_train_epochs: int in [5, 10]
    * seed: int in [20, 40]
    * per_device_train_batch_size: one of 32 or 64
    """
    space = {}
    space["learning_rate"] = trial.suggest_float(
        "learning_rate", 1e-4, 1e-2, log=True)
    space["weight_decay"] = trial.suggest_float("weight_decay", 0.1, 0.3)
    space["num_train_epochs"] = trial.suggest_int("num_train_epochs", 5, 10)
    space["seed"] = trial.suggest_int("seed", 20, 40)
    space["per_device_train_batch_size"] = trial.suggest_categorical(
        "per_device_train_batch_size", [32, 64])
    return space


# Launch the search: 10 trials on the optuna backend, drawing hyperparameters
# from my_hp_space, with the objective to be maximized. The return value is
# not captured.
trainer.hyperparameter_search(
    hp_space=my_hp_space,
    backend="optuna",
    direction="maximize",
    n_trials=10,
)

Trial 0 成功完成,但接下来的 Trial 1 崩溃,并出现以下错误:

  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/integrations.py", line 138, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer.py", line 1376, in train
    self.log(metrics)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer.py", line 1688, in log
    self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer_callback.py", line 371, in on_log
    return self.call_event("on_log", args, state, control, logs=logs)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer_callback.py", line 378, in call_event
    result = getattr(callback, event)(
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/integrations.py", line 754, in on_log
    self._wandb.log({**logs, "train/global_step": state.global_step})
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/wandb/sdk/lib/preinit.py", line 38, in preinit_wrapper
    raise wandb.Error("You must call wandb.init() before {}()".format(name))
wandb.errors.Error: You must call wandb.init() before wandb.log()

非常感谢任何帮助。

请使用最新版本的 wandb 和 transformers 运行代码再检查一下:wandb 0.11.0、transformers 4.9.0。

在我这里可以正常运行。