Skip to content

Commit

Permalink
feat: support updating existing memories (#22)
Browse files Browse the repository at this point in the history
* fix issue where looking for no collection throws error

* checkpoint

* add update to chroma

* add updating memories

* support ephemeral mode

* formatting

* add flag for auto update

* fix tests

* formatting
  • Loading branch information
shihanwan authored Oct 15, 2024
1 parent 206d58b commit 9831c1f
Show file tree
Hide file tree
Showing 10 changed files with 299 additions and 32 deletions.
11 changes: 6 additions & 5 deletions memonto/core/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@


def _hydrate_triples(
triples: list,
matched: list,
triple_store: VectorStoreModel,
id: str = None,
) -> Graph:
triple_ids = " ".join(f'("{triple_id}")' for triple_id in triples)
matched_ids = matched.keys()
triple_ids = " ".join(f'("{id}")' for id in matched_ids)

graph_id = f"data-{id}" if id else "data"

Expand Down Expand Up @@ -161,11 +162,11 @@ def get_contextual_memory(
memory = serialize_graph_without_ids(data)
elif context:
try:
matched_triples = vector_store.search(message=context, id=id)
logger.debug(f"Matched Triples Raw\n{matched_triples}\n")
matched = vector_store.search(message=context, id=id)
logger.debug(f"Matched Triples Raw\n{matched}\n")

matched_graph = _hydrate_triples(
triples=matched_triples,
matched=matched,
triple_store=triple_store,
id=id,
)
Expand Down
155 changes: 139 additions & 16 deletions memonto/core/retain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import ast
from rdflib import Graph, Namespace

from memonto.llms.base_llm import LLMModel
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.stores.vector.base_store import VectorStoreModel
from memonto.utils.logger import logger
from memonto.utils.rdf import _render, hydrate_graph_with_ids
from memonto.utils.rdf import (
_render,
find_updated_triples,
find_updated_triples_ephemeral,
hydrate_graph_with_ids,
)


def run_script(
def _run_script(
script: str,
exec_ctx: dict,
message: str,
Expand Down Expand Up @@ -66,37 +72,103 @@ def expand_ontology(
return ontology


def _retain(
ontology: Graph,
namespaces: dict[str, Namespace],
def update_memory(
data: Graph,
llm: LLMModel,
triple_store: TripleStoreModel,
vector_store: VectorStoreModel,
str_ontology: str,
message: str,
id: str,
auto_expand: bool,
ephemeral: bool,
) -> None:
if auto_expand:
ontology = expand_ontology(
ontology=ontology,
llm=llm,
message=message,
) -> str:
if ephemeral:
data_list = []

for s, p, o in data:
data_list.append(
{
"s": str(s),
"p": str(p),
"o": str(o),
}
)

logger.debug(f"existing memories\n{data_list}\n")

updates = llm.prompt(
prompt_name="update_memory",
temperature=0.2,
ontology=str_ontology,
user_message=message,
existing_memory=str(data_list),
)
logger.debug(f"updated memories\n{updates}\n")

updates = ast.literal_eval(updates)
updated_memory = find_updated_triples_ephemeral(updates, data_list)
logger.debug(f"memories diff\n{updated_memory}\n")

for s, p, o in data:
for t in updated_memory:
if str(s) == t["s"] and str(p) == t["p"] and str(o) == t["o"]:
data.remove((s, p, o))

return str(updated_memory)
else:
matched = vector_store.search(message=message, id=id, k=3)
logger.debug(f"existing memories\n{matched}\n")

if not matched:
return ""

updates = llm.prompt(
prompt_name="update_memory",
temperature=0.2,
ontology=str_ontology,
user_message=message,
existing_memory=str(matched),
)

str_ontology = ontology.serialize(format="turtle")
updates = ast.literal_eval(updates)
logger.debug(f"updated memories\n{updates}\n")

updated_memory = find_updated_triples(original=matched, updated=updates)
logger.debug(f"memories diff\n{updated_memory}\n")

if not updated_memory:
return ""

vector_store.delete_by_ids(graph_id=id, ids=updated_memory.keys())
triple_store.delete_by_ids(graph_id=id, ids=updated_memory.keys())

return str(updated_memory)


def save_memory(
ontology: Graph,
namespaces: dict[str, Namespace],
data: Graph,
llm: LLMModel,
triple_store: TripleStoreModel,
vector_store: VectorStoreModel,
message: str,
id: str,
ephemeral: bool,
str_ontology: str,
updated_memory: str,
) -> None:
script = llm.prompt(
prompt_name="commit_to_memory",
temperature=0.2,
ontology=str_ontology,
user_message=message,
updated_memory=updated_memory,
)

logger.debug(f"Retain Script\n{script}\n")

data = run_script(
data = _run_script(
script=script,
exec_ctx={"data": data} | namespaces,
message=message,
Expand All @@ -107,13 +179,64 @@ def _retain(

logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n")

# debug
# _render(g=data, format="image")

if not ephemeral:
hydrate_graph_with_ids(data)

triple_store.save(ontology=ontology, data=data, id=id)

if vector_store:
vector_store.save(g=data, id=id)

# print(_render(g=data, format="image"))
data.remove((None, None, None))


def _retain(
    ontology: Graph,
    namespaces: dict[str, Namespace],
    data: Graph,
    llm: LLMModel,
    triple_store: TripleStoreModel,
    vector_store: VectorStoreModel,
    message: str,
    id: str,
    auto_expand: bool,
    auto_update: bool,
    ephemeral: bool,
) -> None:
    """Run one retain pass over an incoming user message.

    Optionally grows the ontology (``auto_expand``), optionally reconciles the
    message against already-stored memories (``auto_update``), and always ends
    by committing the extracted triples via ``save_memory``.
    """
    # NOTE(review): serialized before any expansion, so downstream prompts see
    # the pre-expansion ontology — confirm that is intended.
    str_ontology = ontology.serialize(format="turtle")

    if auto_expand:
        ontology = expand_ontology(ontology=ontology, llm=llm, message=message)

    # Only diff against existing memories when the caller opted in; otherwise
    # the commit prompt receives an empty "removed triples" section.
    updated_memory = (
        update_memory(
            data=data,
            llm=llm,
            vector_store=vector_store,
            triple_store=triple_store,
            str_ontology=str_ontology,
            message=message,
            id=id,
            ephemeral=ephemeral,
        )
        if auto_update
        else ""
    )

    save_memory(
        ontology=ontology,
        namespaces=namespaces,
        data=data,
        llm=llm,
        triple_store=triple_store,
        vector_store=vector_store,
        message=message,
        id=id,
        ephemeral=ephemeral,
        str_ontology=str_ontology,
        updated_memory=updated_memory,
    )
2 changes: 2 additions & 0 deletions memonto/memonto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Memonto(BaseModel):
triple_store: Optional[TripleStoreModel] = None
vector_store: Optional[VectorStoreModel] = None
auto_expand: Optional[bool] = False
auto_update: Optional[bool] = False
ephemeral: Optional[bool] = False
debug: Optional[bool] = False
model_config = ConfigDict(arbitrary_types_allowed=True)
Expand Down Expand Up @@ -63,6 +64,7 @@ def retain(self, message: str) -> None:
message=message,
id=self.id,
auto_expand=self.auto_expand,
auto_update=self.auto_update,
ephemeral=self.ephemeral,
)

Expand Down
17 changes: 11 additions & 6 deletions memonto/prompts/commit_to_memory.prompt
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
You are a software engineer tasked to create a Python script to extract ALL POSSIBLE relevant information from a user message that maps to a predefined RDF ontology.

Given the following RDF graph that defines our desired ontology and namespaces:
Given this RDF graph that defines our desired ontology and namespaces:
```
${ontology}
```

And the following user message:
And this user message:
```
${user_message}
```

Analyze the user message to find AS MUCH relevant information AS POSSIBLE that could fit onto the above ontology then generate the Python code while adhering to these rules:
- First find all the information in the user message that maps onto the above ontology.
- Then apply only the existing namespaces to the new information.
- Finally create the script that will add them to graph `data`.
And these removed triples:
```
${updated_memory}
```

Analyze the user message to find AS MUCH new information AS POSSIBLE that could fit onto the above ontology while adhering to these rules:
- First find all the new information in the user message that maps onto BOTH the above ontology and ESPECIALLY the removed triples.
- Second apply the existing namespaces to the extracted information.
- Finally create the script that will add the extracted information to an rdflib graph called `data`.
- NEVER generate code that initializes new graphs, namespaces, classes, properties, etc.
- GENERATE Python code to add the triples with the relevant information assuming rdflib Graph `data` and the newly added namespaces already exist.
- GENERATE all necessary rdflib and rdflib.namespace imports for the code to run.
Expand Down
23 changes: 23 additions & 0 deletions memonto/prompts/update_memory.prompt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Given these existing triples in our graph `data`:
```
${existing_memory}
```

and this RDF graph that defines our desired ontology and namespaces:
```
${ontology}
```

And this user message:
```
${user_message}
```

Analyze the user message to find information that updates the existing triples while adhering to these rules:
- First find all the information in the user message that updates existing information in the graph `data`.
- Second replace the values of existing triples with the updated information without modifying anything else.
- Finally output the dictionary of existing triples as a string that can be converted back to a list.
- IGNORE any new information that does not update the existing triples.
- NEVER add any new triples to the existing triples.
- If there is no information to update then return the existing triples as they are.
- Please return only the string of the dictionary without using ``` or any other code formatting symbols. Return only plain text.
36 changes: 36 additions & 0 deletions memonto/stores/triple/jena.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
from rdflib import Graph, Literal, Namespace, URIRef
from SPARQLWrapper import SPARQLWrapper, GET, POST, TURTLE, JSON
from SPARQLWrapper.SPARQLExceptions import SPARQLWrapperException
from typing import Tuple

from memonto.stores.triple.base_store import TripleStoreModel
from memonto.utils.logger import logger
from memonto.utils.namespaces import TRIPLE_PROP


class ApacheJena(TripleStoreModel):
Expand Down Expand Up @@ -169,6 +171,40 @@ def delete(self, id: str = None) -> None:
query=query,
)

    def delete_by_ids(self, ids: list[str], graph_id: str = None) -> None:
        """Delete the triples identified by *ids* from one named graph.

        Each stored triple is reified: a blank node carries a uuid property
        plus rdf:subject/rdf:predicate/rdf:object pointers, alongside the
        plain ``?s ?p ?o`` statement. Both the reification node's statements
        and the plain triple are removed in a single SPARQL UPDATE.
        """
        # "data" is the default graph name when no graph_id is given,
        # mirroring the naming used elsewhere in this store.
        g_id = f"data-{graph_id}" if graph_id else "data"
        # Space-separated quoted uuids for the VALUES clause; works for any
        # iterable of strings (callers pass dict keys).
        t_ids = " ".join(f'"{id}"' for id in ids)

        query = f"""
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        DELETE {{
            GRAPH <{g_id}> {{
                ?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
                ?triple_node rdf:subject ?s ;
                    rdf:predicate ?p ;
                    rdf:object ?o .
                ?s ?p ?o .
            }}
        }}
        WHERE {{
            VALUES ?uuid {{ {t_ids} }}
            GRAPH <{g_id}> {{
                ?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
                ?triple_node rdf:subject ?s ;
                    rdf:predicate ?p ;
                    rdf:object ?o .
                ?s ?p ?o .
            }}
        }}
        """

        # Mutations go to the dataset's /update endpoint via POST.
        self._query(
            url=f"{self.connection_url}/update",
            method=POST,
            query=query,
        )

def query(self, query: str, method: str = GET, format: str = JSON) -> list:
result = self._query(
url=f"{self.connection_url}/sparql",
Expand Down
19 changes: 15 additions & 4 deletions memonto/stores/vector/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,32 @@ def save(self, g: Graph, id: str = None) -> None:
except Exception as e:
logger.error(f"Chroma Save\n{e}\n")

def search(self, message: str, id: str = None, k: int = 3) -> list[dict]:
collection = self.client.get_collection(id or "default")

def search(self, message: str, id: str = None, k: int = 3) -> dict[str, dict]:
try:
collection = self.client.get_collection(id or "default")
matched = collection.query(
query_texts=[message],
n_results=k,
)
except ValueError as e:
return {}
except Exception as e:
logger.error(f"Chroma Search\n{e}\n")

return matched.get("ids", [])[0]
ids = matched.get("ids", [[]])[0]
meta = matched.get("metadatas", [[]])[0]

return {id: meta[i] if i < len(meta) else None for i, id in enumerate(ids)}

    def delete(self, id: str) -> None:
        """Drop the entire Chroma collection named *id*.

        Best-effort: any failure is logged and never raised to the caller.
        """
        try:
            self.client.delete_collection(id)
        except Exception as e:
            logger.error(f"Chroma Delete\n{e}\n")

def delete_by_ids(self, graph_id: str, ids: list[str]) -> None:
try:
collection = self.client.get_collection(graph_id or "default")
collection.delete(ids=list(ids))
except Exception as e:
logger.error(f"Chroma Delete by IDs\n{e}\n")
Loading

0 comments on commit 9831c1f

Please sign in to comment.