Skip to content

Commit

Permalink
Solves #25
Browse files Browse the repository at this point in the history
  • Loading branch information
Norman Nabhan committed Mar 9, 2023
1 parent 20afd48 commit fc62f9a
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 11 deletions.
74 changes: 63 additions & 11 deletions cag/framework/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from pyArango.theExceptions import DocumentNotFoundError, SimpleQueryError

from cag.graph_elements.nodes import GenericOOSNode
from cag.graph_elements.relations import GenericEdge
from cag.utils.config import Config, configuration

from cag.graph_elements.base_graph import *
from pyArango.collection import Document, Collection
from pyArango.collection import Document, Collection, Collection_metaclass
import re
from typing import Any, Optional
from typing import Any, Optional, Union

from cag import logger

Expand Down Expand Up @@ -39,7 +41,6 @@ def __init__(self, conf: Config = None):
self.graph_name = conf.graph
if self.database.hasGraph(self.graph_name):
self.graph: BaseGraph = self.database.graphs[self.graph_name]

else:
edge_def_arr = []
for ed in edges:
Expand All @@ -48,14 +49,24 @@ def __init__(self, conf: Config = None):
+ ed["from_collections"]
+ ed["to_collections"]
):
if not self.database.hasCollection(col):
self.database.createCollection(col)

if not self.database.hasCollection(
self.__get_collection_name(col)
):
self.database.createCollection(
self.__get_collection_name(col)
)
edge_def_arr.append(
EdgeDefinition(
ed["relation"],
fromCollections=ed["from_collections"],
toCollections=ed["to_collections"],
self.__get_collection_name(ed["relation"]),
fromCollections=[
self.__get_collection_name(col)
for col in ed["from_collections"]
],
toCollections=[
self.__get_collection_name(col)
for col in ed["to_collections"]
],
)
)
if len(edge_def_arr) == 0:
Expand All @@ -75,12 +86,53 @@ def __init__(self, conf: Config = None):
# Setup graph structure
for ed in edges:
self.graph.update_graph_structure(
ed["relation"],
ed["from_collections"],
ed["to_collections"],
self.__get_collection_name(ed["relation"]),
[
self.__get_collection_name(col)
for col in ed["from_collections"]
],
[
self.__get_collection_name(col)
for col in ed["to_collections"]
],
create_collections=True,
)

def __get_collection_name(
self, collection: Union[str, Collection_metaclass]
) -> str:
"""
Returns the name of a collection based on the input collection. If the collection is a string,
it returns the same string. If the collection is an instance of Collection_metaclass, it tries
to return the '_name' attribute of the class. If '_name' is not available, it returns the class
name. Raises ValueError if the input collection is not a string or an instance of
Collection_metaclass.
Args:
collection (Union[str, Collection_metaclass]): The input collection, which can be a string
or an instance of Collection_metaclass.
Returns:
str: The name of the collection.
Raises:
ValueError: If the input collection is not a string or an instance of Collection_metaclass.
"""
if isinstance(collection, str):
# Backward compatibility, when strings are used in edge definition
return collection
if isinstance(collection, Collection_metaclass):
# When a class of GenericOOSNode gets passed, we take the _name if possible
if hasattr(collection, "_name"):
return collection._name # noqa
else:
# Otherwise just take the name of the class
return collection.__name__
raise ValueError(
f"{collection} is of incompatible type {type(collection)}"
f"Make sure it's a str, GenericOOSNode or GenericEdge!"
)

def get_document(
self,
collectionName: str,
Expand Down
68 changes: 68 additions & 0 deletions tests/test_graph_creator/test_improve _edge_definitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Tests for https://github.com/DLR-SC/corpus-annotation-graph-builder/issues/25

import os

import pytest

from cag.framework import GraphCreatorBase
from cag.graph_elements.nodes import GenericOOSNode, Field
from cag.graph_elements.relations import GenericEdge
from cag.utils.config import Config
from tests.test_graph_creator import config_factory


class CollectionA(GenericOOSNode):
_name = "CollectionA"
_fields = {"value": Field(), "value2": Field(), **GenericOOSNode._fields}


class CollectionB(GenericOOSNode):
_name = "CollectionB"
_fields = {"value": Field(), **GenericOOSNode._fields}


class CollectionC(GenericOOSNode):
_name = "CollectionC"
_fields = {"value": Field(), **GenericOOSNode._fields}


class HasRelation(GenericEdge):
_fields = GenericEdge._fields


class HasAnotherRelation(GenericEdge):
_fields = GenericEdge._fields


class SampleGraphCreator(GraphCreatorBase):
_name = "SampleGraphCreator"
_description = "Graph based on the DLR elib corpus"

_edge_definitions = [
{
"relation": HasRelation,
"from_collections": [CollectionA],
"to_collections": [CollectionB],
},
{
"relation": "HasAnotherRelation",
"from_collections": [CollectionC],
"to_collections": [CollectionC],
},
]

def init_graph(self):
pass


class TestGC25:
def test_arango_connection(self):
config = config_factory()
assert config.arango_db.name == config.database

def test_create_collection(self):
config = config_factory()
SampleGraphCreator("", config_factory())
assert config.arango_db.has_collection("CollectionA")
assert config.arango_db.has_collection("CollectionB")
assert config.arango_db.has_collection("HasRelation")

0 comments on commit fc62f9a

Please sign in to comment.