Skip to content

Commit

Permalink
Merge pull request #32 from singnet/senna-20-2
Browse files Browse the repository at this point in the history
[das-query-engine #20]  Implement a query method in DAS that expect a MeTTa expression rather than a PM query
  • Loading branch information
andre-senna authored Nov 14, 2023
2 parents 4c8e97c + f059aa3 commit d826f60
Show file tree
Hide file tree
Showing 7 changed files with 632 additions and 39 deletions.
160 changes: 129 additions & 31 deletions hyperon_das/api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from itertools import product
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from hyperon_das_atomdb import WILDCARD

from hyperon_das.exceptions import (
DatabaseTypeException,
InitializeServerException,
MethodNotAllowed,
UnexpectedQueryFormat,
QueryParametersException,
)
from hyperon_das.factory import DatabaseFactory, DatabaseType, database_factory
from hyperon_das.logger import logger
from hyperon_das.cache import LazyQueryEvaluator, ListIterator, QueryAnswerIterator
from hyperon_das.pattern_matcher import (
LogicalExpression,
PatternMatchingAnswer,
)
from hyperon_das.utils import QueryOutputFormat, QueryParameters
from hyperon_das.utils import Assignment, QueryAnswer, QueryOutputFormat, QueryParameters


class DistributedAtomSpace:
Expand All @@ -30,10 +33,10 @@ def __init__(
try:
DatabaseType(database)
except ValueError as e:
raise DatabaseTypeException(
self._error(DatabaseTypeException(
message=str(e),
details=f'possible values {DatabaseType.values()}',
)
))

if database == DatabaseType.SERVER.value and not host:
raise InitializeServerException(
Expand Down Expand Up @@ -104,6 +107,40 @@ def _turn_into_deep_representation(self, assignments) -> list:
results.append(result)
return results

def _error(self, exception: Exception):
logger().error(str(exception))
raise exception

def _recursive_query(
self,
query: Dict[str, Any],
mappings: Set[Assignment] = None,
extra_parameters: Optional[Dict[str, Any]] = None,
) -> QueryAnswerIterator:
if query["atom_type"] == "node":
atom_handle = self.db.get_node_handle(query["type"], query["name"])
return ListIterator([QueryAnswer(self.db.get_atom_as_dict(atom_handle), None)])
elif query["atom_type"] == "link":
matched_targets = []
for target in query["targets"]:
if target["atom_type"] == "node" or target["atom_type"] == "link":
matched = self._recursive_query(target, mappings, extra_parameters)
if matched:
matched_targets.append(matched)
elif target["atom_type"] == "variable":
matched_targets.append(ListIterator([QueryAnswer(target, None)]))
else:
self._error(UnexpectedQueryFormat(
message="Query processing reached an unexpected state",
details=f'link: {str(query)} link target: {str(query)}',
))
return LazyQueryEvaluator(query["type"], matched_targets, self, extra_parameters)
else:
self._error(UnexpectedQueryFormat(
message="Query processing reached an unexpected state",
details=f'query: {str(query)}',
))

def clear_database(self) -> None:
"""Clear all data"""
return self.db.clear_database()
Expand Down Expand Up @@ -166,7 +203,7 @@ def get_atom(
answer = self.db.get_atom_as_deep_representation(handle)
return json.dumps(answer, sort_keys=False, indent=4)
else:
raise ValueError(f"Invalid output format: '{output_format}'")
self._error(ValueError(f"Invalid output format: '{output_format}'"))

def get_node(
self,
Expand Down Expand Up @@ -231,7 +268,7 @@ def get_node(
answer = self.db.get_atom_as_deep_representation(node_handle)
return json.dumps(answer, sort_keys=False, indent=4)
else:
raise ValueError(f"Invalid output format: '{output_format}'")
self._error(ValueError(f"Invalid output format: '{output_format}'"))

def get_nodes(
self,
Expand Down Expand Up @@ -298,7 +335,7 @@ def get_nodes(
]
return json.dumps(answer, sort_keys=False, indent=4)
else:
raise ValueError(f"Invalid output format: '{output_format}'")
self._error(ValueError(f"Invalid output format: '{output_format}'"))

def get_link(
self,
Expand Down Expand Up @@ -347,7 +384,7 @@ def get_link(
try:
link_handle = self.db.get_link_handle(link_type, targets)
except Exception as e:
raise e
self._error(e)

if output_format == QueryOutputFormat.HANDLE or link_handle is None:
return link_handle
Expand All @@ -359,7 +396,7 @@ def get_link(
)
return json.dumps(answer, sort_keys=False, indent=4)
else:
raise ValueError(f"Invalid output format: '{output_format}'")
self._error(ValueError(f"Invalid output format: '{output_format}'"))

def get_links(
self,
Expand Down Expand Up @@ -431,7 +468,7 @@ def get_links(
db_answer = self.db.get_matched_type(link_type)
else:
# TODO: Improve this message error. What is invalid?
raise ValueError("Invalid parameters")
self._error(ValueError("Invalid parameters"))

if output_format == QueryOutputFormat.HANDLE:
return self._to_handle_list(db_answer)
Expand All @@ -440,7 +477,7 @@ def get_links(
elif output_format == QueryOutputFormat.JSON:
return self._to_json(db_answer)
else:
raise ValueError(f"Invalid output format: '{output_format}'")
self.error(ValueError(f"Invalid output format: '{output_format}'"))

def get_link_type(self, link_handle: str) -> str:
"""
Expand All @@ -467,10 +504,7 @@ def get_link_type(self, link_handle: str) -> str:
return resp
# TODO: Find out what specific exceptions might happen
except Exception as e:
logger().warning(
f"An error occurred during the query. Detail:'{str(e)}'"
)
raise e
self._error(e)

def get_link_targets(self, link_handle: str) -> List[str]:
"""
Expand Down Expand Up @@ -500,10 +534,7 @@ def get_link_targets(self, link_handle: str) -> List[str]:
return resp
# TODO: Find out what specific exceptions might happen
except Exception as e:
logger().warning(
f"An error occurred during the query. Detail:'{str(e)}'"
)
raise e
self._error(e)

def get_node_type(self, node_handle: str) -> str:
"""
Expand All @@ -528,10 +559,7 @@ def get_node_type(self, node_handle: str) -> str:
return resp
# TODO: Find out what specific exceptions might happen
except Exception as e:
logger().warning(
f"An error occurred during the query. Detail:'{str(e)}'"
)
raise e
self._error(e)

def get_node_name(self, node_handle: str) -> str:
"""
Expand All @@ -556,12 +584,82 @@ def get_node_name(self, node_handle: str) -> str:
return resp
# TODO: Find out what specific exceptions might happen
except Exception as e:
logger().warning(
f"An error occurred during the query. Detail:'{str(e)}'"
)
raise e
self._error(e)

def query(
self,
query: Dict[str, Any],
extra_parameters: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Perform a query on the knowledge base using a dict as input. Returns a
list of dicts as result.
The input dict is a link, used as a pattern to make the query.
Variables can be used as link targets as well as nodes. Nested links are
allowed as well.
Args:
query (Dict[str, Any]): A pattern described as a link (possibly with nested links) with
nodes and variables used to query the knoeledge base.
extra_paramaters (Dict[str, Any], optional): query optional parameters
Returns:
List[Dict[str, Any]]: a list of dicts with the matching subgraphs
Raises:
UnexpectedQueryFormat: If query resolution lead to an invalid state
Notes:
- No logical connectors (AND, OR, NOT) are allowed
- If no match is found for the query, an empty list is returned.
Example:
>>> hash_table_api.add_link({
"type": "Expression",
"targets": [
{"type": "Symbol", "name": "Test"},
{
"type": "Expression",
"targets": [
{"type": "Symbol", "name": "Test"},
{"type": "Symbol", "name": "2"}
]
}
]
})
>>> query_params = {
"toplevel_only": False,
"return_type": QueryOutputFormat.ATOM_INFO,
}
>>> q1 = {
"atom_type": "link",
"type": "Expression",
"targets": [
{"atom_type": "variable", "name": "v1"},
{
"atom_type": "link",
"type": "Expression",
"targets": [
{"atom_type": "variable", "name": "v2"},
{"atom_type": "node", "type": "Symbol", "name": "2"},
]
}
]
}
>>> result = hash_table_api.query(q1, query_params)
>>> print(result)
[{'handle': 'dbcf1c7b610a5adea335bf08f6509978', 'type': 'Expression', 'template': ['Expression', 'Symbol', ['Expression', 'Symbol', 'Symbol']], 'targets': [{'handle': '963d66edfb77236054125e3eb866c8b5', 'type': 'Symbol', 'name': 'Test'}, {'handle': '233d9a6da7d49d4164d863569e9ab7b6', 'type': 'Expression', 'template': ['Expression', 'Symbol', 'Symbol'], 'targets': [{'handle': '963d66edfb77236054125e3eb866c8b5', 'type': 'Symbol', 'name': 'Test'}, {'handle': '9f27a331633c8bc3c49435ffabb9110e', 'type': 'Symbol', 'name': '2'}]}]}]
"""
query_results = self._recursive_query(query, extra_parameters)
logger().debug(f"query: {query} result: {str(query_results)}")
answer = []
for result in query_results:
answer.append(result.grounded_atom)
return answer

def pattern_matcher_query(
self,
query: LogicalExpression,
extra_parameters: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -668,16 +766,16 @@ def add_node(self, node_params: Dict[str, Any]) -> Dict[str, Any]:
if self._db_type == DatabaseType.RAM_ONLY.value:
return self.db.add_node(node_params)
else:
raise MethodNotAllowed(
self._error(MethodNotAllowed(
message='This method is permited only in memory database',
details='Instantiate the class sent the database type as `ram_only`',
)
))

def add_link(self, link_params: Dict[str, Any]) -> Dict[str, Any]:
if self._db_type == DatabaseType.RAM_ONLY.value:
return self.db.add_link(link_params)
else:
raise MethodNotAllowed(
self._error(MethodNotAllowed(
message='This method is permited only in memory database',
details='Instantiate the class sent the database type as `ram_only`',
)
))
Loading

0 comments on commit d826f60

Please sign in to comment.