Skip to content

Commit

Permalink
add novelty prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Jun 7, 2024
1 parent 202cb21 commit 3a4ade4
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions rank_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def wrapper(*args, **kwargs):
tries += 1
return func(*args, **kwargs)
except Exception as e:
pattern = r"Please try again in (\d+m)?(\d+.?\d*)(\w+)\."
pattern = r"Please try again in (\d+m)?(\d+(?:\.\d+)?)(\w+)\."
match = re.search(pattern, str(e))
if match is None or tries > 5:
if match is None or tries > 30:
raise e
minutes = 0
if match.group(1):
Expand All @@ -69,7 +69,7 @@ def wrapper(*args, **kwargs):
range(int(seconds) + 1),
position=1,
leave=False,
desc="Waiting for rate limit",
desc=f"Waiting for rate limit, tries: {tries}",
):
time.sleep(1)

Expand Down Expand Up @@ -199,7 +199,22 @@ def run_retriever(topics, searcher, qrels=None, k=100, qid=None):
return ranks


def get_prefix_prompt(query, num):
def get_prefix_prompt(query, num, novelty_prompt):
if novelty_prompt:
return [
{
"role": "system",
"content": "You are RankGPT, an intelligent assistant that can rank passages based on their relevancy and novelty to the query.",
},
{
"role": "user",
"content": f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance and novelty to the query: {query}.",
},
{
"role": "assistant",
"content": "Okay, please provide the passages.",
},
]
return [
{
"role": "system",
Expand All @@ -213,19 +228,21 @@ def get_prefix_prompt(query, num):
]


def get_post_prompt(query, num, novelty_prompt=False):
    """Build the closing user message that asks the model for the ranking.

    Args:
        query: Search query text interpolated into the prompt.
        num: Number of passages referred to in the prompt.
        novelty_prompt: When true, return the variant that asks for a ranking
            by relevance and novelty and tells the model to place all but one
            of any near-duplicate passages at the bottom. Defaults to False so
            that two-argument callers keep receiving the relevance-only
            prompt.

    Returns:
        The prompt string.
    """
    # The prompt wording is runtime text sent to the model; keep it verbatim.
    if novelty_prompt:
        return f"Search Query: {query}. \nRank the {num} passages above based on their relevance and novelty to the search query. The passages should be listed in descending order using identifiers. The most relevant and novel passages should be listed first. For near-duplicate passages, put all but one of the passages at the bottom of the ranking. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain."
    return f"Search Query: {query}. \nRank the {num} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain."


def create_permutation_instruction(
item, rank_start=0, rank_end=100, model_name="gpt-3.5-turbo"
item, rank_start=0, rank_end=100, model_name="gpt-3.5-turbo", novelty_prompt=False
):
query = item["query"]
num = len(item["hits"][rank_start:rank_end])

max_length = 300
while True:
messages = get_prefix_prompt(query, num)
messages = get_prefix_prompt(query, num, novelty_prompt)
rank = 0
for hit in item["hits"][rank_start:rank_end]:
rank += 1
Expand All @@ -238,7 +255,7 @@ def create_permutation_instruction(
messages.append(
{"role": "assistant", "content": f"Received passage [{rank}]."}
)
messages.append({"role": "user", "content": get_post_prompt(query, num)})
messages.append({"role": "user", "content": get_post_prompt(query, num, novelty_prompt)})

if (
num_tokens_from_messages(messages, model_name)
Expand Down Expand Up @@ -300,6 +317,7 @@ def permutation_pipeline(
rank_start: int = 0,
rank_end: int = 100,
model_name="gpt-3.5-turbo",
novelty_prompt=False,
):
messages = create_permutation_instruction(
item=item, rank_start=rank_start, rank_end=rank_end, model_name=model_name
Expand All @@ -319,6 +337,7 @@ def sliding_windows(
window_size: int = 20,
step: int = 10,
model_name="gpt-3.5-turbo",
novelty_prompt=False,
):
item = copy.deepcopy(item)
end_pos = rank_end
Expand All @@ -331,6 +350,7 @@ def sliding_windows(
rank_start=start_pos,
rank_end=end_pos,
model_name=model_name,
novelty_prompt=novelty_prompt,
)
end_pos = end_pos - step
start_pos = start_pos - step
Expand Down Expand Up @@ -378,6 +398,7 @@ def main():
parser.add_argument("--step", type=int, default=10)
parser.add_argument("--k", type=int, default=100)
parser.add_argument("--add_positive_passages", action="store_true")
parser.add_argument("--novelty-prompt", action="store_true")

args = parser.parse_args()

Expand Down Expand Up @@ -449,6 +470,7 @@ def main():
step=args.step,
model_name=args.model_name,
api_key=args.api_key,
novelty_prompt=args.novelty_prompt,
)
new_df = to_df(new_item)
reranked_run = pd.concat((reranked_run, new_df))
Expand Down

0 comments on commit 3a4ade4

Please sign in to comment.