Skip to content

Commit

Permalink
Make targets optional and return all variables if targets=None. (#10
Browse files Browse the repository at this point in the history
)
  • Loading branch information
tobiasraabe authored Jun 30, 2022
1 parent 96d9dce commit 8794c21
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 15 deletions.
8 changes: 5 additions & 3 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ releases are available on `Anaconda.org
<https://anaconda.org/OpenSourceEconomics/dags>`_.



0.2.2
-----
0.2.2 - 2022-xx-xx
------------------

- :gh:`5` Updates examples used in tests (:ghuser:`janosg`)
- :gh:`7` improves the examples in the test cases.
- :gh:`10` turns ``targets`` into an optional argument. All variables in the DAG are
returned by default.


0.2.1 - 2022-03-29
Expand Down
16 changes: 9 additions & 7 deletions src/dags/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def concatenate_functions(
functions,
targets,
targets=None,
return_type="tuple",
aggregator=None,
enforce_signature=True,
Expand All @@ -31,8 +31,8 @@ def concatenate_functions(
functions (dict or list): Dict or list of functions. If a list, the function
name is inferred from the __name__ attribute of the entries. If a dict,
the name of the function is set to the dictionary key.
targets (str): Name of the function that produces the target or list of such
function names.
targets (str | None): Name of the function that produces the target or list of
such function names. If the value is `None`, all variables are returned.
return_type (str): One of "tuple", "list", "dict". This is ignored if the
targets are a single string or if an aggregator is provided.
aggregator (callable or None): Binary reduction function that is used to
Expand All @@ -45,8 +45,8 @@ def concatenate_functions(
function: A function that produces targets when called with suitable arguments.
"""
_targets = _harmonize_targets(targets)
_functions = _harmonize_functions(functions)
_targets = _harmonize_targets(targets, list(_functions))
_fail_if_targets_have_wrong_types(_targets)
_fail_if_functions_are_missing(_functions, _targets)

Expand Down Expand Up @@ -91,8 +91,8 @@ def get_ancestors(functions, targets, include_targets=False):
set: The ancestors
"""
_targets = _harmonize_targets(targets)
_functions = _harmonize_functions(functions)
_targets = _harmonize_targets(targets, list(_functions))
_fail_if_targets_have_wrong_types(_targets)
_fail_if_functions_are_missing(_functions, _targets)

Expand All @@ -113,8 +113,10 @@ def _harmonize_functions(functions):
return functions


def _harmonize_targets(targets):
if isinstance(targets, str):
def _harmonize_targets(targets, function_names):
if targets is None:
targets = function_names
elif isinstance(targets, str):
targets = [targets]
return targets

Expand Down
21 changes: 21 additions & 0 deletions tests/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ def _complete_utility(wage, working_hours, leisure_weight):
return util


def test_concatenate_functions_no_target():
concatenated = concatenate_functions(functions=[_utility, _leisure, _consumption])

calculated_result = concatenated(wage=5, working_hours=8, leisure_weight=2)

expected_utility = _complete_utility(wage=5, working_hours=8, leisure_weight=2)
expected_leisure = _leisure(working_hours=8)
expected_consumption = _consumption(working_hours=8, wage=5)

assert calculated_result == (
expected_utility,
expected_leisure,
expected_consumption,
)

calculated_args = set(inspect.signature(concatenated).parameters)
expected_args = {"leisure_weight", "wage", "working_hours"}

assert calculated_args == expected_args


def test_concatenate_functions_single_target():
concatenated = concatenate_functions(
functions=[_utility, _unrelated, _leisure, _consumption],
Expand Down
6 changes: 1 addition & 5 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
[tox]
envlist = pytest, sphinx
skipsdist = True
skip_missing_interpreters = True

[testenv]
basepython = python
usedevelop = true

[testenv:pytest]
setenv =
CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1
conda_channels =
conda-forge
nodefaults
Expand Down

0 comments on commit 8794c21

Please sign in to comment.