diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py
index 2541613..ae09232 100644
--- a/src/rank_llm/rerank/listwise/listwise_rankllm.py
+++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py
@@ -88,7 +88,7 @@ def rerank_batch(
         batch_size = 1
         reorder_policy = self.reorder_policy
 
-        model_functions, consumption = self._get_model_function(batched)
+        model_functions, consumption = self._get_model_function(batched, **kwargs)
 
         # reranking using vllm
         if len(set([len(req.candidates) for req in requests])) != 1:
@@ -549,7 +549,7 @@ def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]):
         return perm
 
     def _get_model_function(
-        self, batched: bool = False, **kwargs
+        self, batched: bool = False, silence: bool = False, **kwargs
    ) -> Tuple[ModelFunction, RerankConsumption]:
         # [(Request, SelectIndex)] -> [Prompt]
 
@@ -577,7 +577,8 @@ def execute(
                 return [
                     self._permutation_to_rank(s, selected_indices)
                     for (s, _), selected_indices in zip(
-                        self.run_llm_batched(batch, **kwargs), selected_indices_batch
+                        self.run_llm_batched(batch, silence=silence, **kwargs),
+                        selected_indices_batch,
                     )
                 ]
 
@@ -598,7 +599,7 @@ def execute(
                 return [
                     self._permutation_to_rank(
-                        self.run_llm(x, **kwargs)[0], selected_indices
+                        self.run_llm(x, silence=silence, **kwargs)[0], selected_indices
                     )
                     for x, selected_indices in zip(batch, selected_indices_batch)
                 ]
 
diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py
index 3e738fb..063e0ee 100644
--- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py
+++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py
@@ -103,7 +103,9 @@ def __init__(
             )
         elif vllm_batched:
             self._llm = LLM(
-                model, download_dir=os.getenv("HF_HOME"), enforce_eager=False
+                model,
+                download_dir=os.getenv("HF_HOME"),
+                enforce_eager=False,
             )
             self._tokenizer = self._llm.get_tokenizer()
         else:
@@ -133,26 +135,31 @@ def rerank_batch(
     def run_llm_batched(
         self,
         prompts: List[str | List[Dict[str, str]]],
+        silence: bool = False,
         current_window_size: Optional[int] = None,
+        **kwargs,
     ) -> List[Tuple[str, int]]:
         if SamplingParams is None:
             raise ImportError(
                 "Please install rank-llm with `pip install rank-llm[vllm]` to use batch inference."
            )
-        logger.info(f"VLLM Generating!")
         sampling_params = SamplingParams(
             temperature=0.0,
             max_tokens=self.num_output_tokens(current_window_size),
             min_tokens=self.num_output_tokens(current_window_size),
         )
-        outputs = self._llm.generate(prompts, sampling_params)
+        outputs = self._llm.generate(prompts, sampling_params, use_tqdm=not silence)
         return [
             (output.outputs[0].text, len(output.outputs[0].token_ids))
             for output in outputs
         ]
 
     def run_llm(
-        self, prompt: str, current_window_size: Optional[int] = None
+        self,
+        prompt: str,
+        silence: bool = False,
+        current_window_size: Optional[int] = None,
+        **kwargs,
     ) -> Tuple[str, int]:
         if current_window_size is None:
             current_window_size = self._window_size
diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py
index 0320964..d570bf5 100644
--- a/src/rank_llm/scripts/run_rank_llm.py
+++ b/src/rank_llm/scripts/run_rank_llm.py
@@ -39,6 +39,7 @@ def main(args):
     vllm_batched = args.vllm_batched
     batch_size = args.batch_size
     reorder_policy = args.reorder_policy
+    silence = args.silence
 
     _ = retrieve_and_rerank(
         model_path=model_path,
@@ -63,6 +64,7 @@ def main(args):
         system_message=system_message,
         vllm_batched=vllm_batched,
         reorder_policy=reorder_policy,
+        silence=silence,
     )
 
 
@@ -182,5 +184,11 @@ def main(args):
         help="policy in reordering. defaultly to be sliding window",
         type=str,
     )
+    parser.add_argument(
+        "--silence",
+        default=False,
+        action="store_true",
+        help="Suppress the tqdm progress bars that cannot otherwise be hidden (they do not support leave=False).",
+    )
     args = parser.parse_args()
     main(args)
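For reference, the silencing added above reduces to forwarding use_tqdm=not silence into vLLM's LLM.generate in the batched path. A minimal standalone sketch of that mechanism follows (assumes the vllm package is installed; the model name and prompt are illustrative placeholders, not taken from the repository):

from vllm import LLM, SamplingParams

silence = True  # mirrors the new --silence CLI flag

# Load a small model for illustration; rank_llm would load the reranker model here.
llm = LLM("facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# use_tqdm=not silence hides vLLM's per-batch progress bar when silencing is enabled.
outputs = llm.generate(
    ["Rank the passages by relevance to the query: ..."],
    sampling_params,
    use_tqdm=not silence,
)
print(outputs[0].outputs[0].text)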