Commit

fix: openai vision
xihajun committed Jan 15, 2025
1 parent 7e1f46a commit fcb7e48
Showing 2 changed files with 80 additions and 22 deletions.
Binary file added example.png
102 changes: 80 additions & 22 deletions llm_dialog_manager/agent.py
@@ -2,7 +2,7 @@
import json
import os
import uuid
from typing import List, Dict, Optional, Union
from typing import List, Dict, Union, Optional, Any
import logging
from pathlib import Path
import random
@@ -97,13 +97,30 @@ def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, I
api_key = os.getenv(f"{service.upper()}_API_KEY")
base_url = os.getenv(f"{service.upper()}_BASE_URL")

def format_messages_for_api(model, messages):
"""Convert ChatHistory messages to the format required by the specific API."""
def format_messages_for_api(
model: str,
messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]
) -> tuple[Optional[str], List[Dict[str, Any]]]:
"""
Convert ChatHistory messages to the format required by the specific API.
Args:
model: The model name (e.g., "claude", "gemini", "gpt")
messages: List of message dictionaries with role and content
Returns:
tuple: (system_message, formatted_messages)
- system_message is extracted system message for Claude, None for others
- formatted_messages is the list of formatted message dictionaries
"""
if "claude" in model and "openai" not in model:
formatted = []
system_msg = ""

# Extract system message if present
if messages and messages[0]["role"] == "system":
system_msg = messages.pop(0)["content"]

for msg in messages:
content = msg["content"]
if isinstance(content, str):
@@ -113,9 +130,12 @@ def format_messages_for_api(model, messages):
combined_content = []
for block in content:
if isinstance(block, str):
combined_content.append({"type": "text", "text": block})
combined_content.append({
"type": "text",
"text": block
})
elif isinstance(block, Image.Image):
# For Claude, convert PIL.Image to base64
# Convert PIL.Image to base64
buffered = io.BytesIO()
block.save(buffered, format="PNG")
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -145,9 +165,12 @@ def format_messages_for_api(model, messages):
"data": block["image_base64"]["data"]
}
})
formatted.append({"role": msg["role"], "content": combined_content})
formatted.append({
"role": msg["role"],
"content": combined_content
})
return system_msg, formatted
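# Illustration (not in the committed file): assuming the standard Anthropic
# content-block shape, a formatted Claude image message looks roughly like
#   {"role": "user", "content": [
#       {"type": "text", "text": "What is in this image?"},
#       {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "..."}}]}
# with any leading system message returned separately as system_msg.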

elif ("gemini" in model or "gpt" in model or "grok" in model) and "openai" not in model:
formatted = []
for msg in messages:
Expand All @@ -160,40 +183,75 @@ def format_messages_for_api(model, messages):
if isinstance(block, str):
parts.append(block)
elif isinstance(block, Image.Image):
# Keep PIL.Image objects as is for Gemini
parts.append(block)
elif isinstance(block, dict):
if block.get("type") == "image_url":
parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
parts.append({
"type": "image_url",
"image_url": {
"url": block["image_url"]["url"]
}
})
elif block.get("type") == "image_base64":
parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
formatted.append({"role": msg["role"], "parts": parts})
parts.append({
"type": "image_base64",
"image_base64": {
"data": block["image_base64"]["data"],
"media_type": block["image_base64"]["media_type"]
}
})
formatted.append({
"role": msg["role"],
"parts": parts
})
return None, formatted
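# Illustration (not in the committed file): Gemini-style models receive a
# "parts" list that keeps plain strings and PIL.Image objects as-is, e.g.
#   {"role": "user", "parts": ["What is in this image?", <PIL.Image.Image object>]}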

else: # OpenAI models
formatted = []
for msg in messages:
content = msg["content"]
if isinstance(content, str):
formatted.append({"role": msg["role"], "content": content})
formatted.append({
"role": msg["role"],
"content": content
})
elif isinstance(content, list):
# OpenAI expects 'content' as string; images are not directly supported
# You can convert images to URLs or descriptions if needed
combined_content = ""
formatted_content = []
for block in content:
if isinstance(block, str):
combined_content += block + "\n"
formatted_content.append({
"type": "text",
"text": block
})
elif isinstance(block, Image.Image):
# Convert PIL.Image to base64 or upload and use URL
# Convert PIL.Image to base64
buffered = io.BytesIO()
block.save(buffered, format="PNG")
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
combined_content += f"[Image Base64: {image_base64[:30]}...]\n"
formatted_content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
})
elif isinstance(block, dict):
if block.get("type") == "image_url":
combined_content += f"[Image: {block['image_url']['url']}]\n"
formatted_content.append({
"type": "image_url",
"image_url": block["image_url"]
})
elif block.get("type") == "image_base64":
combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n"
formatted.append({"role": msg["role"], "content": combined_content.strip()})
formatted_content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{block['image_base64']['data']}"
}
})
formatted.append({
"role": msg["role"],
"content": formatted_content
})
return None, formatted
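# Illustration (not in the committed file): before this fix the OpenAI branch
# collapsed everything into one string, e.g.
#   {"role": "user", "content": "What is in this image?\n[Image Base64: ...]"}
# It now emits typed content blocks, e.g.
#   {"role": "user", "content": [
#       {"type": "text", "text": "What is in this image?"},
#       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]}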

system_msg, formatted_messages = format_messages_for_api(model, messages.copy())
@@ -546,7 +604,7 @@ def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = Non
if __name__ == "__main__":
# Example Usage
# Create an Agent instance (Gemini model)
agent = Agent("gemini-1.5-flash", "you are Jack101", memory_enabled=True)
agent = Agent("gemini-1.5-flash-openai", "you are Jack101", memory_enabled=True)

# Add an image
agent.add_image(image_path="example.png")
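For context, a minimal sketch of how the reformatted payload could be used end to end against an OpenAI-compatible chat completions endpoint. The client setup, model name, and prompt below are illustrative assumptions, not taken from this commit:

import base64
from openai import OpenAI

# Encode the image added in this commit as a data URL, mirroring what the new
# OpenAI branch of format_messages_for_api produces.
with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url",
         "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
    ],
}]

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
response = client.chat.completions.create(model="gpt-4o-mini", messages=messages)
print(response.choices[0].message.content)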
