"""Agent module: multi-provider LLM chat completion with image support.

Supports Anthropic Claude (direct or Vertex), Google Gemini, x.ai Grok and
OpenAI-compatible chat models. Message content may be plain text or a list of
content blocks (text, PIL images, image-URL / image-base64 descriptor dicts).
"""
import base64
import io
import json
import logging
import os
import random
import uuid
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Union

# Third-party imports
import anthropic
import openai
import requests
import google.generativeai as genai
from anthropic import AnthropicVertex
from dotenv import load_dotenv
from PIL import Image

# Local imports
from .chat_history import ChatHistory

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# One content block: plain text, a PIL image, or an image descriptor dict.
ContentBlock = Union[str, Image.Image, Dict]
# Full message list as stored by ChatHistory.
MessageList = List[Dict[str, Union[str, List[ContentBlock]]]]


def load_env_vars():
    """Load environment variables from a .env file next to this module."""
    env_path = Path(__file__).parent / '.env'
    if env_path.exists():
        load_dotenv(env_path)
    else:
        # Not fatal: keys may already be set in the process environment.
        logger.warning(".env file not found at %s", env_path)


load_env_vars()


def completion(model: str, messages: MessageList, max_tokens: int = 1000,
               temperature: float = 0.5, api_key: Optional[str] = None,
               base_url: Optional[str] = None, json_format: bool = False) -> str:
    """Generate a completion using the specified model and messages.

    Args:
        model: Model identifier (e.g. "claude-3-5-sonnet", "gemini-1.5-pro",
            "gpt-4o", "grok-beta"); the substring selects the provider.
        messages: Chat messages; each content is a string or a list of blocks.
        max_tokens: Maximum tokens to generate.
        temperature: Sampling temperature.
        api_key: Explicit API key; otherwise read from <SERVICE>_API_KEY.
        base_url: Explicit endpoint; otherwise read from <SERVICE>_BASE_URL.
        json_format: Request a JSON-object response where the API supports it
            (OpenAI-compatible endpoints and Gemini's response_mime_type).

    Returns:
        The generated assistant text.

    Raises:
        Exception: Provider SDK errors are logged and re-raised.
    """
    try:
        # Map the model family to the env-var prefix used for credentials.
        # NOTE(review): reconstructed from context — confirm the service names
        # match the key_manager / .env conventions used elsewhere.
        if "claude" in model:
            service = "anthropic"
        elif "gemini" in model:
            service = "gemini"
        elif "grok" in model:
            service = "x"
        else:
            service = "openai"

        if not api_key:
            api_key = os.getenv(f"{service.upper()}_API_KEY")
            base_url = os.getenv(f"{service.upper()}_BASE_URL")

        def format_messages_for_api(model, messages):
            """Convert ChatHistory messages into the provider wire format.

            Returns (system_msg, formatted_messages). system_msg is None for
            providers that keep system messages inline.
            """
            if "claude" in model:
                formatted = []
                system_msg = ""
                if messages and messages[0]["role"] == "system":
                    system_msg = messages.pop(0)["content"]
                for msg in messages:
                    content = msg["content"]
                    if isinstance(content, str):
                        formatted.append({"role": msg["role"], "content": content})
                    elif isinstance(content, list):
                        blocks = []
                        for block in content:
                            if isinstance(block, str):
                                blocks.append({"type": "text", "text": block})
                            elif isinstance(block, Image.Image):
                                # Claude takes images inline as base64 PNG.
                                buffered = io.BytesIO()
                                block.save(buffered, format="PNG")
                                data = base64.b64encode(buffered.getvalue()).decode("utf-8")
                                blocks.append({
                                    "type": "image",
                                    "source": {"type": "base64",
                                               "media_type": "image/png",
                                               "data": data},
                                })
                            elif isinstance(block, dict):
                                if block.get("type") == "image_url":
                                    blocks.append({
                                        "type": "image",
                                        "source": {"type": "url",
                                                   "url": block["image_url"]["url"]},
                                    })
                                elif block.get("type") == "image_base64":
                                    blocks.append({
                                        "type": "image",
                                        "source": {
                                            "type": "base64",
                                            "media_type": block["image_base64"]["media_type"],
                                            "data": block["image_base64"]["data"],
                                        },
                                    })
                        formatted.append({"role": msg["role"], "content": blocks})
                return system_msg, formatted

            if "gemini" in model:
                # genai uses "parts" lists and accepts PIL images directly.
                formatted = []
                for msg in messages:
                    content = msg["content"]
                    parts = list(content) if isinstance(content, list) else [content]
                    formatted.append({"role": msg["role"], "parts": parts})
                return None, formatted

            # OpenAI-compatible providers (gpt, grok, everything else): flatten
            # content blocks into one text string per message.
            # BUGFIX: "gpt"/"grok" previously fell into the Gemini "parts"
            # shape, which the OpenAI chat API rejects.
            formatted = []
            for msg in messages:
                content = msg["content"]
                if isinstance(content, str):
                    formatted.append({"role": msg["role"], "content": content})
                elif isinstance(content, list):
                    pieces = []
                    for block in content:
                        if isinstance(block, str):
                            pieces.append(block)
                        elif isinstance(block, Image.Image):
                            buffered = io.BytesIO()
                            block.save(buffered, format="PNG")
                            data = base64.b64encode(buffered.getvalue()).decode("utf-8")
                            pieces.append(f"[Image Base64: {data[:30]}...]")
                        elif isinstance(block, dict):
                            if block.get("type") == "image_url":
                                pieces.append(f"[Image: {block['image_url']['url']}]")
                            elif block.get("type") == "image_base64":
                                pieces.append(f"[Image Base64: {block['image_base64']['data'][:30]}...]")
                    formatted.append({"role": msg["role"],
                                      "content": "\n".join(pieces).strip()})
            return None, formatted

        system_msg, formatted_messages = format_messages_for_api(model, messages.copy())

        if "claude" in model:
            # Prefer Vertex when its project/region are configured.
            vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
            vertex_region = os.getenv('VERTEX_REGION')
            if vertex_project_id and vertex_region:
                client = AnthropicVertex(region=vertex_region,
                                         project_id=vertex_project_id)
            else:
                client = anthropic.Anthropic(api_key=api_key, base_url=base_url)

            response = client.messages.create(model=model, max_tokens=max_tokens,
                                              temperature=temperature,
                                              messages=formatted_messages,
                                              system=system_msg)

            # Keep asking for more output while the reply was cut off.
            # BUGFIX: the Messages API exposes text via response.content[0].text;
            # `response.completion` belongs to the legacy Text Completions API
            # and raises AttributeError here.
            while response.stop_reason == "max_tokens":
                chunk = response.content[0].text
                if formatted_messages[-1]['role'] == "user":
                    formatted_messages.append({"role": "assistant", "content": chunk})
                elif isinstance(formatted_messages[-1]['content'], list):
                    formatted_messages[-1]['content'].append({"type": "text", "text": chunk})
                else:
                    formatted_messages[-1]['content'] += chunk
                response = client.messages.create(model=model, max_tokens=max_tokens,
                                                  temperature=temperature,
                                                  messages=formatted_messages,
                                                  system=system_msg)

            final_text = response.content[0].text
            if (formatted_messages[-1]['role'] == "assistant"
                    and response.stop_reason == "end_turn"
                    and isinstance(formatted_messages[-1]['content'], str)):
                # Stitch the continuation onto the pre-filled assistant turn.
                return formatted_messages[-1]['content'] + final_text
            return final_text

        if "gemini" in model:
            try:
                # First try Google's OpenAI-compatible endpoint.
                client = openai.OpenAI(
                    api_key=api_key,
                    base_url="https://generativelanguage.googleapis.com/v1beta/")
                # That endpoint wants "content" strings, not "parts" lists.
                openai_messages = []
                for msg in formatted_messages:
                    text = "\n".join(p for p in msg["parts"] if isinstance(p, str))
                    openai_messages.append({"role": msg["role"], "content": text})
                # BUGFIX: "plain_text" is not a valid response_format type
                # ("text"/"json_object" only); omit it unless JSON is requested.
                extra = {"response_format": {"type": "json_object"}} if json_format else {}
                response = client.chat.completions.create(model=model,
                                                          messages=openai_messages,
                                                          temperature=temperature,
                                                          **extra)
                return response.choices[0].message.content
            except Exception:
                # Fall back to the native genai SDK (handles PIL images).
                logger.info("Falling back to Google's genai library")
                genai.configure(api_key=api_key)

                gemini_messages = []
                pending_system = None
                for msg in formatted_messages:
                    if msg["role"] == "system":
                        # genai has no system role: fold it into the next turn.
                        pending_system = "\n".join(
                            p for p in msg["parts"] if isinstance(p, str))
                        continue
                    # genai only understands "user" and "model" roles.
                    role = "model" if msg["role"] == "assistant" else "user"
                    parts = list(msg["parts"])
                    if pending_system and parts and isinstance(parts[0], str):
                        parts[0] = f"{pending_system}\n\n{parts[0]}"
                        pending_system = None
                    gemini_messages.append({"role": role, "parts": parts})

                mime_type = "application/json" if json_format else "text/plain"
                model_instance = genai.GenerativeModel(model_name=model)
                response = model_instance.generate_content(
                    gemini_messages,
                    generation_config=genai.types.GenerationConfig(
                        temperature=temperature,
                        response_mime_type=mime_type,
                        max_output_tokens=max_tokens))
                return response.text

        if "grok" in model:
            # Randomly exercise either SDK; both hit the same x.ai backend.
            if random.choice([True, False]):
                logger.info("Using Anthropic SDK for Grok model")
                client = anthropic.Anthropic(api_key=api_key,
                                             base_url="https://api.x.ai")
                msgs = list(formatted_messages)
                system_text = ""
                if msgs and msgs[0]["role"] == "system":
                    system_text = msgs.pop(0)["content"]
                response = client.messages.create(model=model,
                                                  max_tokens=max_tokens,
                                                  temperature=temperature,
                                                  messages=msgs,
                                                  system=system_text)
                # BUGFIX: Messages API text, not legacy `.completion`.
                return response.content[0].text

            logger.info("Using OpenAI SDK for Grok model")
            client = openai.OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
            extra = {"response_format": {"type": "json_object"}} if json_format else {}
            response = client.chat.completions.create(model=model,
                                                      messages=formatted_messages,
                                                      max_tokens=max_tokens,
                                                      temperature=temperature,
                                                      **extra)
            return response.choices[0].message.content

        # Default: any OpenAI-compatible chat model.
        client = openai.OpenAI(api_key=api_key, base_url=base_url)
        extra = {"response_format": {"type": "json_object"}} if json_format else {}
        response = client.chat.completions.create(model=model,
                                                  messages=formatted_messages,
                                                  max_tokens=max_tokens,
                                                  temperature=temperature,
                                                  **extra)
        return response.choices[0].message.content

    except Exception as e:
        logger.error("Error in completion: %s", e)
        raise


