Skip to content

Commit

Permalink
Refactor: Simplified multimodal LLM check
Browse files Browse the repository at this point in the history
- Removed LLMSupportedModalities for cleaner architecture.
- Simplified image handling: image URLs are now always downloaded and embedded as base64 data URIs before being sent to the LLM.
- If the selected LLM does not support image input, the request fails with the provider's error, as before (no upfront capability probing).
  • Loading branch information
Pingdred committed Dec 15, 2024
1 parent 209dc34 commit 8cd6220
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 104 deletions.
110 changes: 19 additions & 91 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import time
import base64
from typing import List, Dict
import requests
from typing_extensions import Protocol

from pydantic import BaseModel

from langchain.base_language import BaseLanguageModel
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers.string import StrOutputParser
Expand Down Expand Up @@ -42,10 +38,6 @@ class Procedure(Protocol):
# }
triggers_map: Dict[str, List[str]]

class LLMSupportedModalities(BaseModel):
image_url: bool = False
image_uri: bool = False


# main class
@singleton
Expand Down Expand Up @@ -135,88 +127,24 @@ def load_language_model(self) -> BaseLanguageModel:
"""

selected_llm = crud.get_setting_by_name(name="llm_selected")
self._llm_modalities: LLMSupportedModalities = LLMSupportedModalities()

def _prepare_content(image: str) -> List[Dict[str, Dict[str, str]]]:
"""Prepare the content structure for the message based on image type and value."""
content = []

content.append({
"type": "image_url",
"image_url": {"url": image}
})

# Add a text instruction for model response
content.append({
"type": "text",
"text": "Respond with `MEOW`."
})
return content

def _check_image_support(llm, image_type: str, image_value: str) -> None:
"""Check if the specified language model supports a given image input type."""
content = _prepare_content(image=image_value)
message = HumanMessage(content=content)

# Retrieve model information
selected_llm_class = selected_llm["value"]["name"]
selected_llm_config = crud.get_setting_by_name(name=selected_llm_class)
model_name = selected_llm_config["value"].get("model_name") or selected_llm_config["value"].get("model")

# Perform the image support check
try:
llm.invoke([message])
setattr(self._llm_modalities, image_type, True)
except Exception as e:
log.warning(f"The LLM '{model_name}' does not support {image_type} as input image.")
log.debug(e)

image_url = "https://raw.githubusercontent.com/cheshire-cat-ai/core/refs/heads/main/readme/cheshire-cat.jpeg"

def _check_image_uri_support(llm) -> None:
"""Check LLM support for base64-encoded image input."""
response = requests.get(image_url)
if response.status_code == 200:
encoded_image = base64.b64encode(response.content).decode('utf-8')
return _check_image_support(llm, "image_uri", f"data:image/jpeg;base64,{encoded_image}")
else:
error_message = f"Unexpected error with status code {response.status_code}"
if response.text:
error_message = response.text
log.error(f"Failed to process image {image_url}: {error_message}")

def _check_image_url_support(llm) -> None:
"""Check LLM support for URL-based image input."""
_check_image_support(llm, "image_url", image_url)

def _initialize_llm(selected_llm):
"""Initialize the LLM based on the selected settings."""
if selected_llm is None:
# Return default LLM
return LLMDefaultConfig.get_llm_from_config({})
else:
# Get LLM factory class
selected_llm_class = selected_llm["value"]["name"]
FactoryClass = get_llm_from_name(selected_llm_class)

# Obtain configuration and instantiate LLM
selected_llm_config = crud.get_setting_by_name(name=selected_llm_class)
model_name = selected_llm_config["value"].get("model_name") or selected_llm_config["value"].get("model") or None
try:
llm = FactoryClass.get_llm_from_config(selected_llm_config["value"])
_check_image_uri_support(llm)
_check_image_url_support(llm)
log.info(f"LLM {model_name} Supported modalities:")
log.info(self._llm_modalities.__dict__)
return llm
except Exception:
import traceback
traceback.print_exc()
return LLMDefaultConfig.get_llm_from_config({})

llm = _initialize_llm(selected_llm)

return llm

if selected_llm is None:
# Return default LLM
return LLMDefaultConfig.get_llm_from_config({})

# Get LLM factory class
selected_llm_class = selected_llm["value"]["name"]
FactoryClass = get_llm_from_name(selected_llm_class)

# Obtain configuration and instantiate LLM
selected_llm_config = crud.get_setting_by_name(name=selected_llm_class)
try:
llm = FactoryClass.get_llm_from_config(selected_llm_config["value"])
return llm
except Exception:
import traceback
traceback.print_exc()
return LLMDefaultConfig.get_llm_from_config({})

def load_language_embedder(self) -> embedders.EmbedderSettings:
"""Hook into the embedder selection.
Expand Down
25 changes: 12 additions & 13 deletions core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fastapi import WebSocket

from cat.log import log
from cat.looking_glass.cheshire_cat import CheshireCat, LLMSupportedModalities
from cat.looking_glass.cheshire_cat import CheshireCat
from cat.looking_glass.callbacks import NewTokenHandler, ModelInteractionHandler
from cat.memory.working_memory import WorkingMemory
from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role, EmbedderModelInteraction
Expand Down Expand Up @@ -591,19 +591,15 @@ def format_human_message(message: HumanMessage) -> HumanMessage:
content = [{"type": "text", "text": message.content.text}]

def format_image(image:str) -> dict:

# Retrieve the supported modalities from the LLM
llm_modalities: LLMSupportedModalities = CheshireCat()._llm_modalities

if image.startswith("http"):
if llm_modalities.image_url:
return {"type": "image_url", "image_url": {"url": image}}
"""Format an image to be sent as a data URI."""

# If the image is a URL, download it and encode it as a data URI
if image.startswith("http"):
response = requests.get(image)
if response.status_code == 200:
# Open the image using Pillow to determine its MIME type
img = Image.open(BytesIO(response.content))
mime_type = img.format.lower() # Get MIME type (e.g., jpeg, png)
mime_type = img.format.lower() # Get MIME type

# Encode the image to base64
encoded_image = base64.b64encode(response.content).decode('utf-8')
Expand All @@ -616,13 +612,16 @@ def format_image(image:str) -> dict:
if response.text:
error_message = response.text

log.error(f"Failed to process image {image}: {error_message}")
log.error(f"Failed to download image: {error_message} from {image}")

return None

if llm_modalities.imge_uri:
return {"type": "image_url", "image_url": {"url": image}}
return {"type": "image_url", "image_url": {"url": image}}

if message.content.image:
content.append(format_image(message.content.image))
formatted_image = format_image(message.content.image)
if formatted_image:
content.append(formatted_image)

return HumanMessage(
name=message.content.who,
Expand Down

0 comments on commit 8cd6220

Please sign in to comment.