Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyChouSr committed Sep 15, 2024
1 parent 059b654 commit d4b1fd3
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
21 changes: 20 additions & 1 deletion fastchat/serve/monitor/classify/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# - if
# - score
import ast
import base64
import os
import re


Expand All @@ -24,6 +26,8 @@ def create_category(name):
return CategoryIF()
elif name == "math_v0.1":
return CategoryMath()
elif name == "criteria_vision_v0.1":
return CategoryVisionHardPrompt()

raise Exception(f"Category name is incorrect: {name}")

Expand Down Expand Up @@ -138,4 +142,19 @@ def post_process(self, judgment):
class CategoryVisionHardPrompt(CategoryHardPrompt):
    """Hard-prompt category whose judge request can carry images.

    Inherits its setup from ``CategoryHardPrompt`` and overrides the name tag
    and ``pre_process`` so the user turn holds the prompt text followed by
    base64-encoded images.
    """

    def __init__(self):
        super().__init__()
        self.name_tag = "criteria_vision_v0.1"

    def _convert_filepath_to_base64(self, filepath):
        """Return the bytes of the file at ``filepath`` as a base64 string."""
        with open(filepath, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def pre_process(self, prompt: str, image_list: "list | None" = None):
        """Build the chat messages for a text prompt plus optional images.

        Args:
            prompt: The user text to classify.
            image_list: Paths of image files to attach. Optional, so the
                single-argument call form used for text-only categories
                also works; ``None`` or an empty list attaches no images.

        Returns:
            A two-message list: a system message containing
            ``self.sys_prompt`` (presumably set by the parent class --
            confirm) and a user message whose content is a list with one
            text part followed by one ``image_url`` part per image.

        Raises:
            OSError: If an image path cannot be opened for reading.
        """
        user_parts = [{"type": "text", "text": prompt}]
        for image_path in image_list or ():
            encoded = self._convert_filepath_to_base64(image_path)
            # NOTE: the data URL always declares image/png, whatever the
            # actual format of the file on disk.
            user_parts.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encoded}"},
                }
            )

        return [
            {"role": "system", "content": self.sys_prompt},
            {"role": "user", "content": user_parts},
        ]
36 changes: 34 additions & 2 deletions fastchat/serve/monitor/classify/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def get_answer(
output_log = {}

for category in categories:
conv = category.pre_process(question["prompt"])
if config["images_dir"]:
conv = category.pre_process(question["prompt"], question["image_list"])
else:
conv = category.pre_process(question["prompt"])
output = chat_completion_openai(
model=model_name,
messages=conv,
Expand Down Expand Up @@ -164,6 +167,30 @@ def find_required_tasks(row):
)
]

def aggregate_entire_conversation(conversation, images_dir):
    """Collect the text and existing image paths from every other turn.

    Only messages at even indices (0, 2, 4, ...) are read. A message's
    ``"content"`` may be either a plain string, or a two-element list of
    ``[text, image_names]``; content of any other type is ignored.

    Args:
        conversation: Sequence of message dicts, each with a ``"content"`` key.
        images_dir: Directory holding ``<image_name>.png`` files. When falsy
            (e.g. ``None``), no image paths are resolved, which avoids a
            ``TypeError`` from ``os.path.join`` on multimodal content.

    Returns:
        A ``(text, image_paths)`` tuple. ``text`` is every collected text
        piece, each prefixed with a newline (so a non-empty result starts
        with ``"\\n"``). ``image_paths`` lists, in order, the joined paths
        that exist on disk.

    Raises:
        ValueError: If a list-typed content does not have exactly two items.
    """
    text_pieces = []
    image_paths = []

    for message in conversation[::2]:
        content = message["content"]
        if isinstance(content, str):
            text_pieces.append("\n" + content)
        elif isinstance(content, list):
            text, image_names = content
            text_pieces.append("\n" + text)

            if not images_dir:
                # No image directory configured: keep the text only.
                continue

            for image_name in image_names:
                candidate = os.path.join(images_dir, f"{image_name}.png")
                # Names without a matching file on disk are silently dropped.
                if os.path.exists(candidate):
                    image_paths.append(candidate)

    return "".join(text_pieces), image_paths

def get_prompt_from_conversation(conversation):
    """Return the first element of ``conversation``.

    As used in this file, the argument is the ``(text, image_list)`` pair
    produced by ``aggregate_entire_conversation``, so this yields the text.
    """
    prompt_text = conversation[0]
    return prompt_text

def get_image_list_from_conversation(conversation):
    """Return the second element of ``conversation``.

    As used in this file, the argument is the ``(text, image_list)`` pair
    produced by ``aggregate_entire_conversation``, so this yields the list
    of image paths.
    """
    images = conversation[1]
    return images

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -247,8 +274,13 @@ def find_required_tasks(row):
)

not_labeled["prompt"] = not_labeled.conversation_a.map(
lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
lambda convo: aggregate_entire_conversation(convo, config["images_dir"])
)

if config["images_dir"]:
not_labeled["image_list"] = not_labeled.prompt.map(get_image_list_from_conversation)
not_labeled = not_labeled[not_labeled.image_list.map(len) > 0]
not_labeled["prompt"] = not_labeled.prompt.map(get_prompt_from_conversation)
not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])

with concurrent.futures.ThreadPoolExecutor(
Expand Down

0 comments on commit d4b1fd3

Please sign in to comment.