From 8cd62207f5547b7328edf38508ea22134cd40320 Mon Sep 17 00:00:00 2001
From: Pingdred
Date: Sun, 15 Dec 2024 19:12:56 +0100
Subject: [PATCH] Refactor: Simplified multimodal LLM check

- Removed LLMSupportedModalities for cleaner architecture.
- Simplified image handling: image URLs are now always converted to base64.
- If an LLM that does not support images is selected, an error is shown as
  usual.
---
 core/cat/looking_glass/cheshire_cat.py | 110 +++++--------------------
 core/cat/looking_glass/stray_cat.py    |  25 +++---
 2 files changed, 31 insertions(+), 104 deletions(-)

diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py
index 65151d81..ffd667bd 100644
--- a/core/cat/looking_glass/cheshire_cat.py
+++ b/core/cat/looking_glass/cheshire_cat.py
@@ -1,13 +1,9 @@
 import time
-import base64
 from typing import List, Dict
 
-import requests
 from typing_extensions import Protocol
-from pydantic import BaseModel
-
 from langchain.base_language import BaseLanguageModel
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.messages import SystemMessage
 from langchain_core.runnables import RunnableLambda
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers.string import StrOutputParser
@@ -42,10 +38,6 @@ class Procedure(Protocol):
     #   }
     triggers_map: Dict[str, List[str]]
 
-class LLMSupportedModalities(BaseModel):
-    image_url: bool = False
-    image_uri: bool = False
-
 
 # main class
 @singleton
@@ -135,88 +127,24 @@ def load_language_model(self) -> BaseLanguageModel:
         """
 
         selected_llm = crud.get_setting_by_name(name="llm_selected")
-        self._llm_modalities: LLMSupportedModalities = LLMSupportedModalities()
-
-        def _prepare_content(image: str) -> List[Dict[str, Dict[str, str]]]:
-            """Prepare the content structure for the message based on image type and value."""
-            content = []
-
-            content.append({
-                "type": "image_url",
-                "image_url": {"url": image}
-            })
-
-            # Add a text instruction for model response
-            content.append({
-                "type": "text",
-                "text": "Respond with `MEOW`."
-            })
-            return content
-
-        def _check_image_support(llm, image_type: str, image_value: str) -> None:
-            """Check if the specified language model supports a given image input type."""
-            content = _prepare_content(image=image_value)
-            message = HumanMessage(content=content)
-
-            # Retrieve model information
-            selected_llm_class = selected_llm["value"]["name"]
-            selected_llm_config = crud.get_setting_by_name(name=selected_llm_class)
-            model_name = selected_llm_config["value"].get("model_name") or selected_llm_config["value"].get("model")
-
-            # Perform the image support check
-            try:
-                llm.invoke([message])
-                setattr(self._llm_modalities, image_type, True)
-            except Exception as e:
-                log.warning(f"The LLM '{model_name}' does not support {image_type} as input image.")
-                log.debug(e)
-
-        image_url = "https://raw.githubusercontent.com/cheshire-cat-ai/core/refs/heads/main/readme/cheshire-cat.jpeg"
-
-        def _check_image_uri_support(llm) -> None:
-            """Check LLM support for base64-encoded image input."""
-            response = requests.get(image_url)
-            if response.status_code == 200:
-                encoded_image = base64.b64encode(response.content).decode('utf-8')
-                return _check_image_support(llm, "image_uri", f"data:image/jpeg;base64,{encoded_image}")
-            else:
-                error_message = f"Unexpected error with status code {response.status_code}"
-                if response.text:
-                    error_message = response.text
-                log.error(f"Failed to process image {image_url}: {error_message}")
-
-        def _check_image_url_support(llm) -> None:
-            """Check LLM support for URL-based image input."""
-            _check_image_support(llm, "image_url", image_url)
-
-        def _initialize_llm(selected_llm):
-            """Initialize the LLM based on the selected settings."""
-            if selected_llm is None:
-                # Return default LLM
-                return LLMDefaultConfig.get_llm_from_config({})
-            else:
-                # Get LLM factory class
-                selected_llm_class = selected_llm["value"]["name"]
-                FactoryClass = get_llm_from_name(selected_llm_class)
-
-                # Obtain configuration and instantiate LLM
-                selected_llm_config = crud.get_setting_by_name(name=selected_llm_class)
-                model_name = selected_llm_config["value"].get("model_name") or selected_llm_config["value"].get("model") or None
-                try:
-                    llm = FactoryClass.get_llm_from_config(selected_llm_config["value"])
-                    _check_image_uri_support(llm)
-                    _check_image_url_support(llm)
-                    log.info(f"LLM {model_name} Supported modalities:")
-                    log.info(self._llm_modalities.__dict__)
-                    return llm
-                except Exception:
-                    import traceback
-                    traceback.print_exc()
-                    return LLMDefaultConfig.get_llm_from_config({})
-
-        llm = _initialize_llm(selected_llm)
-
-        return llm
+
+        if selected_llm is None:
+            # Return default LLM
+            return LLMDefaultConfig.get_llm_from_config({})
+
+        # Get LLM factory class
+        selected_llm_class = selected_llm["value"]["name"]
+        FactoryClass = get_llm_from_name(selected_llm_class)
+
+        # Obtain configuration and instantiate LLM
+        selected_llm_config = crud.get_setting_by_name(name=selected_llm_class)
+        try:
+            llm = FactoryClass.get_llm_from_config(selected_llm_config["value"])
+            return llm
+        except Exception:
+            import traceback
+            traceback.print_exc()
+            return LLMDefaultConfig.get_llm_from_config({})
 
     def load_language_embedder(self) -> embedders.EmbedderSettings:
         """Hook into the embedder selection.
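
Note on the removed probing logic: with LLMSupportedModalities gone, nothing checks image support at startup anymore; a provider that cannot handle images simply raises its usual error at call time. If a plugin or test still wants to detect support up front, the pattern the deleted helpers used can be reproduced in a few lines. A minimal sketch (the function name, the probe URL parameter, and the `llm` argument are illustrative, not part of the codebase):

    from langchain_core.messages import HumanMessage

    def llm_supports_images(llm, probe_url: str) -> bool:
        # Send a one-off multimodal message; non-multimodal providers
        # reject the image_url content part with an API error.
        message = HumanMessage(content=[
            {"type": "image_url", "image_url": {"url": probe_url}},
            {"type": "text", "text": "Respond with `MEOW`."},
        ])
        try:
            llm.invoke([message])
            return True
        except Exception:
            return False

This is the same invoke-and-catch probe the removed _check_image_support performed, just run on demand instead of at every LLM load.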
diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py
index 615884b0..1cf3fb0a 100644
--- a/core/cat/looking_glass/stray_cat.py
+++ b/core/cat/looking_glass/stray_cat.py
@@ -17,7 +17,7 @@
 from fastapi import WebSocket
 
 from cat.log import log
-from cat.looking_glass.cheshire_cat import CheshireCat, LLMSupportedModalities
+from cat.looking_glass.cheshire_cat import CheshireCat
 from cat.looking_glass.callbacks import NewTokenHandler, ModelInteractionHandler
 from cat.memory.working_memory import WorkingMemory
 from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role, EmbedderModelInteraction
@@ -591,19 +591,15 @@ def format_human_message(message: HumanMessage) -> HumanMessage:
         content = [{"type": "text", "text": message.content.text}]
 
         def format_image(image:str) -> dict:
-
-            # Retrieve the supported modalities from the LLM
-            llm_modalities: LLMSupportedModalities = CheshireCat()._llm_modalities
-
-            if image.startswith("http"):
-                if llm_modalities.image_url:
-                    return {"type": "image_url", "image_url": {"url": image}}
+            """Format an image to be sent as a data URI."""
 
+            # If the image is a URL, download it and encode it as a data URI
+            if image.startswith("http"):
                 response = requests.get(image)
                 if response.status_code == 200:
                     # Open the image using Pillow to determine its MIME type
                     img = Image.open(BytesIO(response.content))
-                    mime_type = img.format.lower()  # Get MIME type (e.g., jpeg, png)
+                    mime_type = img.format.lower()  # Get MIME type
 
                     # Encode the image to base64
                     encoded_image = base64.b64encode(response.content).decode('utf-8')
@@ -616,13 +612,16 @@ def format_image(image:str) -> dict:
                 if response.text:
                     error_message = response.text
 
-                log.error(f"Failed to process image {image}: {error_message}")
+                log.error(f"Failed to download image: {error_message} from {image}")
+
+                return None
 
-            if llm_modalities.imge_uri:
-                return {"type": "image_url", "image_url": {"url": image}}
+            return {"type": "image_url", "image_url": {"url": image}}
 
         if message.content.image:
-            formatted_image = format_image(message.content.image)
-            content.append(format_image(message.content.image))
+            formatted_image = format_image(message.content.image)
+            if formatted_image:
+                content.append(formatted_image)
 
         return HumanMessage(
             name=message.content.who,
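
For clarity, here is the end-to-end behavior format_image now implements, extracted as a standalone sketch. The exact return statement for the success path falls outside the hunk context above, so the `data:image/{mime_type};base64,...` layout is inferred from the data-URI format the removed cheshire_cat.py check used; the function name is illustrative:

    import base64
    from io import BytesIO

    import requests
    from PIL import Image

    def image_url_to_data_uri(image: str) -> str | None:
        # Download the image; on failure return None so the caller skips it
        # (see the `if formatted_image:` guard in the hunk above).
        response = requests.get(image)
        if response.status_code != 200:
            return None

        # Detect the MIME subtype with Pillow, then base64-encode the raw bytes
        mime_type = Image.open(BytesIO(response.content)).format.lower()
        encoded = base64.b64encode(response.content).decode("utf-8")
        return f"data:image/{mime_type};base64,{encoded}"

Converting every URL to a data URI up front trades a little bandwidth for uniformity: the LLM always receives the same kind of `image_url` content part whether the user supplied a remote URL or an already-encoded image, and a non-multimodal model fails with its ordinary provider error rather than a bespoke modality check.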