From 267ac5bc83ee83ac22db37ed1d52d88804aa5f22 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Tue, 15 Aug 2023 02:03:32 -0500 Subject: [PATCH] Add test to check that API matches networkx (signature, etc.) --- .../cugraph_nx/tests/test_match_api.py | 40 +++++++++++++++++++ .../cugraph-nx/cugraph_nx/utils/decorators.py | 6 +-- 2 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 python/cugraph-nx/cugraph_nx/tests/test_match_api.py diff --git a/python/cugraph-nx/cugraph_nx/tests/test_match_api.py b/python/cugraph-nx/cugraph_nx/tests/test_match_api.py new file mode 100644 index 00000000000..f2b88c7f137 --- /dev/null +++ b/python/cugraph-nx/cugraph_nx/tests/test_match_api.py @@ -0,0 +1,40 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect + +import networkx as nx + +import cugraph_nx as cnx +from cugraph_nx.utils import networkx_algorithm + + +def test_match_signature_and_names(): + """Simple test to ensure our signatures and basic module layout match networkx.""" + for name, func in vars(cnx.interface.BackendInterface).items(): + if not isinstance(func, networkx_algorithm): + continue + dispatchable_func = nx.utils.backends._registered_algorithms[name] + orig_func = dispatchable_func.orig_func + # Matching signatures? + sig = inspect.signature(orig_func) + assert sig == inspect.signature(func) + # Matching function names? + assert func.__name__ == dispatchable_func.__name__ == orig_func.__name__ + # Matching dispatch names? + assert func.name == dispatchable_func.name + # Matching modules (i.e., where function defined)? + assert ( + "networkx." + func.__module__.split(".", 1)[1] + == dispatchable_func.__module__ + == orig_func.__module__ + ) diff --git a/python/cugraph-nx/cugraph_nx/utils/decorators.py b/python/cugraph-nx/cugraph_nx/utils/decorators.py index 65909e32485..7bda3e58b6b 100644 --- a/python/cugraph-nx/cugraph_nx/utils/decorators.py +++ b/python/cugraph-nx/cugraph_nx/utils/decorators.py @@ -10,7 +10,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect from functools import partial, update_wrapper from networkx.utils.decorators import not_implemented_for @@ -33,6 +32,7 @@ def __new__(cls, func=None, *, name=None): if func is None: return partial(networkx_algorithm, name=name) instance = object.__new__(cls) + # update_wrapper sets __wrapped__, which will be used for the signature update_wrapper(instance, func) instance.__defaults__ = func.__defaults__ instance.__kwdefaults__ = func.__kwdefaults__ @@ -41,10 +41,6 @@ def __new__(cls, func=None, *, name=None): setattr(BackendInterface, instance.name, instance) return instance - @property - def __signature__(self): - return inspect.signature(self.__wrapped__) - def _can_run(self, func): """Set the `can_run` attribute to the decorated function.""" self.can_run = func