Skip to content

Commit

Permalink
Add test to check that API matches networkx (signature, etc.)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Aug 15, 2023
1 parent 690b639 commit 267ac5b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
40 changes: 40 additions & 0 deletions python/cugraph-nx/cugraph_nx/tests/test_match_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

import networkx as nx

import cugraph_nx as cnx
from cugraph_nx.utils import networkx_algorithm


def test_match_signature_and_names():
"""Simple test to ensure our signatures and basic module layout match networkx."""
for name, func in vars(cnx.interface.BackendInterface).items():
if not isinstance(func, networkx_algorithm):
continue
dispatchable_func = nx.utils.backends._registered_algorithms[name]
orig_func = dispatchable_func.orig_func
# Matching signatures?
sig = inspect.signature(orig_func)
assert sig == inspect.signature(func)
# Matching function names?
assert func.__name__ == dispatchable_func.__name__ == orig_func.__name__
# Matching dispatch names?
assert func.name == dispatchable_func.name
# Matching modules (i.e., where function defined)?
assert (
"networkx." + func.__module__.split(".", 1)[1]
== dispatchable_func.__module__
== orig_func.__module__
)
6 changes: 1 addition & 5 deletions python/cugraph-nx/cugraph_nx/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial, update_wrapper

from networkx.utils.decorators import not_implemented_for
Expand All @@ -33,6 +32,7 @@ def __new__(cls, func=None, *, name=None):
if func is None:
return partial(networkx_algorithm, name=name)
instance = object.__new__(cls)
# update_wrapper sets __wrapped__, which will be used for the signature
update_wrapper(instance, func)
instance.__defaults__ = func.__defaults__
instance.__kwdefaults__ = func.__kwdefaults__
Expand All @@ -41,10 +41,6 @@ def __new__(cls, func=None, *, name=None):
setattr(BackendInterface, instance.name, instance)
return instance

@property
def __signature__(self):
return inspect.signature(self.__wrapped__)

def _can_run(self, func):
"""Set the `can_run` attribute to the decorated function."""
self.can_run = func
Expand Down

0 comments on commit 267ac5b

Please sign in to comment.