Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#153] Refactor the constructor of the TraverseEngine and the request method in the FunctionsClient #157

Merged
merged 6 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[#154] Refactor remote tests to use a single declaration of remote host/port
[#153] Refactor the constructor of the TraverseEngine and the request method in the FunctionsClient
32 changes: 26 additions & 6 deletions hyperon_das/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import contextlib
import json
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
from hyperon_das_atomdb import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist
from requests import exceptions, sessions

from hyperon_das.logger import logger

Expand All @@ -15,13 +16,32 @@ def __init__(self, url: str, server_count: int = 0, name: Optional[str] = None):

def _send_request(self, payload) -> Any:
try:
response = requests.request('POST', url=self.url, data=json.dumps(payload))
with sessions.Session() as session:
response = session.request(method='POST', url=self.url, data=json.dumps(payload))

response.raise_for_status()

try:
response_data = response.json()
except exceptions.JSONDecodeError as e:
raise Exception(f"JSON decode error: {str(e)}")

if response.status_code == 200:
return response.json()
return response_data
else:
return response.json()['error']
except requests.exceptions.RequestException as e:
raise e
return response_data.get(
'error', f'Unknown error with status code {response.status_code}'
)
except exceptions.ConnectionError as e:
raise Exception(f"Connection error: {str(e)}")
except exceptions.Timeout as e:
raise Exception(f"Request timed out: {str(e)}")
except exceptions.HTTPError as e:
with contextlib.suppress(exceptions.JSONDecodeError):
return response.json().get('error')
raise Exception(f"HTTP error occurred: {str(e)}")
except exceptions.RequestException as e:
raise Exception(f"Request exception occurred: {str(e)}")

def get_atom(self, handle: str, **kwargs) -> Union[str, Dict]:
payload = {
Expand Down
4 changes: 1 addition & 3 deletions hyperon_das/das.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,6 @@ def get_traversal_cursor(self, handle: str, **kwargs) -> TraverseEngine:
TraverseEngine: The object that allows traversal of the hypergraph
"""
try:
self.get_atom(handle)
return TraverseEngine(handle, das=self, **kwargs)
except AtomDoesNotExist:
raise GetTraversalCursorException(message="Cannot start Traversal. Atom does not exist")

return TraverseEngine(handle, das=self, **kwargs)
15 changes: 8 additions & 7 deletions hyperon_das/query_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import requests
from hyperon_das_atomdb import WILDCARD
from hyperon_das_atomdb.exceptions import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist
from requests import sessions

from hyperon_das.cache import (
AndEvaluator,
Expand Down Expand Up @@ -279,12 +279,13 @@ def _connect_server(self, host: str, port: Optional[str] = None):
def _is_server_connect(self, url: str) -> bool:
logger().debug(f'connecting to remote Das {url}')
try:
response = requests.request(
'POST',
url=url,
data=json.dumps({"action": "ping", "input": {}}),
timeout=10,
)
with sessions.Session() as session:
response = session.request(
method='POST',
url=url,
data=json.dumps({"action": "ping", "input": {}}),
timeout=10,
)
except Exception:
return False
if response.status_code == 200:
Expand Down
10 changes: 8 additions & 2 deletions hyperon_das/traverse_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
class TraverseEngine:
def __init__(self, handle: str, **kwargs) -> None:
self.das: DistributedAtomSpace = kwargs['das']
self._cursor = self.das.get_atom(handle)

try:
atom = self.das.get_atom(handle)
except AtomDoesNotExist as e:
raise e

self._cursor = atom

def get(self) -> Dict[str, Any]:
return self.das.get_atom(self._cursor['handle'])
return self._cursor

def get_links(self, **kwargs) -> QueryAnswerIterator:
incoming_links = self.das.get_incoming_links(
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class TestFunctionsClient:
@pytest.fixture
def mock_request(self):
with patch('requests.request') as mock_request:
with patch('requests.sessions.Session.request') as mock_request:
yield mock_request

def test_get_atom_success(self, mock_request):
Expand All @@ -26,7 +26,7 @@ def test_get_atom_success(self, mock_request):
result = client.get_atom(handle='123')

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_atom", "input": {"handle": "123"}}',
)
Expand All @@ -48,7 +48,7 @@ def test_get_node_success(self, mock_request):
result = client.get_node(node_type='Concept', node_name='human')

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_node", "input": {"node_type": "Concept", "node_name": "human"}}',
)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_get_link_success(self, mock_request):
)

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_link", "input": {"link_type": "Similarity", "link_targets": ["af12f10f9ae2002a1607ba0b47ba8407", "1cdffc6b0b89ff41d68bec237481d1e1"]}}',
)
Expand All @@ -107,7 +107,7 @@ def test_get_links_success(self, mock_request):
)

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_links", "input": {"link_type": "Inheritance", "kwargs": {}, "link_targets": ["4e8e26e3276af8a5c2ac2cc2dc95c6d2", "80aff30094874e75028033a38ce677bb"]}}',
)
Expand Down Expand Up @@ -162,7 +162,7 @@ def test_count_atoms_success(self, mock_request):
result = client.count_atoms()

mock_request.assert_called_once_with(
'POST', url='http://example.com', data='{"action": "count_atoms", "input": {}}'
method='POST', url='http://example.com', data='{"action": "count_atoms", "input": {}}'
)

assert result == tuple(expected_response)
Loading