Skip to content

Commit

Permalink
feat: long term memory in main graph.
Browse files Browse the repository at this point in the history
  • Loading branch information
theofficialvedantjoshi committed Jan 6, 2025
1 parent 9bb1412 commit 5a87362
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 13 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ LANGCHAIN_TRACING_V2=true
LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
LANGCHAIN_API_KEY=...
LANGCHAIN_PROJECT="bitsgpt-rewrite"
POSTGRES_DB=
POSTGRES_USER=
POSTGRES_PASS=
POSTGRES_HOST=
POSTGRES_PORT=
```

Replace the keys with the appropriate values.
Expand Down
26 changes: 20 additions & 6 deletions src/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from src.tools.memory_tool import tool_modify_memory
from langchain.agents import create_tool_calling_agent, AgentExecutor

load_dotenv()

Expand All @@ -26,7 +28,7 @@ def __init__(self):
self.prompts[agent_name] = f.read()

def get_prompt(
self, agent_name: str, query: str, chat_history: str, agent_scratchpad=False
self, agent_name: str, user_input: str, agent_scratchpad=False
) -> ChatPromptTemplate:

prompt = [
Expand All @@ -36,17 +38,18 @@ def get_prompt(
),
(
"user",
textwrap.dedent(
f"<query>{query}</query>\n\n<history>{chat_history}</history>"
),
textwrap.dedent(user_input),
),
]
if agent_scratchpad:
prompt.append(("placeholder", "{agent_scratchpad}"))
return ChatPromptTemplate.from_messages(prompt)

def intent_classifier(self, query: str, chat_history: str) -> str:
prompt = self.get_prompt("INTENT_CLASSIFIER_AGENT", query, chat_history)
prompt = self.get_prompt(
"INTENT_CLASSIFIER_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)

chain = prompt | self.llm

Expand All @@ -59,7 +62,10 @@ def intent_classifier(self, query: str, chat_history: str) -> str:
return result

def general_campus_query(self, query: str, chat_history: str) -> str:
prompt = self.get_prompt("GENERAL_CAMPUS_QUERY_AGENT", query, chat_history)
prompt = self.get_prompt(
"GENERAL_CAMPUS_QUERY_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)

chain = prompt | self.llm

Expand All @@ -73,3 +79,11 @@ def general_campus_query(self, query: str, chat_history: str) -> str:

def course_query(self, query: str, chat_history: str) -> str:
raise NotImplementedError("Course query not implemented yet")

def long_term_memory(self, id: str, query: str, memories: str) -> str:
tools = [tool_modify_memory]
prompt = self.get_prompt("LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True)
agent = create_tool_calling_agent(self.llm, tools, prompt)
chain = AgentExecutor(agent=agent, tools=tools)
result = chain.invoke({"user_id": id, "memories": memories})
return result["output"]
22 changes: 18 additions & 4 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
general_campus_query,
intent_classifier,
not_related_query,
long_term_memory,
)
from .state import State

Expand All @@ -21,6 +22,7 @@ def create_graph(self) -> StateGraph:
graph.add_node("course_query", course_query)
graph.add_node("general_campus_query", general_campus_query)
graph.add_node("not_related_query", not_related_query)
graph.add_node("long_term_memory", long_term_memory)

graph.set_entry_point("intent_classifer")

Expand All @@ -34,8 +36,20 @@ def intent_router(state):

graph.add_conditional_edges("intent_classifer", intent_router)

graph.add_edge("course_query", END)
graph.add_edge("general_campus_query", END)
graph.add_edge("not_related_query", END)

graph.add_edge(
"course_query",
"long_term_memory",
)
graph.add_edge(
"general_campus_query",
"long_term_memory",
)
graph.add_edge(
"not_related_query",
"long_term_memory",
)
graph.add_edge(
"long_term_memory",
END,
)
return graph
13 changes: 13 additions & 0 deletions src/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,16 @@ def not_related_query(state: State):
"I'm sorry, I don't understand the question, if it relates to campus please rephrase."
)
return {"messages": [result]}


def long_term_memory(state: State):
query = state["messages"][0].content
user_id = state["user_id"]
# parse long term memory here.
long_term_memories = state.get("long_term_memories", "")
result = agents.long_term_memory(
user_id,
query,
long_term_memories,
)
return {"messages": [result]}
3 changes: 3 additions & 0 deletions src/state.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Annotated, Optional

from langgraph.graph.message import add_messages
from src.memory.data import AddMemory
from typing_extensions import TypedDict


class State(TypedDict):
user_id: str
messages: Annotated[list, add_messages]
chat_history: Optional[str]
long_term_memories: Optional[list[AddMemory]]
6 changes: 3 additions & 3 deletions src/tools/memory_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def modify_memory(
"""
print(f"Modifying memory for {id} with action {action}")
conn = psycopg2.connect(
dbname="postgres",
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host="localhost",
port="5432",
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
)
cur = conn.cursor()
cur.execute(
Expand Down

0 comments on commit 5a87362

Please sign in to comment.