Skip to content

Commit

Permalink
fix label_id issue
Browse files Browse the repository at this point in the history
  • Loading branch information
davidheineman committed Jan 8, 2025
1 parent e485097 commit 23e26cb
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions olmo/eval/downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,11 @@ def collate_fn(self, data):
"cont_byte_len": torch.LongTensor(cont_byte_lens),
"input_ids": torch.stack(queries),
"dc_input_ids": torch.stack(dc_queries),
"label_id": torch.LongTensor(label_ids),
}

if not isinstance(label_ids, str):
batch["label_id"] = torch.LongTensor(label_ids)

return batch

def token_encode(self, string: str) -> List[int]:
Expand Down Expand Up @@ -1538,7 +1540,7 @@ def prep_examples(self):
label_id = request["label"]
cont_id = request["idx"]
if self.metric_type in ["ce_loss", "bpb"]:
if label_id != cont_id:
if label_id != cont_id and not isinstance(label_id, str):
# Skip non-target continuations for ce_loss and bpb
continue
else:
Expand Down Expand Up @@ -1758,6 +1760,7 @@ def doc_to_label(self, doc) -> int:
"csqa_rc_0shot_bpb": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_0shot", "metric_type": "bpb"}),
"csqa_rc_5shot": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_5shot", "metric_type": "len_norm"}),
"csqa_rc_5shot_bpb": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_5shot", "metric_type": "bpb"}),
"gsm8k_gold_bpb_5shot": (OEEvalTask, {"dataset_path": "gsm8k", "dataset_name": "gold_bpb_5shot", "metric_type": "bpb"}),
"hellaswag_mc_5shot": (
OEEvalTask,
{"dataset_path": "hellaswag", "dataset_name": "mc_5shot", "metric_type": "acc"},
Expand Down
2 changes: 1 addition & 1 deletion olmo_data/oe_eval_tasks/gsm8k/gold_bpb_5shot/config.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"task_name": "gsm8k", "task_hash": "c9a8b5bfa866f678c3ea4ef06729f149", "task_config": {"task_name": "gsm8k", "task_core": "gsm8k", "limit": null, "split": "test", "num_shots": 8, "fewshot_seed": 1234, "primary_metric": "logits_per_byte", "random_subsample_seed": 1234, "context_kwargs": {"no_cot": false}, "generation_kwargs": {"max_gen_toks": 512, "do_sample": false, "temperature": 0.0, "stop_sequences": ["Question:", "</s>", "<|im_end|>", "\n\n"], "repeats": 1}, "metric_kwargs": {"regexes_to_ignore": [",", "\\$", "(?s).*#### ", "\\.$"]}, "native_id_field": "id", "fewshot_source": "STD:GSM8k", "dataset_path": "gsm8k", "dataset_name": "main", "use_chat_format": null, "version": 0.1, "revision": null, "compute_gold_bpb": true, "metadata": {"alias": "gsm8k::bpb"}}, "current_date": "2025-01-08 21:03:44 UTC", "num_instances": 1319}
{"task_name": "gsm8k", "task_hash": "c9a8b5bfa866f678c3ea4ef06729f149", "task_config": {"task_name": "gsm8k", "task_core": "gsm8k", "limit": null, "split": "test", "num_shots": 8, "fewshot_seed": 1234, "primary_metric": "logits_per_byte", "random_subsample_seed": 1234, "context_kwargs": {"no_cot": false}, "generation_kwargs": {"max_gen_toks": 512, "do_sample": false, "temperature": 0.0, "stop_sequences": ["Question:", "</s>", "<|im_end|>", "\n\n"], "repeats": 1}, "metric_kwargs": {"regexes_to_ignore": [",", "\\$", "(?s).*#### ", "\\.$"]}, "native_id_field": "id", "fewshot_source": "STD:GSM8k", "dataset_path": "gsm8k", "dataset_name": "main", "use_chat_format": null, "version": 0.1, "revision": null, "compute_gold_bpb": true, "metadata": {"alias": "gsm8k::bpb"}}, "current_date": "2025-01-08 21:30:11 UTC", "num_instances": 1319}
Binary file modified olmo_data/oe_eval_tasks/gsm8k/gold_bpb_5shot/requests.jsonl.gz
Binary file not shown.

0 comments on commit 23e26cb

Please sign in to comment.