diff --git a/src/helm/benchmark/metrics/unitxt_metrics.py b/src/helm/benchmark/metrics/unitxt_metrics.py index 2fc684b9c24..95ea4325bca 100644 --- a/src/helm/benchmark/metrics/unitxt_metrics.py +++ b/src/helm/benchmark/metrics/unitxt_metrics.py @@ -18,7 +18,10 @@ class UnitxtMetric(MetricInterface): def __init__(self, **kwargs): super().__init__() - dataset_name = ",".join(f"{key}={value}" for key, value in kwargs.items()) + if len(kwargs) == 1 and "recipe" in kwargs: + dataset_name = kwargs["recipe"] + else: + dataset_name = ",".join(f"{key}={value}" for key, value in kwargs.items()) self.dataset = load_dataset("unitxt/data", dataset_name, trust_remote_code=True) def evaluate( diff --git a/src/helm/benchmark/run_specs/unitxt_run_specs.py b/src/helm/benchmark/run_specs/unitxt_run_specs.py index 3b2ee2e201a..0a4b51a9d59 100644 --- a/src/helm/benchmark/run_specs/unitxt_run_specs.py +++ b/src/helm/benchmark/run_specs/unitxt_run_specs.py @@ -10,8 +10,9 @@ @run_spec_function("unitxt") def get_unitxt_spec(**kwargs) -> RunSpec: card = kwargs.get("card") - if not card: - raise Exception("Unitxt card must be specified") + recipe = kwargs.get("recipe") + if not card and not recipe: + raise Exception("Unitxt card or recipe must be specified") if os.environ.get("HELM_UNITXT_SHORTEN_RUN_SPEC_NAMES", "").lower() == "true": name_suffix = ",".join( [f"{key}={value}" for key, value in kwargs.items() if key not in ["template_card_index", "loader_limit"]] @@ -46,5 +47,5 @@ def get_unitxt_spec(**kwargs) -> RunSpec: MetricSpec(class_name="helm.benchmark.metrics.unitxt_metrics.UnitxtMetric", args=kwargs), ] + get_basic_metric_specs([]), - groups=[f"unitxt_{card}"], + groups=[f"unitxt_{card or recipe}"], ) diff --git a/src/helm/benchmark/scenarios/unitxt_scenario.py b/src/helm/benchmark/scenarios/unitxt_scenario.py index 95e0d125464..e321c066a47 100644 --- a/src/helm/benchmark/scenarios/unitxt_scenario.py +++ b/src/helm/benchmark/scenarios/unitxt_scenario.py @@ -32,13 +32,19 @@ def __init__(self, **kwargs): self.kwargs = kwargs def get_instances(self, output_path: str) -> List[Instance]: - dataset_name = ",".join(f"{key}={value}" for key, value in self.kwargs.items()) + if len(self.kwargs) == 1 and "recipe" in self.kwargs: + dataset_name = self.kwargs["recipe"] + else: + dataset_name = ",".join(f"{key}={value}" for key, value in self.kwargs.items()) dataset = load_dataset("unitxt/data", dataset_name, trust_remote_code=True) instances: List[Instance] = [] for unitxt_split_name, helm_split_name in UnitxtScenario.UNITXT_SPLIT_NAME_TO_HELM_SPLIT_NAME.items(): - for index, row in enumerate(dataset[unitxt_split_name]): + dataset_split = dataset.get(unitxt_split_name) + if dataset_split is None: + continue + for index, row in enumerate(dataset_split): references = [ Reference( output=Output(text=reference_text),