add rerank, update demo and tests #112

Merged: 2 commits, Apr 25, 2024
64 changes: 40 additions & 24 deletions src/rank_llm/2cr/msmarco.py
@@ -273,30 +273,46 @@ def generate_report(args):
row_cnt=row_cnt,
condition_name=table_keys[name],
row=row_ids[name],
s1=f'{table[name]["dl19"]["MULT"]:.0f}'
if table[name]["dl19"]["MULT"] != 0
else "-",
s2=f'{table[name]["dl20"]["MAP"]:.4f}'
if table[name]["dl20"]["MAP"] != 0
else "SPLADE++ EnsembleDistil",
s3=f'{table[name]["dl19"]["R@1K"]:.4f}'
if table[name]["dl19"]["R@1K"] != 0
else "100",
s4=f'{table[name]["dl19"]["nDCG@10"]:.4f}'
if table[name]["dl19"]["nDCG@10"] != 0
else "-",
s5=f'{table[name]["dl20"]["nDCG@10"]:.4f}'
if table[name]["dl20"]["nDCG@10"] != 0
else "-",
s6=f'{table[name]["dl20"]["R@1K"]:.4f}'
if table[name]["dl20"]["R@1K"] != 0
else "",
s7=f'{table[name]["dev"]["MRR@10"]:.4f}'
if table[name]["dev"]["MRR@10"] != 0
else "",
s8=f'{table[name]["dev"]["R@1K"]:.4f}'
if table[name]["dev"]["R@1K"] != 0
else "",
s1=(
f'{table[name]["dl19"]["MULT"]:.0f}'
if table[name]["dl19"]["MULT"] != 0
else "-"
),
s2=(
f'{table[name]["dl20"]["MAP"]:.4f}'
if table[name]["dl20"]["MAP"] != 0
else "SPLADE++ EnsembleDistil"
),
s3=(
f'{table[name]["dl19"]["R@1K"]:.4f}'
if table[name]["dl19"]["R@1K"] != 0
else "100"
),
s4=(
f'{table[name]["dl19"]["nDCG@10"]:.4f}'
if table[name]["dl19"]["nDCG@10"] != 0
else "-"
),
s5=(
f'{table[name]["dl20"]["nDCG@10"]:.4f}'
if table[name]["dl20"]["nDCG@10"] != 0
else "-"
),
s6=(
f'{table[name]["dl20"]["R@1K"]:.4f}'
if table[name]["dl20"]["R@1K"] != 0
else ""
),
s7=(
f'{table[name]["dev"]["MRR@10"]:.4f}'
if table[name]["dev"]["MRR@10"] != 0
else ""
),
s8=(
f'{table[name]["dev"]["R@1K"]:.4f}'
if table[name]["dev"]["R@1K"] != 0
else ""
),
cmd1=format_command(commands[name]["dl19"]),
cmd2=format_command(commands[name]["dl20"]),
cmd3=format_command(commands[name]["dev"]),
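
The msmarco.py hunk above is formatting-only: each of the s1 through s8 conditional f-strings is wrapped in parentheses (the multi-line style newer auto-formatters such as Black produce), with no change in behavior. A minimal sketch of the pattern, using a made-up metric value:

# Sketch of the parenthesized conditional-expression style from the hunk above.
# The metric value is made up for illustration; behavior matches the unwrapped form.
ndcg_at_10 = 0.7024
cell = (
    f"{ndcg_at_10:.4f}"
    if ndcg_at_10 != 0
    else "-"
)
print(cell)  # 0.7024
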
11 changes: 9 additions & 2 deletions src/rank_llm/data.py
@@ -62,8 +62,15 @@ def read_requests_from_file(file_path: str) -> List[Request]:


class DataWriter:
def __init__(self, data: Union[List[Result] | List[Request]], append: bool = False):
self._data = data
def __init__(
self,
data: Union[Request | Result | List[Result] | List[Request]],
append: bool = False,
):
if isinstance(data, list):
self._data = data
else:
self._data = [data]
self._append = append

def write_ranking_exec_summary(self, filename: str):
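
The DataWriter change above normalizes its input: a bare Request or Result is now wrapped into a one-element list, so existing list-based callers keep working. A hedged sketch of the two equivalent constructions, building the Result via dacite the same way the updated test_trec_eval.py below does:

# Sketch only: the Result field names follow the from_dict usage in test_trec_eval.py.
from dacite import from_dict

from rank_llm.data import DataWriter, Result

result = from_dict(
    data_class=Result,
    data={
        "query": {"text": "example query", "qid": "q1"},
        "candidates": [{"qid": "q1", "docid": "D1", "score": 0.9}],
    },
)

DataWriter(result)    # single Result, wrapped internally as [result]
DataWriter([result])  # list form, stored as-is
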
7 changes: 5 additions & 2 deletions src/rank_llm/demo/rerank_inline_hits.py
@@ -10,6 +10,7 @@
sys.path.append(parent)

from rank_llm.data import Request, DataWriter
from rank_llm.rerank.vicuna_reranker import VicunaReranker
from rank_llm.rerank.zephyr_reranker import ZephyrReranker

request_dict = {
@@ -74,9 +75,11 @@
],
}

requests = [from_dict(data_class=Request, data=request_dict)]
request = from_dict(data_class=Request, data=request_dict)
reranker = ZephyrReranker()
rerank_results = reranker.rerank_batch(requests)
rerank_results = reranker.rerank(request=request)
reranker = VicunaReranker()
rerank_results = reranker.rerank(request=request)
print(rerank_results)

# write rerank results
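
The remainder of the demo is collapsed here. A hypothetical sketch of how its "write rerank results" step could use the updated DataWriter with the single Result returned by rerank(); the filename is illustrative, and write_ranking_exec_summary is the only writer method visible in this diff:

# Hypothetical continuation of the demo, not its actual collapsed code.
from rank_llm.data import DataWriter

writer = DataWriter(rerank_results)  # rerank_results is the single Result from rerank()
writer.write_ranking_exec_summary("rerank_exec_summary.json")  # illustrative filename
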
2 changes: 1 addition & 1 deletion src/rank_llm/rerank/rank_gpt.py
@@ -58,7 +58,7 @@ def __init__(
raise ValueError("Please provide OpenAI Keys.")
if prompt_mode not in [PromptMode.RANK_GPT, PromptMode.LRL]:
raise ValueError(
f"unsupported prompt mode for GPT models: {prompt_mode}, expected RANK_GPT or LRL."
f"unsupported prompt mode for GPT models: {prompt_mode}, expected {PromptMode.RANK_GPT} or {PromptMode.LRL}."
)

self._window_size = window_size
2 changes: 1 addition & 1 deletion src/rank_llm/rerank/rank_listwise_os_llm.py
@@ -59,7 +59,7 @@ def __init__(
assert torch.cuda.is_available()
if prompt_mode != PromptMode.RANK_GPT:
raise ValueError(
f"Unsupported prompt mode: {prompt_mode}. The only prompt mode currently supported is a slight variation of Rank_GPT prompt."
f"Unsupported prompt mode: {prompt_mode}. The only prompt mode currently supported is a slight variation of {PromptMode.RANK_GPT} prompt."
)
# ToDo: Make repetition_penalty configurable
self._llm, self._tokenizer = load_model(model, device=device, num_gpus=num_gpus)
39 changes: 39 additions & 0 deletions src/rank_llm/rerank/reranker.py
@@ -54,6 +54,45 @@ def rerank_batch(
results.append(result)
return results

def rerank(
self,
request: Request,
rank_start: int = 0,
rank_end: int = 100,
window_size: int = 20,
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
) -> Result:
"""
Reranks a request using the RankLLM agent.

This function applies a sliding window algorithm to rerank the results.
Each window of results is processed by the RankLLM agent to obtain a new ranking.

Args:
request (Request): The reranking request which has a query and a candidates list.
rank_start (int, optional): The starting rank for processing. Defaults to 0.
rank_end (int, optional): The end rank for processing. Defaults to 100.
window_size (int, optional): The size of each sliding window. Defaults to 20.
step (int, optional): The step size for moving the window. Defaults to 10.
shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False.
logging (bool, optional): Enables logging of the reranking process. Defaults to False.

Returns:
Result: the rerank result which contains the reranked candidates.
"""
results = self.rerank_batch(
requests=[request],
rank_start=rank_start,
rank_end=rank_end,
window_size=window_size,
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
)
return results[0]

def write_rerank_results(
self,
retrieval_method_name: str,
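
The new Reranker.rerank() above is a thin convenience wrapper: it forwards a single Request to rerank_batch() and returns the only Result. The sliding-window traversal itself lives in the RankLLM agent, not in this wrapper; the snippet below only illustrates how windows of size 20 with step 10 could cover ranks 0 to 100 (RankGPT-style listwise rerankers typically walk the candidate list back to front), and is not the agent's actual code:

# Illustration only: enumerate the windows a back-to-front sliding pass would
# visit for the default arguments of rerank().
rank_start, rank_end, window_size, step = 0, 100, 20, 10

windows = []
end = rank_end
while end > rank_start:
    start = max(rank_start, end - window_size)
    windows.append((start, end))
    if start == rank_start:
        break
    end -= step

print(windows)  # [(80, 100), (70, 90), ..., (10, 30), (0, 20)]
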
40 changes: 39 additions & 1 deletion src/rank_llm/rerank/vicuna_reranker.py
@@ -58,7 +58,7 @@ def rerank_batch(
List[Result]: A list containing the reranked results.

Note:
check 'rerank' for implementation details of reranking process.
check 'reranker.rerank_batch' for implementation details of reranking process.
"""
return self._reranker.rerank_batch(
requests=requests,
@@ -69,3 +69,41 @@ def rerank_batch(
shuffle_candidates=shuffle_candidates,
logging=logging,
)

def rerank(
self,
request: Request,
rank_start: int = 0,
rank_end: int = 100,
window_size: int = 20,
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
) -> Result:
"""
Reranks a request using the Vicuna model.

Args:
request (Request): The reranking request which has a query and a candidates list.
rank_start (int, optional): The starting rank for processing. Defaults to 0.
rank_end (int, optional): The end rank for processing. Defaults to 100.
window_size (int, optional): The size of each sliding window. Defaults to 20.
step (int, optional): The step size for moving the window. Defaults to 10.
shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False.
logging (bool, optional): Enables logging of the reranking process. Defaults to False.

Returns:
Result: the rerank result which contains the reranked candidates.

Note:
check 'reranker.rerank' for implementation details of reranking process.
"""
return self._reranker.rerank(
request=request,
rank_start=rank_start,
rank_end=rank_end,
window_size=window_size,
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
)
40 changes: 39 additions & 1 deletion src/rank_llm/rerank/zephyr_reranker.py
@@ -58,7 +58,7 @@ def rerank_batch(
List[Result]: A list containing the reranked results.

Note:
check 'rerank' for implementation details of reranking process.
check 'reranker.rerank_batch' for implementation details of reranking process.
"""
return self._reranker.rerank_batch(
requests=requests,
@@ -69,3 +69,41 @@ def rerank_batch(
shuffle_candidates=shuffle_candidates,
logging=logging,
)

def rerank(
self,
request: Request,
rank_start: int = 0,
rank_end: int = 100,
window_size: int = 20,
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
) -> Result:
"""
Reranks a request using the Zephyr model.

Args:
request (Request): The reranking request which has a query and a candidates list.
rank_start (int, optional): The starting rank for processing. Defaults to 0.
rank_end (int, optional): The end rank for processing. Defaults to 100.
window_size (int, optional): The size of each sliding window. Defaults to 20.
step (int, optional): The step size for moving the window. Defaults to 10.
shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False.
logging (bool, optional): Enables logging of the reranking process. Defaults to False.

Returns:
Result: the rerank result which contains the reranked candidates.

Note:
check 'reranker.rerank' for implementation details of reranking process.
"""
return self._reranker.rerank(
request=request,
rank_start=rank_start,
rank_end=rank_end,
window_size=window_size,
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
)
6 changes: 3 additions & 3 deletions test/analysis/test_response_analyzer.py
@@ -1,7 +1,7 @@
import unittest

from src.rank_llm.analysis.response_analysis import ResponseAnalyzer
from src.rank_llm.result import RankingExecInfo, Result
from src.rank_llm.data import RankingExecInfo, Result


class TestResponseAnalyzer(unittest.TestCase):
@@ -10,7 +10,7 @@ def setUp(self):
self.mock_results = [
Result(
query="Query 1",
hits=[],
candidates=[],
ranking_exec_summary=[
RankingExecInfo(
prompt="I will provide you with 3 passages",
@@ -28,7 +28,7 @@ def setUp(self):
),
Result(
query="Query 2",
hits=[],
candidates=[],
ranking_exec_summary=[
RankingExecInfo(
prompt="I will provide you with 4 passages",
28 changes: 18 additions & 10 deletions test/evaluation/test_trec_eval.py
@@ -1,23 +1,31 @@
import unittest
from unittest.mock import patch

from dacite import from_dict

from src.rank_llm.evaluation.trec_eval import EvalFunction
from src.rank_llm.result import Result
from src.rank_llm.data import Result


class TestEvalFunction(unittest.TestCase):
def setUp(self):
self.results = [
Result(
query="Query1",
hits=[
{"qid": "q1", "docid": "D1", "rank": 1, "score": 0.9},
{"qid": "q1", "docid": "D2", "rank": 2, "score": 0.8},
],
from_dict(
data_class=Result,
data={
"query": {"text": "Query1", "qid": "q1"},
"candidates": [
{"qid": "q1", "docid": "D1", "score": 0.9},
{"qid": "q1", "docid": "D2", "score": 0.8},
],
},
),
Result(
query="Query2",
hits=[{"qid": "q2", "docid": "D3", "rank": 1, "score": 0.85}],
from_dict(
data_class=Result,
data={
"query": {"text": "Query2", "qid": "q2"},
"candidates": [{"qid": "q2", "docid": "D3", "score": 0.85}],
},
),
]
self.qrels_path = "path/to/qrels"