Skip to content

Commit

Permalink
Merge pull request #8 from crux-bphc:state-messages-fix
Browse files Browse the repository at this point in the history
Update state messages and Add agent depths
  • Loading branch information
Chaitanya-Keyal authored Jan 18, 2025
2 parents 5413a09 + 07e600a commit 03d9c25
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
10 changes: 9 additions & 1 deletion src/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def __init__(self):
agent_name = agent_name.split(".")[0]
self.prompts[agent_name] = f.read()

self.depths = {
"intent_classifier": 1,
"general_campus_query": 2,
"course_query": 2,
"not_related_query": 2,
"long_term_memory": 3,
}

def _get_prompt(
self, agent_name: str, user_input: str, agent_scratchpad=False
) -> ChatPromptTemplate:
Expand Down Expand Up @@ -63,7 +71,7 @@ def intent_classifier(self, query: str, chat_history: str) -> str:
}
)

return result
return result.content

def general_campus_query(self, query: str, chat_history: str) -> str:
PROMPT_TEMPLATE = self._get_prompt(
Expand Down
39 changes: 25 additions & 14 deletions src/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,55 @@


def intent_classifier(state: State):
query = state["messages"][0].content
result = agents.intent_classifier(query, state.get("chat_history", ""))

query = state["messages"][-agents.depths["intent_classifier"]].content
result = AIMessage(agents.intent_classifier(query, state.get("chat_history", "")))
return {"messages": [result]}


def course_query(state: State):
query = state["messages"][0].content
query = state["messages"][-agents.depths["course_query"]].content
result = AIMessage("Course query not implemented yet")
# Add short term memory here once implemented.
return {"messages": [result]}


def general_campus_query(state: State):
query = state["messages"][0].content
result = agents.general_campus_query(query, state.get("chat_history", ""))
add_short_term_memory(state["user_id"], query, result, "general_campus_query")
query = state["messages"][-agents.depths["general_campus_query"]].content
result = AIMessage(
agents.general_campus_query(query, state.get("chat_history", ""))
)
add_short_term_memory(
state["user_id"],
query,
result.content,
"general_campus_query",
)
return {"messages": [result]}


def not_related_query(state: State):
query = state["messages"][0].content
query = state["messages"][-agents.depths["not_related_query"]].content
result = AIMessage(
"I'm sorry, I don't understand the question, if it relates to campus please rephrase."
)
add_short_term_memory(
state["user_id"], query, not_related_query, "intent_classifier"
state["user_id"],
query,
"not_related_query",
"intent_classifier",
)
return {"messages": [result]}


def long_term_memory(state: State):
query = state["messages"][0].content
query = state["messages"][-agents.depths["long_term_memory"]].content
user_id = state["user_id"]
long_term_memories = parse_long_term_memory(state.get("long_term_memories", []))
result = agents.long_term_memory(
user_id,
query,
long_term_memories,
result = AIMessage(
agents.long_term_memory(
user_id,
query,
long_term_memories,
)
)
return {"messages": [result]}
3 changes: 2 additions & 1 deletion src/state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Annotated, Optional

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

Expand All @@ -8,6 +9,6 @@

class State(TypedDict):
user_id: str
messages: Annotated[list, add_messages]
messages: Annotated[list[AIMessage | HumanMessage], add_messages]
chat_history: Optional[str]
long_term_memories: Optional[list[LongTermMemory]]

0 comments on commit 03d9c25

Please sign in to comment.