diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml index 78c129000..ce4c2c1c9 100644 --- a/.github/workflows/tag.yml +++ b/.github/workflows/tag.yml @@ -36,6 +36,17 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v3 + + # Uses the `docker/setup-qemu-action@v3` + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + # Uses the `docker/setup-buildx-action@v3` + - name: Set up docker Buildx + uses: docker/setup-buildx-action@v3 + with: + platforms: linux/amd64,linux/arm64 + # Uses the `docker/login-action` # action to log in to the Container registry using the account and password that will publish the packages. # Once published, the packages are scoped to the account defined here. diff --git a/core/Dockerfile b/core/Dockerfile index 5b3dfe05d..1f687ef02 100644 --- a/core/Dockerfile +++ b/core/Dockerfile @@ -4,7 +4,7 @@ FROM python:3.10.11-slim-bullseye ### PREPARE BUILD WITH NECESSARY FILES AND FOLDERS ### RUN mkdir -p /app && mkdir -p /admin COPY ./pyproject.toml /app/pyproject.toml -COPY ./cat/plugins /app/cat/plugins +COPY ./cat /app/cat COPY ./install_plugin_dependencies.py /app/install_plugin_dependencies.py ### SYSTEM SETUP ### diff --git a/core/cat/factory/custom_embedder.py b/core/cat/factory/custom_embedder.py index 40b6b29a1..3e595f6ed 100644 --- a/core/cat/factory/custom_embedder.py +++ b/core/cat/factory/custom_embedder.py @@ -64,4 +64,25 @@ def embed_query(self, text: str) -> List[float]: ret = httpx.post(self.url, data=payload, timeout=None) ret.raise_for_status() return ret.json()['data'][0]['embedding'] + +class CustomFastembedEmbeddings(Embeddings): + """Use Fastembed for embedding. + """ + def __init__(self, url, model,max_length) -> None: + self.url = url + output = httpx.post(f"{url}/embeddings", json={"model": model, "max_length": max_length}, follow_redirects=True, timeout=None) + output.raise_for_status() + + + def embed_documents(self, texts: List[str]): + payload = json.dumps({"document": texts}) + ret = httpx.post(f"{self.url}/embeddings/document", data=payload, timeout=None) + ret.raise_for_status() + return ret.json() + + def embed_query(self, text: str) -> List[float]: + payload = json.dumps({"prompt": text}) + ret = httpx.post(f"{self.url}/embeddings/prompt", data=payload, timeout=None) + ret.raise_for_status() + return ret.json() \ No newline at end of file diff --git a/core/cat/factory/custom_llm.py b/core/cat/factory/custom_llm.py index 2a2d37b51..b909968db 100644 --- a/core/cat/factory/custom_llm.py +++ b/core/cat/factory/custom_llm.py @@ -3,6 +3,7 @@ import requests from langchain.llms.base import LLM from langchain.llms.openai import OpenAI +from langchain.llms.ollama import Ollama class LLMDefault(LLM): @@ -86,4 +87,3 @@ def __init__(self, **kwargs): self.url = kwargs['url'] self.openai_api_base = os.path.join(self.url, "v1") - \ No newline at end of file diff --git a/core/cat/factory/embedder.py b/core/cat/factory/embedder.py index cfd340958..cebaa2ec6 100644 --- a/core/cat/factory/embedder.py +++ b/core/cat/factory/embedder.py @@ -2,7 +2,7 @@ import langchain from pydantic import BaseModel, ConfigDict -from cat.factory.custom_embedder import DumbEmbedder, CustomOpenAIEmbeddings +from cat.factory.custom_embedder import CustomFastembedEmbeddings, DumbEmbedder, CustomOpenAIEmbeddings # Base class to manage LLM configuration. @@ -108,7 +108,7 @@ class EmbedderCohereConfig(EmbedderSettings): class EmbedderHuggingFaceHubConfig(EmbedderSettings): - repo_id: str = "sentence-transformers/all-MiniLM-L12-v2" + repo_id: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" huggingfacehub_api_token: str _pyclass: Type = langchain.embeddings.HuggingFaceHubEmbeddings @@ -119,6 +119,20 @@ class EmbedderHuggingFaceHubConfig(EmbedderSettings): } ) +class EmbedderFastEmbedConfig(EmbedderSettings): + url: str + model: str = "intfloat/multilingual-e5-large" + max_length: int = 512 + + _pyclass: Type = CustomFastembedEmbeddings + + model_config = ConfigDict( + json_schema_extra = { + "humanReadableName": "Fast Embedder", + "description": "Configuration for Fast embeddings", + } + ) + SUPPORTED_EMDEDDING_MODELS = [ EmbedderDumbConfig, @@ -128,6 +142,7 @@ class EmbedderHuggingFaceHubConfig(EmbedderSettings): EmbedderAzureOpenAIConfig, EmbedderCohereConfig, EmbedderHuggingFaceHubConfig, + EmbedderFastEmbedConfig ] diff --git a/core/cat/factory/llm.py b/core/cat/factory/llm.py index ffbf66940..4794d5f4e 100644 --- a/core/cat/factory/llm.py +++ b/core/cat/factory/llm.py @@ -1,6 +1,7 @@ import langchain from langchain.chat_models import ChatOpenAI, AzureChatOpenAI from langchain.llms import OpenAI, AzureOpenAI +from langchain.llms.ollama import Ollama from typing import Dict, List, Type import json @@ -272,6 +273,24 @@ class LLMGooglePalmConfig(LLMSettings): } ) +class LLMOllamaConfig(LLMSettings): + base_url: str + model: str = "llama2" + num_ctx: int = 2048 + repeat_last_n: int = 64 + repeat_penalty: float = 1.1 + temperature: float = 0.8 + + _pyclass: Type = Ollama + + model_config = ConfigDict( + json_schema_extra = { + "humanReadableName": "Ollama", + "description": "Configuration for Ollama", + "link": "https://ollama.ai/library" + } + ) + SUPPORTED_LANGUAGE_MODELS = [ LLMDefaultConfig, @@ -286,7 +305,8 @@ class LLMGooglePalmConfig(LLMSettings): LLMAzureOpenAIConfig, LLMAzureChatOpenAIConfig, LLMAnthropicConfig, - LLMGooglePalmConfig + LLMGooglePalmConfig, + LLMOllamaConfig ] # LLM_SCHEMAS contains metadata to let any client know diff --git a/core/cat/api_auth.py b/core/cat/headers.py similarity index 89% rename from core/cat/api_auth.py rename to core/cat/headers.py index 1dac95a23..482e33ff0 100644 --- a/core/cat/api_auth.py +++ b/core/cat/headers.py @@ -51,3 +51,11 @@ def check_api_key(request: Request, api_key: str = Security(api_key_header)) -> status_code=403, detail={"error": "Invalid API Key"} ) + + +def check_user_id(request: Request) -> str: + user_id = request.headers.get("user_id") + if user_id: + return user_id + else: + return "user" \ No newline at end of file diff --git a/core/cat/looking_glass/agent_manager.py b/core/cat/looking_glass/agent_manager.py index 2d2acc33c..b6810de7e 100644 --- a/core/cat/looking_glass/agent_manager.py +++ b/core/cat/looking_glass/agent_manager.py @@ -11,6 +11,7 @@ from cat.looking_glass import prompts from cat.looking_glass.callbacks import NewTokenHandler from cat.looking_glass.output_parser import ToolOutputParser +from cat.memory.working_memory import WorkingMemory from cat.utils import verbal_timedelta from cat.log import log @@ -72,10 +73,9 @@ def execute_tool_agent(self, agent_input, allowed_tools): return out - def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix): + def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, working_memory: WorkingMemory): input_variables = [i for i in agent_input.keys() if i in prompt_prefix + prompt_suffix] - # memory chain (second step) memory_prompt = PromptTemplate( template = prompt_prefix + prompt_suffix, @@ -88,13 +88,13 @@ def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix): verbose=True ) - out = memory_chain(agent_input, callbacks=[NewTokenHandler(self.cat)]) + out = memory_chain(agent_input, callbacks=[NewTokenHandler(self.cat, working_memory)]) out["output"] = out["text"] del out["text"] return out - def execute_agent(self): + def execute_agent(self, working_memory): """Instantiate the Agent with tools. The method formats the main prompt and gather the allowed tools. It also instantiates a conversational Agent @@ -106,11 +106,10 @@ def execute_agent(self): Instance of the Agent provided with a set of tools. """ mad_hatter = self.cat.mad_hatter - working_memory = self.cat.working_memory # prepare input to be passed to the agent. # Info will be extracted from working memory - agent_input = self.format_agent_input() + agent_input = self.format_agent_input(working_memory) agent_input = mad_hatter.execute_hook("before_agent_starts", agent_input) # should we ran the default agent? fast_reply = {} @@ -161,7 +160,7 @@ def execute_agent(self): agent_input["tools_output"] = "## Tools output: \n" + tools_result["output"] if tools_result["output"] else "" # Execute the memory chain - out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix) + out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory) # If some tools are used the intermediate step are added to the agent output out["intermediate_steps"] = used_tools @@ -178,11 +177,11 @@ def execute_agent(self): #Adding the tools_output key in agent input, needed by the memory chain agent_input["tools_output"] = "" # Execute the memory chain - out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix) + out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory) return out - def format_agent_input(self): + def format_agent_input(self, working_memory): """Format the input for the Agent. The method formats the strings of recalled memories and chat history that will be provided to the Langchain @@ -206,7 +205,7 @@ def format_agent_input(self): agent_prompt_chat_history """ - working_memory = self.cat.working_memory + # format memories to be inserted in the prompt episodic_memory_formatted_content = self.agent_prompt_episodic_memories( diff --git a/core/cat/looking_glass/callbacks.py b/core/cat/looking_glass/callbacks.py index b224c9e5d..a50d5058b 100644 --- a/core/cat/looking_glass/callbacks.py +++ b/core/cat/looking_glass/callbacks.py @@ -4,8 +4,9 @@ class NewTokenHandler(BaseCallbackHandler): - def __init__(self, cat): + def __init__(self, cat, working_memory): self.cat = cat + self.working_memory = working_memory def on_llm_new_token(self, token: str, **kwargs) -> None: - self.cat.send_ws_message(token, "chat_token") + self.cat.send_ws_message(token, "chat_token", self.working_memory) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 3a284d4b2..a06689715 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -1,7 +1,8 @@ import time from copy import deepcopy import traceback -from typing import Literal, get_args +from typing import Literal, get_args, Dict +import langchain import os import asyncio import langchain @@ -9,12 +10,13 @@ from langchain.chat_models import ChatOpenAI, AzureChatOpenAI from langchain.base_language import BaseLanguageModel +import cat.utils as utils from cat.log import log from cat.db import crud from cat.db.database import Database from cat.rabbit_hole import RabbitHole from cat.mad_hatter.mad_hatter import MadHatter -from cat.memory.working_memory import WorkingMemoryList +from cat.memory.working_memory import WorkingMemoryList, WorkingMemory from cat.memory.long_term_memory import LongTermMemory from cat.looking_glass.agent_manager import AgentManager from cat.looking_glass.callbacks import NewTokenHandler @@ -26,7 +28,7 @@ MSG_TYPES = Literal["notification", "chat", "error", "chat_token"] # main class -class CheshireCat: +class CheshireCat(): """The Cheshire Cat. This is the main class that manages everything. @@ -38,6 +40,15 @@ class CheshireCat: """ + # CheshireCat is a singleton, this is the instance + _instance = None + + # get instance or create as the constructor is called + def __new__(cls): + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + def __init__(self): """Cat initialization. @@ -71,7 +82,7 @@ def __init__(self): # queue of cat messages not directly related to last user input # i.e. finished uploading a file - self.ws_messages = asyncio.Queue() + self.ws_messages: Dict[str, asyncio.Queue] = {} def load_natural_language(self): """Load Natural Language related objects. @@ -128,7 +139,6 @@ def get_language_model(self) -> BaseLanguageModel: return llm - def get_language_embedder(self) -> embedders.EmbedderSettings: """Hook into the embedder selection. @@ -237,7 +247,7 @@ def load_memory(self): # Load default shared working memory user self.working_memory = self.working_memory_list.get_working_memory() - def recall_relevant_memories_to_working_memory(self): + def recall_relevant_memories_to_working_memory(self, working_memory): """Retrieve context from memory. The method retrieves the relevant memories from the vector collections that are given as context to the LLM. @@ -257,8 +267,8 @@ def recall_relevant_memories_to_working_memory(self): before_cat_recalls_procedural_memories after_cat_recalls_memories """ - user_id = self.working_memory.get_user_id() - recall_query = self.working_memory["user_message_json"]["text"] + user_id = working_memory.get_user_id() + recall_query = working_memory["user_message_json"]["text"] # We may want to search in memory recall_query = self.mad_hatter.execute_hook("cat_recall_query", recall_query) @@ -266,7 +276,7 @@ def recall_relevant_memories_to_working_memory(self): # Embed recall query recall_query_embedding = self.embedder.embed_query(recall_query) - self.working_memory["recall_query"] = recall_query + working_memory["recall_query"] = recall_query # hook to do something before recall begins self.mad_hatter.execute_hook("before_cat_recalls_memories") @@ -297,8 +307,8 @@ def recall_relevant_memories_to_working_memory(self): # hooks to change recall configs for each memory recall_configs = [ self.mad_hatter.execute_hook("before_cat_recalls_episodic_memories", default_episodic_recall_config), - self.mad_hatter.execute_hook("before_cat_recalls_declarative_memories", default_procedural_recall_config), - self.mad_hatter.execute_hook("before_cat_recalls_procedural_memories", default_declarative_recall_config) + self.mad_hatter.execute_hook("before_cat_recalls_declarative_memories", default_declarative_recall_config), + self.mad_hatter.execute_hook("before_cat_recalls_procedural_memories", default_procedural_recall_config) ] memory_types = self.memory.vectors.collections.keys() @@ -310,7 +320,7 @@ def recall_relevant_memories_to_working_memory(self): vector_memory = getattr(self.memory.vectors, memory_type) memories = vector_memory.recall_memories_from_embedding(**config) - self.working_memory[memory_key] = memories + working_memory[memory_key] = memories # hook to modify/enrich retrieved memories self.mad_hatter.execute_hook("after_cat_recalls_memories") @@ -345,19 +355,24 @@ def llm(self, prompt: str, chat: bool = False, stream: bool = False) -> str: if isinstance(self._llm, langchain.chat_models.base.BaseChatModel): return self._llm.call_as_llm(prompt, callbacks=callbacks) - def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): + def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification", working_memory: WorkingMemory = None): """Send a message via websocket. This method is useful for sending a message via websocket directly without passing through the LLM Parameters ---------- + working_memory content : str The content of the message. msg_type : str The type of the message. Should be either `notification`, `chat` or `error` """ + # no working memory passed, send message to default user + if working_memory is None: + working_memory = self.working_memory_list.get_working_memory() + options = get_args(MSG_TYPES) if msg_type not in options: @@ -365,7 +380,7 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): if msg_type == "error": asyncio.run( - self.ws_messages.put( + working_memory.ws_messages.put( { "type": msg_type, "name": "GenericError", @@ -375,36 +390,13 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): ) else: asyncio.run( - self.ws_messages.put( + working_memory.ws_messages.put( { "type": msg_type, "content": content } ) - ) - - def get_base_url(self): - """Allows the Cat expose the base url.""" - secure = os.getenv('CORE_USE_SECURE_PROTOCOLS', '') - if secure != '': - secure = 's' - return f'http{secure}://{os.environ["CORE_HOST"]}:{os.environ["CORE_PORT"]}/' - - def get_base_path(self): - """Allows the Cat expose the base path.""" - return "cat/" - - def get_plugin_path(self): - """Allows the Cat expose the plugins path.""" - return os.path.join(self.get_base_path(), "plugins/") - - def get_static_url(self): - """Allows the Cat expose the static server url.""" - return self.get_base_url() + "static/" - - def get_static_path(self): - """Allows the Cat expose the static files path.""" - return os.path.join(self.get_base_path(), "static/") + ) def __call__(self, user_message_json): """Call the Cat instance. @@ -433,18 +425,20 @@ def __call__(self, user_message_json): # Change working memory based on received user_id user_id = user_message_json.get('user_id', 'user') user_message_json['user_id'] = user_id - self.working_memory = self.working_memory_list.get_working_memory(user_id) + # ccat class working memory is the default "user" working memory + # self.working_memory = self.working_memory_list.get_working_memory(user_id) + user_working_memory = self.working_memory_list.get_working_memory(user_id) # hook to modify/enrich user input user_message_json = self.mad_hatter.execute_hook("before_cat_reads_message", user_message_json) # store last message in working memory - self.working_memory["user_message_json"] = user_message_json + user_working_memory["user_message_json"] = user_message_json # recall episodic and declarative memories from vector collections # and store them in working_memory try: - self.recall_relevant_memories_to_working_memory() + self.recall_relevant_memories_to_working_memory(user_working_memory) except Exception as e: log.error(e) traceback.print_exc(e) @@ -462,7 +456,7 @@ def __call__(self, user_message_json): # reply with agent try: - cat_message = self.agent_manager.execute_agent() + cat_message = self.agent_manager.execute_agent(user_working_memory) except Exception as e: # This error happens when the LLM # does not respect prompt instructions. @@ -476,7 +470,7 @@ def __call__(self, user_message_json): unparsable_llm_output = error_description.replace("Could not parse LLM output: `", "").replace("`", "") cat_message = { - "input": self.working_memory["user_message_json"]["text"], + "input": user_working_memory["user_message_json"]["text"], "intermediate_steps": [], "output": unparsable_llm_output } @@ -485,9 +479,9 @@ def __call__(self, user_message_json): log.info(cat_message) # update conversation history - user_message = self.working_memory["user_message_json"]["text"] - self.working_memory.update_conversation_history(who="Human", message=user_message) - self.working_memory.update_conversation_history(who="AI", message=cat_message["output"]) + user_message = user_working_memory["user_message_json"]["text"] + user_working_memory.update_conversation_history(who="Human", message=user_message) + user_working_memory.update_conversation_history(who="AI", message=cat_message["output"]) # store user message in episodic memory # TODO: vectorize and store also conversation chunks @@ -498,9 +492,9 @@ def __call__(self, user_message_json): ) # build data structure for output (response and why with memories) - episodic_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory["episodic_memories"]] - declarative_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory["declarative_memories"]] - procedural_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory["procedural_memories"]] + episodic_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in user_working_memory["episodic_memories"]] + declarative_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in user_working_memory["declarative_memories"]] + procedural_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in user_working_memory["procedural_memories"]] final_output = { "type": "chat", @@ -520,3 +514,34 @@ def __call__(self, user_message_json): final_output = self.mad_hatter.execute_hook("before_cat_sends_message", final_output) return final_output + + + # TODO: remove this method in a few versions, current version 1.2.0 + def get_base_url(): + """Allows the Cat exposing the base url.""" + log.warning("This method will be removed, import cat.utils tu use it instead.") + return utils.get_base_url() + + # TODO: remove this method in a few versions, current version 1.2.0 + def get_base_path(): + """Allows the Cat exposing the base path.""" + log.warning("This method will be removed, import cat.utils tu use it instead.") + return utils.get_base_path() + + # TODO: remove this method in a few versions, current version 1.2.0 + def get_plugins_path(): + """Allows the Cat exposing the plugins path.""" + log.warning("This method will be removed, import cat.utils tu use it instead.") + return utils.get_plugins_path() + + # TODO: remove this method in a few versions, current version 1.2.0 + def get_static_url(): + """Allows the Cat exposing the static server url.""" + log.warning("This method will be removed, import cat.utils tu usit instead.") + return utils.get_static_url() + + # TODO: remove this method in a few versions, current version 1.2.0 + def get_static_path(): + """Allows the Cat exposing the static files path.""" + log.warning("This method will be removed, import cat.utils tu usit instead.") + return utils.get_static_path() diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 34d872a8b..b03af1054 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -1,11 +1,10 @@ import glob -import json import time import shutil import os import traceback +import cat.utils as utils from copy import deepcopy - from cat.log import log from cat.db import crud from cat.db.models import Setting @@ -33,14 +32,15 @@ def __init__(self, ccat): self.active_plugins = [] + self.plugins_folder = utils.get_plugins_path() + self.find_plugins() def install_plugin(self, package_plugin): # extract zip/tar file into plugin folder - plugins_folder = self.ccat.get_plugin_path() extractor = PluginExtractor(package_plugin) - plugin_path = extractor.extract(plugins_folder) + plugin_path = extractor.extract(self.plugins_folder) # remove zip after extraction os.remove(package_plugin) @@ -82,10 +82,8 @@ def find_plugins(self): # plus the default core plugin s(where default hooks and tools are defined) core_plugin_folder = "cat/mad_hatter/core_plugin/" - # plugin folder is "cat/plugins/" in production, "tests/mocks/mock_plugin_folder/" during tests - plugins_folder = self.ccat.get_plugin_path() - - all_plugin_folders = [core_plugin_folder] + glob.glob(f"{plugins_folder}*/") + # plugin folder is "cat/plugins/" in production, "tests/mocks/mock_plugin_folder/" during tests + all_plugin_folders = [core_plugin_folder] + glob.glob(f"{self.plugins_folder}*/") log.info("ACTIVE PLUGINS:") log.info(self.active_plugins) diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index c68387bb6..27d487e61 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -2,15 +2,16 @@ import sys import json import glob -import importlib import traceback +import importlib +from importlib import machinery from typing import Dict from inspect import getmembers from pydantic import BaseModel from cat.mad_hatter.decorators import CatTool, CatHook, CatPluginOverride from cat.utils import to_camel_case -from cat.log import log, get_log_level +from cat.log import log # Empty class to represent basic plugin Settings model @@ -83,18 +84,35 @@ def deactivate(self): # get plugin settings JSON schema def settings_schema(self): - # is "plugin_settings_schema" hook defined in the plugin? + # is "settings_schema" hook defined in the plugin? for h in self._plugin_overrides: if h.name == "settings_schema": return h.function() + else: + # if the "settings_schema" is not defined but + # "settings_model" is it get the schema from the model + if h.name == "settings_model": + return h.function().model_json_schema() # default schema (empty) return PluginSettingsModel.model_json_schema() + # get plugin settings Pydantic model + def settings_model(self): + + # is "settings_model" hook defined in the plugin? + for h in self._plugin_overrides: + if h.name == "settings_model": + return h.function() + + # default schema (empty) + return PluginSettingsModel + + # load plugin settings def load_settings(self): - # is "plugin_settings_load" hook defined in the plugin? + # is "settings_load" hook defined in the plugin? for h in self._plugin_overrides: if h.name == "load_settings": return h.function() @@ -114,13 +132,14 @@ def load_settings(self): except Exception as e: log.error(f"Unable to load plugin {self._id} settings") log.error(e) + raise e return settings # save plugin settings def save_settings(self, settings: Dict): - # is "plugin_settings_save" hook defined in the plugin? + # is "settings_save" hook defined in the plugin? for h in self._plugin_overrides: if h.name == "save_settings": return h.function(settings) @@ -182,13 +201,31 @@ def _install_requirements(self): log.info(f"Installing requirements for: {self.id}") os.system(f'pip install --no-cache-dir -r "{req_file}"') - # lists of hooks and tools def _load_decorated_functions(self): hooks = [] tools = [] plugin_overrides = [] + """ + for py_file in self.py_files: + module_name = os.path.splitext(os.path.basename(py_file))[0] + + log.info(f"Import module {py_file}") + + # save a reference to decorated functions + try: + plugin_module = machinery.SourceFileLoader(module_name, py_file).load_module() + hooks += getmembers(plugin_module, self._is_cat_hook) + tools += getmembers(plugin_module, self._is_cat_tool) + plugin_overrides += getmembers(plugin_module, self._is_cat_plugin_override) + except Exception as e: + log.error(f"Error in {module_name}: {str(e)}") + traceback.print_exc() + raise Exception(f"Unable to load the plugin {self._id}") + """ + + for py_file in self.py_files: py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry @@ -203,8 +240,9 @@ def _load_decorated_functions(self): except Exception as e: log.error(f"Error in {py_filename}: {str(e)}") traceback.print_exc() - raise Exception(f"Unable to load the plugin {self._id}") - + raise Exception(f"Unable to load the plugin {self._id}") + + # clean and enrich instances hooks = list(map(self._clean_hook, hooks)) tools = list(map(self._clean_tool, tools)) diff --git a/core/cat/main.py b/core/cat/main.py index b93801fac..cf6a519bd 100644 --- a/core/cat/main.py +++ b/core/cat/main.py @@ -12,9 +12,9 @@ from cat.log import log from cat.routes import base, settings, llm, embedder, memory, plugins, upload, websocket from cat.routes.static import public, admin, static -from cat.api_auth import check_api_key +from cat.headers import check_api_key from cat.routes.openapi import get_openapi_configuration_function -from cat.looking_glass.cheshire_cat import CheshireCat +from cat.looking_glass.cheshire_cat import CheshireCat @asynccontextmanager @@ -41,7 +41,6 @@ def custom_generate_unique_id(route: APIRoute): # REST API cheshire_cat_api = FastAPI( lifespan=lifespan, - dependencies=[Depends(check_api_key)], generate_unique_id_function=custom_generate_unique_id ) @@ -57,14 +56,15 @@ def custom_generate_unique_id(route: APIRoute): ) # Add routers to the middleware stack. -cheshire_cat_api.include_router(base.router, tags=["Status"]) -cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings") -cheshire_cat_api.include_router(llm.router, tags=["Large Language Model"], prefix="/llm") -cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder") -cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins") -cheshire_cat_api.include_router(memory.router, tags=["Memory"], prefix="/memory") -cheshire_cat_api.include_router(upload.router, tags=["Rabbit Hole"], prefix="/rabbithole") -cheshire_cat_api.include_router(websocket.router, tags=["Websocket"]) +# TODO: To workaround the dependencies of the websocket, their are added manually in each router +cheshire_cat_api.include_router(base.router, tags=["Status"], dependencies=[Depends(check_api_key)]) +cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings", dependencies=[Depends(check_api_key)]) +cheshire_cat_api.include_router(llm.router, tags=["Large Language Model"], prefix="/llm", dependencies=[Depends(check_api_key)]) +cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder", dependencies=[Depends(check_api_key)]) +cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins", dependencies=[Depends(check_api_key)]) +cheshire_cat_api.include_router(memory.router, tags=["Memory"], prefix="/memory", dependencies=[Depends(check_api_key)]) +cheshire_cat_api.include_router(upload.router, tags=["Rabbit Hole"], prefix="/rabbithole", dependencies=[Depends(check_api_key)]) +cheshire_cat_api.include_router(websocket.router, tags=["WebSocket"]) # mount static files # this cannot be done via fastapi.APIrouter: diff --git a/core/cat/memory/working_memory.py b/core/cat/memory/working_memory.py index 5d1b1ed38..6bc236dda 100644 --- a/core/cat/memory/working_memory.py +++ b/core/cat/memory/working_memory.py @@ -1,3 +1,5 @@ +import asyncio + class WorkingMemory(dict): """Cat's volatile memory. @@ -16,6 +18,8 @@ class WorkingMemory(dict): def __init__(self): # The constructor instantiates a `dict` with a 'history' key to store conversation history + # and the asyncio queue to manage the session notifications + self.ws_messages = asyncio.Queue() super().__init__(history=[]) def get_user_id(self): diff --git a/core/cat/routes/embedder.py b/core/cat/routes/embedder.py index d7a69a90f..089d5c059 100644 --- a/core/cat/routes/embedder.py +++ b/core/cat/routes/embedder.py @@ -91,7 +91,7 @@ def get_embedder_settings(request: Request, languageEmbedderName: str) -> Dict: def upsert_embedder_setting( request: Request, languageEmbedderName: str, - payload: Dict = Body(example={"openai_api_key": "your-key-here"}), + payload: Dict = Body(examples={"openai_api_key": "your-key-here"}), ) -> Dict: """Upsert the Embedder setting""" diff --git a/core/cat/routes/llm.py b/core/cat/routes/llm.py index fbde5dd28..5dfe96b1d 100644 --- a/core/cat/routes/llm.py +++ b/core/cat/routes/llm.py @@ -85,7 +85,7 @@ def get_llm_settings(request: Request, languageModelName: str) -> Dict: def upsert_llm_setting( request: Request, languageModelName: str, - payload: Dict = Body(example={"openai_api_key": "your-key-here"}), + payload: Dict = Body(examples={"openai_api_key": "your-key-here"}), ) -> Dict: """Upsert the Large Language Model setting""" diff --git a/core/cat/routes/memory.py b/core/cat/routes/memory.py index e009b5eea..dffd7dbad 100644 --- a/core/cat/routes/memory.py +++ b/core/cat/routes/memory.py @@ -1,5 +1,6 @@ from typing import Dict -from fastapi import Query, Request, APIRouter, HTTPException +from cat.headers import check_user_id +from fastapi import Query, Request, APIRouter, HTTPException, Depends router = APIRouter() @@ -7,10 +8,10 @@ # GET memories from recall @router.get("/recall/") async def recall_memories_from_text( - request: Request, - text: str = Query(description="Find memories similar to this text."), - k: int = Query(default=100, description="How many memories to return."), - user_id: str = Query(default="user", description="User id."), + request: Request, + text: str = Query(description="Find memories similar to this text."), + k: int = Query(default=100, description="How many memories to return."), + user_id = Depends(check_user_id) ) -> Dict: """Search k memories similar to given text.""" @@ -87,7 +88,7 @@ async def get_collections(request: Request) -> Dict: # DELETE all collections @router.delete("/collections/") async def wipe_collections( - request: Request, + request: Request, ) -> Dict: """Delete and create all collections""" @@ -111,8 +112,7 @@ async def wipe_collections( # DELETE one collection @router.delete("/collections/{collection_id}/") -async def wipe_single_collection(request: Request, - collection_id: str) -> Dict: +async def wipe_single_collection(request: Request, collection_id: str) -> Dict: """Delete and recreate a collection""" ccat = request.app.state.ccat @@ -143,9 +143,9 @@ async def wipe_single_collection(request: Request, # DELETE memories @router.delete("/collections/{collection_id}/points/{memory_id}/") async def wipe_memory_point( - request: Request, - collection_id: str, - memory_id: str + request: Request, + collection_id: str, + memory_id: str ) -> Dict: """Delete a specific point in memory""" @@ -181,9 +181,9 @@ async def wipe_memory_point( @router.delete("/collections/{collection_id}/points") async def wipe_memory_points_by_metadata( - request: Request, - collection_id: str, - metadata: Dict = {}, + request: Request, + collection_id: str, + metadata: Dict = {}, ) -> Dict: """Delete points in memory by filter""" @@ -201,9 +201,12 @@ async def wipe_memory_points_by_metadata( # DELETE conversation history from working memory @router.delete("/conversation_history/") async def wipe_conversation_history( - request: Request, + request: Request, + user_id = Depends(check_user_id), ) -> Dict: - """Delete conversation history from working memory""" + """Delete the specified user's conversation history from working memory""" + + # TODO: Add possibility to wipe the working memory of specified user id ccat = request.app.state.ccat ccat.working_memory["history"] = [] @@ -211,3 +214,21 @@ async def wipe_conversation_history( return { "deleted": True, } + + +# GET conversation history from working memory +@router.get("/conversation_history/") +async def get_conversation_history( + request: Request, + user_id = Depends(check_user_id), +) -> Dict: + """Get the specified user's conversation history from working memory""" + + # TODO: Add possibility to get the working memory of specified user id + + ccat = request.app.state.ccat + history = ccat.working_memory["history"] + + return { + "history": history + } \ No newline at end of file diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 409e7d0f4..4be087570 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -8,6 +8,8 @@ from urllib.parse import urlparse import requests +from pydantic import ValidationError + router = APIRouter() @@ -65,9 +67,9 @@ async def get_available_plugins( @router.post("/upload/") async def install_plugin( - request: Request, - file: UploadFile, - background_tasks: BackgroundTasks + request: Request, + file: UploadFile, + background_tasks: BackgroundTasks ) -> Dict: """Install a new plugin from a zip file""" @@ -104,8 +106,8 @@ async def install_plugin( async def install_plugin_from_registry( request: Request, background_tasks: BackgroundTasks, - payload: Dict = Body(example={"url": "https://github.com/plugin-dev-account/plugin-repo"}) - ) -> Dict: + payload: Dict = Body(examples={"url": "https://github.com/plugin-dev-account/plugin-repo"}) +) -> Dict: """Install a new plugin from registry""" # access cat instance @@ -219,15 +221,18 @@ async def get_plugins_settings(request: Request) -> Dict: # plugins are managed by the MadHatter class for plugin in ccat.mad_hatter.plugins.values(): - plugin_settings = plugin.load_settings() - plugin_schema = plugin.settings_schema() - if plugin_schema['properties'] == {}: - plugin_schema = {} - settings.append({ - "name": plugin.id, - "value": plugin_settings, - "schema": plugin_schema - }) + try: + plugin_settings = plugin.load_settings() + plugin_schema = plugin.settings_schema() + if plugin_schema['properties'] == {}: + plugin_schema = {} + settings.append({ + "name": plugin.id, + "value": plugin_settings, + "schema": plugin_schema + }) + except: + log.error(f"Error loading {plugin} settings") return { "settings": settings, @@ -264,7 +269,7 @@ async def get_plugin_settings(request: Request, plugin_id: str) -> Dict: async def upsert_plugin_settings( request: Request, plugin_id: str, - payload: Dict = Body(example={"setting_a": "some value", "setting_b": "another value"}), + payload: Dict = Body(examples={"setting_a": "some value", "setting_b": "another value"}), ) -> Dict: """Updates the settings of a specific plugin""" @@ -277,7 +282,21 @@ async def upsert_plugin_settings( detail = { "error": "Plugin not found" } ) - final_settings = ccat.mad_hatter.plugins[plugin_id].save_settings(payload) + # Get the plugin object + plugin = ccat.mad_hatter.plugins[plugin_id] + + try: + # Load the plugin settings Pydantic model + PluginSettingsModel = plugin.settings_model() + # Validate the settings + PluginSettingsModel.model_validate(payload) + except ValidationError as e: + raise HTTPException( + status_code = 400, + detail = { "error": "\n".join(list( map((lambda x: x["msg"]), e.errors())))} + ) + + final_settings = plugin.save_settings(payload) return { "name": plugin_id, diff --git a/core/cat/routes/static/auth_static.py b/core/cat/routes/static/auth_static.py index 42fb8355f..a6492df07 100644 --- a/core/cat/routes/static/auth_static.py +++ b/core/cat/routes/static/auth_static.py @@ -1,12 +1,12 @@ -from fastapi.staticfiles import StaticFiles -from fastapi import Request -from cat.api_auth import check_api_key - -class AuthStatic(StaticFiles): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - async def __call__(self, scope, receive, send) -> None: - reqeust = Request(scope, receive=receive) - check_api_key(reqeust.headers.get("access_token")) +from fastapi.staticfiles import StaticFiles +from fastapi import Request +from cat.headers import check_api_key + +class AuthStatic(StaticFiles): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + async def __call__(self, scope, receive, send) -> None: + request = Request(scope, receive=receive) + check_api_key(request.headers.get("access_token")) await super().__call__(scope, receive, send) \ No newline at end of file diff --git a/core/cat/routes/upload.py b/core/cat/routes/upload.py index 14a35322f..da4ef3b42 100644 --- a/core/cat/routes/upload.py +++ b/core/cat/routes/upload.py @@ -10,17 +10,16 @@ # receive files via http endpoint -# TODO: should we receive files also via websocket? @router.post("/") async def upload_file( - request: Request, - file: UploadFile, - background_tasks: BackgroundTasks, - chunk_size: int = Body( - default=400, - description="Maximum length of each chunk after the document is split (in characters)", - ), - chunk_overlap: int = Body(default=100, description="Chunk overlap (in characters)") + request: Request, + file: UploadFile, + background_tasks: BackgroundTasks, + chunk_size: int = Body( + default=400, + description="Maximum length of each chunk after the document is split (in characters)", + ), + chunk_overlap: int = Body(default=100, description="Chunk overlap (in characters)") ) -> Dict: """Upload a file containing text (.txt, .md, .pdf, etc.). File content will be extracted and segmented into chunks. Chunks will be then vectorized and stored into documents memory. @@ -58,16 +57,16 @@ async def upload_file( @router.post("/web/") async def upload_url( - request: Request, - background_tasks: BackgroundTasks, - url: str = Body( - description="URL of the website to which you want to save the content" - ), - chunk_size: int = Body( - default=400, - description="Maximum length of each chunk after the document is split (in characters)", - ), - chunk_overlap: int = Body(default=100, description="Chunk overlap (in characters)") + request: Request, + background_tasks: BackgroundTasks, + url: str = Body( + description="URL of the website to which you want to save the content" + ), + chunk_size: int = Body( + default=400, + description="Maximum length of each chunk after the document is split (in characters)", + ), + chunk_overlap: int = Body(default=100, description="Chunk overlap (in characters)") ): """Upload a url. Website content will be extracted and segmented into chunks. Chunks will be then vectorized and stored into documents memory.""" @@ -111,9 +110,9 @@ async def upload_url( @router.post("/memory/") async def upload_memory( - request: Request, - file: UploadFile, - background_tasks: BackgroundTasks + request: Request, + file: UploadFile, + background_tasks: BackgroundTasks ) -> Dict: """Upload a memory json file to the cat memory""" diff --git a/core/cat/routes/websocket.py b/core/cat/routes/websocket.py index 21e02b5e5..bfbff7430 100644 --- a/core/cat/routes/websocket.py +++ b/core/cat/routes/websocket.py @@ -1,16 +1,13 @@ import traceback import asyncio - +from cat.looking_glass.cheshire_cat import CheshireCat +from typing import Dict, Optional from fastapi import APIRouter, WebSocketDisconnect, WebSocket from cat.log import log from fastapi.concurrency import run_in_threadpool router = APIRouter() -# This constant sets the interval (in seconds) at which the system checks for notifications. -QUEUE_CHECK_INTERVAL = 1 # seconds - - class ConnectionManager: """ Manages active WebSocket connections. @@ -18,81 +15,90 @@ class ConnectionManager: def __init__(self): # List to store all active WebSocket connections. - self.active_connections: list[WebSocket] = [] + self.active_connections: Dict[str, WebSocket] = {} - async def connect(self, websocket: WebSocket): + async def connect(self, websocket: WebSocket, user_id: str = "user"): """ Accept the incoming WebSocket connection and add it to the active connections list. """ - await websocket.accept() - self.active_connections.append(websocket) + self.active_connections[user_id] = websocket - def disconnect(self, websocket: WebSocket): + def disconnect(self, ccat: CheshireCat, user_id: str = "user"): """ Remove the given WebSocket from the active connections list. """ - self.active_connections.remove(websocket) + del self.active_connections[user_id] + if user_id in ccat.ws_messages: + del ccat.ws_messages[user_id] - async def send_personal_message(self, message: str, websocket: WebSocket): + async def send_personal_message(self, message: str, user_id: str = "user"): """ Send a personal message (in JSON format) to the specified WebSocket. """ - await websocket.send_json(message) + if user_id in self.active_connections: + await self.active_connections[user_id].send_json(message) async def broadcast(self, message: str): """ Send a message to all active WebSocket connections. """ - for connection in self.active_connections: + for connection in self.active_connections.values(): await connection.send_json(message) manager = ConnectionManager() -async def receive_message(websocket: WebSocket, ccat: object): +async def receive_message(ccat: CheshireCat, user_id: str = "user"): """ Continuously receive messages from the WebSocket and forward them to the `ccat` object for processing. """ + while True: - user_message = await websocket.receive_json() + # Receive the next message from the WebSocket. + user_message = await manager.active_connections[user_id].receive_json() + user_message["user_id"] = user_id # Run the `ccat` object's method in a threadpool since it might be a CPU-bound operation. cat_message = await run_in_threadpool(ccat, user_message) # Send the response message back to the user. - await manager.send_personal_message(cat_message, websocket) + await manager.send_personal_message(cat_message, user_id) -async def check_messages(websocket: WebSocket, ccat): +async def check_messages(ccat, user_id='user'): """ Periodically check if there are any new notifications from the `ccat` instance and send them to the user. """ - while True: + while True: # extract from FIFO list websocket notification - notification = await ccat.ws_messages.get() - await manager.send_personal_message(notification, websocket) - + notification = await ccat.working_memory_list.get_working_memory(user_id).ws_messages.get() + await manager.send_personal_message(notification, user_id) + -@router.websocket_route("/ws") -async def websocket_endpoint(websocket: WebSocket): +@router.websocket("/ws") +@router.websocket("/ws/{user_id}") +async def websocket_endpoint(websocket: WebSocket, user_id: str = "user"): """ - Endpoint to handle incoming WebSocket connections, process messages, and check for messages. + Endpoint to handle incoming WebSocket connections by user id, process messages, and check for messages. """ # Retrieve the `ccat` instance from the application's state. ccat = websocket.app.state.ccat - # Add the new WebSocket connection to the manager. - await manager.connect(websocket) + if user_id in manager.active_connections: + # Skip the coroutine if the same user is already connected via WebSocket. + return + # Add the new WebSocket connection to the manager. + await manager.connect(websocket, user_id) try: # Process messages and check for notifications concurrently. await asyncio.gather( - receive_message(websocket, ccat), - check_messages(websocket, ccat) + receive_message(ccat, user_id), + check_messages(ccat, user_id) ) except WebSocketDisconnect: # Handle the event where the user disconnects their WebSocket. @@ -105,7 +111,7 @@ async def websocket_endpoint(websocket: WebSocket): "type": "error", "name": type(e).__name__, "description": str(e), - }, websocket) + }, user_id) finally: - # Always ensure the WebSocket is removed from the manager, regardless of how the above block exits. - manager.disconnect(websocket) + # Remove the WebSocket from the manager when the user disconnects. + manager.disconnect(ccat, user_id) diff --git a/core/cat/utils.py b/core/cat/utils.py index 8792a99e0..13b9cea66 100644 --- a/core/cat/utils.py +++ b/core/cat/utils.py @@ -1,9 +1,10 @@ """Various utiles used from the projects.""" - +import os +import inspect from datetime import timedelta -def to_camel_case(text :str ) -> str: +def to_camel_case(text: str) -> str: """Format string to camel case. Takes a string of words separated by either hyphens or underscores and returns a string of words in camel case. @@ -67,3 +68,49 @@ def verbal_timedelta(td: timedelta) -> str: return "{} ago".format(abs_delta) else: return "{} ago".format(abs_delta) + + +def get_base_url(): + """Allows exposing the base url.""" + secure = os.getenv('CORE_USE_SECURE_PROTOCOLS', '') + if secure != '': + secure = 's' + return f'http{secure}://{os.environ["CORE_HOST"]}:{os.environ["CORE_PORT"]}/' + + +def get_base_path(): + """Allows exposing the base path.""" + return "cat/" + + +def get_plugins_path(): + """Allows exposing the plugins' path.""" + return os.path.join(get_base_path(), "plugins/") + + +def get_static_url(): + """Allows exposing the static server url.""" + return get_base_url() + "static/" + + +def get_static_path(): + """Allows exposing the static files' path.""" + return os.path.join(get_base_path(), "static/") + + +def get_current_plugin_path(): + """Allows accessing the current plugin path.""" + # Get the current execution frame of the calling module, + # then the previous frame in the call stack + frame = inspect.currentframe().f_back + # Get the module associated with the frame + module = inspect.getmodule(frame) + # Get the absolute and then relative path of the calling module's file + abs_path = inspect.getabsfile(module) + rel_path = os.path.relpath(abs_path) + # Replace the root and get only the current plugin folder + plugin_suffix = rel_path.replace(get_plugins_path(), "") + # Plugin's folder + folder_name = plugin_suffix.split("/")[0] + # Get current plugin's folder + return os.path.join(get_plugins_path(), folder_name) diff --git a/core/pyproject.toml b/core/pyproject.toml index 157c6abdc..6c66c2316 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "Cheshire-Cat" description = "Production ready AI assistant framework" -version = "1.2.0" +version = "1.3.0" requires-python = ">=3.10" license = { file="LICENSE" } authors = [ diff --git a/core/tests/conftest.py b/core/tests/conftest.py index d3e899817..4f4c94fca 100644 --- a/core/tests/conftest.py +++ b/core/tests/conftest.py @@ -8,12 +8,12 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from tinydb import Query -from cat.db import models + from cat.db.database import Database from cat.log import log -from cat.looking_glass.cheshire_cat import CheshireCat +import cat.utils as utils + from qdrant_client import QdrantClient from cat.memory.vector_memory import VectorMemory @@ -46,10 +46,10 @@ def app(monkeypatch) -> Generator[FastAPI, Any, None]: Create a new setup on each test case, with new mocks for both Qdrant and TinyDB """ - # Use mock plugin folder - def mock_plugin_folder(self, *args, **kwargs): + # Use mock utils plugin folder + def get_test_plugin_folder(): return "tests/mocks/mock_plugin_folder/" - monkeypatch.setattr(CheshireCat, "get_plugin_path", mock_plugin_folder) + utils.get_plugins_path = get_test_plugin_folder # Use in memory vector db def mock_connect_to_vector_memory(self, *args, **kwargs): diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index eddff27dd..bb1909107 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -1,8 +1,9 @@ import os -import shutil import pytest from inspect import isfunction +import cat.utils as utils + from cat.mad_hatter.mad_hatter import MadHatter, Plugin from cat.mad_hatter.decorators import CatHook, CatTool from cat.looking_glass.cheshire_cat import CheshireCat @@ -69,7 +70,7 @@ def test_plugin_install(mad_hatter: MadHatter, plugin_is_flat): # archive extracted assert os.path.exists( - os.path.join(mad_hatter.ccat.get_plugin_path(), "mock_plugin") + os.path.join(utils.get_plugins_path(), "mock_plugin") ) # plugins list updated @@ -127,7 +128,7 @@ def test_plugin_uninstall(mad_hatter: MadHatter, plugin_is_flat): # directory removed assert not os.path.exists( - os.path.join(mad_hatter.ccat.get_plugin_path(), "mock_plugin") + os.path.join(utils.get_plugins_path(), "mock_plugin") ) # plugins list updated diff --git a/core/tests/routes/memory/test_memory_by_user.py b/core/tests/routes/memory/test_memory_by_user.py index 4c74f109b..dff37eee7 100644 --- a/core/tests/routes/memory/test_memory_by_user.py +++ b/core/tests/routes/memory/test_memory_by_user.py @@ -1,14 +1,12 @@ - from tests.utils import send_websocket_message # episodic memories are saved having the correct user def test_episodic_memory_by_user(client): - # send websocket message from user A + # send websocket message from user C send_websocket_message({ - "text": "I am user A", - "user_id": "A" - }, client) + "text": "I am user C", + }, client, user_id="C") # episodic recall (no user) params = { @@ -22,24 +20,22 @@ def test_episodic_memory_by_user(client): # episodic recall (memories from non existing user) params = { - "text": "I am user", - "user_id": "H" + "text": "I am user not existing" } - response = client.get(f"/memory/recall/", params=params) + response = client.get(f"/memory/recall/", params=params, headers={"user_id": "not_existing"}) json = response.json() assert response.status_code == 200 episodic_memories = json["vectors"]["collections"]["episodic"] assert len(episodic_memories) == 0 - # episodic recall (memories from user A) + # episodic recall (memories from user C) params = { - "text": "I am user", - "user_id": "A" + "text": "I am user C" } - response = client.get(f"/memory/recall/", params=params) + response = client.get(f"/memory/recall/", params=params, headers={"user_id": "C"}) json = response.json() assert response.status_code == 200 episodic_memories = json["vectors"]["collections"]["episodic"] assert len(episodic_memories) == 1 - assert episodic_memories[0]["metadata"]["source"] == "A" + assert episodic_memories[0]["metadata"]["source"] == "C" diff --git a/core/tests/routes/memory/test_memory_recall.py b/core/tests/routes/memory/test_memory_recall.py index 350260341..14040355d 100644 --- a/core/tests/routes/memory/test_memory_recall.py +++ b/core/tests/routes/memory/test_memory_recall.py @@ -54,7 +54,6 @@ def test_memory_recall_success(client): episodic_memories = json["vectors"]["collections"]["episodic"] assert len(episodic_memories) == num_messages # all 3 retrieved - # search with query and k def test_memory_recall_with_k_success(client): diff --git a/core/tests/utils.py b/core/tests/utils.py index 643d4618b..7eb855c7f 100644 --- a/core/tests/utils.py +++ b/core/tests/utils.py @@ -3,13 +3,11 @@ # utility function to communicate with the cat via websocket -def send_websocket_message(msg, client): +def send_websocket_message(msg, client, user_id="user"): - with client.websocket_connect("/ws") as websocket: - + with client.websocket_connect(f"/ws/{user_id}") as websocket: # sed ws message websocket.send_json(msg) - # get reply reply = websocket.receive_json() @@ -20,15 +18,20 @@ def send_websocket_message(msg, client): def send_n_websocket_messages(num_messages, client): responses = [] - for m in range(num_messages): - message = { - "text": f"Red Queen {m}" - } - res = send_websocket_message(message, client) - responses.append(res) + + with client.websocket_connect(f"/ws") as websocket: + for m in range(num_messages): + message = { + "text": f"Red Queen {m}" + } + # sed ws message + websocket.send_json(message) + # get reply + reply = websocket.receive_json() + responses.append(reply) return responses - + def key_in_json(key, json): return key in json.keys() diff --git a/docker-compose.yml b/docker-compose.yml index 24be86b8e..08dc0557e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -31,7 +31,7 @@ services: restart: unless-stopped cheshire-cat-vector-memory: - image: qdrant/qdrant:v1.1.1 + image: qdrant/qdrant:v1.6.1 container_name: cheshire_cat_vector_memory expose: - 6333