Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add basic forget functionality #6

Merged
merged 3 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions memonto/core/forget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.stores.vector.base_store import VectorStoreModel


def forget_memory(
id: str,
triple_store: TripleStoreModel,
vector_store: VectorStoreModel,
) -> None:
if vector_store:
vector_store.delete(id)

if triple_store:
triple_store.delete(id)
30 changes: 17 additions & 13 deletions memonto/core/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,23 @@ def recall_memory(
id: str,
) -> str:
if message:
matched_triples = vector_store.search(message=message, id=id)
triples = _hydrate_triples(
triples=matched_triples,
triple_store=triple_store,
id=id,
)
contextual_memory = _find_adjacent_triples(
triples=triples,
triple_store=triple_store,
id=id,
)

logger.debug(f"Matched Triples\n{json.dumps(triples, indent=2)}\n")
try:
matched_triples = vector_store.search(message=message, id=id)
triples = _hydrate_triples(
triples=matched_triples,
triple_store=triple_store,
id=id,
)
contextual_memory = _find_adjacent_triples(
triples=triples,
triple_store=triple_store,
id=id,
)

logger.debug(f"Matched Triples\n{json.dumps(triples, indent=2)}\n")
except ValueError as e:
logger.debug(f"Recall Exception\n{e}\n")
contextual_memory = ""
else:
contextual_memory = _find_all(triple_store=triple_store, id=id)

Expand Down
10 changes: 8 additions & 2 deletions memonto/memonto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from memonto.core.configure import configure
from memonto.core.init import init
from memonto.core.forget import forget_memory
from memonto.core.query import query_memory_data
from memonto.core.recall import recall_memory
from memonto.core.remember import load_memory
Expand Down Expand Up @@ -109,6 +110,7 @@ def recall(self, message: str = None) -> str:
id=self.id,
)

# TODO: no longer needed, can be deprecated or removed
def remember(self) -> None:
"""
Load existing memories from the memory store to a memonto instance.
Expand All @@ -123,11 +125,15 @@ def remember(self) -> None:
id=self.id,
)

def forget(self):
def forget(self) -> None:
"""
Remove memories from the memory store.
"""
pass
forget_memory(
id=self.id,
triple_store=self.triple_store,
vector_store=self.vector_store,
)

def query(self, uri: URIRef = None, query: str = None) -> list:
"""
Expand Down
1 change: 1 addition & 0 deletions memonto/prompts/summarize_memory.prompt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ Describe the RDF graph in one paragraph and make sure to follow these rules:
- FOCUS on the telling a story about the who, what, where, how, etc.
- LEAVE OUT anything not explicitly defined, do not make assumptions.
- DO NOT describe the RDF graph schema and DO NOT mention the RDF graph at all.
- If the RDF graph is empty then just return that there are currently no stored memory.
- Make sure to use plain and simple English.
9 changes: 9 additions & 0 deletions memonto/stores/triple/jena.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ def get(

return result["results"]["bindings"]

def delete(self, id: str = None) -> None:
query = f"""DROP GRAPH <ontology-{id}> ; DROP GRAPH <data-{id}> ;"""

self._query(
url=f"{self.connection_url}/update",
method=POST,
query=query,
)

def query(self, query: str, method: str = GET, format: str = JSON) -> list:
result = self._query(
url=f"{self.connection_url}/sparql",
Expand Down
3 changes: 3 additions & 0 deletions memonto/stores/vector/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@ def search(self, message: str, id: str = None, k: int = 3) -> list[dict]:
)

return [json.loads(t.get("triple", "{}")) for t in matched["metadatas"][0]]

def delete(self, id: str) -> None:
self.client.delete_collection(id)
Loading