Skip to content

Commit

Permalink
Allow running recipes from the Unitxt catalog (#3267)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Jan 14, 2025
1 parent 5f5c17e commit 1bc7bd0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/helm/benchmark/metrics/unitxt_metrics.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,7 +18,10 @@ class UnitxtMetric(MetricInterface):

     def __init__(self, **kwargs):
         super().__init__()
-        dataset_name = ",".join(f"{key}={value}" for key, value in kwargs.items())
+        if len(kwargs) == 1 and "recipe" in kwargs:
+            dataset_name = kwargs["recipe"]
+        else:
+            dataset_name = ",".join(f"{key}={value}" for key, value in kwargs.items())
         self.dataset = load_dataset("unitxt/data", dataset_name, trust_remote_code=True)

     def evaluate(
Expand Down
7 changes: 4 additions & 3 deletions src/helm/benchmark/run_specs/unitxt_run_specs.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,8 +10,9 @@
 @run_spec_function("unitxt")
 def get_unitxt_spec(**kwargs) -> RunSpec:
     card = kwargs.get("card")
-    if not card:
-        raise Exception("Unitxt card must be specified")
+    recipe = kwargs.get("recipe")
+    if not card and not recipe:
+        raise Exception("Unitxt card or recipe must be specified")
     if os.environ.get("HELM_UNITXT_SHORTEN_RUN_SPEC_NAMES", "").lower() == "true":
         name_suffix = ",".join(
             [f"{key}={value}" for key, value in kwargs.items() if key not in ["template_card_index", "loader_limit"]]
Expand Down Expand Up @@ -46,5 +47,5 @@ def get_unitxt_spec(**kwargs) -> RunSpec:
             MetricSpec(class_name="helm.benchmark.metrics.unitxt_metrics.UnitxtMetric", args=kwargs),
         ]
         + get_basic_metric_specs([]),
-        groups=[f"unitxt_{card}"],
+        groups=[f"unitxt_{card or recipe}"],
     )
10 changes: 8 additions & 2 deletions src/helm/benchmark/scenarios/unitxt_scenario.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,13 +32,19 @@ def __init__(self, **kwargs):
         self.kwargs = kwargs

     def get_instances(self, output_path: str) -> List[Instance]:
-        dataset_name = ",".join(f"{key}={value}" for key, value in self.kwargs.items())
+        if len(self.kwargs) == 1 and "recipe" in self.kwargs:
+            dataset_name = self.kwargs["recipe"]
+        else:
+            dataset_name = ",".join(f"{key}={value}" for key, value in self.kwargs.items())
         dataset = load_dataset("unitxt/data", dataset_name, trust_remote_code=True)

         instances: List[Instance] = []

         for unitxt_split_name, helm_split_name in UnitxtScenario.UNITXT_SPLIT_NAME_TO_HELM_SPLIT_NAME.items():
-            for index, row in enumerate(dataset[unitxt_split_name]):
+            dataset_split = dataset.get(unitxt_split_name)
+            if dataset_split is None:
+                continue
+            for index, row in enumerate(dataset_split):
                 references = [
                     Reference(
                         output=Output(text=reference_text),
Expand Down

0 comments on commit 1bc7bd0

Please sign in to comment.