class Agent:
    """A stateful chat agent bound to one model, with image support."""

    def __init__(self, model_name: str,
                 messages: Optional[Union[str, MessageList]] = None,
                 memory_enabled: bool = False,
                 api_key: Optional[str] = None) -> None:
        """Initialize an Agent instance.

        Args:
            model_name: Provider model identifier.
            messages: Initial system prompt (str) or full message list.
            memory_enabled: Keep assistant replies in the history.
            api_key: Explicit API key; env vars are used when omitted.
        """
        self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
        self.model_name = model_name
        self.history = ChatHistory(messages) if messages else ChatHistory()
        self.memory_enabled = memory_enabled
        self.api_key = api_key
        self.repo_content = []  # text payloads queued by add_repo()

    def add_message(self, role: str, content: Union[str, List[ContentBlock]]):
        """Append a message with an explicit role."""
        self.history.add_message(content, role)

    def add_user_message(self, content: Union[str, List[ContentBlock]]):
        """Append a user message (must follow system/assistant)."""
        self.history.add_user_message(content)

    def add_assistant_message(self, content: Union[str, List[ContentBlock]]):
        """Append an assistant message (must follow user)."""
        self.history.add_assistant_message(content)

    def add_image(self, image_path: Optional[str] = None,
                  image_url: Optional[str] = None,
                  media_type: Optional[str] = "image/jpeg"):
        """Attach an image to the current user turn.

        Exactly one of image_path / image_url must be provided. The block is
        appended to the last user message, or starts a new user message.

        Raises:
            ValueError: Neither source was given.
            FileNotFoundError: image_path does not exist.
        """
        if not image_path and not image_url:
            raise ValueError("Either image_path or image_url must be provided.")

        if image_path:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file {image_path} does not exist.")
            if "gemini" in self.model_name:
                # Gemini's SDK accepts PIL images directly.
                image_block = Image.open(image_path)
            else:
                # Claude and others take base64-encoded image payloads.
                with open(image_path, "rb") as img_file:
                    image_data = base64.standard_b64encode(
                        img_file.read()).decode("utf-8")
                image_block = {"type": "image_base64",
                               "image_base64": {"media_type": media_type,
                                                "data": image_data}}
        else:
            # URL references use the same descriptor shape for every provider.
            image_block = {"type": "image_url", "image_url": {"url": image_url}}

        if self.history.last_role == "user":
            current = self.history.messages[-1]["content"]
            if isinstance(current, list):
                current.append(image_block)
            else:
                self.history.messages[-1]["content"] = [current, image_block]
        else:
            self.history.add_message([image_block], "user")

    def generate_response(self, max_tokens=3585, temperature=0.7,
                          json_format: bool = False) -> str:
        """Generate a response from the agent.

        Args:
            max_tokens: Maximum number of tokens. Defaults to 3585.
            temperature: Sampling temperature. Defaults to 0.7.
            json_format: Request JSON output where the API supports it.

        Returns:
            The generated response text.

        Raises:
            ValueError: The history is empty.
        """
        if not self.history.messages:
            raise ValueError("No messages in history to generate response from")

        response_text = completion(
            model=self.model_name,
            messages=self.history.messages,
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=self.api_key,
            json_format=json_format,
        )

        # BUGFIX: the old code branched on model_name.startswith("openai"),
        # which never matches real model names ("gpt-4", ...). All providers
        # share the same bookkeeping; gemini/grok keep list-typed content.
        last = self.history.messages[-1]
        if last["role"] == "assistant":
            if "gemini" in self.model_name or "grok" in self.model_name:
                if isinstance(last["content"], list):
                    last["content"].append(response_text)
                else:
                    last["content"] = [last["content"], response_text]
            else:
                last["content"] = response_text
        elif self.memory_enabled:
            self.add_message("assistant", response_text)

        return response_text

    def save_conversation(self):
        """Serialize the message history to <agent-id>.json (UTF-8)."""
        filename = f"{self.id}.json"
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(self.history.messages, file, ensure_ascii=False, indent=4)

    def load_conversation(self, filename: Optional[str] = None):
        """Load a history JSON; defaults to this agent's own save file.

        NOTE(review): PIL image objects are not JSON-serializable, so only
        text / descriptor-dict content round-trips through save/load.
        """
        if filename is None:
            filename = f"{self.id}.json"
        with open(filename, 'r', encoding='utf-8') as file:
            messages = json.load(file)
        self.history = ChatHistory(messages)

    def add_repo(self, repo_url: Optional[str] = None,
                 username: Optional[str] = None,
                 repo_name: Optional[str] = None,
                 commit_hash: Optional[str] = None):
        """Download a GitHub repo archive and queue its text files.

        Either repo_url, or both username and repo_name, must be given.
        NOTE(review): the extraction filter below is reconstructed — confirm
        which file extensions the original implementation kept.

        Raises:
            ValueError: No usable URL, or the download failed.
        """
        if username and repo_name:
            if commit_hash:
                repo_url = (f"https://github.com/{username}/{repo_name}"
                            f"/archive/{commit_hash}.zip")
            else:
                repo_url = (f"https://github.com/{username}/{repo_name}"
                            f"/archive/refs/heads/main.zip")

        if not repo_url:
            raise ValueError("Either repo_url or both username and repo_name must be provided")

        response = requests.get(repo_url)
        if response.status_code == 200:
            with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
                for file_name in zf.namelist():
                    if file_name.endswith(('.py', '.md', '.txt')):
                        with zf.open(file_name) as fh:
                            self.repo_content.append(
                                fh.read().decode('utf-8', errors='ignore'))
        else:
            raise ValueError(f"Failed to download repository from {repo_url}")


