From 07e600a15c1fa94179f98aef760bc3bc079dfd6b Mon Sep 17 00:00:00 2001 From: Chaitanya-Keyal <66475772+Chaitanya-Keyal@users.noreply.github.com> Date: Sun, 19 Jan 2025 02:35:59 +0530 Subject: [PATCH] fix: update state messages structure and add agent depths --- src/agents.py | 10 +++++++++- src/nodes.py | 39 +++++++++++++++++++++++++-------------- src/state.py | 3 ++- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/agents.py b/src/agents.py index a560113..d8874b8 100644 --- a/src/agents.py +++ b/src/agents.py @@ -31,6 +31,14 @@ def __init__(self): agent_name = agent_name.split(".")[0] self.prompts[agent_name] = f.read() + self.depths = { + "intent_classifier": 1, + "general_campus_query": 2, + "course_query": 2, + "not_related_query": 2, + "long_term_memory": 3, + } + def _get_prompt( self, agent_name: str, user_input: str, agent_scratchpad=False ) -> ChatPromptTemplate: @@ -63,7 +71,7 @@ def intent_classifier(self, query: str, chat_history: str) -> str: } ) - return result + return result.content def general_campus_query(self, query: str, chat_history: str) -> str: PROMPT_TEMPLATE = self._get_prompt( diff --git a/src/nodes.py b/src/nodes.py index ef1dc01..eeb22d7 100644 --- a/src/nodes.py +++ b/src/nodes.py @@ -9,44 +9,55 @@ def intent_classifier(state: State): - query = state["messages"][0].content - result = agents.intent_classifier(query, state.get("chat_history", "")) - + query = state["messages"][-agents.depths["intent_classifier"]].content + result = AIMessage(agents.intent_classifier(query, state.get("chat_history", ""))) return {"messages": [result]} def course_query(state: State): - query = state["messages"][0].content + query = state["messages"][-agents.depths["course_query"]].content result = AIMessage("Course query not implemented yet") # Add short term memory here once implemented. return {"messages": [result]} def general_campus_query(state: State): - query = state["messages"][0].content - result = agents.general_campus_query(query, state.get("chat_history", "")) - add_short_term_memory(state["user_id"], query, result, "general_campus_query") + query = state["messages"][-agents.depths["general_campus_query"]].content + result = AIMessage( + agents.general_campus_query(query, state.get("chat_history", "")) + ) + add_short_term_memory( + state["user_id"], + query, + result.content, + "general_campus_query", + ) return {"messages": [result]} def not_related_query(state: State): - query = state["messages"][0].content + query = state["messages"][-agents.depths["not_related_query"]].content result = AIMessage( "I'm sorry, I don't understand the question, if it relates to campus please rephrase." ) add_short_term_memory( - state["user_id"], query, not_related_query, "intent_classifier" + state["user_id"], + query, + "not_related_query", + "intent_classifier", ) return {"messages": [result]} def long_term_memory(state: State): - query = state["messages"][0].content + query = state["messages"][-agents.depths["long_term_memory"]].content user_id = state["user_id"] long_term_memories = parse_long_term_memory(state.get("long_term_memories", [])) - result = agents.long_term_memory( - user_id, - query, - long_term_memories, + result = AIMessage( + agents.long_term_memory( + user_id, + query, + long_term_memories, + ) ) return {"messages": [result]} diff --git a/src/state.py b/src/state.py index 51b34f1..a1934a7 100644 --- a/src/state.py +++ b/src/state.py @@ -1,5 +1,6 @@ from typing import Annotated, Optional +from langchain_core.messages import AIMessage, HumanMessage from langgraph.graph.message import add_messages from typing_extensions import TypedDict @@ -8,6 +9,6 @@ class State(TypedDict): user_id: str - messages: Annotated[list, add_messages] + messages: Annotated[list[AIMessage | HumanMessage], add_messages] chat_history: Optional[str] long_term_memories: Optional[list[LongTermMemory]]