Hyperparameter search with Optuna on Hugging Face fails with a wandb error



I'm using this simple script, based on the example blog post. However, it fails because of wandb. Setting wandb to offline mode didn't help either.

from datasets import load_dataset, load_metric
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)
import wandb

wandb.init()
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
dataset = load_dataset('glue', 'mrpc')
metric = load_metric('glue', 'mrpc')

def encode(examples):
    outputs = tokenizer(
        examples['sentence1'], examples['sentence2'], truncation=True)
    return outputs

encoded_dataset = dataset.map(encode, batched=True)

def model_init():
    return AutoModelForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', return_dict=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# Evaluate during training and a bit more often
# than the default to be able to prune bad trials early.
# Disabling tqdm is a matter of preference.
training_args = TrainingArguments(
    "test", eval_steps=500, disable_tqdm=True,
    evaluation_strategy='steps',)
trainer = Trainer(
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    model_init=model_init,
    compute_metrics=compute_metrics,
)
def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 0.1, 0.3),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 5, 10),
        "seed": trial.suggest_int("seed", 20, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [32, 64]),
    }

trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    n_trials=10,
    hp_space=my_hp_space
)

Trial 0 completes successfully, but the next one, Trial 1, crashes with the following error:

  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/integrations.py", line 138, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer.py", line 1376, in train
    self.log(metrics)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer.py", line 1688, in log
    self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer_callback.py", line 371, in on_log
    return self.call_event("on_log", args, state, control, logs=logs)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer_callback.py", line 378, in call_event
    result = getattr(callback, event)(
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/integrations.py", line 754, in on_log
    self._wandb.log({**logs, "train/global_step": state.global_step})
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/wandb/sdk/lib/preinit.py", line 38, in preinit_wrapper
    raise wandb.Error("You must call wandb.init() before {}()".format(name))
wandb.errors.Error: You must call wandb.init() before wandb.log()

Thanks a lot for your help.

Please try running the code with the latest versions of wandb and transformers. It works with wandb 0.11.0 and transformers 4.9.0.
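Before re-running the search, it can help to confirm which versions are actually installed in the active environment. The following is a minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+). The helper names `parse_version`, `meets_minimum` and `check_versions` are hypothetical, and the version parsing is deliberately simplified: it keeps only the leading numeric components and ignores pre-release suffixes instead of following the full PEP 440 rules.

```python
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Turn a release string such as '4.9.0' into a comparable tuple (4, 9, 0).

    Simplified (assumption): only leading numeric components are kept, so
    suffixes like 'rc1' or '.dev0' are dropped rather than ordered per PEP 440.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break  # stop at the first non-numeric component, e.g. 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    """Return True if `installed` is at least `required` (numeric comparison)."""
    a, b = parse_version(installed), parse_version(required)
    # Pad with zeros so that '0.11' compares equal to '0.11.0'.
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# Versions the answer reports as working together.
MINIMUM = {"wandb": "0.11.0", "transformers": "4.9.0"}


def check_versions(minimum=MINIMUM) -> dict:
    """Map each package name to (installed_version_or_None, meets_minimum)."""
    report = {}
    for package, required in minimum.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = (None, False)  # not installed at all
            continue
        report[package] = (installed, meets_minimum(installed, required))
    return report


if __name__ == "__main__":
    for package, (installed, ok) in check_versions().items():
        status = "OK" if ok else "upgrade needed"
        print(f"{package}: installed={installed}, required>={MINIMUM[package]} -> {status}")
```

Run it in the same conda environment as the failing script; any package reported as "upgrade needed" can then be upgraded with pip or conda before trying the hyperparameter search again. Comparing integer tuples rather than raw strings matters here, since a plain string comparison would wrongly rank "4.10.0" below "4.9.0".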