if __name__ == "__main__":
    # Example usage: a Gemini agent describing a local image.
    agent = Agent("gemini-1.5-flash", "you are an assistant", memory_enabled=True)

    # Add an image (relative path instead of a machine-specific one).
    agent.add_image(image_path="example.png")

    # Add a user message
    agent.add_message("user", "What's in this image?")

    # Generate response with JSON format enabled
    try:
        response = agent.generate_response(json_format=True)
        print("Response:", response)
    except Exception as e:
        logger.error("Failed to generate response: %s", e)

    # Print the entire conversation history
    print("Conversation History:")
    print(agent.history)

    # Pop the last message
    last_message = agent.history.pop()
    print("Last Message:", last_message)

    # Generate another response without JSON format
    response = agent.generate_response()
    print("Response:", response)
"""Chat history container supporting text and image content blocks."""
from typing import Dict, List, Union

try:
    # Pillow is only needed when image content blocks are used.
    from PIL import Image
except ImportError:  # optional dependency: text-only histories still work
    Image = None


def _is_image(obj) -> bool:
    """True when obj is a PIL image; always False if Pillow is unavailable."""
    return Image is not None and isinstance(obj, Image.Image)


class ChatHistory:
    """Ordered list of {"role", "content"} messages with role bookkeeping.

    Content is either a plain string or a list of blocks: strings, PIL
    images, or {"type": "image_url"/"image_base64", ...} descriptor dicts.
    """

    def __init__(self, input_data: Union[str, List[Dict]] = "") -> None:
        """Start from a system prompt string or a full message list."""
        self.messages: List[Dict] = []
        if isinstance(input_data, str) and input_data:
            self.add_message(input_data, "system")
        elif isinstance(input_data, list):
            self.load_messages(input_data)
        # Empty history behaves as if a system turn just ended.
        self.last_role: str = self.get_last_role()

    def load_messages(self, messages: List[Dict]) -> None:
        """Validate and bulk-append messages.

        Raises:
            ValueError: A message lacks role/content or has an unknown role.
        """
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("Each message must have a 'role' and 'content'.")
            if message["role"] not in ["user", "assistant", "system"]:
                raise ValueError(f"Invalid role: {message['role']}")
            self.messages.append(message)
        self.last_role = self.get_last_role()

    def get_last_role(self) -> str:
        """Role of the newest message, or "system" for an empty history."""
        return self.messages[-1]["role"] if self.messages else "system"

    @staticmethod
    def _format_content(content) -> str:
        """Human-readable rendering of one message's content.

        Shared by __str__ and __getitem__ (previously duplicated three times).
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif _is_image(block):
                    # BUGFIX: in-memory PIL images have no .filename attribute.
                    parts.append(f"[Image Object: {getattr(block, 'filename', 'in-memory')}]")
                elif isinstance(block, dict):
                    if block.get("type") == "image_url":
                        parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
                    elif block.get("type") == "image_base64":
                        parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
            return "\n".join(parts)
        return str(content)

    def pop(self):
        """Remove and return the newest message's content (None if empty)."""
        if not self.messages:
            return None
        popped_message = self.messages.pop()
        self.last_role = self.get_last_role()
        return popped_message["content"]

    def __len__(self):
        return len(self.messages)

    def __str__(self):
        return "\n".join(
            f"Message {i} ({msg['role']}): {self._format_content(msg['content'])}"
            for i, msg in enumerate(self.messages))

    def __getitem__(self, key):
        """Index or slice the history; prints a formatted view as a side effect."""
        if isinstance(key, slice):
            sliced = self.messages[key]
            print("\n".join(
                f"({m['role']}): {self._format_content(m['content'])}"
                for m in sliced))
            return sliced
        if isinstance(key, int):
            if key < 0:
                key += len(self.messages)
            if not 0 <= key < len(self.messages):
                raise IndexError("Message index out of range.")
            snippet = self.get_conversation_snippet(key)
            print("\n".join(f"({v['role']}): {v['content']}"
                            for v in snippet.values() if v))
            return self.messages[key]
        raise TypeError("Invalid argument type.")

    def __setitem__(self, index, value):
        """Replace the content at index, keeping the message's role.

        BUGFIX: the old code guessed the role from index parity
        ("system" if even else "user"), silently corrupting roles.
        """
        if not isinstance(value, (str, list)):
            raise ValueError("Message content must be a string or a list of content blocks.")
        role = self.messages[index]["role"]
        self.messages[index] = {"role": role, "content": value}

    def __add__(self, message):
        """Append message content, alternating roles after the last speaker.

        NOTE(review): next-role choice reconstructed from the alternation
        rules enforced by add_user_message / add_assistant_message.
        """
        next_role = "user" if self.last_role in ("system", "assistant") else "assistant"
        self.add_message(message, next_role)

    def __contains__(self, item):
        """Substring search over all string content (including list blocks)."""
        for message in self.messages:
            content = message['content']
            if isinstance(content, str) and item in content:
                return True
            if isinstance(content, list):
                if any(isinstance(b, str) and item in b for b in content):
                    return True
        return False

    def add_message(self, content, role: str):
        """Append a message and update last_role."""
        self.messages.append({"role": role, "content": content})
        self.last_role = role

    def add_user_message(self, content):
        """Append a user turn; must follow a system or assistant turn."""
        if self.last_role in ["system", "assistant"]:
            self.add_message(content, "user")
        else:
            raise ValueError("A user message must follow a system or assistant message.")

    def add_assistant_message(self, content):
        """Append an assistant turn; must follow a user turn."""
        if self.last_role == "user":
            self.add_message(content, "assistant")
        else:
            raise ValueError("An assistant message must follow a user message.")

    def display_conversation_status(self):
        """Print a short summary of the conversation state."""
        last_content = self.messages[-1]["content"] if self.messages else ""
        status = {
            "total_messages": len(self.messages),
            "last_role": self.last_role,
            "last_message_content": self._format_content(last_content),
        }
        print(f"Total messages: {status['total_messages']}")
        print(f"Role of the last message: {status['last_role']}")
        print(f"Content of the last message: {status['last_message_content']}")

    def search_for_keyword(self, keyword):
        """Return messages whose string content contains keyword (case-insensitive)."""
        needle = keyword.lower()
        results = []
        for msg in self.messages:
            content = msg['content']
            if isinstance(content, str) and needle in content.lower():
                results.append(msg)
            elif isinstance(content, list):
                if any(isinstance(b, str) and needle in b.lower() for b in content):
                    results.append(msg)
        return results

    def has_user_or_assistant_spoken_since_last_system(self):
        """True if a user/assistant turn occurred after the last system turn."""
        for msg in reversed(self.messages):
            if msg["role"] == "system":
                return False
            if msg["role"] in ("user", "assistant"):
                return True
        return False

    def get_conversation_snippet(self, index):
        """Return the message at index with its neighbours for context.

        NOTE(review): reconstructed helper (referenced by __getitem__ and
        display_snippet but not visible in the reviewed chunk).
        """
        snippet = {"previous": None, "current": None, "next": None}
        if 0 <= index < len(self.messages):
            snippet["current"] = self.messages[index]
            if index > 0:
                snippet["previous"] = self.messages[index - 1]
            if index + 1 < len(self.messages):
                snippet["next"] = self.messages[index + 1]
        return snippet

    def display_snippet(self, index):
        """Print the snippet around index, highlighting the current message."""
        snippet = self.get_conversation_snippet(index)
        for key, msg in snippet.items():
            if not msg:
                continue
            text = self._format_content(msg['content'])
            if key == "current":
                text = self.color_text(text, "green")
            print(f"({msg['role']}): {text}")

    @staticmethod
    def color_text(text, color):
        """Wrap text in an ANSI color code; unknown colors pass through."""
        colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
        return f"{colors.get(color, '')}{text}{colors.get('end', '')}"