class CategoryVisionHardPrompt(CategoryHardPrompt):
    """Hard-prompt category whose judge request can carry images.

    Only the name tag and the message construction differ from the parent
    class; everything else is inherited from ``CategoryHardPrompt``.
    """

    def __init__(self):
        super().__init__()
        self.name_tag = "criteria_vision_v0.1"

    def _convert_filepath_to_base64(self, filepath):
        """Return the bytes of the file at ``filepath`` as a base64 string."""
        with open(filepath, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def pre_process(self, prompt: str, image_list: list = None):
        """Build the chat messages sent to the judge model.

        Args:
            prompt: Text of the user prompt to classify.
            image_list: Paths of image files to attach. Optional, so the
                one-argument call ``pre_process(prompt)`` used for the other
                categories keeps working and yields a text-only request.

        Returns:
            A two-message conversation: the system prompt read from
            ``self.sys_prompt`` (presumably set by the parent initializer),
            then one user message whose content is a list holding a text part
            followed by one ``image_url`` part per image.

        Raises:
            OSError: If one of the image files cannot be opened or read.
        """
        content_parts = [{"type": "text", "text": prompt}]
        # Every image is inlined as a base64 data URL; the MIME type is fixed
        # to PNG regardless of the file's actual format.
        content_parts.extend(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{self._convert_filepath_to_base64(image_path)}"
                },
            }
            for image_path in image_list or ()
        )

        return [
            {"role": "system", "content": self.sys_prompt},
            {"role": "user", "content": content_parts},
        ]
def aggregate_entire_conversation(conversation, images_dir):
    """Merge every second message of a conversation into one prompt.

    Messages at even indices (0, 2, 4, ...) are read; presumably these are the
    user's turns. A message whose ``content`` is a string contributes that
    string. A message whose ``content`` is a list is unpacked as
    ``(text, image_ids)``: the text is appended and every image id is resolved
    to ``<images_dir>/<image_id>.png``. Content of any other type is ignored.

    Args:
        conversation: Sequence of message dicts, each with a ``"content"`` key.
        images_dir: Directory holding the PNG files. When falsy, no image
            lookup is attempted and the returned image list is empty.

    Returns:
        A ``(text, image_paths)`` tuple. ``text`` is the collected pieces
        joined with newlines (no leading newline); ``image_paths`` lists only
        the files that exist on disk. A tuple is returned even when
        ``images_dir`` is unset, so callers must unpack the result before
        treating it as text.
    """
    text_parts = []
    image_paths = []

    for message in conversation[::2]:
        content = message["content"]
        if isinstance(content, str):
            text_parts.append(content)
        elif isinstance(content, list):
            text_content, image_ids = content
            text_parts.append(text_content)

            # Without a directory there is nothing to resolve against;
            # os.path.join(None, ...) would raise TypeError.
            if not images_dir:
                continue
            for image_id in image_ids:
                image_path = os.path.join(images_dir, f"{image_id}.png")
                # Ids without a file on disk are silently dropped.
                if os.path.exists(image_path):
                    image_paths.append(image_path)

    return "\n".join(text_parts), image_paths


def get_prompt_from_conversation(conversation):
    """Return the first element of ``conversation``.

    In this file it is mapped over the ``(text, image_paths)`` tuples produced
    by ``aggregate_entire_conversation``, so it yields the text part.
    """
    return conversation[0]


def get_image_list_from_conversation(conversation):
    """Return the second element of ``conversation``.

    In this file it is mapped over the ``(text, image_paths)`` tuples produced
    by ``aggregate_entire_conversation``, so it yields the image path list.
    """
    return conversation[1]