Skip to content

Commit

Permalink
Merge branch 'main' into campus-rag
Browse files Browse the repository at this point in the history
  • Loading branch information
theofficialvedantjoshi authored Jan 18, 2025
2 parents ed6aabd + 999502f commit 0167d96
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 79 deletions.
10 changes: 6 additions & 4 deletions src/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self):
agent_name = agent_name.split(".")[0]
self.prompts[agent_name] = f.read()

def get_prompt(
def _get_prompt(
self, agent_name: str, user_input: str, agent_scratchpad=False
) -> ChatPromptTemplate:

Expand All @@ -50,7 +50,7 @@ def get_prompt(
return ChatPromptTemplate.from_messages(prompt)

def intent_classifier(self, query: str, chat_history: str) -> str:
prompt = self.get_prompt(
prompt = self._get_prompt(
"INTENT_CLASSIFIER_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)
Expand All @@ -66,7 +66,7 @@ def intent_classifier(self, query: str, chat_history: str) -> str:
return result

def general_campus_query(self, query: str, chat_history: str) -> str:
PROMPT_TEMPLATE = self.get_prompt(
PROMPT_TEMPLATE = self._get_prompt(
"GENERAL_CAMPUS_QUERY_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)
Expand All @@ -87,7 +87,9 @@ def course_query(self, query: str, chat_history: str) -> str:

def long_term_memory(self, id: str, query: str, memories: str) -> str:
tools = [tool_modify_memory]
prompt = self.get_prompt("LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True)
prompt = self._get_prompt(
"LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True
)
agent = create_tool_calling_agent(self.llm, tools, prompt)
chain = AgentExecutor(agent=agent, tools=tools)
result = chain.invoke({"user_id": id, "memories": memories})
Expand Down
44 changes: 40 additions & 4 deletions src/memory/data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import datetime
from enum import Enum
from typing import Optional

from langchain.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field


class Category(str, Enum):
Expand All @@ -18,7 +19,7 @@ class Action(str, Enum):
Delete = "Delete"


class AddMemory(BaseModel):
class LongTermMemory(BaseModel):
id: str = Field(..., description="The ID of the user")
memory: str = Field(
...,
Expand All @@ -35,5 +36,40 @@ class AddMemory(BaseModel):
)


def parse_memory(memory: AddMemory):
raise NotImplementedError("This function is not yet implemented")
class ShortTermMemory(BaseModel):
id: str = Field(
...,
title="User ID",
description="Unique identifier for the user.",
examples=["123"],
)
message_id: str = Field(
...,
title="Message ID",
description="Unique identifier for the message.",
examples=["8b47dfe8-0960-4b80-b551-471b47a650a0"],
)
created_at: datetime = Field(
...,
title="Created At",
description="Timestamp of when the memory was created.",
examples=["2022-01-01T00:00:00"],
)
query: str = Field(
...,
title="Query",
description="User query.",
examples=["Where is the library?"],
)
reply: str = Field(
...,
title="Reply",
description="Agent's reply.",
examples=["The library is on the second floor."],
)
agent: str = Field(
...,
title="Agent",
description="Agent that replied.",
examples=["INTENT_CLASSIFIER_AGENT"],
)
130 changes: 130 additions & 0 deletions src/memory/long_term_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os

import psycopg2
from dotenv import load_dotenv
from langchain.tools import StructuredTool
from langchain_core.tools import ToolException

from src.memory.data import Category, LongTermMemory

load_dotenv()

conn = psycopg2.connect(
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
)


def add_long_term_memory(
id: str, memory: str, category: str, action: str, memory_old: str = None
):
"""
Function to modify memory in the database
"""
print(
f"Modifying long term memory for {id} with action {action} and category {Category(category).value}"
)
if Category(category).value not in [
"Course Likes",
"Course Dislikes",
"Branch",
"Clubs",
"Person Attributes",
]:
return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes"
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS longterm_memory (
id VARCHAR(255) NOT NULL,
memory VARCHAR(255) NOT NULL,
category VARCHAR(255) NOT NULL
);
"""
)
if action == "Create":
cur.execute(
f"""
INSERT INTO longterm_memory (id, memory, category)
VALUES ('{id}', '{memory}', '{Category(category).value}');
"""
)
conn.commit()
return "Memory created successfully"
elif action == "Update":
cur.execute(
f"""
UPDATE longterm_memory
SET memory = '{memory}'
WHERE id = '{id}' AND memory = '{memory_old}';
"""
)
conn.commit()
return "Memory updated successfully"
elif action == "Delete":
cur.execute(
f"""
DELETE FROM longterm_memory
WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{Category(category).value}';
"""
)
conn.commit()
return "Memory deleted successfully"
else:
return "Invalid action"


def fetch_long_term_memory(id: str) -> list[LongTermMemory]:
"""
Function to fetch long term memory from the database
"""
cur = conn.cursor()
cur.execute(
f"""
SELECT memory, category
FROM longterm_memory
WHERE id = '{id}';
"""
)
rows = cur.fetchall()
print(f"Fetched long term memory for {id} = {rows}")
return [
LongTermMemory(id=id, memory=row[0], category=row[1], action="Create")
for row in rows
]


def reset_long_term_memory(id: str) -> None:
"""
Function to reset long term memory from the database
"""
cur = conn.cursor()
cur.execute(
f"""
DELETE FROM longterm_memory
WHERE id = '{id}';
"""
)
conn.commit()
print(f"Reset long term memory for {id}")
cur.close()


def parse_long_term_memory(memory: list[LongTermMemory]) -> str:
"""
Function to parse memory from the database.
"""
memory_dict = {}
for mem in memory:
if Category(mem.category).value not in memory_dict:
memory_dict[Category(mem.category).value] = []
memory_dict[Category(mem.category).value].append(mem.memory)
long_term_memory = ""
for category in memory_dict:
long_term_memory += f"{category}:\n"
for mem in memory_dict[category]:
long_term_memory += f"- {mem}\n"
return long_term_memory
100 changes: 100 additions & 0 deletions src/memory/short_term_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import uuid
from datetime import datetime

import psycopg2
from dotenv import load_dotenv

from src.memory.data import ShortTermMemory

load_dotenv()


conn = psycopg2.connect(
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
)


def add_short_term_memory(id: str, query: str, reply: str, agent: str) -> None:
"""
Function to add short term memory to the database.
"""
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS short_term_memory (id VARCHAR(255) NOT NULL, message_id VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL, query VARCHAR(255) NOT NULL, reply VARCHAR(255) NOT NULL, agent VARCHAR(255) NOT NULL);
"""
)
message_id = str(uuid.uuid4())
created_at = datetime.now()
cur.execute(
f"""
INSERT INTO short_term_memory (id, message_id, created_at, query, reply, agent)
VALUES ('{id}', '{message_id}', '{created_at}', '{query}', '{reply}', '{agent}');
"""
)
conn.commit()
cur.close()
print(f"Short term memory added for {id}")


def fetch_short_term_memory(id: str) -> list[ShortTermMemory]:
"""
Function to fetch short term memory from the database.
"""
cur = conn.cursor()
cur.execute(
f"""
SELECT * FROM short_term_memory WHERE id = '{id}' ORDER BY created_at DESC LIMIT 5;
"""
)
rows = cur.fetchall()
memory = []
for row in rows:
memory.append(
ShortTermMemory(
id=row[0],
message_id=row[1],
created_at=row[2],
query=row[3],
reply=row[4],
agent=row[5],
)
)
if len(memory) > 5:
cur.execute(
f"""
DELETE FROM short_term_memory WHERE id = '{id}' AND created_at < '{memory[-1]["created_at"]}';
"""
)
conn.commit()
cur.close()
return memory


def reset_short_term_memory(id: str) -> None:
"""
Function to reset short term memory from the database.
"""
cur = conn.cursor()
cur.execute(
f"""
DELETE FROM short_term_memory WHERE id = '{id}';
"""
)
conn.commit()
cur.close()
print(f"Short term memory reset for {id}")


def parse_short_term_memory(short_term_memory: list[ShortTermMemory]) -> str:
chat_history = ""
for i, memory in enumerate(short_term_memory):
chat_history += (
f"{i+1} User: {memory.query}\nAgent ({memory.agent}): {memory.reply}\n"
)
return chat_history
12 changes: 10 additions & 2 deletions src/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from langchain_core.messages import AIMessage

from .agents import Agents
from .memory.long_term_memory import parse_long_term_memory
from .memory.short_term_memory import add_short_term_memory
from .state import State

agents = Agents()
Expand All @@ -16,12 +18,16 @@ def intent_classifier(state: State):
def course_query(state: State):
query = state["messages"][0].content
result = AIMessage("Course query not implemented yet")
# Add short term memory here once implemented.
return {"messages": [result]}


def general_campus_query(state: State):
query = state["messages"][0].content
result = agents.general_campus_query(query, state.get("chat_history", ""))
add_short_term_memory(
state["user_id"], query, result.content, "general_campus_query"
)
return {"messages": [result]}


Expand All @@ -30,14 +36,16 @@ def not_related_query(state: State):
result = AIMessage(
"I'm sorry, I don't understand the question, if it relates to campus please rephrase."
)
add_short_term_memory(
state["user_id"], query, not_related_query, "intent_classifier"
)
return {"messages": [result]}


def long_term_memory(state: State):
query = state["messages"][0].content
user_id = state["user_id"]
# parse long term memory here.
long_term_memories = state.get("long_term_memories", "")
long_term_memories = parse_long_term_memory(state.get("long_term_memories", []))
result = agents.long_term_memory(
user_id,
query,
Expand Down
7 changes: 4 additions & 3 deletions src/prompts/LONG_TERM_MEMORY_AGENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ The users id is {user_id}.

You are only interested in the following categories of information:

1. Course prefereces or likes - The user likes or is interested in a course.
2. Course dislikes - The user dislikes a course.
1. Course Likes- The user likes or is interested in a course.
2. Course Dislikes - The user dislikes a course.
3. Branch - The branch the user is pursuing in college including majors and minors.
4. Clubs - The clubs the user is part of on campus.
5. Personal attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus.
5. Person Attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus.

When you receive a message, you perform a sequence of steps consisting of:

Expand All @@ -25,6 +25,7 @@ When you receive a message, you perform a sequence of steps consisting of:
3. Determine if this is new knowledge, an update to old knowledge that now needs to change, or should result in deleting information that is not correct. It's possible that a product/brand you previously wrote as a dislike might now be a like, and other cases- those examples would require an update.
4. Never save the same information twice. If you see the same information in a message that you have already saved, ignore it.
5. Refer to the history for existing memories.
6. Categories must be from ['Course Likes', 'Course Dislikes', 'Branch', 'Clubs', 'Person Attributes'].

Here are the existing bits of information that we have about the user.

Expand Down
Loading

0 comments on commit 0167d96

Please sign in to comment.