Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Short term memory #5

Merged
merged 4 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self):
agent_name = agent_name.split(".")[0]
self.prompts[agent_name] = f.read()

def get_prompt(
def _get_prompt(
self, agent_name: str, user_input: str, agent_scratchpad=False
) -> ChatPromptTemplate:

Expand All @@ -47,7 +47,7 @@ def get_prompt(
return ChatPromptTemplate.from_messages(prompt)

def intent_classifier(self, query: str, chat_history: str) -> str:
prompt = self.get_prompt(
prompt = self._get_prompt(
"INTENT_CLASSIFIER_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)
Expand All @@ -63,7 +63,7 @@ def intent_classifier(self, query: str, chat_history: str) -> str:
return result

def general_campus_query(self, query: str, chat_history: str) -> str:
prompt = self.get_prompt(
prompt = self._get_prompt(
"GENERAL_CAMPUS_QUERY_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)
Expand All @@ -83,7 +83,9 @@ def course_query(self, query: str, chat_history: str) -> str:

def long_term_memory(self, id: str, query: str, memories: str) -> str:
tools = [tool_modify_memory]
prompt = self.get_prompt("LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True)
prompt = self._get_prompt(
"LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True
)
agent = create_tool_calling_agent(self.llm, tools, prompt)
chain = AgentExecutor(agent=agent, tools=tools)
result = chain.invoke({"user_id": id, "memories": memories})
Expand Down
44 changes: 40 additions & 4 deletions src/memory/data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import datetime
from enum import Enum
from typing import Optional

from langchain.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field


class Category(str, Enum):
Expand All @@ -18,7 +19,7 @@ class Action(str, Enum):
Delete = "Delete"


class AddMemory(BaseModel):
class LongTermMemory(BaseModel):
id: str = Field(..., description="The ID of the user")
memory: str = Field(
...,
Expand All @@ -35,5 +36,40 @@ class AddMemory(BaseModel):
)


def parse_memory(memory: AddMemory):
raise NotImplementedError("This function is not yet implemented")
class ShortTermMemory(BaseModel):
id: str = Field(
...,
title="User ID",
description="Unique identifier for the user.",
examples=["123"],
)
message_id: str = Field(
...,
title="Message ID",
description="Unique identifier for the message.",
examples=["8b47dfe8-0960-4b80-b551-471b47a650a0"],
)
created_at: datetime = Field(
...,
title="Created At",
description="Timestamp of when the memory was created.",
examples=["2022-01-01T00:00:00"],
)
query: str = Field(
...,
title="Query",
description="User query.",
examples=["Where is the library?"],
)
reply: str = Field(
...,
title="Reply",
description="Agent's reply.",
examples=["The library is on the second floor."],
)
agent: str = Field(
...,
title="Agent",
description="Agent that replied.",
examples=["INTENT_CLASSIFIER_AGENT"],
)
130 changes: 130 additions & 0 deletions src/memory/long_term_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os

import psycopg2
from dotenv import load_dotenv
from langchain.tools import StructuredTool
from langchain_core.tools import ToolException

from src.memory.data import Category, LongTermMemory

load_dotenv()

conn = psycopg2.connect(
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
)


def add_long_term_memory(
id: str, memory: str, category: str, action: str, memory_old: str = None
):
"""
Function to modify memory in the database
"""
print(
f"Modifying long term memory for {id} with action {action} and category {Category(category).value}"
)
if Category(category).value not in [
"Course Likes",
"Course Dislikes",
"Branch",
"Clubs",
"Person Attributes",
]:
return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes"
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS longterm_memory (
id VARCHAR(255) NOT NULL,
memory VARCHAR(255) NOT NULL,
category VARCHAR(255) NOT NULL
);
"""
)
if action == "Create":
cur.execute(
f"""
INSERT INTO longterm_memory (id, memory, category)
VALUES ('{id}', '{memory}', '{Category(category).value}');
"""
)
conn.commit()
return "Memory created successfully"
elif action == "Update":
cur.execute(
f"""
UPDATE longterm_memory
SET memory = '{memory}'
WHERE id = '{id}' AND memory = '{memory_old}';
"""
)
conn.commit()
return "Memory updated successfully"
elif action == "Delete":
cur.execute(
f"""
DELETE FROM longterm_memory
WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{Category(category).value}';
"""
)
conn.commit()
return "Memory deleted successfully"
else:
return "Invalid action"


def fetch_long_term_memory(id: str) -> list[LongTermMemory]:
"""
Function to fetch long term memory from the database
"""
cur = conn.cursor()
cur.execute(
f"""
SELECT memory, category
FROM longterm_memory
WHERE id = '{id}';
"""
)
rows = cur.fetchall()
print(f"Fetched long term memory for {id} = {rows}")
return [
LongTermMemory(id=id, memory=row[0], category=row[1], action="Create")
for row in rows
]


def reset_long_term_memory(id: str) -> None:
"""
Function to reset long term memory from the database
"""
cur = conn.cursor()
cur.execute(
f"""
DELETE FROM longterm_memory
WHERE id = '{id}';
"""
)
conn.commit()
print(f"Reset long term memory for {id}")
cur.close()


def parse_long_term_memory(memory: list[LongTermMemory]) -> str:
"""
Function to parse memory from the database.
"""
memory_dict = {}
for mem in memory:
if Category(mem.category).value not in memory_dict:
memory_dict[Category(mem.category).value] = []
memory_dict[Category(mem.category).value].append(mem.memory)
long_term_memory = ""
for category in memory_dict:
long_term_memory += f"{category}:\n"
for mem in memory_dict[category]:
long_term_memory += f"- {mem}\n"
return long_term_memory
100 changes: 100 additions & 0 deletions src/memory/short_term_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import uuid
from datetime import datetime

import psycopg2
from dotenv import load_dotenv

from src.memory.data import ShortTermMemory

load_dotenv()


conn = psycopg2.connect(
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
)


def add_short_term_memory(id: str, query: str, reply: str, agent: str) -> None:
"""
Function to add short term memory to the database.
"""
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS short_term_memory (id VARCHAR(255) NOT NULL, message_id VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL, query VARCHAR(255) NOT NULL, reply VARCHAR(255) NOT NULL, agent VARCHAR(255) NOT NULL);
"""
)
message_id = str(uuid.uuid4())
created_at = datetime.now()
cur.execute(
f"""
INSERT INTO short_term_memory (id, message_id, created_at, query, reply, agent)
VALUES ('{id}', '{message_id}', '{created_at}', '{query}', '{reply}', '{agent}');
"""
)
conn.commit()
cur.close()
print(f"Short term memory added for {id}")


def fetch_short_term_memory(id: str) -> list[ShortTermMemory]:
"""
Function to fetch short term memory from the database.
"""
cur = conn.cursor()
cur.execute(
f"""
SELECT * FROM short_term_memory WHERE id = '{id}' ORDER BY created_at DESC LIMIT 5;
"""
)
rows = cur.fetchall()
memory = []
for row in rows:
memory.append(
ShortTermMemory(
id=row[0],
message_id=row[1],
created_at=row[2],
query=row[3],
reply=row[4],
agent=row[5],
)
)
if len(memory) > 5:
cur.execute(
f"""
DELETE FROM short_term_memory WHERE id = '{id}' AND created_at < '{memory[-1]["created_at"]}';
"""
)
conn.commit()
cur.close()
return memory


def reset_short_term_memory(id: str) -> None:
"""
Function to reset short term memory from the database.
"""
cur = conn.cursor()
cur.execute(
f"""
DELETE FROM short_term_memory WHERE id = '{id}';
"""
)
conn.commit()
cur.close()
print(f"Short term memory reset for {id}")


def parse_short_term_memory(short_term_memory: list[ShortTermMemory]) -> str:
chat_history = ""
for i, memory in enumerate(short_term_memory):
chat_history += (
f"{i+1} User: {memory.query}\nAgent ({memory.agent}): {memory.reply}\n"
)
return chat_history
12 changes: 10 additions & 2 deletions src/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from langchain_core.messages import AIMessage

from .agents import Agents
from .memory.long_term_memory import parse_long_term_memory
from .memory.short_term_memory import add_short_term_memory
from .state import State

agents = Agents()
Expand All @@ -16,12 +18,16 @@ def intent_classifier(state: State):
def course_query(state: State):
query = state["messages"][0].content
result = AIMessage("Course query not implemented yet")
# Add short term memory here once implemented.
return {"messages": [result]}


def general_campus_query(state: State):
query = state["messages"][0].content
result = agents.general_campus_query(query, state.get("chat_history", ""))
add_short_term_memory(
state["user_id"], query, result.content, "general_campus_query"
)
return {"messages": [result]}


Expand All @@ -30,14 +36,16 @@ def not_related_query(state: State):
result = AIMessage(
"I'm sorry, I don't understand the question, if it relates to campus please rephrase."
)
add_short_term_memory(
state["user_id"], query, not_related_query, "intent_classifier"
)
return {"messages": [result]}


def long_term_memory(state: State):
query = state["messages"][0].content
user_id = state["user_id"]
# parse long term memory here.
long_term_memories = state.get("long_term_memories", "")
long_term_memories = parse_long_term_memory(state.get("long_term_memories", []))
result = agents.long_term_memory(
user_id,
query,
Expand Down
7 changes: 4 additions & 3 deletions src/prompts/LONG_TERM_MEMORY_AGENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ The users id is {user_id}.

You are only interested in the following categories of information:

1. Course prefereces or likes - The user likes or is interested in a course.
2. Course dislikes - The user dislikes a course.
1. Course Likes- The user likes or is interested in a course.
2. Course Dislikes - The user dislikes a course.
3. Branch - The branch the user is pursuing in college including majors and minors.
4. Clubs - The clubs the user is part of on campus.
5. Personal attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus.
5. Person Attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus.

When you receive a message, you perform a sequence of steps consisting of:

Expand All @@ -25,6 +25,7 @@ When you receive a message, you perform a sequence of steps consisting of:
3. Determine if this is new knowledge, an update to old knowledge that now needs to change, or should result in deleting information that is not correct. It's possible that a product/brand you previously wrote as a dislike might now be a like, and other cases- those examples would require an update.
4. Never save the same information twice. If you see the same information in a message that you have already saved, ignore it.
5. Refer to the history for existing memories.
6. Categories must be from ['Course Likes', 'Course Dislikes', 'Branch', 'Clubs', 'Person Attributes'].

Here are the existing bits of information that we have about the user.

Expand Down
Loading
Loading