diff --git a/src/agents.py b/src/agents.py
index 29f8d7a..3608895 100644
--- a/src/agents.py
+++ b/src/agents.py
@@ -28,7 +28,7 @@ def __init__(self):
agent_name = agent_name.split(".")[0]
self.prompts[agent_name] = f.read()
- def get_prompt(
+ def _get_prompt(
self, agent_name: str, user_input: str, agent_scratchpad=False
) -> ChatPromptTemplate:
@@ -47,7 +47,7 @@ def get_prompt(
return ChatPromptTemplate.from_messages(prompt)
def intent_classifier(self, query: str, chat_history: str) -> str:
- prompt = self.get_prompt(
+ prompt = self._get_prompt(
"INTENT_CLASSIFIER_AGENT",
f"{query}\n\n{chat_history}",
)
@@ -63,7 +63,7 @@ def intent_classifier(self, query: str, chat_history: str) -> str:
return result
def general_campus_query(self, query: str, chat_history: str) -> str:
- prompt = self.get_prompt(
+ prompt = self._get_prompt(
"GENERAL_CAMPUS_QUERY_AGENT",
f"{query}\n\n{chat_history}",
)
@@ -83,7 +83,9 @@ def course_query(self, query: str, chat_history: str) -> str:
def long_term_memory(self, id: str, query: str, memories: str) -> str:
tools = [tool_modify_memory]
- prompt = self.get_prompt("LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True)
+ prompt = self._get_prompt(
+ "LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True
+ )
agent = create_tool_calling_agent(self.llm, tools, prompt)
chain = AgentExecutor(agent=agent, tools=tools)
result = chain.invoke({"user_id": id, "memories": memories})
diff --git a/src/memory/data.py b/src/memory/data.py
index 893b0f9..24795c7 100644
--- a/src/memory/data.py
+++ b/src/memory/data.py
@@ -1,7 +1,8 @@
+from datetime import datetime
from enum import Enum
from typing import Optional
-from langchain.pydantic_v1 import BaseModel, Field
+from pydantic import BaseModel, Field
class Category(str, Enum):
@@ -18,7 +19,7 @@ class Action(str, Enum):
Delete = "Delete"
-class AddMemory(BaseModel):
+class LongTermMemory(BaseModel):
id: str = Field(..., description="The ID of the user")
memory: str = Field(
...,
@@ -35,5 +36,40 @@ class AddMemory(BaseModel):
)
-def parse_memory(memory: AddMemory):
- raise NotImplementedError("This function is not yet implemented")
+class ShortTermMemory(BaseModel):
+ id: str = Field(
+ ...,
+ title="User ID",
+ description="Unique identifier for the user.",
+ examples=["123"],
+ )
+ message_id: str = Field(
+ ...,
+ title="Message ID",
+ description="Unique identifier for the message.",
+ examples=["8b47dfe8-0960-4b80-b551-471b47a650a0"],
+ )
+ created_at: datetime = Field(
+ ...,
+ title="Created At",
+ description="Timestamp of when the memory was created.",
+ examples=["2022-01-01T00:00:00"],
+ )
+ query: str = Field(
+ ...,
+ title="Query",
+ description="User query.",
+ examples=["Where is the library?"],
+ )
+ reply: str = Field(
+ ...,
+ title="Reply",
+ description="Agent's reply.",
+ examples=["The library is on the second floor."],
+ )
+ agent: str = Field(
+ ...,
+ title="Agent",
+ description="Agent that replied.",
+ examples=["INTENT_CLASSIFIER_AGENT"],
+ )
diff --git a/src/memory/long_term_memory.py b/src/memory/long_term_memory.py
new file mode 100644
index 0000000..54fe79b
--- /dev/null
+++ b/src/memory/long_term_memory.py
@@ -0,0 +1,130 @@
+import os
+
+import psycopg2
+from dotenv import load_dotenv
+from langchain.tools import StructuredTool
+from langchain_core.tools import ToolException
+
+from src.memory.data import Category, LongTermMemory
+
+load_dotenv()
+
+conn = psycopg2.connect(
+ dbname=os.getenv("POSTGRES_DB"),
+ user=os.getenv("POSTGRES_USER"),
+ password=os.getenv("POSTGRES_PASS"),
+ host=os.getenv("POSTGRES_HOST"),
+ port=os.getenv("POSTGRES_PORT"),
+)
+
+
+def add_long_term_memory(
+ id: str, memory: str, category: str, action: str, memory_old: str = None
+):
+ """
+ Function to modify memory in the database
+ """
+ print(
+ f"Modifying long term memory for {id} with action {action} and category {Category(category).value}"
+ )
+ if Category(category).value not in [
+ "Course Likes",
+ "Course Dislikes",
+ "Branch",
+ "Clubs",
+ "Person Attributes",
+ ]:
+ return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes"
+ cur = conn.cursor()
+ cur.execute(
+ """
+ CREATE TABLE IF NOT EXISTS longterm_memory (
+ id VARCHAR(255) NOT NULL,
+ memory VARCHAR(255) NOT NULL,
+ category VARCHAR(255) NOT NULL
+ );
+ """
+ )
+ if action == "Create":
+ cur.execute(
+ f"""
+ INSERT INTO longterm_memory (id, memory, category)
+ VALUES ('{id}', '{memory}', '{Category(category).value}');
+ """
+ )
+ conn.commit()
+ return "Memory created successfully"
+ elif action == "Update":
+ cur.execute(
+ f"""
+ UPDATE longterm_memory
+ SET memory = '{memory}'
+ WHERE id = '{id}' AND memory = '{memory_old}';
+ """
+ )
+ conn.commit()
+ return "Memory updated successfully"
+ elif action == "Delete":
+ cur.execute(
+ f"""
+ DELETE FROM longterm_memory
+ WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{Category(category).value}';
+ """
+ )
+ conn.commit()
+ return "Memory deleted successfully"
+ else:
+ return "Invalid action"
+
+
+def fetch_long_term_memory(id: str) -> list[LongTermMemory]:
+ """
+ Function to fetch long term memory from the database
+ """
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ SELECT memory, category
+ FROM longterm_memory
+ WHERE id = '{id}';
+ """
+ )
+ rows = cur.fetchall()
+ print(f"Fetched long term memory for {id} = {rows}")
+ return [
+ LongTermMemory(id=id, memory=row[0], category=row[1], action="Create")
+ for row in rows
+ ]
+
+
+def reset_long_term_memory(id: str) -> None:
+ """
+ Function to reset long term memory from the database
+ """
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ DELETE FROM longterm_memory
+ WHERE id = '{id}';
+ """
+ )
+ conn.commit()
+ print(f"Reset long term memory for {id}")
+ cur.close()
+
+
+def parse_long_term_memory(memory: list[LongTermMemory]) -> str:
+ """
+ Function to parse memory from the database.
+ """
+ memory_dict = {}
+ for mem in memory:
+ if Category(mem.category).value not in memory_dict:
+ memory_dict[Category(mem.category).value] = []
+ memory_dict[Category(mem.category).value].append(mem.memory)
+ long_term_memory = ""
+ for category in memory_dict:
+ long_term_memory += f"{category}:\n"
+ for mem in memory_dict[category]:
+ long_term_memory += f"- {mem}\n"
+ return long_term_memory
diff --git a/src/memory/short_term_memory.py b/src/memory/short_term_memory.py
new file mode 100644
index 0000000..f537ed8
--- /dev/null
+++ b/src/memory/short_term_memory.py
@@ -0,0 +1,100 @@
+import os
+import uuid
+from datetime import datetime
+
+import psycopg2
+from dotenv import load_dotenv
+
+from src.memory.data import ShortTermMemory
+
+load_dotenv()
+
+
+conn = psycopg2.connect(
+ dbname=os.getenv("POSTGRES_DB"),
+ user=os.getenv("POSTGRES_USER"),
+ password=os.getenv("POSTGRES_PASS"),
+ host=os.getenv("POSTGRES_HOST"),
+ port=os.getenv("POSTGRES_PORT"),
+)
+
+
+def add_short_term_memory(id: str, query: str, reply: str, agent: str) -> None:
+ """
+ Function to add short term memory to the database.
+ """
+ cur = conn.cursor()
+ cur.execute(
+ """
+ CREATE TABLE IF NOT EXISTS short_term_memory (id VARCHAR(255) NOT NULL, message_id VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL, query VARCHAR(255) NOT NULL, reply VARCHAR(255) NOT NULL, agent VARCHAR(255) NOT NULL);
+ """
+ )
+ message_id = str(uuid.uuid4())
+ created_at = datetime.now()
+ cur.execute(
+ f"""
+ INSERT INTO short_term_memory (id, message_id, created_at, query, reply, agent)
+ VALUES ('{id}', '{message_id}', '{created_at}', '{query}', '{reply}', '{agent}');
+ """
+ )
+ conn.commit()
+ cur.close()
+ print(f"Short term memory added for {id}")
+
+
+def fetch_short_term_memory(id: str) -> list[ShortTermMemory]:
+ """
+ Function to fetch short term memory from the database.
+ """
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ SELECT * FROM short_term_memory WHERE id = '{id}' ORDER BY created_at DESC LIMIT 5;
+ """
+ )
+ rows = cur.fetchall()
+ memory = []
+ for row in rows:
+ memory.append(
+ ShortTermMemory(
+ id=row[0],
+ message_id=row[1],
+ created_at=row[2],
+ query=row[3],
+ reply=row[4],
+ agent=row[5],
+ )
+ )
+ if len(memory) > 5:
+ cur.execute(
+ f"""
+ DELETE FROM short_term_memory WHERE id = '{id}' AND created_at < '{memory[-1]["created_at"]}';
+ """
+ )
+ conn.commit()
+ cur.close()
+ return memory
+
+
+def reset_short_term_memory(id: str) -> None:
+ """
+ Function to reset short term memory from the database.
+ """
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ DELETE FROM short_term_memory WHERE id = '{id}';
+ """
+ )
+ conn.commit()
+ cur.close()
+ print(f"Short term memory reset for {id}")
+
+
+def parse_short_term_memory(short_term_memory: list[ShortTermMemory]) -> str:
+ chat_history = ""
+ for i, memory in enumerate(short_term_memory):
+ chat_history += (
+ f"{i+1} User: {memory.query}\nAgent ({memory.agent}): {memory.reply}\n"
+ )
+ return chat_history
diff --git a/src/nodes.py b/src/nodes.py
index 1e2feb8..367c1bd 100644
--- a/src/nodes.py
+++ b/src/nodes.py
@@ -1,6 +1,8 @@
from langchain_core.messages import AIMessage
from .agents import Agents
+from .memory.long_term_memory import parse_long_term_memory
+from .memory.short_term_memory import add_short_term_memory
from .state import State
agents = Agents()
@@ -16,12 +18,16 @@ def intent_classifier(state: State):
def course_query(state: State):
query = state["messages"][0].content
result = AIMessage("Course query not implemented yet")
+ # Add short term memory here once implemented.
return {"messages": [result]}
def general_campus_query(state: State):
query = state["messages"][0].content
result = agents.general_campus_query(query, state.get("chat_history", ""))
+ add_short_term_memory(
+ state["user_id"], query, result.content, "general_campus_query"
+ )
return {"messages": [result]}
@@ -30,14 +36,16 @@ def not_related_query(state: State):
result = AIMessage(
"I'm sorry, I don't understand the question, if it relates to campus please rephrase."
)
+ add_short_term_memory(
+ state["user_id"], query, not_related_query, "intent_classifier"
+ )
return {"messages": [result]}
def long_term_memory(state: State):
query = state["messages"][0].content
user_id = state["user_id"]
- # parse long term memory here.
- long_term_memories = state.get("long_term_memories", "")
+ long_term_memories = parse_long_term_memory(state.get("long_term_memories", []))
result = agents.long_term_memory(
user_id,
query,
diff --git a/src/prompts/LONG_TERM_MEMORY_AGENT.md b/src/prompts/LONG_TERM_MEMORY_AGENT.md
index 81fab19..6e8db33 100644
--- a/src/prompts/LONG_TERM_MEMORY_AGENT.md
+++ b/src/prompts/LONG_TERM_MEMORY_AGENT.md
@@ -12,11 +12,11 @@ The users id is {user_id}.
You are only interested in the following categories of information:
-1. Course prefereces or likes - The user likes or is interested in a course.
-2. Course dislikes - The user dislikes a course.
+1. Course Likes- The user likes or is interested in a course.
+2. Course Dislikes - The user dislikes a course.
3. Branch - The branch the user is pursuing in college including majors and minors.
4. Clubs - The clubs the user is part of on campus.
-5. Personal attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus.
+5. Person Attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus.
When you receive a message, you perform a sequence of steps consisting of:
@@ -25,6 +25,7 @@ When you receive a message, you perform a sequence of steps consisting of:
3. Determine if this is new knowledge, an update to old knowledge that now needs to change, or should result in deleting information that is not correct. It's possible that a product/brand you previously wrote as a dislike might now be a like, and other cases- those examples would require an update.
4. Never save the same information twice. If you see the same information in a message that you have already saved, ignore it.
5. Refer to the history for existing memories.
+6. Categories must be from ['Course Likes', 'Course Dislikes', 'Branch', 'Clubs', 'Person Attributes'].
Here are the existing bits of information that we have about the user.
diff --git a/src/state.py b/src/state.py
index 61f23da..51b34f1 100644
--- a/src/state.py
+++ b/src/state.py
@@ -3,11 +3,11 @@
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict
-from src.memory.data import AddMemory
+from src.memory.data import LongTermMemory
class State(TypedDict):
user_id: str
messages: Annotated[list, add_messages]
chat_history: Optional[str]
- long_term_memories: Optional[list[AddMemory]]
+ long_term_memories: Optional[list[LongTermMemory]]
diff --git a/src/tools/memory_tool.py b/src/tools/memory_tool.py
index 3a40ccc..612d1a7 100644
--- a/src/tools/memory_tool.py
+++ b/src/tools/memory_tool.py
@@ -1,74 +1,18 @@
-import os
-
-import psycopg2
-from dotenv import load_dotenv
from langchain.tools import StructuredTool
+from langchain_core.tools import ToolException
-from src.memory.data import AddMemory
-
-load_dotenv()
-
-conn = psycopg2.connect(
- dbname=os.getenv("POSTGRES_DB"),
- user=os.getenv("POSTGRES_USER"),
- password=os.getenv("POSTGRES_PASS"),
- host=os.getenv("POSTGRES_HOST"),
- port=os.getenv("POSTGRES_PORT"),
-)
+from src.memory.data import LongTermMemory
+from src.memory.long_term_memory import add_long_term_memory
-def modify_memory(
- id: str, memory: str, category: str, action: str, memory_old: str = None
-):
- """
- Function to modify memory in the database
- """
- print(f"Modifying long term memory for {id} with action {action}")
- cur = conn.cursor()
- cur.execute(
- """
- CREATE TABLE IF NOT EXISTS longterm_memory (
- id VARCHAR(255) NOT NULL,
- memory VARCHAR(255) NOT NULL,
- category VARCHAR(255) NOT NULL
- );
- """
- )
- if action == "Create":
- cur.execute(
- f"""
- INSERT INTO longterm_memory (id, memory, category)
- VALUES ('{id}', '{memory}', '{category}');
- """
- )
- conn.commit()
- return "Memory created successfully"
- elif action == "Update":
- cur.execute(
- f"""
- UPDATE longterm_memory
- SET memory = '{memory}'
- WHERE id = '{id}' AND memory = '{memory_old}';
- """
- )
- conn.commit()
- return "Memory updated successfully"
- elif action == "Delete":
- cur.execute(
- f"""
- DELETE FROM longterm_memory
- WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{category}';
- """
- )
- conn.commit()
- return "Memory deleted successfully"
- else:
- return "Invalid action"
+def _handle_tool_error(error: ToolException) -> str:
+ return f"The following errors occurred during tool execution: `{error.args[0]}`"
tool_modify_memory = StructuredTool.from_function(
- func=modify_memory,
+ func=add_long_term_memory,
name="modify_memory",
description="Modify the long term memory of a user",
- args_schema=AddMemory,
+ args_schema=LongTermMemory,
+ handle_tool_error=_handle_tool_error,
)