Skip to content

Commit

Permalink
remove bnodes from recall
Browse files Browse the repository at this point in the history
  • Loading branch information
shihanwan committed Oct 16, 2024
1 parent 81e5a31 commit 9677149
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 159 deletions.
2 changes: 1 addition & 1 deletion memonto/core/forget.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _forget(
vector_store.delete(id)

if triple_store:
triple_store.delete(id)
triple_store.delete_all(id)
except ValueError as e:
logger.warning(e)
except Exception as e:
Expand Down
156 changes: 3 additions & 153 deletions memonto/core/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,146 +8,6 @@
from memonto.utils.rdf import serialize_graph_without_ids


def _hydrate_triples(
    matched: dict,
    triple_store: TripleStoreModel,
    id: str = None,
) -> Graph:
    """Reconstruct full (s, p, o) triples for a set of matched triple UUIDs.

    Each key of ``matched`` is the uuid of a reified triple node tagged with
    ``TRIPLE_PROP.uuid`` in the store; a CONSTRUCT query pulls the node's
    rdf:subject / rdf:predicate / rdf:object out of the selected named graph.

    :param matched: Mapping keyed by triple UUID (only the keys are used).
        NOTE: previously annotated ``list``, but ``.keys()`` is called on it.
    :param triple_store: Store that executes the SPARQL query.
        NOTE: previously annotated ``VectorStoreModel``; it is queried with
        SPARQL, i.e. used as a triple store.
    :param id: Optional user id selecting the per-user named graph.
    :return: Graph containing the hydrated triples.
    """
    # Build a SPARQL VALUES list like ("uuid1") ("uuid2"); a distinct loop
    # variable avoids shadowing the `id` parameter used below.
    triple_ids = " ".join(f'("{matched_id}")' for matched_id in matched.keys())

    # Per-user graphs are named "data-<id>"; the shared graph is "data".
    graph_id = f"data-{id}" if id else "data"

    query = f"""
    PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
    CONSTRUCT {{
        ?s ?p ?o .
    }}
    WHERE {{
        GRAPH <{graph_id}> {{
            VALUES (?uuid) {{ {triple_ids} }}
            ?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
            ?triple_node rdf:subject ?s ;
                         rdf:predicate ?p ;
                         rdf:object ?o .
        }}
    }}
    """

    result = triple_store.query(query=query, format="turtle")

    g = Graph()
    g.parse(data=result, format="turtle")

    return g


def _get_formatted_node(node: URIRef | Literal | BNode) -> str:
    """Render an RDF term in SPARQL-compatible surface syntax.

    URIs are wrapped in angle brackets, literals in double quotes, and
    blank nodes get the ``_:`` prefix. Anything else falls back to the
    quoted-literal form.
    """
    if isinstance(node, URIRef):
        return f"<{node}>"
    if isinstance(node, Literal):
        return f'"{node}"'
    if isinstance(node, BNode):
        return f"_:{node}"
    # Unknown term type: treat it like a literal.
    return f'"{node}"'


def _find_adjacent_triples(
    triples: Graph,
    triple_store: TripleStoreModel,
    id: str = None,
    depth: int = 1,
) -> str:
    """Breadth-first expand ``triples`` to triples touching them, up to ``depth`` hops.

    Starting from every subject and object in ``triples``, repeatedly queries the
    store for triples whose subject or object is in the current frontier of
    unexplored nodes.

    NOTE(review): only the result of the FINAL hop's query is returned, not the
    union of all hops — this mirrors the original behavior; confirm it is intended.

    :param triples: Seed graph whose nodes form the initial frontier.
    :param triple_store: Store that executes the SPARQL queries.
        NOTE: previously annotated ``VectorStoreModel``, inconsistent with
        ``_find_all``; it is queried with SPARQL.
    :param id: Optional user id selecting the per-user named graph.
    :param depth: Number of expansion hops to perform.
    :return: Turtle serialization of the last hop's triples, or "" when there
        was nothing to expand.
    :raises ValueError: If a SPARQL query fails or returns no result.
    """
    nodes_set = set()
    for s, p, o in triples:
        nodes_set.add(_get_formatted_node(s))
        nodes_set.add(_get_formatted_node(o))

    explored_nodes = set(nodes_set)
    frontier = nodes_set.copy()

    # Loop-invariant: hoisted out of the expansion loop.
    graph_id = f"data-{id}" if id else "data"

    result_triples = None

    for _ in range(depth):
        if not frontier:
            break

        node_list = ", ".join(frontier)

        query = f"""
        CONSTRUCT {{
            ?s ?p ?o .
        }}
        WHERE {{
            GRAPH <{graph_id}> {{
                ?s ?p ?o .
                FILTER (?s IN ({node_list}) || ?o IN ({node_list}))
            }}
        }}
        """

        logger.debug(f"Find adjacent triples SPARQL query\n{query}\n")

        try:
            result_triples = triple_store.query(query=query, format="turtle")
        except Exception as e:
            raise ValueError(f"SPARQL Query Error: {e}")

        if result_triples is None:
            raise ValueError("SPARQL query returned no results")

        graph = Graph()
        graph.parse(data=result_triples, format="turtle")

        # Collect nodes seen for the first time; they form the next frontier.
        next_frontier = set()
        for s, p, o in graph:
            for formatted in (_get_formatted_node(s), _get_formatted_node(o)):
                if formatted not in explored_nodes:
                    next_frontier.add(formatted)
                    explored_nodes.add(formatted)

        frontier = next_frontier

    if result_triples is None:
        return ""

    # Reuse the last query's response instead of re-executing the same query
    # against the store (the original issued an extra, redundant round-trip).
    return result_triples


def _find_all(triple_store: TripleStoreModel, id: str = None) -> str:
    """Return every user-authored triple in the graph, serialized as Turtle.

    Reified bookkeeping triples (subjects tagged with ``TRIPLE_PROP.uuid``)
    are filtered out of the result.

    :param triple_store: Store that executes the SPARQL query.
    :param id: Optional user id selecting the per-user named graph. Previously
        ``id=None`` produced the bogus graph name ``data-None``; now it falls
        back to the shared ``data`` graph, consistent with the other helpers.
    :return: Turtle text, or "" when the graph is empty.
    """
    graph_id = f"data-{id}" if id else "data"

    result = triple_store.query(
        query=f"""
        CONSTRUCT {{
            ?s ?p ?o .
        }} WHERE {{
            GRAPH <{graph_id}> {{
                ?s ?p ?o .
                FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
            }}
        }}
        """,
        format="turtle",
    )

    # Some backends return raw bytes; normalize to text.
    if isinstance(result, bytes):
        result = result.decode("utf-8")

    if not result:
        return ""

    return str(result)


def get_contextual_memory(
data: Graph,
vector_store: VectorStoreModel,
Expand All @@ -165,25 +25,15 @@ def get_contextual_memory(
matched = vector_store.search(message=context, id=id)
logger.debug(f"Matched Triples Raw\n{matched}\n")

matched_graph = _hydrate_triples(
memory = triple_store.get_context(
matched=matched,
triple_store=triple_store,
id=id,
)
matched_triples = matched_graph.serialize(format="turtle")
logger.debug(f"Matched Triples\n{matched_triples}\n")

memory = _find_adjacent_triples(
triples=matched_graph,
triple_store=triple_store,
id=id,
graph_id=id,
depth=1,
)
logger.debug(f"Adjacent Triples\n{memory}\n")
except ValueError as e:
logger.debug(f"Recall Exception\n{e}\n")
else:
memory = _find_all(triple_store=triple_store, id=id)
memory = triple_store.get_all(graph_id=id)

logger.debug(f"Contextual Memory\n{memory}\n")
return memory
Expand Down
149 changes: 147 additions & 2 deletions memonto/stores/triple/jena.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.utils.logger import logger
from memonto.utils.namespaces import TRIPLE_PROP
from memonto.utils.rdf import format_node


class ApacheJena(TripleStoreModel):
Expand Down Expand Up @@ -53,6 +54,43 @@ def _get_prefixes(self, g: Graph) -> list[str]:
gt = g.serialize(format="turtle")
return [line for line in gt.splitlines() if line.startswith("@prefix")]

def _hydrate_triples(
    self,
    matched: dict,
    graph_id: str = None,
) -> Graph:
    """Reconstruct full (s, p, o) triples for a set of matched triple UUIDs.

    Each key of ``matched`` is the uuid of a reified triple node tagged with
    ``TRIPLE_PROP.uuid``; a CONSTRUCT query pulls the node's rdf:subject /
    rdf:predicate / rdf:object out of the selected named graph.

    :param matched: Mapping keyed by triple UUID (only the keys are used).
        NOTE: previously annotated ``list``, but ``.keys()`` is called on it.
    :param graph_id: Optional user id selecting the per-user named graph.
    :return: Graph containing the hydrated triples.
    """
    # Build a SPARQL VALUES list like ("uuid1") ("uuid2"); a distinct loop
    # variable avoids shadowing the builtin `id`.
    triple_ids = " ".join(f'("{matched_id}")' for matched_id in matched.keys())

    g_id = f"data-{graph_id}" if graph_id else "data"

    query = f"""
    PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
    CONSTRUCT {{
        ?s ?p ?o .
    }}
    WHERE {{
        GRAPH <{g_id}> {{
            VALUES (?uuid) {{ {triple_ids} }}
            ?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
            ?triple_node rdf:subject ?s ;
                         rdf:predicate ?p ;
                         rdf:object ?o .
        }}
    }}
    """

    result = self._query(
        url=f"{self.connection_url}/sparql",
        method=GET,
        query=query,
    )

    g = Graph()
    g.parse(data=result, format="turtle")

    return g

def _load(
self,
g: Graph,
Expand Down Expand Up @@ -162,8 +200,115 @@ def get(

return result["results"]["bindings"]

def delete(self, id: str = None) -> None:
query = f"""DROP GRAPH <ontology-{id}> ; DROP GRAPH <data-{id}> ;"""
def get_all(self, graph_id: str = None) -> str:
    """Return every user-authored triple in the graph, serialized as Turtle.

    Reified bookkeeping triples (subjects tagged with ``TRIPLE_PROP.uuid``)
    are filtered out of the result.

    :param graph_id: Optional user id selecting the per-user named graph;
        falls back to the shared ``data`` graph.
    :return: Turtle text, or "" when the graph is empty.
    """
    g_id = f"data-{graph_id}" if graph_id else "data"

    # PEP 8: spaces around the assignment (was `query=f"""...`).
    query = f"""
    CONSTRUCT {{
        ?s ?p ?o .
    }} WHERE {{
        GRAPH <{g_id}> {{
            ?s ?p ?o .
            FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
        }}
    }}
    """

    result = self._query(
        url=f"{self.connection_url}/sparql",
        method=GET,
        query=query,
    )

    # The HTTP layer may hand back raw bytes; normalize to text.
    if isinstance(result, bytes):
        result = result.decode("utf-8")

    if not result:
        return ""

    return str(result)

def get_context(
    self,
    matched: dict[str, dict],
    graph_id: str,
    depth: int = 1,
) -> str:
    """Find triples adjacent to the matched triples, up to ``depth`` hops away.

    Hydrates the matched triple UUIDs into concrete triples, then repeatedly
    queries for triples whose subject or object touches the current frontier
    of unexplored nodes. Bookkeeping triples (subjects tagged with
    ``TRIPLE_PROP.uuid``) are excluded from every hop.

    NOTE(review): only the result of the FINAL hop's query is returned, not
    the union of all hops — preserved from the original logic; confirm it is
    intended.

    :param matched: Mapping keyed by triple UUID (only the keys are used).
    :param graph_id: Optional user id selecting the per-user named graph.
    :param depth: Number of expansion hops to perform.
    :return: Turtle serialization of the last hop's triples, or "" when there
        was nothing to expand.
    :raises ValueError: If a SPARQL query fails or returns no result.
    """
    g_id = f"data-{graph_id}" if graph_id else "data"

    matched_graph = self._hydrate_triples(
        matched=matched,
        graph_id=graph_id,
    )
    logger.debug(f"Matched Triples\n{matched_graph.serialize(format='turtle')}\n")

    nodes_set = set()
    for s, p, o in matched_graph:
        nodes_set.add(format_node(s))
        nodes_set.add(format_node(o))

    explored_nodes = set(nodes_set)
    frontier = nodes_set.copy()

    result = None

    for _ in range(depth):
        if not frontier:
            break

        node_list = ", ".join(frontier)

        query = f"""
        CONSTRUCT {{
            ?s ?p ?o .
        }}
        WHERE {{
            GRAPH <{g_id}> {{
                ?s ?p ?o .
                FILTER (
                    (?s IN ({node_list}) || ?o IN ({node_list})) &&
                    NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
                )
            }}
        }}
        """

        logger.debug(f"Find adjacent triples SPARQL query\n{query}\n")

        try:
            result = self.query(query=query, format="turtle")
        except Exception as e:
            raise ValueError(f"SPARQL Query Error: {e}")

        if result is None:
            raise ValueError("SPARQL query returned no results")

        graph = Graph()
        graph.parse(data=result, format="turtle")

        # Collect nodes seen for the first time; they form the next frontier.
        next_frontier = set()
        for s, p, o in graph:
            for node in (format_node(s), format_node(o)):
                if node not in explored_nodes:
                    next_frontier.add(node)
                    explored_nodes.add(node)

        frontier = next_frontier

    if result is None:
        return ""

    # Reuse the last query's response instead of re-executing the same query
    # against the store (the original issued an extra, redundant round-trip).
    logger.debug(f"Adjacent Triples\n{result}\n")
    return result

def delete_all(self, graph_id: str = None) -> None:
d_id = f"data-{graph_id}" if graph_id else "data"
o_id = f"ontology-{graph_id}" if graph_id else "ontology"

query = f"""DROP GRAPH <{o_id}> ; DROP GRAPH <{d_id}> ;"""

self._query(
url=f"{self.connection_url}/update",
Expand Down
2 changes: 0 additions & 2 deletions memonto/stores/vector/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def save(self, g: Graph, ns: dict[str, Namespace], id: str = None) -> None:
_p = to_human_readable(str(p), ns)
_o = to_human_readable(str(o), ns)

print(_s, _p, _o)

id = ""
for bnode in g.subjects(RDF.subject, s):
if (bnode, RDF.predicate, p) in g and (bnode, RDF.object, o) in g:
Expand Down
13 changes: 12 additions & 1 deletion memonto/utils/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import uuid
from collections import defaultdict
from rdflib import Graph, Literal, BNode, Namespace
from rdflib import Graph, Literal, BNode, Namespace, URIRef
from rdflib.namespace import RDF, RDFS, OWL
from typing import Union

Expand All @@ -31,6 +31,17 @@ def to_human_readable(c: str, ns: dict[str, Namespace]) -> str:
return c


def format_node(node: URIRef | Literal | BNode) -> str:
    """Render an RDF term in SPARQL-compatible surface syntax.

    URIs are wrapped in angle brackets, literals in double quotes, and
    blank nodes get the ``_:`` prefix. Anything else falls back to the
    quoted-literal form.
    """
    if isinstance(node, URIRef):
        return f"<{node}>"
    if isinstance(node, Literal):
        return f'"{node}"'
    if isinstance(node, BNode):
        return f"_:{node}"
    # Unknown term type: treat it like a literal.
    return f'"{node}"'


def serialize_graph_without_ids(g: Graph, format: str = "turtle") -> Graph:
graph = Graph()

Expand Down

0 comments on commit 9677149

Please sign in to comment.