Skip to content

Commit

Permalink
Refactor multimedia attributes
Browse files Browse the repository at this point in the history
Modified the `CatMessage` and `UserMessage` classes to use singular `image` and `audio` attributes
  • Loading branch information
Pingdred committed Dec 15, 2024
1 parent 27aba7e commit 209dc34
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 85 deletions.
75 changes: 27 additions & 48 deletions core/cat/convo/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ class CatMessage(BaseModelDict):
Deprecated. The text content of the message. Use `text` instead.
text : Optional[str], default=None
The text content of the message.
images : Optional[Union[List[str], str]], default=None
List of image file URLs or base64 data URIs that represent images associated with the message. A single string can also be provided and will be converted to a list.
audio : Optional[Union[List[str], str]], default=None
List of audio file URLs or base64 data URIs that represent audio associated with the message. A single string can also be provided and will be converted to a list.
image : Optional[str], default=None
Image file URL or base64 data URI that represents the image associated with the message.
audio : Optional[str], default=None
Audio file URL or base64 data URI that represents the audio associated with the message.
why : Optional[MessageWhy], default=None
Additional contextual information related to the message.
who : str, default="AI"
Expand All @@ -151,10 +151,10 @@ class CatMessage(BaseModelDict):
The name of the message author, by default AI.
text : Optional[str]
The text content of the message.
images : Optional[List[str]]
List of image URLs or paths associated with the message, if any.
audio : Optional[List[str]]
List of audio file URLs or paths associated with the message, if any.
image : Optional[str]
Image file URL or base64 data URI that represents the image associated with the message.
audio : Optional[str]
Audio file URL or base64 data URI that represents the audio associated with the message.
why : Optional[MessageWhy]
Additional contextual information related to the message.
Expand All @@ -167,32 +167,26 @@ class CatMessage(BaseModelDict):
user_id: str
who: str = "AI"
text: Optional[str] = None
images: Optional[List[str]] = None
audio: Optional[List[str]] = None
image: Optional[str] = None
audio: Optional[str] = None
why: Optional[MessageWhy] = None

def __init__(
self,
user_id: str,
content: Optional[str] = None,
text: Optional[str] = None,
images: Optional[Union[List[str], str]] = None,
audio: Optional[Union[List[str], str]] = None,
image: Optional[str] = None,
audio: Optional[str] = None,
why: Optional[MessageWhy] = None,
who: str = "AI",
**kwargs,
):
if isinstance(images, str):
images = [images]

if isinstance(audio, str):
audio = [audio]

if content:
deprecation_warning("The `content` parameter is deprecated. Use `text` instead.")
text = content # Map 'content' to 'text'

super().__init__(user_id=user_id, who=who, content=content, text=text, images=images, audio=audio, why=why, **kwargs)
super().__init__(user_id=user_id, who=who, text=text, image=image, audio=audio, why=why, **kwargs)

@computed_field
@property
Expand All @@ -218,18 +212,21 @@ def content(self, value):

class UserMessage(BaseModelDict):
"""
Represents a message from a user, containing text and optional multimedia content such as images and audio.
Represents a message from a user, containing text and optional multimedia content such as an image and audio.
This class is used to encapsulate the details of a message sent by a user, including the user's identifier,
the text content of the message, and any associated multimedia content such as image or audio files.
Parameters
----------
user_id : str
Unique identifier for the user sending the message.
text : Optional[str], default=None
The text content of the message. Can be `None` if no text is provided.
images : Optional[Union[List[str], str]], default=None
List of image file URLs or base64 data URIs that represent images associated with the message. A single string can also be provided and will be converted to a list.
audio : Optional[Union[List[str], str]], default=None
List of audio file URLs or base64 data URIs that represent audio associated with the message. A single string can also be provided and will be converted to a list.
image : Optional[str], default=None
Image file URL or base64 data URI that represents the image associated with the message.
audio : Optional[str], default=None
Audio file URL or base64 data URI that represents the audio associated with the message.
who : str, default="Human"
The name of the message author, by default “Human”.
Expand All @@ -241,32 +238,14 @@ class UserMessage(BaseModelDict):
The name of the message author, by default “Human”.
text : Optional[str]
The text content of the message.
images : Optional[List[str]]
List of images associated with the message, if any.
audio : Optional[List[str]]
List of audio files associated with the message, if any.
image : Optional[str]
Image file URL or base64 data URI that represents the image associated with the message.
audio : Optional[str]
Audio file URL or base64 data URI that represents the audio associated with the message.
"""

