diff --git a/src/agents.py b/src/agents.py index 29f8d7a..3608895 100644 --- a/src/agents.py +++ b/src/agents.py @@ -28,7 +28,7 @@ def __init__(self): agent_name = agent_name.split(".")[0] self.prompts[agent_name] = f.read() - def get_prompt( + def _get_prompt( self, agent_name: str, user_input: str, agent_scratchpad=False ) -> ChatPromptTemplate: @@ -47,7 +47,7 @@ def get_prompt( return ChatPromptTemplate.from_messages(prompt) def intent_classifier(self, query: str, chat_history: str) -> str: - prompt = self.get_prompt( + prompt = self._get_prompt( "INTENT_CLASSIFIER_AGENT", f"{query}\n\n{chat_history}", ) @@ -63,7 +63,7 @@ def intent_classifier(self, query: str, chat_history: str) -> str: return result def general_campus_query(self, query: str, chat_history: str) -> str: - prompt = self.get_prompt( + prompt = self._get_prompt( "GENERAL_CAMPUS_QUERY_AGENT", f"{query}\n\n{chat_history}", ) @@ -83,7 +83,9 @@ def course_query(self, query: str, chat_history: str) -> str: def long_term_memory(self, id: str, query: str, memories: str) -> str: tools = [tool_modify_memory] - prompt = self.get_prompt("LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True) + prompt = self._get_prompt( + "LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True + ) agent = create_tool_calling_agent(self.llm, tools, prompt) chain = AgentExecutor(agent=agent, tools=tools) result = chain.invoke({"user_id": id, "memories": memories}) diff --git a/src/memory/data.py b/src/memory/data.py index 893b0f9..24795c7 100644 --- a/src/memory/data.py +++ b/src/memory/data.py @@ -1,7 +1,8 @@ +from datetime import datetime from enum import Enum from typing import Optional -from langchain.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field class Category(str, Enum): @@ -18,7 +19,7 @@ class Action(str, Enum): Delete = "Delete" -class AddMemory(BaseModel): +class LongTermMemory(BaseModel): id: str = Field(..., description="The ID of the user") memory: str = Field( ..., @@ -35,5 +36,40 @@ class AddMemory(BaseModel): ) -def parse_memory(memory: AddMemory): - raise NotImplementedError("This function is not yet implemented") +class ShortTermMemory(BaseModel): + id: str = Field( + ..., + title="User ID", + description="Unique identifier for the user.", + examples=["123"], + ) + message_id: str = Field( + ..., + title="Message ID", + description="Unique identifier for the message.", + examples=["8b47dfe8-0960-4b80-b551-471b47a650a0"], + ) + created_at: datetime = Field( + ..., + title="Created At", + description="Timestamp of when the memory was created.", + examples=["2022-01-01T00:00:00"], + ) + query: str = Field( + ..., + title="Query", + description="User query.", + examples=["Where is the library?"], + ) + reply: str = Field( + ..., + title="Reply", + description="Agent's reply.", + examples=["The library is on the second floor."], + ) + agent: str = Field( + ..., + title="Agent", + description="Agent that replied.", + examples=["INTENT_CLASSIFIER_AGENT"], + ) diff --git a/src/memory/long_term_memory.py b/src/memory/long_term_memory.py new file mode 100644 index 0000000..54fe79b --- /dev/null +++ b/src/memory/long_term_memory.py @@ -0,0 +1,130 @@ +import os + +import psycopg2 +from dotenv import load_dotenv +from langchain.tools import StructuredTool +from langchain_core.tools import ToolException + +from src.memory.data import Category, LongTermMemory + +load_dotenv() + +conn = psycopg2.connect( + dbname=os.getenv("POSTGRES_DB"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASS"), + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), +) + + +def add_long_term_memory( + id: str, memory: str, category: str, action: str, memory_old: str = None +): + """ + Function to modify memory in the database + """ + print( + f"Modifying long term memory for {id} with action {action} and category {Category(category).value}" + ) + if Category(category).value not in [ + "Course Likes", + "Course Dislikes", + "Branch", + "Clubs", + "Person Attributes", + ]: + return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes" + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS longterm_memory ( + id VARCHAR(255) NOT NULL, + memory VARCHAR(255) NOT NULL, + category VARCHAR(255) NOT NULL + ); + """ + ) + if action == "Create": + cur.execute( + f""" + INSERT INTO longterm_memory (id, memory, category) + VALUES ('{id}', '{memory}', '{Category(category).value}'); + """ + ) + conn.commit() + return "Memory created successfully" + elif action == "Update": + cur.execute( + f""" + UPDATE longterm_memory + SET memory = '{memory}' + WHERE id = '{id}' AND memory = '{memory_old}'; + """ + ) + conn.commit() + return "Memory updated successfully" + elif action == "Delete": + cur.execute( + f""" + DELETE FROM longterm_memory + WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{Category(category).value}'; + """ + ) + conn.commit() + return "Memory deleted successfully" + else: + return "Invalid action" + + +def fetch_long_term_memory(id: str) -> list[LongTermMemory]: + """ + Function to fetch long term memory from the database + """ + cur = conn.cursor() + cur.execute( + f""" + SELECT memory, category + FROM longterm_memory + WHERE id = '{id}'; + """ + ) + rows = cur.fetchall() + print(f"Fetched long term memory for {id} = {rows}") + return [ + LongTermMemory(id=id, memory=row[0], category=row[1], action="Create") + for row in rows + ] + + +def reset_long_term_memory(id: str) -> None: + """ + Function to reset long term memory from the database + """ + cur = conn.cursor() + cur.execute( + f""" + DELETE FROM longterm_memory + WHERE id = '{id}'; + """ + ) + conn.commit() + print(f"Reset long term memory for {id}") + cur.close() + + +def parse_long_term_memory(memory: list[LongTermMemory]) -> str: + """ + Function to parse memory from the database. + """ + memory_dict = {} + for mem in memory: + if Category(mem.category).value not in memory_dict: + memory_dict[Category(mem.category).value] = [] + memory_dict[Category(mem.category).value].append(mem.memory) + long_term_memory = "" + for category in memory_dict: + long_term_memory += f"{category}:\n" + for mem in memory_dict[category]: + long_term_memory += f"- {mem}\n" + return long_term_memory diff --git a/src/memory/short_term_memory.py b/src/memory/short_term_memory.py new file mode 100644 index 0000000..f537ed8 --- /dev/null +++ b/src/memory/short_term_memory.py @@ -0,0 +1,100 @@ +import os +import uuid +from datetime import datetime + +import psycopg2 +from dotenv import load_dotenv + +from src.memory.data import ShortTermMemory + +load_dotenv() + + +conn = psycopg2.connect( + dbname=os.getenv("POSTGRES_DB"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASS"), + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), +) + + +def add_short_term_memory(id: str, query: str, reply: str, agent: str) -> None: + """ + Function to add short term memory to the database. + """ + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS short_term_memory (id VARCHAR(255) NOT NULL, message_id VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL, query VARCHAR(255) NOT NULL, reply VARCHAR(255) NOT NULL, agent VARCHAR(255) NOT NULL); + """ + ) + message_id = str(uuid.uuid4()) + created_at = datetime.now() + cur.execute( + f""" + INSERT INTO short_term_memory (id, message_id, created_at, query, reply, agent) + VALUES ('{id}', '{message_id}', '{created_at}', '{query}', '{reply}', '{agent}'); + """ + ) + conn.commit() + cur.close() + print(f"Short term memory added for {id}") + + +def fetch_short_term_memory(id: str) -> list[ShortTermMemory]: + """ + Function to fetch short term memory from the database. + """ + cur = conn.cursor() + cur.execute( + f""" + SELECT * FROM short_term_memory WHERE id = '{id}' ORDER BY created_at DESC LIMIT 5; + """ + ) + rows = cur.fetchall() + memory = [] + for row in rows: + memory.append( + ShortTermMemory( + id=row[0], + message_id=row[1], + created_at=row[2], + query=row[3], + reply=row[4], + agent=row[5], + ) + ) + if len(memory) > 5: + cur.execute( + f""" + DELETE FROM short_term_memory WHERE id = '{id}' AND created_at < '{memory[-1]["created_at"]}'; + """ + ) + conn.commit() + cur.close() + return memory + + +def reset_short_term_memory(id: str) -> None: + """ + Function to reset short term memory from the database. + """ + cur = conn.cursor() + cur.execute( + f""" + DELETE FROM short_term_memory WHERE id = '{id}'; + """ + ) + conn.commit() + cur.close() + print(f"Short term memory reset for {id}") + + +def parse_short_term_memory(short_term_memory: list[ShortTermMemory]) -> str: + chat_history = "" + for i, memory in enumerate(short_term_memory): + chat_history += ( + f"{i+1} User: {memory.query}\nAgent ({memory.agent}): {memory.reply}\n" + ) + return chat_history diff --git a/src/nodes.py b/src/nodes.py index 1e2feb8..367c1bd 100644 --- a/src/nodes.py +++ b/src/nodes.py @@ -1,6 +1,8 @@ from langchain_core.messages import AIMessage from .agents import Agents +from .memory.long_term_memory import parse_long_term_memory +from .memory.short_term_memory import add_short_term_memory from .state import State agents = Agents() @@ -16,12 +18,16 @@ def intent_classifier(state: State): def course_query(state: State): query = state["messages"][0].content result = AIMessage("Course query not implemented yet") + # Add short term memory here once implemented. return {"messages": [result]} def general_campus_query(state: State): query = state["messages"][0].content result = agents.general_campus_query(query, state.get("chat_history", "")) + add_short_term_memory( + state["user_id"], query, result.content, "general_campus_query" + ) return {"messages": [result]} @@ -30,14 +36,16 @@ def not_related_query(state: State): result = AIMessage( "I'm sorry, I don't understand the question, if it relates to campus please rephrase." ) + add_short_term_memory( + state["user_id"], query, not_related_query, "intent_classifier" + ) return {"messages": [result]} def long_term_memory(state: State): query = state["messages"][0].content user_id = state["user_id"] - # parse long term memory here. - long_term_memories = state.get("long_term_memories", "") + long_term_memories = parse_long_term_memory(state.get("long_term_memories", [])) result = agents.long_term_memory( user_id, query, diff --git a/src/prompts/LONG_TERM_MEMORY_AGENT.md b/src/prompts/LONG_TERM_MEMORY_AGENT.md index 81fab19..6e8db33 100644 --- a/src/prompts/LONG_TERM_MEMORY_AGENT.md +++ b/src/prompts/LONG_TERM_MEMORY_AGENT.md @@ -12,11 +12,11 @@ The users id is {user_id}. You are only interested in the following categories of information: -1. Course prefereces or likes - The user likes or is interested in a course. -2. Course dislikes - The user dislikes a course. +1. Course Likes- The user likes or is interested in a course. +2. Course Dislikes - The user dislikes a course. 3. Branch - The branch the user is pursuing in college including majors and minors. 4. Clubs - The clubs the user is part of on campus. -5. Personal attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus. +5. Person Attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus. When you receive a message, you perform a sequence of steps consisting of: @@ -25,6 +25,7 @@ When you receive a message, you perform a sequence of steps consisting of: 3. Determine if this is new knowledge, an update to old knowledge that now needs to change, or should result in deleting information that is not correct. It's possible that a product/brand you previously wrote as a dislike might now be a like, and other cases- those examples would require an update. 4. Never save the same information twice. If you see the same information in a message that you have already saved, ignore it. 5. Refer to the history for existing memories. +6. Categories must be from ['Course Likes', 'Course Dislikes', 'Branch', 'Clubs', 'Person Attributes']. Here are the existing bits of information that we have about the user. diff --git a/src/state.py b/src/state.py index 61f23da..51b34f1 100644 --- a/src/state.py +++ b/src/state.py @@ -3,11 +3,11 @@ from langgraph.graph.message import add_messages from typing_extensions import TypedDict -from src.memory.data import AddMemory +from src.memory.data import LongTermMemory class State(TypedDict): user_id: str messages: Annotated[list, add_messages] chat_history: Optional[str] - long_term_memories: Optional[list[AddMemory]] + long_term_memories: Optional[list[LongTermMemory]] diff --git a/src/tools/memory_tool.py b/src/tools/memory_tool.py index 3a40ccc..612d1a7 100644 --- a/src/tools/memory_tool.py +++ b/src/tools/memory_tool.py @@ -1,74 +1,18 @@ -import os - -import psycopg2 -from dotenv import load_dotenv from langchain.tools import StructuredTool +from langchain_core.tools import ToolException -from src.memory.data import AddMemory - -load_dotenv() - -conn = psycopg2.connect( - dbname=os.getenv("POSTGRES_DB"), - user=os.getenv("POSTGRES_USER"), - password=os.getenv("POSTGRES_PASS"), - host=os.getenv("POSTGRES_HOST"), - port=os.getenv("POSTGRES_PORT"), -) +from src.memory.data import LongTermMemory +from src.memory.long_term_memory import add_long_term_memory -def modify_memory( - id: str, memory: str, category: str, action: str, memory_old: str = None -): - """ - Function to modify memory in the database - """ - print(f"Modifying long term memory for {id} with action {action}") - cur = conn.cursor() - cur.execute( - """ - CREATE TABLE IF NOT EXISTS longterm_memory ( - id VARCHAR(255) NOT NULL, - memory VARCHAR(255) NOT NULL, - category VARCHAR(255) NOT NULL - ); - """ - ) - if action == "Create": - cur.execute( - f""" - INSERT INTO longterm_memory (id, memory, category) - VALUES ('{id}', '{memory}', '{category}'); - """ - ) - conn.commit() - return "Memory created successfully" - elif action == "Update": - cur.execute( - f""" - UPDATE longterm_memory - SET memory = '{memory}' - WHERE id = '{id}' AND memory = '{memory_old}'; - """ - ) - conn.commit() - return "Memory updated successfully" - elif action == "Delete": - cur.execute( - f""" - DELETE FROM longterm_memory - WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{category}'; - """ - ) - conn.commit() - return "Memory deleted successfully" - else: - return "Invalid action" +def _handle_tool_error(error: ToolException) -> str: + return f"The following errors occurred during tool execution: `{error.args[0]}`" tool_modify_memory = StructuredTool.from_function( - func=modify_memory, + func=add_long_term_memory, name="modify_memory", description="Modify the long term memory of a user", - args_schema=AddMemory, + args_schema=LongTermMemory, + handle_tool_error=_handle_tool_error, )