-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
300 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import os | ||
|
||
import psycopg2 | ||
from dotenv import load_dotenv | ||
from langchain.tools import StructuredTool | ||
from langchain_core.tools import ToolException | ||
|
||
from src.memory.data import Category, LongTermMemory | ||
|
||
load_dotenv() | ||
|
||
conn = psycopg2.connect( | ||
dbname=os.getenv("POSTGRES_DB"), | ||
user=os.getenv("POSTGRES_USER"), | ||
password=os.getenv("POSTGRES_PASS"), | ||
host=os.getenv("POSTGRES_HOST"), | ||
port=os.getenv("POSTGRES_PORT"), | ||
) | ||
|
||
|
||
def add_long_term_memory( | ||
id: str, memory: str, category: str, action: str, memory_old: str = None | ||
): | ||
""" | ||
Function to modify memory in the database | ||
""" | ||
print( | ||
f"Modifying long term memory for {id} with action {action} and category {Category(category).value}" | ||
) | ||
if Category(category).value not in [ | ||
"Course Likes", | ||
"Course Dislikes", | ||
"Branch", | ||
"Clubs", | ||
"Person Attributes", | ||
]: | ||
return "Invalid category choose from: Course Likes, Course Dislikes, Branch, Clubs, Person Attributes" | ||
cur = conn.cursor() | ||
cur.execute( | ||
""" | ||
CREATE TABLE IF NOT EXISTS longterm_memory ( | ||
id VARCHAR(255) NOT NULL, | ||
memory VARCHAR(255) NOT NULL, | ||
category VARCHAR(255) NOT NULL | ||
); | ||
""" | ||
) | ||
if action == "Create": | ||
cur.execute( | ||
f""" | ||
INSERT INTO longterm_memory (id, memory, category) | ||
VALUES ('{id}', '{memory}', '{Category(category).value}'); | ||
""" | ||
) | ||
conn.commit() | ||
return "Memory created successfully" | ||
elif action == "Update": | ||
cur.execute( | ||
f""" | ||
UPDATE longterm_memory | ||
SET memory = '{memory}' | ||
WHERE id = '{id}' AND memory = '{memory_old}'; | ||
""" | ||
) | ||
conn.commit() | ||
return "Memory updated successfully" | ||
elif action == "Delete": | ||
cur.execute( | ||
f""" | ||
DELETE FROM longterm_memory | ||
WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{Category(category).value}'; | ||
""" | ||
) | ||
conn.commit() | ||
return "Memory deleted successfully" | ||
else: | ||
return "Invalid action" | ||
|
||
|
||
def fetch_long_term_memory(id: str) -> list[LongTermMemory]: | ||
""" | ||
Function to fetch long term memory from the database | ||
""" | ||
cur = conn.cursor() | ||
cur.execute( | ||
f""" | ||
SELECT memory, category | ||
FROM longterm_memory | ||
WHERE id = '{id}'; | ||
""" | ||
) | ||
rows = cur.fetchall() | ||
print(f"Fetched long term memory for {id} = {rows}") | ||
return [ | ||
LongTermMemory(id=id, memory=row[0], category=row[1], action="Create") | ||
for row in rows | ||
] | ||
|
||
|
||
def reset_long_term_memory(id: str) -> None: | ||
""" | ||
Function to reset long term memory from the database | ||
""" | ||
cur = conn.cursor() | ||
cur.execute( | ||
f""" | ||
DELETE FROM longterm_memory | ||
WHERE id = '{id}'; | ||
""" | ||
) | ||
conn.commit() | ||
print(f"Reset long term memory for {id}") | ||
cur.close() | ||
|
||
|
||
def parse_long_term_memory(memory: list[LongTermMemory]) -> str: | ||
""" | ||
Function to parse memory from the database. | ||
""" | ||
memory_dict = {} | ||
for mem in memory: | ||
if Category(mem.category).value not in memory_dict: | ||
memory_dict[Category(mem.category).value] = [] | ||
memory_dict[Category(mem.category).value].append(mem.memory) | ||
long_term_memory = "" | ||
for category in memory_dict: | ||
long_term_memory += f"{category}:\n" | ||
for mem in memory_dict[category]: | ||
long_term_memory += f"- {mem}\n" | ||
return long_term_memory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import os | ||
import uuid | ||
from datetime import datetime | ||
|
||
import psycopg2 | ||
from dotenv import load_dotenv | ||
|
||
from src.memory.data import ShortTermMemory | ||
|
||
load_dotenv() | ||
|
||
|
||
conn = psycopg2.connect( | ||
dbname=os.getenv("POSTGRES_DB"), | ||
user=os.getenv("POSTGRES_USER"), | ||
password=os.getenv("POSTGRES_PASS"), | ||
host=os.getenv("POSTGRES_HOST"), | ||
port=os.getenv("POSTGRES_PORT"), | ||
) | ||
|
||
|
||
def add_short_term_memory(id: str, query: str, reply: str, agent: str) -> None: | ||
""" | ||
Function to add short term memory to the database. | ||
""" | ||
cur = conn.cursor() | ||
cur.execute( | ||
""" | ||
CREATE TABLE IF NOT EXISTS short_term_memory (id VARCHAR(255) NOT NULL, message_id VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL, query VARCHAR(255) NOT NULL, reply VARCHAR(255) NOT NULL, agent VARCHAR(255) NOT NULL); | ||
""" | ||
) | ||
message_id = str(uuid.uuid4()) | ||
created_at = datetime.now() | ||
cur.execute( | ||
f""" | ||
INSERT INTO short_term_memory (id, message_id, created_at, query, reply, agent) | ||
VALUES ('{id}', '{message_id}', '{created_at}', '{query}', '{reply}', '{agent}'); | ||
""" | ||
) | ||
conn.commit() | ||
cur.close() | ||
print(f"Short term memory added for {id}") | ||
|
||
|
||
def fetch_short_term_memory(id: str) -> list[ShortTermMemory]: | ||
""" | ||
Function to fetch short term memory from the database. | ||
""" | ||
cur = conn.cursor() | ||
cur.execute( | ||
f""" | ||
SELECT * FROM short_term_memory WHERE id = '{id}' ORDER BY created_at DESC LIMIT 5; | ||
""" | ||
) | ||
rows = cur.fetchall() | ||
memory = [] | ||
for row in rows: | ||
memory.append( | ||
ShortTermMemory( | ||
id=row[0], | ||
message_id=row[1], | ||
created_at=row[2], | ||
query=row[3], | ||
reply=row[4], | ||
agent=row[5], | ||
) | ||
) | ||
if len(memory) > 5: | ||
cur.execute( | ||
f""" | ||
DELETE FROM short_term_memory WHERE id = '{id}' AND created_at < '{memory[-1]["created_at"]}'; | ||
""" | ||
) | ||
conn.commit() | ||
cur.close() | ||
return memory | ||
|
||
|
||
def reset_short_term_memory(id: str) -> None: | ||
""" | ||
Function to reset short term memory from the database. | ||
""" | ||
cur = conn.cursor() | ||
cur.execute( | ||
f""" | ||
DELETE FROM short_term_memory WHERE id = '{id}'; | ||
""" | ||
) | ||
conn.commit() | ||
cur.close() | ||
print(f"Short term memory reset for {id}") | ||
|
||
|
||
def parse_short_term_memory(short_term_memory: list[ShortTermMemory]) -> str: | ||
chat_history = "" | ||
for i, memory in enumerate(short_term_memory): | ||
chat_history += ( | ||
f"{i+1} User: {memory.query}\nAgent ({memory.agent}): {memory.reply}\n" | ||
) | ||
return chat_history |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.