From 63649c3765c02abe51bf9b3abbc28c9034584683 Mon Sep 17 00:00:00 2001 From: theofficialvedantjoshi Date: Sat, 18 Jan 2025 11:48:27 +0530 Subject: [PATCH 1/4] Feat: Fetching and resetting long term memory. --- src/prompts/LONG_TERM_MEMORY_AGENT.md | 7 ++-- src/tools/memory_tool.py | 59 +++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/src/prompts/LONG_TERM_MEMORY_AGENT.md b/src/prompts/LONG_TERM_MEMORY_AGENT.md index 81fab19..6e8db33 100644 --- a/src/prompts/LONG_TERM_MEMORY_AGENT.md +++ b/src/prompts/LONG_TERM_MEMORY_AGENT.md @@ -12,11 +12,11 @@ The users id is {user_id}. You are only interested in the following categories of information: -1. Course prefereces or likes - The user likes or is interested in a course. -2. Course dislikes - The user dislikes a course. +1. Course Likes- The user likes or is interested in a course. +2. Course Dislikes - The user dislikes a course. 3. Branch - The branch the user is pursuing in college including majors and minors. 4. Clubs - The clubs the user is part of on campus. -5. Personal attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus. +5. Person Attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus. When you receive a message, you perform a sequence of steps consisting of: @@ -25,6 +25,7 @@ When you receive a message, you perform a sequence of steps consisting of: 3. Determine if this is new knowledge, an update to old knowledge that now needs to change, or should result in deleting information that is not correct. It's possible that a product/brand you previously wrote as a dislike might now be a like, and other cases- those examples would require an update. 4. Never save the same information twice. 
If you see the same information in a message that you have already saved, ignore it. 5. Refer to the history for existing memories. +6. Categories must be from ['Course Likes', 'Course Dislikes', 'Branch', 'Clubs', 'Person Attributes']. Here are the existing bits of information that we have about the user. diff --git a/src/tools/memory_tool.py b/src/tools/memory_tool.py index 3a40ccc..93f052b 100644 --- a/src/tools/memory_tool.py +++ b/src/tools/memory_tool.py @@ -3,8 +3,9 @@ import psycopg2 from dotenv import load_dotenv from langchain.tools import StructuredTool +from langchain_core.tools import ToolException -from src.memory.data import AddMemory +from src.memory.data import AddMemory, Category load_dotenv() @@ -23,7 +24,17 @@ def modify_memory( """ Function to modify memory in the database """ - print(f"Modifying long term memory for {id} with action {action}") + print( + f"Modifying long term memory for {id} with action {action} and category {Category(category).value}" + ) + if Category(category).value not in [ + "Course Likes", + "Course Dislikes", + "Branch", + "Clubs", + "Person Attributes", + ]: + return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes" cur = conn.cursor() cur.execute( """ @@ -38,7 +49,7 @@ def modify_memory( cur.execute( f""" INSERT INTO longterm_memory (id, memory, category) - VALUES ('{id}', '{memory}', '{category}'); + VALUES ('{id}', '{memory}', '{Category(category).value}'); """ ) conn.commit() @@ -57,7 +68,7 @@ def modify_memory( cur.execute( f""" DELETE FROM longterm_memory - WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{category}'; + WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{Category(category).value}'; """ ) conn.commit() @@ -66,9 +77,49 @@ def modify_memory( return "Invalid action" +def fetch_long_term_memory(id: str) -> list[AddMemory]: + """ + Function to fetch long term memory from the database + """ + cur = conn.cursor() + cur.execute( + f""" 
+ SELECT memory, category + FROM longterm_memory + WHERE id = '{id}'; + """ + ) + rows = cur.fetchall() + print(f"Fetched long term memory for {id} = {rows}") + return [ + AddMemory(id=id, memory=row[0], category=row[1], action="Create") + for row in rows + ] + + +def reset_long_term_memory(id: str) -> None: + """ + Function to reset long term memory from the database + """ + cur = conn.cursor() + cur.execute( + f""" + DELETE FROM longterm_memory + WHERE id = '{id}'; + """ + ) + conn.commit() + print(f"Reset long term memory for {id}") + + +def _handle_tool_error(error: ToolException) -> str: + return f"The following errors occurred during tool execution: `{error.args[0]}`" + + tool_modify_memory = StructuredTool.from_function( func=modify_memory, name="modify_memory", description="Modify the long term memory of a user", args_schema=AddMemory, + handle_tool_error=_handle_tool_error, ) From 012468e2f2e966c7e0db5ef2d56a2d16d64c6e36 Mon Sep 17 00:00:00 2001 From: theofficialvedantjoshi Date: Sat, 18 Jan 2025 11:48:52 +0530 Subject: [PATCH 2/4] Feat: Parsing long term memories. --- src/memory/data.py | 19 ++++++++++++++++--- src/nodes.py | 4 ++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/memory/data.py b/src/memory/data.py index 893b0f9..6d57729 100644 --- a/src/memory/data.py +++ b/src/memory/data.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Optional -from langchain.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field class Category(str, Enum): @@ -35,5 +35,18 @@ class AddMemory(BaseModel): ) -def parse_memory(memory: AddMemory): - raise NotImplementedError("This function is not yet implemented") +def parse_memory(memory: list[AddMemory]) -> str: + """ + Function to parse memory from the database. 
+ """ + memory_dict = {} + for mem in memory: + if Category(mem.category).value not in memory_dict: + memory_dict[Category(mem.category).value] = [] + memory_dict[Category(mem.category).value].append(mem.memory) + long_term_memory = "" + for category in memory_dict: + long_term_memory += f"{category}:\n" + for mem in memory_dict[category]: + long_term_memory += f"- {mem}\n" + return long_term_memory diff --git a/src/nodes.py b/src/nodes.py index 1e2feb8..16ab42a 100644 --- a/src/nodes.py +++ b/src/nodes.py @@ -1,6 +1,7 @@ from langchain_core.messages import AIMessage from .agents import Agents +from .memory.data import parse_memory from .state import State agents = Agents() @@ -36,8 +37,7 @@ def not_related_query(state: State): def long_term_memory(state: State): query = state["messages"][0].content user_id = state["user_id"] - # parse long term memory here. - long_term_memories = state.get("long_term_memories", "") + long_term_memories = parse_memory(state.get("long_term_memories", [])) result = agents.long_term_memory( user_id, query, From e05d9828099476d83d566711d52028c501c03164 Mon Sep 17 00:00:00 2001 From: theofficialvedantjoshi Date: Sat, 18 Jan 2025 13:52:02 +0530 Subject: [PATCH 3/4] Feat: Short term memory for agents. 
--- src/agents.py | 10 ++- src/memory/data.py | 55 ++++++++++---- src/memory/long_term_memory.py | 130 ++++++++++++++++++++++++++++++++ src/memory/short_term_memory.py | 100 ++++++++++++++++++++++++ src/nodes.py | 4 +- src/state.py | 4 +- src/tools/memory_tool.py | 115 +--------------------------- 7 files changed, 283 insertions(+), 135 deletions(-) create mode 100644 src/memory/long_term_memory.py create mode 100644 src/memory/short_term_memory.py diff --git a/src/agents.py b/src/agents.py index 29f8d7a..3608895 100644 --- a/src/agents.py +++ b/src/agents.py @@ -28,7 +28,7 @@ def __init__(self): agent_name = agent_name.split(".")[0] self.prompts[agent_name] = f.read() - def get_prompt( + def _get_prompt( self, agent_name: str, user_input: str, agent_scratchpad=False ) -> ChatPromptTemplate: @@ -47,7 +47,7 @@ def get_prompt( return ChatPromptTemplate.from_messages(prompt) def intent_classifier(self, query: str, chat_history: str) -> str: - prompt = self.get_prompt( + prompt = self._get_prompt( "INTENT_CLASSIFIER_AGENT", f"{query}\n\n{chat_history}", ) @@ -63,7 +63,7 @@ def intent_classifier(self, query: str, chat_history: str) -> str: return result def general_campus_query(self, query: str, chat_history: str) -> str: - prompt = self.get_prompt( + prompt = self._get_prompt( "GENERAL_CAMPUS_QUERY_AGENT", f"{query}\n\n{chat_history}", ) @@ -83,7 +83,9 @@ def course_query(self, query: str, chat_history: str) -> str: def long_term_memory(self, id: str, query: str, memories: str) -> str: tools = [tool_modify_memory] - prompt = self.get_prompt("LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True) + prompt = self._get_prompt( + "LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True + ) agent = create_tool_calling_agent(self.llm, tools, prompt) chain = AgentExecutor(agent=agent, tools=tools) result = chain.invoke({"user_id": id, "memories": memories}) diff --git a/src/memory/data.py b/src/memory/data.py index 6d57729..24795c7 100644 --- a/src/memory/data.py +++ 
b/src/memory/data.py @@ -1,3 +1,4 @@ +from datetime import datetime from enum import Enum from typing import Optional @@ -18,7 +19,7 @@ class Action(str, Enum): Delete = "Delete" -class AddMemory(BaseModel): +class LongTermMemory(BaseModel): id: str = Field(..., description="The ID of the user") memory: str = Field( ..., @@ -35,18 +36,40 @@ class AddMemory(BaseModel): ) -def parse_memory(memory: list[AddMemory]) -> str: - """ - Function to parse memory from the database. - """ - memory_dict = {} - for mem in memory: - if Category(mem.category).value not in memory_dict: - memory_dict[Category(mem.category).value] = [] - memory_dict[Category(mem.category).value].append(mem.memory) - long_term_memory = "" - for category in memory_dict: - long_term_memory += f"{category}:\n" - for mem in memory_dict[category]: - long_term_memory += f"- {mem}\n" - return long_term_memory +class ShortTermMemory(BaseModel): + id: str = Field( + ..., + title="User ID", + description="Unique identifier for the user.", + examples=["123"], + ) + message_id: str = Field( + ..., + title="Message ID", + description="Unique identifier for the message.", + examples=["8b47dfe8-0960-4b80-b551-471b47a650a0"], + ) + created_at: datetime = Field( + ..., + title="Created At", + description="Timestamp of when the memory was created.", + examples=["2022-01-01T00:00:00"], + ) + query: str = Field( + ..., + title="Query", + description="User query.", + examples=["Where is the library?"], + ) + reply: str = Field( + ..., + title="Reply", + description="Agent's reply.", + examples=["The library is on the second floor."], + ) + agent: str = Field( + ..., + title="Agent", + description="Agent that replied.", + examples=["INTENT_CLASSIFIER_AGENT"], + ) diff --git a/src/memory/long_term_memory.py b/src/memory/long_term_memory.py new file mode 100644 index 0000000..54fe79b --- /dev/null +++ b/src/memory/long_term_memory.py @@ -0,0 +1,130 @@ +import os + +import psycopg2 +from dotenv import load_dotenv +from 
langchain.tools import StructuredTool +from langchain_core.tools import ToolException + +from src.memory.data import Category, LongTermMemory + +load_dotenv() + +conn = psycopg2.connect( + dbname=os.getenv("POSTGRES_DB"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASS"), + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), +) + + +def add_long_term_memory( + id: str, memory: str, category: str, action: str, memory_old: str = None +): + """ + Create, update or delete one long term memory row for a user; values are bound as SQL parameters so quotes in the text cannot break the statement. + """ + print( + f"Modifying long term memory for {id} with action {action} and category {Category(category).value}" + ) + if Category(category).value not in [ + "Course Likes", + "Course Dislikes", + "Branch", + "Clubs", + "Person Attributes", + ]: + return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes" + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS longterm_memory ( + id VARCHAR(255) NOT NULL, + memory VARCHAR(255) NOT NULL, + category VARCHAR(255) NOT NULL + ); + """ + ) + if action == "Create": + cur.execute( + """ + INSERT INTO longterm_memory (id, memory, category) + VALUES (%s, %s, %s); + """, (id, memory, Category(category).value), + ) + conn.commit() + return "Memory created successfully" + elif action == "Update": + cur.execute( + """ + UPDATE longterm_memory + SET memory = %s + WHERE id = %s AND memory = %s; + """, (memory, id, memory_old), + ) + conn.commit() + return "Memory updated successfully" + elif action == "Delete": + cur.execute( + """ + DELETE FROM longterm_memory + WHERE id = %s AND memory = %s AND category = %s; + """, (id, memory_old, Category(category).value), + ) + conn.commit() + return "Memory deleted successfully" + else: + return "Invalid action" + + +def fetch_long_term_memory(id: str) -> list[LongTermMemory]: + """ + Function to fetch long term memory from the database + """ + cur = conn.cursor() + cur.execute( + f""" + SELECT memory, category + 
FROM longterm_memory + WHERE id = '{id}'; + """ + ) + rows = cur.fetchall() + print(f"Fetched long term memory for {id} = {rows}") + return [ + LongTermMemory(id=id, memory=row[0], category=row[1], action="Create") + for row in rows + ] + + +def reset_long_term_memory(id: str) -> None: + """ + Function to reset long term memory from the database + """ + cur = conn.cursor() + cur.execute( + f""" + DELETE FROM longterm_memory + WHERE id = '{id}'; + """ + ) + conn.commit() + print(f"Reset long term memory for {id}") + cur.close() + + +def parse_long_term_memory(memory: list[LongTermMemory]) -> str: + """ + Function to parse memory from the database. + """ + memory_dict = {} + for mem in memory: + if Category(mem.category).value not in memory_dict: + memory_dict[Category(mem.category).value] = [] + memory_dict[Category(mem.category).value].append(mem.memory) + long_term_memory = "" + for category in memory_dict: + long_term_memory += f"{category}:\n" + for mem in memory_dict[category]: + long_term_memory += f"- {mem}\n" + return long_term_memory diff --git a/src/memory/short_term_memory.py b/src/memory/short_term_memory.py new file mode 100644 index 0000000..f537ed8 --- /dev/null +++ b/src/memory/short_term_memory.py @@ -0,0 +1,100 @@ +import os +import uuid +from datetime import datetime + +import psycopg2 +from dotenv import load_dotenv + +from src.memory.data import ShortTermMemory + +load_dotenv() + + +conn = psycopg2.connect( + dbname=os.getenv("POSTGRES_DB"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASS"), + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), +) + + +def add_short_term_memory(id: str, query: str, reply: str, agent: str) -> None: + """ + Function to add short term memory to the database. 
+ """ + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS short_term_memory (id VARCHAR(255) NOT NULL, message_id VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL, query VARCHAR(255) NOT NULL, reply VARCHAR(255) NOT NULL, agent VARCHAR(255) NOT NULL); + """ + ) + message_id = str(uuid.uuid4()) + created_at = datetime.now() + cur.execute( + """ + INSERT INTO short_term_memory (id, message_id, created_at, query, reply, agent) + VALUES (%s, %s, %s, %s, %s, %s); + """, (id, message_id, created_at, query, reply, agent), + ) + conn.commit() + cur.close() + print(f"Short term memory added for {id}") + + +def fetch_short_term_memory(id: str) -> list[ShortTermMemory]: + """ + Fetch the five most recent short term memories for a user (newest first) and prune any older rows. + """ + cur = conn.cursor() + cur.execute( + """ + SELECT * FROM short_term_memory WHERE id = %s ORDER BY created_at DESC LIMIT 5; + """, (id,), + ) + rows = cur.fetchall() + memory = [] + for row in rows: + memory.append( + ShortTermMemory( + id=row[0], + message_id=row[1], + created_at=row[2], + query=row[3], + reply=row[4], + agent=row[5], + ) + ) + if len(memory) >= 5: + cur.execute( + """ + DELETE FROM short_term_memory WHERE id = %s AND created_at < %s; + """, (id, memory[-1].created_at), + ) + conn.commit() + cur.close() + return memory + + +def reset_short_term_memory(id: str) -> None: + """ + Function to reset short term memory from the database. 
+ """ + cur = conn.cursor() + cur.execute( + f""" + DELETE FROM short_term_memory WHERE id = '{id}'; + """ + ) + conn.commit() + cur.close() + print(f"Short term memory reset for {id}") + + +def parse_short_term_memory(short_term_memory: list[ShortTermMemory]) -> str: + chat_history = "" + for i, memory in enumerate(short_term_memory): + chat_history += ( + f"{i+1} User: {memory.query}\nAgent ({memory.agent}): {memory.reply}\n" + ) + return chat_history diff --git a/src/nodes.py b/src/nodes.py index 16ab42a..c6ea769 100644 --- a/src/nodes.py +++ b/src/nodes.py @@ -1,7 +1,7 @@ from langchain_core.messages import AIMessage from .agents import Agents -from .memory.data import parse_memory +from .memory.long_term_memory import parse_long_term_memory from .state import State agents = Agents() @@ -37,7 +37,7 @@ def not_related_query(state: State): def long_term_memory(state: State): query = state["messages"][0].content user_id = state["user_id"] - long_term_memories = parse_memory(state.get("long_term_memories", [])) + long_term_memories = parse_long_term_memory(state.get("long_term_memories", [])) result = agents.long_term_memory( user_id, query, diff --git a/src/state.py b/src/state.py index 61f23da..51b34f1 100644 --- a/src/state.py +++ b/src/state.py @@ -3,11 +3,11 @@ from langgraph.graph.message import add_messages from typing_extensions import TypedDict -from src.memory.data import AddMemory +from src.memory.data import LongTermMemory class State(TypedDict): user_id: str messages: Annotated[list, add_messages] chat_history: Optional[str] - long_term_memories: Optional[list[AddMemory]] + long_term_memories: Optional[list[LongTermMemory]] diff --git a/src/tools/memory_tool.py b/src/tools/memory_tool.py index 93f052b..612d1a7 100644 --- a/src/tools/memory_tool.py +++ b/src/tools/memory_tool.py @@ -1,115 +1,8 @@ -import os - -import psycopg2 -from dotenv import load_dotenv from langchain.tools import StructuredTool from langchain_core.tools import ToolException -from 
src.memory.data import AddMemory, Category - -load_dotenv() - -conn = psycopg2.connect( - dbname=os.getenv("POSTGRES_DB"), - user=os.getenv("POSTGRES_USER"), - password=os.getenv("POSTGRES_PASS"), - host=os.getenv("POSTGRES_HOST"), - port=os.getenv("POSTGRES_PORT"), -) - - -def modify_memory( - id: str, memory: str, category: str, action: str, memory_old: str = None -): - """ - Function to modify memory in the database - """ - print( - f"Modifying long term memory for {id} with action {action} and category {Category(category).value}" - ) - if Category(category).value not in [ - "Course Likes", - "Course Dislikes", - "Branch", - "Clubs", - "Person Attributes", - ]: - return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes" - cur = conn.cursor() - cur.execute( - """ - CREATE TABLE IF NOT EXISTS longterm_memory ( - id VARCHAR(255) NOT NULL, - memory VARCHAR(255) NOT NULL, - category VARCHAR(255) NOT NULL - ); - """ - ) - if action == "Create": - cur.execute( - f""" - INSERT INTO longterm_memory (id, memory, category) - VALUES ('{id}', '{memory}', '{Category(category).value}'); - """ - ) - conn.commit() - return "Memory created successfully" - elif action == "Update": - cur.execute( - f""" - UPDATE longterm_memory - SET memory = '{memory}' - WHERE id = '{id}' AND memory = '{memory_old}'; - """ - ) - conn.commit() - return "Memory updated successfully" - elif action == "Delete": - cur.execute( - f""" - DELETE FROM longterm_memory - WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{Category(category).value}'; - """ - ) - conn.commit() - return "Memory deleted successfully" - else: - return "Invalid action" - - -def fetch_long_term_memory(id: str) -> list[AddMemory]: - """ - Function to fetch long term memory from the database - """ - cur = conn.cursor() - cur.execute( - f""" - SELECT memory, category - FROM longterm_memory - WHERE id = '{id}'; - """ - ) - rows = cur.fetchall() - print(f"Fetched long term memory 
for {id} = {rows}") - return [ - AddMemory(id=id, memory=row[0], category=row[1], action="Create") - for row in rows - ] - - -def reset_long_term_memory(id: str) -> None: - """ - Function to reset long term memory from the database - """ - cur = conn.cursor() - cur.execute( - f""" - DELETE FROM longterm_memory - WHERE id = '{id}'; - """ - ) - conn.commit() - print(f"Reset long term memory for {id}") +from src.memory.data import LongTermMemory +from src.memory.long_term_memory import add_long_term_memory def _handle_tool_error(error: ToolException) -> str: @@ -117,9 +10,9 @@ def _handle_tool_error(error: ToolException) -> str: tool_modify_memory = StructuredTool.from_function( - func=modify_memory, + func=add_long_term_memory, name="modify_memory", description="Modify the long term memory of a user", - args_schema=AddMemory, + args_schema=LongTermMemory, handle_tool_error=_handle_tool_error, ) From 67b87ff2f11c2088580b9f16eaddcd4825777edb Mon Sep 17 00:00:00 2001 From: theofficialvedantjoshi Date: Sat, 18 Jan 2025 14:06:11 +0530 Subject: [PATCH 4/4] Feat: Add short term memory for end nodes. --- src/nodes.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/nodes.py b/src/nodes.py index c6ea769..367c1bd 100644 --- a/src/nodes.py +++ b/src/nodes.py @@ -2,6 +2,7 @@ from .agents import Agents from .memory.long_term_memory import parse_long_term_memory +from .memory.short_term_memory import add_short_term_memory from .state import State agents = Agents() @@ -17,12 +18,16 @@ def intent_classifier(state: State): def course_query(state: State): query = state["messages"][0].content result = AIMessage("Course query not implemented yet") + # Add short term memory here once implemented. 
return {"messages": [result]} def general_campus_query(state: State): query = state["messages"][0].content result = agents.general_campus_query(query, state.get("chat_history", "")) + add_short_term_memory( + state["user_id"], query, result.content, "general_campus_query" + ) return {"messages": [result]} @@ -31,6 +36,9 @@ def not_related_query(state: State): result = AIMessage( "I'm sorry, I don't understand the question, if it relates to campus please rephrase." ) + add_short_term_memory( + state["user_id"], query, result.content, "intent_classifier" + ) return {"messages": [result]}