user_id: str
who: str = "Human"
text: Optional[str] = None
images: Optional[List[str]] = None
audio: Optional[List[str]] = None

def __init__(
self,
user_id: str,
text: Optional[str] = None,
images: Optional[Union[List[str], str]] = None,
audio: Optional[Union[List[str], str]] = None,
who: str = "Human",
**kwargs,
):
if isinstance(images, str):
images = [images]

if isinstance(audio, str):
audio = [audio]

super().__init__(user_id=user_id, who=who, text=text, images=images, audio=audio, **kwargs)

image: Optional[str] = None
audio: Optional[str] = None
2 changes: 1 addition & 1 deletion core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _check_image_support(llm, image_type: str, image_value: str) -> None:
llm.invoke([message])
setattr(self._llm_modalities, image_type, True)
except Exception as e:
log.warning(f"The LLM '{model_name}' does not support {image_type} as input images.")
log.warning(f"The LLM '{model_name}' does not support {image_type} as input image.")
log.debug(e)

image_url = "https://raw.githubusercontent.com/cheshire-cat-ai/core/refs/heads/main/readme/cheshire-cat.jpeg"
Expand Down
58 changes: 25 additions & 33 deletions core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,39 +587,18 @@ def stringify_chat_history(self, latest_n: int = 5) -> str:
def langchainfy_chat_history(self, latest_n: int = 5) -> List[BaseMessage]:

def format_human_message(message: HumanMessage) -> HumanMessage:
"""Format a human message, including any text and images."""
"""Format a human message, including any text and image."""
content = [{"type": "text", "text": message.content.text}]

if message.content.images:
content.extend(format_images(message.content.images))

return HumanMessage(
name=message.content.who,
content=content
)

def format_ai_message(message) -> AIMessage:
"""Format an AI message with text content only."""
return AIMessage(
name=message.content.who,
content=message.content.text
)

def format_images(images: List[str]) -> List[dict]:
"""Format a list of images into the required structure for langchain messages,
downloading the image if the model only supports data URIs but not image URLs."""
def format_image(image:str) -> dict:

# Retrieve the supported modalities from the LLM
llm_modalities: LLMSupportedModalities = CheshireCat()._llm_modalities

# Retrieve the supported modalities from the LLM
llm_modalities: LLMSupportedModalities = CheshireCat()._llm_modalities

formatted_images = []

for image in images:
if image.startswith("http"):
if llm_modalities.image_url:
formatted_images.append({"type": "image_url", "image_url": {"url": image}})
continue

return {"type": "image_url", "image_url": {"url": image}}

response = requests.get(image)
if response.status_code == 200:
# Open the image using Pillow to determine its MIME type
Expand All @@ -631,18 +610,31 @@ def format_images(images: List[str]) -> List[dict]:
image_uri = f"data:image/{mime_type};base64,{encoded_image}"

# Add the image as a data URI with the correct MIME type
formatted_images.append({"type": "image_url", "image_url": {"url": image_uri}})
return {"type": "image_url", "image_url": {"url": image_uri}}
else:
error_message = f"Unexpected error with status code {response.status_code}"
if response.text:
error_message = response.text

log.error(f"Failed to process image {image}: {error_message}")
else:
if llm_modalities.imge_uri:
formatted_images.append({"type": "image_url", "image_url": {"url": image}})

if llm_modalities.imge_uri:
return {"type": "image_url", "image_url": {"url": image}}

if message.content.image:
content.append(format_image(message.content.image))

return formatted_images
return HumanMessage(
name=message.content.who,
content=content
)

def format_ai_message(message) -> AIMessage:
"""Format an AI message with text content only."""
return AIMessage(
name=message.content.who,
content=message.content.text
)

chat_history = self.working_memory.history[-latest_n:]
recent_history = chat_history[-latest_n:]
Expand Down
3 changes: 0 additions & 3 deletions core/cat/memory/working_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class HistoryEntry(BaseModelDict):
when: float
content: Union[UserMessage, CatMessage]

def __init__(self, role: Role, when: float, content: Union[UserMessage, CatMessage]):
super().__init__(role=role, when=when, content=content)

@computed_field
@property
def message(self) -> str:
Expand Down

0 comments on commit 209dc34

Please sign in to comment.