diff --git a/example.png b/example.png new file mode 100644 index 0000000..e58939c Binary files /dev/null and b/example.png differ diff --git a/llm_dialog_manager/agent.py b/llm_dialog_manager/agent.py index f46c2ae..27b2ee8 100644 --- a/llm_dialog_manager/agent.py +++ b/llm_dialog_manager/agent.py @@ -2,7 +2,7 @@ import json import os import uuid -from typing import List, Dict, Optional, Union +from typing import List, Dict, Union, Optional, Any import logging from pathlib import Path import random @@ -97,13 +97,30 @@ def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, I api_key = os.getenv(f"{service.upper()}_API_KEY") base_url = os.getenv(f"{service.upper()}_BASE_URL") - def format_messages_for_api(model, messages): - """Convert ChatHistory messages to the format required by the specific API.""" + def format_messages_for_api( + model: str, + messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] + ) -> tuple[Optional[str], List[Dict[str, Any]]]: + """ + Convert ChatHistory messages to the format required by the specific API. + + Args: + model: The model name (e.g., "claude", "gemini", "gpt") + messages: List of message dictionaries with role and content + + Returns: + tuple: (system_message, formatted_messages) + - system_message is extracted system message for Claude, None for others + - formatted_messages is the list of formatted message dictionaries + """ if "claude" in model and "openai" not in model: formatted = [] system_msg = "" + + # Extract system message if present if messages and messages[0]["role"] == "system": system_msg = messages.pop(0)["content"] + for msg in messages: content = msg["content"] if isinstance(content, str): @@ -113,9 +130,12 @@ def format_messages_for_api(model, messages): combined_content = [] for block in content: if isinstance(block, str): - combined_content.append({"type": "text", "text": block}) + combined_content.append({ + "type": "text", + "text": block + }) elif isinstance(block, Image.Image): - # For Claude, convert PIL.Image to base64 + # Convert PIL.Image to base64 buffered = io.BytesIO() block.save(buffered, format="PNG") image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") @@ -145,9 +165,12 @@ def format_messages_for_api(model, messages): "data": block["image_base64"]["data"] } }) - formatted.append({"role": msg["role"], "content": combined_content}) + formatted.append({ + "role": msg["role"], + "content": combined_content + }) return system_msg, formatted - + elif ("gemini" in model or "gpt" in model or "grok" in model) and "openai" not in model: formatted = [] for msg in messages: @@ -160,40 +183,75 @@ def format_messages_for_api(model, messages): if isinstance(block, str): parts.append(block) elif isinstance(block, Image.Image): + # Keep PIL.Image objects as is for Gemini parts.append(block) elif isinstance(block, dict): if block.get("type") == "image_url": - parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}}) + parts.append({ + "type": "image_url", + "image_url": { + "url": block["image_url"]["url"] + } + }) elif block.get("type") == "image_base64": - parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}}) - formatted.append({"role": msg["role"], "parts": parts}) + parts.append({ + "type": "image_base64", + "image_base64": { + "data": block["image_base64"]["data"], + "media_type": block["image_base64"]["media_type"] + } + }) + formatted.append({ + "role": msg["role"], + "parts": parts + }) return None, formatted - + else: # OpenAI models formatted = [] for msg in messages: content = msg["content"] if isinstance(content, str): - formatted.append({"role": msg["role"], "content": content}) + formatted.append({ + "role": msg["role"], + "content": content + }) elif isinstance(content, list): - # OpenAI expects 'content' as string; images are not directly supported - # You can convert images to URLs or descriptions if needed - combined_content = "" + formatted_content = [] for block in content: if isinstance(block, str): - combined_content += block + "\n" + formatted_content.append({ + "type": "text", + "text": block + }) elif isinstance(block, Image.Image): - # Convert PIL.Image to base64 or upload and use URL + # Convert PIL.Image to base64 buffered = io.BytesIO() block.save(buffered, format="PNG") image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") - combined_content += f"[Image Base64: {image_base64[:30]}...]\n" + formatted_content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + } + }) elif isinstance(block, dict): if block.get("type") == "image_url": - combined_content += f"[Image: {block['image_url']['url']}]\n" + formatted_content.append({ + "type": "image_url", + "image_url": block["image_url"] + }) elif block.get("type") == "image_base64": - combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n" - formatted.append({"role": msg["role"], "content": combined_content.strip()}) + formatted_content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{block['image_base64']['data']}" + } + }) + formatted.append({ + "role": msg["role"], + "content": formatted_content + }) return None, formatted system_msg, formatted_messages = format_messages_for_api(model, messages.copy()) @@ -546,7 +604,7 @@ def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = Non if __name__ == "__main__": # Example Usage # Create an Agent instance (Gemini model) - agent = Agent("gemini-1.5-flash", "you are Jack101", memory_enabled=True) + agent = Agent("gemini-1.5-flash-openai", "you are Jack101", memory_enabled=True) # Add an image agent.add_image(image_path="example.png")