Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shihanwan committed Oct 16, 2024
1 parent 9677149 commit 9be03b5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
11 changes: 11 additions & 0 deletions memonto/stores/triple/base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def get(self):
"""
pass

@abstractmethod
def get_all(self, graph_id: str = None) -> str:
"""
Get all memory data from the datastore.
:param graph_id: The id of the graph to get all memory data from.
:return: A string representation of the memory data.
"""
pass

@abstractmethod
def query(self):
"""
Expand Down
30 changes: 17 additions & 13 deletions tests/core/test_recall.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import pytest
from rdflib import Graph, Literal, URIRef
from unittest.mock import ANY, MagicMock, patch
from unittest.mock import MagicMock, patch

from memonto.core.recall import _recall
from memonto.memonto import Memonto
from memonto.stores.triple.jena import ApacheJena


@pytest.fixture
def jena():
return ApacheJena(connection_url="http://localhost:8080/test-dataset")


@pytest.fixture
Expand Down Expand Up @@ -62,16 +69,16 @@ def data_graph():
return g


@patch("memonto.core.recall._find_all")
def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id, data_graph):
@patch("memonto.stores.triple.jena.ApacheJena.get_all")
def test_fetch_all_memory(mock_get_all, jena, mock_llm, mock_store, id, data_graph):
all_memory = "all memory"
mock_find_all.return_value = all_memory
mock_get_all.return_value = all_memory

_recall(
data=data_graph,
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
triple_store=jena,
context=None,
id=id,
ephemeral=False,
Expand All @@ -84,26 +91,24 @@ def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id, data_graph):
)


@patch("memonto.core.recall._find_adjacent_triples")
@patch("memonto.core.recall._hydrate_triples")
@patch("memonto.stores.triple.jena.ApacheJena.get_context")
def test_fetch_some_memory(
mock_hydrate_triples,
mock_find_adjacent_triples,
mock_get_context,
jena,
mock_llm,
mock_store,
user_query,
id,
data_graph,
):
some_memory = "some memory"
mock_find_adjacent_triples.return_value = some_memory
mock_hydrate_triples.return_value = Graph()
mock_get_context.return_value = some_memory

_recall(
data=data_graph,
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
triple_store=jena,
context=user_query,
id=id,
ephemeral=False,
Expand All @@ -117,7 +122,6 @@ def test_fetch_some_memory(


def test_fetch_some_memory_ephemeral(mock_llm, data_graph):

mem = _recall(
data=data_graph,
llm=mock_llm,
Expand Down

0 comments on commit 9be03b5

Please sign in to comment.