diff --git a/core/cat/convo/messages.py b/core/cat/convo/messages.py index c81a30450..b61f0c045 100644 --- a/core/cat/convo/messages.py +++ b/core/cat/convo/messages.py @@ -132,10 +132,10 @@ class CatMessage(BaseModelDict): Deprecated. The text content of the message. Use `text` instead. text : Optional[str], default=None The text content of the message. - images : Optional[Union[List[str], str]], default=None - List of image file URLs or base64 data URIs that represent images associated with the message. A single string can also be provided and will be converted to a list. - audio : Optional[Union[List[str], str]], default=None - List of audio file URLs or base64 data URIs that represent audio associated with the message. A single string can also be provided and will be converted to a list. + image : Optional[str], default=None + Image file URL or base64 data URI that represents the image associated with the message. + audio : Optional[str], default=None + Audio file URL or base64 data URI that represents the audio associated with the message. why : Optional[MessageWhy], default=None Additional contextual information related to the message. who : str, default="AI" @@ -151,10 +151,10 @@ class CatMessage(BaseModelDict): The name of the message author, by default AI. text : Optional[str] The text content of the message. - images : Optional[List[str]] - List of image URLs or paths associated with the message, if any. - audio : Optional[List[str]] - List of audio file URLs or paths associated with the message, if any. + image : Optional[str] + Image file URL or base64 data URI that represents the image associated with the message, if any. + audio : Optional[str] + Audio file URL or base64 data URI that represents the audio associated with the message, if any. why : Optional[MessageWhy] Additional contextual information related to the message. 
@@ -167,8 +167,8 @@ class CatMessage(BaseModelDict): user_id: str who: str = "AI" text: Optional[str] = None - images: Optional[List[str]] = None - audio: Optional[List[str]] = None + image: Optional[str] = None + audio: Optional[str] = None why: Optional[MessageWhy] = None def __init__( @@ -176,23 +176,17 @@ def __init__( user_id: str, content: Optional[str] = None, text: Optional[str] = None, - images: Optional[Union[List[str], str]] = None, - audio: Optional[Union[List[str], str]] = None, + image: Optional[str] = None, + audio: Optional[str] = None, why: Optional[MessageWhy] = None, who: str = "AI", **kwargs, ): - if isinstance(images, str): - images = [images] - - if isinstance(audio, str): - audio = [audio] - if content: deprecation_warning("The `content` parameter is deprecated. Use `text` instead.") text = content # Map 'content' to 'text' - super().__init__(user_id=user_id, who=who, content=content, text=text, images=images, audio=audio, why=why, **kwargs) + super().__init__(user_id=user_id, who=who, text=text, image=image, audio=audio, why=why, **kwargs) @computed_field @property @@ -218,7 +212,10 @@ def content(self, value): class UserMessage(BaseModelDict): """ - Represents a message from a user, containing text and optional multimedia content such as images and audio. + Represents a message from a user, containing text and optional multimedia content such as image and audio. + + This class is used to encapsulate the details of a message sent by a user, including the user's identifier, + the text content of the message, and any associated multimedia content such as image or audio files. Parameters ---------- @@ -226,10 +223,10 @@ class UserMessage(BaseModelDict): Unique identifier for the user sending the message. text : Optional[str], default=None The text content of the message. Can be `None` if no text is provided. 
- images : Optional[Union[List[str], str]], default=None - List of image file URLs or base64 data URIs that represent images associated with the message. A single string can also be provided and will be converted to a list. - audio : Optional[Union[List[str], str]], default=None - List of audio file URLs or base64 data URIs that represent audio associated with the message. A single string can also be provided and will be converted to a list. + image : Optional[str], default=None + Image file URL or base64 data URI that represents the image associated with the message. + audio : Optional[str], default=None + Audio file URL or base64 data URI that represents the audio associated with the message. who : str, default="Human" The name of the message author, by default “Human”. @@ -241,32 +238,14 @@ class UserMessage(BaseModelDict): The name of the message author, by default “Human”. text : Optional[str] The text content of the message. - images : Optional[List[str]] - List of images associated with the message, if any. - audio : Optional[List[str]] - List of audio files associated with the message, if any. + image : Optional[str] + Image file URL or base64 data URI that represents the image associated with the message, if any. + audio : Optional[str] + Audio file URL or base64 data URI that represents the audio associated with the message, if any. 
""" user_id: str who: str = "Human" text: Optional[str] = None - images: Optional[List[str]] = None - audio: Optional[List[str]] = None - - def __init__( - self, - user_id: str, - text: Optional[str] = None, - images: Optional[Union[List[str], str]] = None, - audio: Optional[Union[List[str], str]] = None, - who: str = "Human", - **kwargs, - ): - if isinstance(images, str): - images = [images] - - if isinstance(audio, str): - audio = [audio] - - super().__init__(user_id=user_id, who=who, text=text, images=images, audio=audio, **kwargs) - + image: Optional[str] = None + audio: Optional[str] = None \ No newline at end of file diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index bb088e942..65151d81c 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -168,7 +168,7 @@ def _check_image_support(llm, image_type: str, image_value: str) -> None: llm.invoke([message]) setattr(self._llm_modalities, image_type, True) except Exception as e: - log.warning(f"The LLM '{model_name}' does not support {image_type} as input images.") + log.warning(f"The LLM '{model_name}' does not support {image_type} as input image.") log.debug(e) image_url = "https://raw.githubusercontent.com/cheshire-cat-ai/core/refs/heads/main/readme/cheshire-cat.jpeg" diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index d60a47dba..615884b0c 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -587,39 +587,18 @@ def stringify_chat_history(self, latest_n: int = 5) -> str: def langchainfy_chat_history(self, latest_n: int = 5) -> List[BaseMessage]: def format_human_message(message: HumanMessage) -> HumanMessage: - """Format a human message, including any text and images.""" + """Format a human message, including any text and image.""" content = [{"type": "text", "text": message.content.text}] - - if message.content.images: - 
content.extend(format_images(message.content.images)) - - return HumanMessage( - name=message.content.who, - content=content - ) - - def format_ai_message(message) -> AIMessage: - """Format an AI message with text content only.""" - return AIMessage( - name=message.content.who, - content=message.content.text - ) - def format_images(images: List[str]) -> List[dict]: - """Format a list of images into the required structure for langchain messages, - downloading the image if the model only supports data URIs but not image URLs.""" + def format_image(image:str) -> dict: + + # Retrieve the supported modalities from the LLM + llm_modalities: LLMSupportedModalities = CheshireCat()._llm_modalities - # Retrieve the supported modalities from the LLM - llm_modalities: LLMSupportedModalities = CheshireCat()._llm_modalities - - formatted_images = [] - - for image in images: if image.startswith("http"): if llm_modalities.image_url: - formatted_images.append({"type": "image_url", "image_url": {"url": image}}) - continue - + return {"type": "image_url", "image_url": {"url": image}} + response = requests.get(image) if response.status_code == 200: # Open the image using Pillow to determine its MIME type @@ -631,18 +610,31 @@ def format_images(images: List[str]) -> List[dict]: image_uri = f"data:image/{mime_type};base64,{encoded_image}" # Add the image as a data URI with the correct MIME type - formatted_images.append({"type": "image_url", "image_url": {"url": image_uri}}) + return {"type": "image_url", "image_url": {"url": image_uri}} else: error_message = f"Unexpected error with status code {response.status_code}" if response.text: error_message = response.text log.error(f"Failed to process image {image}: {error_message}") - else: - if llm_modalities.imge_uri: - formatted_images.append({"type": "image_url", "image_url": {"url": image}}) + + if llm_modalities.imge_uri: + return {"type": "image_url", "image_url": {"url": image}} + + if message.content.image: + 
content.append(format_image(message.content.image)) - return formatted_images + return HumanMessage( + name=message.content.who, + content=content + ) + + def format_ai_message(message) -> AIMessage: + """Format an AI message with text content only.""" + return AIMessage( + name=message.content.who, + content=message.content.text + ) chat_history = self.working_memory.history[-latest_n:] recent_history = chat_history[-latest_n:] diff --git a/core/cat/memory/working_memory.py b/core/cat/memory/working_memory.py index 24f0b7cdb..5b3b11d6e 100644 --- a/core/cat/memory/working_memory.py +++ b/core/cat/memory/working_memory.py @@ -40,9 +40,6 @@ class HistoryEntry(BaseModelDict): when: float content: Union[UserMessage, CatMessage] - def __init__(self, role: Role, when: float, content: Union[UserMessage, CatMessage]): - super().__init__(role=role, when=when, content=content) - @computed_field @property def message(self) -> str: