diff --git a/core/cat/agents/base_agent.py b/core/cat/agents/base_agent.py index f94155caa..d9116d508 100644 --- a/core/cat/agents/base_agent.py +++ b/core/cat/agents/base_agent.py @@ -13,5 +13,5 @@ class AgentOutput(BaseModelDict): class BaseAgent(ABC): @abstractmethod - async def execute(*args, **kwargs) -> AgentOutput: + def execute(*args, **kwargs) -> AgentOutput: pass \ No newline at end of file diff --git a/core/cat/agents/form_agent.py b/core/cat/agents/form_agent.py index d393c8fdf..8be84fd3d 100644 --- a/core/cat/agents/form_agent.py +++ b/core/cat/agents/form_agent.py @@ -5,7 +5,7 @@ class FormAgent(BaseAgent): - async def execute(self, stray) -> AgentOutput: + def execute(self, stray) -> AgentOutput: # get active form from working memory active_form = stray.working_memory.active_form diff --git a/core/cat/agents/main_agent.py b/core/cat/agents/main_agent.py index 7ea5a0226..e52389bb7 100644 --- a/core/cat/agents/main_agent.py +++ b/core/cat/agents/main_agent.py @@ -26,7 +26,7 @@ def __init__(self): else: self.verbose = False - async def execute(self, stray) -> AgentOutput: + def execute(self, stray) -> AgentOutput: """Execute the agents. Returns @@ -66,7 +66,7 @@ async def execute(self, stray) -> AgentOutput: # run tools and forms procedures_agent = ProceduresAgent() - procedures_agent_out : AgentOutput = await procedures_agent.execute(stray) + procedures_agent_out : AgentOutput = procedures_agent.execute(stray) if procedures_agent_out.return_direct: return procedures_agent_out @@ -74,7 +74,7 @@ async def execute(self, stray) -> AgentOutput: # - no procedures were recalled or selected or # - procedures have all return_direct=False memory_agent = MemoryAgent() - memory_agent_out : AgentOutput = await memory_agent.execute( + memory_agent_out : AgentOutput = memory_agent.execute( # TODO: should all agents only receive stray? stray, prompt_prefix, prompt_suffix ) diff --git a/core/cat/agents/memory_agent.py b/core/cat/agents/memory_agent.py index 2308c483c..22c34f954 100644 --- a/core/cat/agents/memory_agent.py +++ b/core/cat/agents/memory_agent.py @@ -11,7 +11,7 @@ class MemoryAgent(BaseAgent): - async def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput: + def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput: prompt_variables = stray.working_memory.agent_input.model_dump() sys_prompt = prompt_prefix + prompt_suffix diff --git a/core/cat/agents/procedures_agent.py b/core/cat/agents/procedures_agent.py index 86baa1b39..9c28a77a1 100644 --- a/core/cat/agents/procedures_agent.py +++ b/core/cat/agents/procedures_agent.py @@ -24,10 +24,10 @@ class ProceduresAgent(BaseAgent): form_agent = FormAgent() allowed_procedures: Dict[str, CatTool | CatForm] = {} - async def execute(self, stray) -> AgentOutput: + def execute(self, stray) -> AgentOutput: # Run active form if present - form_output: AgentOutput = await self.form_agent.execute(stray) + form_output: AgentOutput = self.form_agent.execute(stray) if form_output.return_direct: return form_output @@ -38,7 +38,7 @@ async def execute(self, stray) -> AgentOutput: log.debug(f"Procedural memories retrived: {len(procedural_memories)}.") try: - procedures_result: AgentOutput = await self.execute_procedures(stray) + procedures_result: AgentOutput = self.execute_procedures(stray) if procedures_result.return_direct: # exit agent if a return_direct procedure was executed return procedures_result @@ -64,7 +64,7 @@ async def execute(self, stray) -> AgentOutput: return AgentOutput() - async def execute_procedures(self, stray): + def execute_procedures(self, stray): # using some hooks mad_hatter = MadHatter() @@ -87,13 +87,13 @@ async def execute_procedures(self, stray): ) # Execute chain and obtain a choice of procedure from the LLM - llm_action: LLMAction = await self.execute_chain(stray, procedures_prompt_template, allowed_procedures) + llm_action: LLMAction = self.execute_chain(stray, procedures_prompt_template, allowed_procedures) # route execution to subagents - return await self.execute_subagents(stray, llm_action, allowed_procedures) + return self.execute_subagents(stray, llm_action, allowed_procedures) - async def execute_chain(self, stray, procedures_prompt_template, allowed_procedures) -> LLMAction: + def execute_chain(self, stray, procedures_prompt_template, allowed_procedures) -> LLMAction: # Prepare info to fill up the prompt prompt_variables = { @@ -136,7 +136,7 @@ async def execute_chain(self, stray, procedures_prompt_template, allowed_procedu return llm_action - async def execute_subagents(self, stray, llm_action, allowed_procedures): + def execute_subagents(self, stray, llm_action, allowed_procedures): # execute chosen tool / form # loop over allowed tools and forms if llm_action.action: @@ -144,7 +144,7 @@ async def execute_subagents(self, stray, llm_action, allowed_procedures): try: if Plugin._is_cat_tool(chosen_procedure): # execute tool - tool_output = await chosen_procedure._arun(llm_action.action_input, stray=stray) + tool_output = chosen_procedure.run(llm_action.action_input, stray=stray) return AgentOutput( output=tool_output, return_direct=chosen_procedure.return_direct, @@ -158,7 +158,7 @@ async def execute_subagents(self, stray, llm_action, allowed_procedures): # store active form in working memory stray.working_memory.active_form = form_instance # execute form - return await self.form_agent.execute(stray) + return self.form_agent.execute(stray) except Exception as e: log.error(f"Error executing {chosen_procedure.procedure_type} `{chosen_procedure.name}`") diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index 5108c4fc5..968159fb7 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -43,8 +43,6 @@ def __init__( self.__main_loop = main_loop - self.__loop = asyncio.new_event_loop() - def __repr__(self): return f"StrayCat(user_id={self.user_id})" @@ -342,7 +340,7 @@ def llm(self, prompt: str, stream: bool = False) -> str: return output - async def __call__(self, message_dict): + def __call__(self, message_dict): """Call the Cat instance. This method is called on the user's message received from the client. @@ -408,7 +406,7 @@ async def __call__(self, message_dict): # reply with agent try: - agent_output: AgentOutput = await self.main_agent.execute(self) + agent_output: AgentOutput = self.main_agent.execute(self) except Exception as e: # This error happens when the LLM # does not respect prompt instructions. @@ -472,7 +470,7 @@ async def __call__(self, message_dict): def run(self, user_message_json, return_message=False): try: - cat_message = self.loop.run_until_complete(self.__call__(user_message_json)) + cat_message = self.__call__(user_message_json) if return_message: # return the message for HTTP usage return cat_message @@ -648,7 +646,3 @@ def main_agent(self): @property def white_rabbit(self): return CheshireCat().white_rabbit - - @property - def loop(self): - return self.__loop diff --git a/core/cat/mad_hatter/decorators/tool.py b/core/cat/mad_hatter/decorators/tool.py index b9cfd16eb..c97f62953 100644 --- a/core/cat/mad_hatter/decorators/tool.py +++ b/core/cat/mad_hatter/decorators/tool.py @@ -45,20 +45,12 @@ def __repr__(self) -> str: return f"CatTool(name={self.name}, return_direct={self.return_direct}, description={self.description})" # we run tools always async, even if they are not defined so in a plugin - def _run(self, input_by_llm: str) -> str: - pass # do nothing - + def _run(self, input_by_llm: str, stray) -> str: + return self.func(input_by_llm, cat=stray) + # we run tools always async, even if they are not defined so in a plugin async def _arun(self, input_by_llm, stray): - - # await if the tool is async - if inspect.iscoroutinefunction(self.func): - return await self.func(input_by_llm, cat=stray) - - # run in executor if the tool is not async - return await stray.loop.run_in_executor( - None, self.func, input_by_llm, stray - ) + pass # override `extra = 'forbid'` for Tool pydantic model in langchain class Config: