From ab7f8a2e4581935091a1c2a7a26643aaacfdfcbb Mon Sep 17 00:00:00 2001 From: Christian Zimpelmann Date: Thu, 6 Oct 2022 09:43:23 +0200 Subject: [PATCH] Add option to get the DAG. (#9) --- .gitignore | 3 + .pre-commit-config.yaml | 16 ++--- CHANGES.rst | 3 +- setup.cfg | 4 -- src/dags/dag.py | 156 ++++++++++++++++++++++++++++++++++------ tests/test_dag.py | 35 +++++++++ 6 files changed, 183 insertions(+), 34 deletions(-) diff --git a/.gitignore b/.gitignore index db3addb..ed0ea34 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,9 @@ target/ profile_default/ ipython_config.py +# VS Code +.vscode + # pyenv .python-version diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d9dca30..cb7cd16 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: debug-statements - id: end-of-file-fixer - repo: https://github.com/asottile/reorder_python_imports - rev: v3.1.0 + rev: v3.8.3 hooks: - id: reorder-python-imports types: [python] @@ -45,12 +45,12 @@ repos: additional_dependencies: [black==22.3.0] types: [rst] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 22.8.0 hooks: - id: black types: [python] - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 5.0.4 hooks: - id: flake8 types: [python] @@ -71,7 +71,7 @@ repos: Pygments, ] - repo: https://github.com/PyCQA/doc8 - rev: 0.11.2 + rev: v1.0.0 hooks: - id: doc8 - repo: meta @@ -86,11 +86,11 @@ repos: args: [--no-build-isolation] additional_dependencies: [setuptools-scm, toml] - repo: https://github.com/PyCQA/doc8 - rev: 0.11.2 + rev: v1.0.0 hooks: - id: doc8 - repo: https://github.com/asottile/setup-cfg-fmt - rev: v1.20.1 + rev: v2.0.0 hooks: - id: setup-cfg-fmt - repo: https://github.com/econchick/interrogate @@ -100,11 +100,11 @@ repos: args: [-v, --fail-under=20] exclude: ^(tests|docs|setup\.py) - repo: https://github.com/codespell-project/codespell - rev: v2.1.0 + rev: v2.2.1 hooks: - id: codespell - repo: https://github.com/asottile/pyupgrade - rev: v2.34.0 + rev: v2.38.2 hooks: - id: pyupgrade args: [--py37-plus] diff --git a/CHANGES.rst b/CHANGES.rst index 5f4595d..ff6240b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,7 +14,8 @@ releases are available on `Anaconda.org - :gh:`7` improves the examples in the test cases. - :gh:`10` turns ``targets`` into an optional argument. All variables in the DAG are returned by default. - +- :gh:`9` Add function to return the DAG. Check for cycles in DAG. + (:ghuser:`ChristianZimpelmann`) 0.2.1 - 2022-03-29 ------------------ diff --git a/setup.cfg b/setup.cfg index 47c3fe6..e4530d8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,10 +17,6 @@ classifiers = Operating System :: POSIX Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 Topic :: Utilities [options] diff --git a/src/dags/dag.py b/src/dags/dag.py index 8c77eed..29f1b3d 100644 --- a/src/dags/dag.py +++ b/src/dags/dag.py @@ -24,15 +24,16 @@ def concatenate_functions( Functions that are not required to produce the targets will simply be ignored. - The arguments of the combined function are all arguments of relevant functions - that are not themselves function names, in alphabetical order. + The arguments of the combined function are all arguments of relevant functions that + are not themselves function names, in alphabetical order. Args: functions (dict or list): Dict or list of functions. If a list, the function - name is inferred from the __name__ attribute of the entries. If a dict, - the name of the function is set to the dictionary key. - targets (str | None): Name of the function that produces the target or list of - such function names. If the value is `None`, all variables are returned. + name is inferred from the __name__ attribute of the entries. If a dict, the + name of the function is set to the dictionary key. + targets (str or list or None): Name of the function that produces the target or + list of such function names. If the value is `None`, all variables are + returned. return_type (str): One of "tuple", "list", "dict". This is ignored if the targets are a single string or if an aggregator is provided. aggregator (callable or None): Binary reduction function that is used to @@ -45,19 +46,99 @@ def concatenate_functions( function: A function that produces targets when called with suitable arguments. """ - _functions = _harmonize_functions(functions) - _targets = _harmonize_targets(targets, list(_functions)) - _fail_if_targets_have_wrong_types(_targets) - _fail_if_functions_are_missing(_functions, _targets) + # Create the DAG. + dag = create_dag(functions, targets) + + # Build combined function. + out = _create_combined_function_from_dag( + dag, functions, targets, return_type, aggregator, enforce_signature + ) + + return out + + +def create_dag(functions, targets): + """Build a directed acyclic graph (DAG) from functions. + + Functions can depend on the output of other functions as inputs, as long as the + dependencies can be described by a directed acyclic graph (DAG). + + Functions that are not required to produce the targets will simply be ignored. + + Args: + functions (dict or list): Dict or list of functions. If a list, the function + name is inferred from the __name__ attribute of the entries. If a dict, the + name of the function is set to the dictionary key. + targets (str or list or None): Name of the function that produces the target or + list of such function names. If the value is `None`, all variables are + returned. + + Returns: + dag: the DAG (as networkx.DiGraph object) + + """ + # Harmonize and check arguments. + _functions, _targets = _harmonize_and_check_functions_and_targets( + functions, targets + ) + + # Create the DAG _raw_dag = _create_complete_dag(_functions) - _dag = _limit_dag_to_targets_and_their_ancestors(_raw_dag, _targets) - _arglist = _create_arguments_of_concatenated_function(_functions, _dag) - _exec_info = _create_execution_info(_functions, _dag) + dag = _limit_dag_to_targets_and_their_ancestors(_raw_dag, _targets) + + # Check if there are cycles in the DAG + _fail_if_dag_contains_cycle(dag) + + return dag + + +def _create_combined_function_from_dag( + dag, + functions, + targets, + return_type="tuple", + aggregator=None, + enforce_signature=True, +): + """Create combined function which allows to execute a complete directed acyclic + graph (DAG) in one function call. + + The arguments of the combined function are all arguments of relevant functions that + are not themselves function names, in alphabetical order. + + Args: + dag (networkx.DiGraph): a DAG of functions + functions (dict or list): Dict or list of functions. If a list, the function + name is inferred from the __name__ attribute of the entries. If a dict, the + name of the function is set to the dictionary key. + targets (str or list or None): Name of the function that produces the target or + list of such function names. If the value is `None`, all variables are + returned. + return_type (str): One of "tuple", "list", "dict". This is ignored if the + targets are a single string or if an aggregator is provided. + aggregator (callable or None): Binary reduction function that is used to + aggregate the targets into a single target. + enforce_signature (bool): If True, the signature of the concatenated function + is enforced. Otherwise it is only provided for introspection purposes. + Enforcing the signature has a small runtime overhead. + + Returns: + function: A function that produces targets when called with suitable arguments. + + """ + # Harmonize and check arguments. + _functions, _targets = _harmonize_and_check_functions_and_targets( + functions, targets + ) + + _arglist = _create_arguments_of_concatenated_function(_functions, dag) + _exec_info = _create_execution_info(_functions, dag) _concatenated = _create_concatenated_function( _exec_info, _arglist, _targets, enforce_signature ) + # Return function in specified format. if isinstance(targets, str) or (aggregator is not None and len(_targets) == 1): out = single_output(_concatenated) elif aggregator is not None: @@ -70,7 +151,7 @@ def concatenate_functions( out = dict_output(_concatenated, keys=_targets) else: raise ValueError( - f"Invalid return type {return_type}. Must be 'list', 'tuple', or 'dict'. " + f"Invalid return type {return_type}. Must be 'list', 'tuple', or 'dict'. " f"You provided {return_type}." ) @@ -91,13 +172,14 @@ def get_ancestors(functions, targets, include_targets=False): set: The ancestors """ - _functions = _harmonize_functions(functions) - _targets = _harmonize_targets(targets, list(_functions)) - _fail_if_targets_have_wrong_types(_targets) - _fail_if_functions_are_missing(_functions, _targets) - raw_dag = _create_complete_dag(_functions) - dag = _limit_dag_to_targets_and_their_ancestors(raw_dag, _targets) + # Harmonize and check arguments. + _functions, _targets = _harmonize_and_check_functions_and_targets( + functions, targets + ) + + # Create the DAG. + dag = create_dag(functions, targets) ancestors = set() for target in _targets: @@ -107,6 +189,29 @@ def get_ancestors(functions, targets, include_targets=False): return ancestors +def _harmonize_and_check_functions_and_targets(functions, targets): + """Harmonize the type of specified functions and targets and do some checks. + + Args: + functions (dict or list): Dict or list of functions. If a list, the function + name is inferred from the __name__ attribute of the entries. If a dict, the + name of the function is set to the dictionary key. + targets (str or list): Name of the function that produces the target or list of + such function names. + + Returns: + functions_harmonized: harmonized functions + targets_harmonized: harmonized targets + + """ + functions_harmonized = _harmonize_functions(functions) + targets_harmonized = _harmonize_targets(targets, list(functions_harmonized)) + _fail_if_targets_have_wrong_types(targets_harmonized) + _fail_if_functions_are_missing(functions_harmonized, targets_harmonized) + + return functions_harmonized, targets_harmonized + + def _harmonize_functions(functions): if isinstance(functions, (list, tuple)): functions = {func.__name__: func for func in functions} @@ -141,6 +246,15 @@ def _fail_if_functions_are_missing(functions, targets): return functions, targets +def _fail_if_dag_contains_cycle(dag): + """Check for cycles in DAG""" + cycles = list(nx.simple_cycles(dag)) + + if len(cycles) > 0: + formatted = _format_list_linewise(cycles) + raise ValueError(f"The DAG contains one or more cycles:\n{formatted}") + + def _create_complete_dag(functions): """Create the complete DAG. @@ -275,7 +389,7 @@ def concatenated(*args, **kwargs): def _format_list_linewise(list_): - formatted_list = '",\n "'.join(list_) + formatted_list = '",\n "'.join([str(c) for c in list_]) return textwrap.dedent( """ [ diff --git a/tests/test_dag.py b/tests/test_dag.py index b1c4c93..fcd25f2 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -3,6 +3,7 @@ import pytest from dags.dag import concatenate_functions +from dags.dag import create_dag from dags.dag import get_ancestors @@ -22,6 +23,14 @@ def _unrelated(working_hours): # noqa: U100 raise NotImplementedError() +def _leisure_cycle(working_hours, _utility): + return 24 - working_hours + _utility + + +def _consumption_cycle(working_hours, wage, _utility): + return wage * working_hours + _utility + + def _complete_utility(wage, working_hours, leisure_weight): """The function that we try to generate dynamically.""" leis = _leisure(working_hours) @@ -157,3 +166,29 @@ def g(f, d): assert list(inspect.signature(concatenated).parameters) == ["c", "d"] assert concatenated(3, 4) == 10 + + +@pytest.mark.parametrize( + "funcs", + [ + { + "_utility": _utility, + "_leisure": _leisure, + "_consumption": _consumption_cycle, + }, + { + "_utility": _utility, + "_leisure": _leisure_cycle, + "_consumption": _consumption_cycle, + }, + ], +) +def test_fail_if_cycle_in_dag(funcs): + with pytest.raises( + ValueError, + match="The DAG contains one or more cycles:", + ): + create_dag( + functions=funcs, + targets=["_utility"], + )