diff --git a/opencompass/models/base.py b/opencompass/models/base.py
index ec29a8f33..cc53e11e1 100644
--- a/opencompass/models/base.py
+++ b/opencompass/models/base.py
@@ -1,4 +1,4 @@
-from abc import abstractclassmethod
+from abc import abstractmethod
 from copy import deepcopy
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -37,7 +37,7 @@ def __init__(self,
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
 
-    @abstractclassmethod
+    @abstractmethod
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
         """Generate results given a list of inputs.
 
@@ -48,8 +48,11 @@ def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
         Returns:
             List[str]: A list of generated strings.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' gen-based evaluation yet, try ppl-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_ppl(self,
                 inputs: List[str],
                 mask_length: Optional[List[int]] = None) -> List[float]:
@@ -66,8 +69,11 @@ def get_ppl(self,
         Returns:
             List[float]: A list of perplexity scores.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' ppl-based evaluation yet, try gen-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
 
@@ -192,7 +198,7 @@ def parse_template(self, prompt_template: PromptType, mode: str) -> str:
         Returns:
             str: The final string.
         """
-        assert isinstance(prompt_template, (str, list, PromptList))
+        assert isinstance(prompt_template, (str, list, PromptList, tuple))
 
         if not isinstance(prompt_template, (str, PromptList)):
             return [self.parse_template(p, mode=mode) for p in prompt_template]
diff --git a/opencompass/models/base_api.py b/opencompass/models/base_api.py
index 8cd750b47..7c8f0b314 100644
--- a/opencompass/models/base_api.py
+++ b/opencompass/models/base_api.py
@@ -1,7 +1,7 @@
 import re
 import threading
 import warnings
-from abc import abstractclassmethod
+from abc import abstractmethod
 from copy import deepcopy
 from time import sleep
 from typing import Dict, List, Optional, Tuple, Union
@@ -46,7 +46,7 @@ def __init__(self,
         self.template_parser = APITemplateParser(meta_template)
         self.logger = get_logger()
 
-    @abstractclassmethod
+    @abstractmethod
     def generate(self, inputs: List[PromptType],
                  max_out_len: int) -> List[str]:
         """Generate results given a list of inputs.
@@ -60,8 +60,11 @@ def generate(self, inputs: List[PromptType],
         Returns:
             List[str]: A list of generated strings.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' gen-based evaluation yet, try ppl-based '
+                                  'instead.')
 
-    @abstractclassmethod
+    @abstractmethod
     def get_ppl(self,
                 inputs: List[PromptType],
                 mask_length: Optional[List[int]] = None) -> List[float]:
@@ -78,6 +81,9 @@ def get_ppl(self,
         Returns:
             List[float]: A list of perplexity scores.
         """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' ppl-based evaluation yet, try gen-based '
+                                  'instead.')
 
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized string. Only English and Chinese
@@ -161,7 +167,7 @@ def parse_template(self, prompt_template: PromptType,
         Returns:
             List[str or PromptList]: The finalized prompt or a conversation.
         """
-        assert isinstance(prompt_template, (str, list, PromptList))
+        assert isinstance(prompt_template, (str, list, PromptList, tuple))
 
         if not isinstance(prompt_template, (str, PromptList)):
             return [self.parse_template(p, mode=mode) for p in prompt_template]
diff --git a/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py b/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
index 07f933552..7a51c96b7 100644
--- a/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
@@ -108,6 +108,12 @@ def predict(self, adv_prompt) -> List:
             ice_template=self.ice_template,
             prompt_template=self.prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = self.retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -124,7 +130,12 @@ def predict(self, adv_prompt) -> List:
 
         # 5. Inference for prompts in each batch
         logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # 5-1. Inference with local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -133,8 +144,12 @@ def predict(self, adv_prompt) -> List:
             generated = results
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
diff --git a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
index 25ba94015..7dc7482af 100644
--- a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
@@ -108,11 +108,13 @@ def write_to_json(self, save_dir: str, filename: str):
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)
 
-    def save_results(self, origin_prompt, prediction, idx):
+    def save_results(self, origin_prompt, prediction, idx, gold=None):
         self.results_dict[str(idx)] = {
             'origin_prompt': origin_prompt,
             'prediction': prediction,
         }
+        if gold:
+            self.results_dict[str(idx)]['gold'] = gold
 
 
 class PPLInferencerOutputHandler:
     results_dict = {}
@@ -147,6 +149,12 @@ def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
         self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
         self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
 
+    def save_golds(self, golds):
+        for idx, gold in enumerate(golds):
+            if str(idx) not in self.results_dict.keys():
+                self.results_dict[str(idx)] = {}
+            self.results_dict[str(idx)]['gold'] = gold
+
 
 class CLPInferencerOutputHandler:
     results_dict = {}
@@ -164,7 +172,13 @@ def save_ice(self, ice):
                 self.results_dict[str(idx)] = {}
             self.results_dict[str(idx)]['in-context examples'] = example
 
-    def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
+    def save_prompt_and_condprob(self,
+                                 input,
+                                 prompt,
+                                 cond_prob,
+                                 idx,
+                                 choices,
+                                 gold=None):
         if str(idx) not in self.results_dict.keys():
             self.results_dict[str(idx)] = {}
         # TODO:
@@ -177,3 +191,4 @@ def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
         self.results_dict[str(idx)]['prediction'] = cond_prob
         # set pred label in case needed
         self.results_dict[str(idx)]['pred_label'] = int(np.argmax(cond_prob))
+        self.results_dict[str(idx)]['gold'] = gold
diff --git a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
index c264fb56a..727506484 100644
--- a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -175,22 +175,35 @@ def inference(self,
                 # minus the bos token
                 choice_target_ids.append(prompt_token_num - 1)
 
+            # 4.1 Fetch and zip prompt & gold answer if output column exists
+            ds_reader = retriever.dataset_reader
+            if ds_reader.output_column:
+                gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            else:
+                gold_ans = [None] * len(prompt_list)
+
             logger.info('Calculating conditional log probability for prompts.')
             for idx in trange(0,
                               len(prompt_list),
                               self.batch_size,
                               disable=not self.is_main_process):
                 sub_prompt_list = prompt_list[idx:idx + self.batch_size]
+                sub_golds = gold_ans[idx:idx + self.batch_size]
                 sub_choice_target_ids = choice_target_ids[idx:idx +
                                                           self.batch_size]
                 sub_res = self.__get_cond_prob(sub_prompt_list,
                                                sub_choice_target_ids,
                                                choice_ids)
 
-                for res, prompt in zip(sub_res, sub_prompt_list):
-                    output_handler.save_prompt_and_condprob(
-                        prompt.replace(ice[idx], ''), prompt, res, index,
-                        choices)
+                for res, prompt, gold in zip(sub_res, sub_prompt_list,
+                                             sub_golds):
+                    example_input = prompt.replace(ice[idx], '')
+                    output_handler.save_prompt_and_condprob(example_input,
+                                                            prompt,
+                                                            res,
+                                                            index,
+                                                            choices,
+                                                            gold=gold)
                     index = index + 1
 
             # 5. Output
diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
index d9aa64a5b..a319b9c95 100644
--- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
@@ -99,6 +99,12 @@ def inference(self,
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -115,7 +121,12 @@ def inference(self,
 
         # 5. Inference for prompts in each batch
         logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # 5-1. Inference with local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry, mode='gen')
@@ -124,8 +135,12 @@ def inference(self,
             generated = results
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
diff --git a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
index 8e4734c36..0d8bad9c8 100644
--- a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
@@ -200,7 +200,13 @@ def inference(self,
             sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
         output_handler.save_predictions(sub_predictions)
 
-        # 7. Output
+        # 7. Fetch gold answers if exist
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            golds = ds_reader.dataset['test'][ds_reader.output_column]
+            output_handler.save_golds(golds)
+
+        # 8. Output
         if self.is_main_process:
             os.makedirs(output_json_filepath, exist_ok=True)
             output_handler.write_to_json(output_json_filepath,
diff --git a/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py b/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py
index 5b802d51e..b7e9cffe3 100644
--- a/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py
@@ -105,6 +105,12 @@ def inference(self,
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -121,7 +127,12 @@ def inference(self,
 
         # 5. Inference for prompts in each batch
        logger.info('Starting inference process...')
-        for entry in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entry, golds = list(zip(*datum))
+            else:
+                entry = datum
+                golds = [None for _ in range(len(entry))]
             # TODO: add more types of CoT method
             # 5-1. Inference sc_size times with local model
             with torch.no_grad():
@@ -137,8 +148,12 @@ def inference(self,
             generated = sc_prediction
 
             # 5-3. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-4. Save intermediate results
diff --git a/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py b/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
index ba6eac383..5fd174835 100644
--- a/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
@@ -333,6 +333,12 @@ def inference(self,
             ice_template=ice_template,
             prompt_template=prompt_template)
 
+        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        ds_reader = retriever.dataset_reader
+        if ds_reader.output_column:
+            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
+            prompt_list = list(zip(prompt_list, gold_ans))
+
         # Create tmp json file for saving intermediate results and future
         # resuming
         index = 0
@@ -349,15 +355,24 @@ def inference(self,
 
         # 5. Inference for prompts in each batch
         logger.info('Starting ToT inference process...')
-        for entries in tqdm(dataloader, disable=not self.is_main_process):
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            if ds_reader.output_column:
+                entries, golds = list(zip(*datum))
+            else:
+                entries = datum
+                golds = [None for _ in range(len(entries))]
             # 5-1. Inference with ToT and local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entries, mode='gen')
                 generated = [self.tot_solve(entry) for entry in entries]
 
             # 5-2. Save current output
-            for prompt, prediction in zip(parsed_entries, generated):
-                output_handler.save_results(prompt, prediction, index)
+            for prompt, prediction, gold in zip(parsed_entries, generated,
+                                                golds):
+                output_handler.save_results(prompt,
+                                            prediction,
+                                            index,
+                                            gold=gold)
                 index = index + 1
 
             # 5-3. Save intermediate results