Updated silence
XKTZ committed Sep 16, 2024
1 parent cc22c95 commit 26beeed
Showing 3 changed files with 27 additions and 9 deletions.
9 changes: 5 additions & 4 deletions src/rank_llm/rerank/listwise/listwise_rankllm.py
@@ -88,7 +88,7 @@ def rerank_batch(
         batch_size = 1

         reorder_policy = self.reorder_policy
-        model_functions, consumption = self._get_model_function(batched)
+        model_functions, consumption = self._get_model_function(batched, **kwargs)

         # reranking using vllm
         if len(set([len(req.candidates) for req in requests])) != 1:
@@ -549,7 +549,7 @@ def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]):
         return perm

     def _get_model_function(
-        self, batched: bool = False, **kwargs
+        self, batched: bool = False, silence: bool = False, **kwargs
     ) -> Tuple[ModelFunction, RerankConsumption]:
         # [(Request, SelectIndex)] -> [Prompt]
@@ -577,7 +577,8 @@ def execute(
             return [
                 self._permutation_to_rank(s, selected_indices)
                 for (s, _), selected_indices in zip(
-                    self.run_llm_batched(batch, **kwargs), selected_indices_batch
+                    self.run_llm_batched(batch, silence=silence, **kwargs),
+                    selected_indices_batch,
                 )
             ]

@@ -598,7 +599,7 @@ def execute(
             return [
                 self._permutation_to_rank(
-                    self.run_llm(x, **kwargs)[0], selected_indices
+                    self.run_llm(x, silence=silence, **kwargs)[0], selected_indices
                 )
                 for x, selected_indices in zip(batch, selected_indices_batch)
             ]
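Taken together, the listwise_rankllm.py changes thread a `silence` flag from `rerank_batch`'s keyword arguments through `_get_model_function` and into the closures that call the LLM. Below is a minimal runnable sketch of that forwarding pattern; the class and method bodies are simplified stand-ins, not the repository's actual implementations.

```python
# Sketch (not the repo's real classes) of how this commit threads `silence`
# from rerank_batch's kwargs through _get_model_function into the LLM calls.
from typing import Callable, List


class ListwiseSketch:
    def run_llm_batched(self, batch: List[str], silence: bool = False, **kwargs) -> List[str]:
        # The real code would pass use_tqdm=not silence to the vLLM engine here.
        return [f"permutation for: {prompt}" for prompt in batch]

    def _get_model_function(self, batched: bool = False, silence: bool = False, **kwargs) -> Callable:
        def execute(batch: List[str]) -> List[str]:
            # `silence` is captured by the closure, as in the diff's execute().
            return self.run_llm_batched(batch, silence=silence, **kwargs)

        return execute

    def rerank_batch(self, requests: List[str], **kwargs) -> List[str]:
        # Forwarding **kwargs is what lets silence reach _get_model_function.
        execute = self._get_model_function(batched=True, **kwargs)
        return execute(requests)


ranker = ListwiseSketch()
print(ranker.rerank_batch(["query 1 prompt", "query 2 prompt"], silence=True))
```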
19 changes: 14 additions & 5 deletions src/rank_llm/rerank/listwise/rank_listwise_os_llm.py
@@ -103,7 +103,9 @@ def __init__(
             )
         elif vllm_batched:
             self._llm = LLM(
-                model, download_dir=os.getenv("HF_HOME"), enforce_eager=False
+                model,
+                download_dir=os.getenv("HF_HOME"),
+                enforce_eager=False,
             )
             self._tokenizer = self._llm.get_tokenizer()
         else:
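For context, a standalone sketch of the engine construction reformatted above; the model name is illustrative, and a GPU with vLLM installed is assumed.

```python
# Sketch: constructing a vLLM engine as in the diff above. HF_HOME controls
# where weights are downloaded and cached; enforce_eager=False keeps CUDA
# graph capture enabled for faster decoding.
import os

from vllm import LLM

llm = LLM(
    "castorini/rank_zephyr_7b_v1_full",  # illustrative model, not from the diff
    download_dir=os.getenv("HF_HOME"),
    enforce_eager=False,
)
tokenizer = llm.get_tokenizer()
```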
@@ -133,26 +135,31 @@ def rerank_batch(
     def run_llm_batched(
         self,
         prompts: List[str | List[Dict[str, str]]],
+        silence: bool = False,
         current_window_size: Optional[int] = None,
         **kwargs,
     ) -> List[Tuple[str, int]]:
         if SamplingParams is None:
             raise ImportError(
                 "Please install rank-llm with `pip install rank-llm[vllm]` to use batch inference."
             )
         logger.info(f"VLLM Generating!")
         sampling_params = SamplingParams(
             temperature=0.0,
             max_tokens=self.num_output_tokens(current_window_size),
             min_tokens=self.num_output_tokens(current_window_size),
         )
-        outputs = self._llm.generate(prompts, sampling_params)
+        outputs = self._llm.generate(prompts, sampling_params, use_tqdm=not silence)
         return [
             (output.outputs[0].text, len(output.outputs[0].token_ids))
             for output in outputs
         ]

     def run_llm(
-        self, prompt: str, current_window_size: Optional[int] = None
+        self,
+        prompt: str,
+        silence: bool = False,
+        current_window_size: Optional[int] = None,
+        **kwargs,
     ) -> Tuple[str, int]:
         if current_window_size is None:
             current_window_size = self._window_size
@@ -163,7 +170,9 @@ def run_llm(
         gen_cfg.min_new_tokens = self.num_output_tokens(current_window_size)
         # gen_cfg.temperature = 0
         gen_cfg.do_sample = False
-        output_ids = self._llm.generate(**inputs, generation_config=gen_cfg)
+        output_ids = self._llm.generate(
+            **inputs, use_tqdm=not silence, generation_config=gen_cfg
+        )

         if self._llm.config.is_encoder_decoder:
             output_ids = output_ids[0]
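The mechanism behind the new flag is vLLM's `use_tqdm` argument to `LLM.generate`, which toggles the per-batch progress bar. A minimal sketch, assuming vLLM is installed (e.g. via `pip install rank-llm[vllm]`) and using an illustrative model name:

```python
# Sketch: silencing vLLM's generation progress bar. With silence=True the
# tqdm bar is suppressed via use_tqdm=False; the result handling mirrors the
# (text, token_count) pairs returned by run_llm_batched above.
from vllm import LLM, SamplingParams

silence = True
llm = LLM("castorini/rank_zephyr_7b_v1_full")  # illustrative model
params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["Rank these passages ..."], params, use_tqdm=not silence)
results = [(o.outputs[0].text, len(o.outputs[0].token_ids)) for o in outputs]
```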
8 changes: 8 additions & 0 deletions src/rank_llm/scripts/run_rank_llm.py
@@ -39,6 +39,7 @@ def main(args):
     vllm_batched = args.vllm_batched
     batch_size = args.batch_size
     reorder_policy = args.reorder_policy
+    silence = args.silence

     _ = retrieve_and_rerank(
         model_path=model_path,
@@ -63,6 +64,7 @@ def main(args):
         system_message=system_message,
         vllm_batched=vllm_batched,
         reorder_policy=reorder_policy,
+        silence=silence,
     )


@@ -182,5 +184,11 @@ def main(args):
         help="policy in reordering; defaults to sliding window",
         type=str,
     )
+    parser.add_argument(
+        "--silence",
+        default=False,
+        action="store_true",
+        help="suppress the leftover tqdm progress bars that cannot otherwise be hidden (setting leave=False is not supported)",
+    )
     args = parser.parse_args()
     main(args)
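The new flag follows argparse's standard `store_true` pattern: omitted means `False`, present means `True`. A self-contained sketch of just that behavior (the real script also wires `args.silence` into `retrieve_and_rerank`):

```python
# Minimal sketch of the --silence flag added above, isolated from the script.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--silence",
    default=False,
    action="store_true",
    help="Suppress auxiliary tqdm progress bars.",
)

args = parser.parse_args(["--silence"])
assert args.silence is True   # flag given
args = parser.parse_args([])
assert args.silence is False  # flag omitted
```

An invocation would then look like `python src/rank_llm/scripts/run_rank_llm.py ... --silence`, with the script's other required arguments omitted here.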
