diff --git a/examples/run_trainier_seq2seq_huggingface.py b/examples/run_trainier_seq2seq_huggingface.py index e800076..6ebfda8 100644 --- a/examples/run_trainier_seq2seq_huggingface.py +++ b/examples/run_trainier_seq2seq_huggingface.py @@ -30,7 +30,8 @@ import nltk # Here to have a nice missing dependency error message early on import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset +import evaluate from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, classification_report, confusion_matrix import transformers @@ -562,7 +563,7 @@ def preprocess_function(examples): metric_name = regex.sub("eval_","",training_args.metric_for_best_model) # metric_name = "f1" #"rouge" if data_args.task.startswith("summarization") else "sacrebleu" print ("[INFO] evlaute using ", metric_name, "score", "task name:", data_args.task) - # metric = load_metric("f1", cache_dir=cache_dir) + # metric = evaluate.Metric("f1", cache_dir=cache_dir) def postprocess_text(preds, labels): preds = [pred.strip() for pred in preds]