Skip to content

Commit

Permalink
Merge pull request #2559 from satra/fix/config
Browse files Browse the repository at this point in the history
fix: propagate explicit workflow config to nodes in a workflow and allow nodes to overwrite
  • Loading branch information
effigies authored Apr 27, 2018
2 parents c9315ca + fea03b8 commit 9333ff8
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions nipype/pipeline/engine/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(self,
self._hashed_inputs = None
self._needed_outputs = []
self.needed_outputs = needed_outputs
self.config = None

@property
def interface(self):
Expand Down
44 changes: 44 additions & 0 deletions nipype/pipeline/engine/tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Tests for the engine workflows module
"""
from glob import glob
import os
from shutil import rmtree
from itertools import product
Expand Down Expand Up @@ -229,3 +230,46 @@ def pick_first(l):
assert os.path.exists(
os.path.join(wf.base_dir, wf.name, n4.name,
'file1.txt')) is keep_inputs


def _test_function4():
raise FileNotFoundError('Generic error')


def test_config_setting(tmpdir):
tmpdir.chdir()
wf = pe.Workflow('config')
wf.base_dir = os.getcwd()

crashdir = os.path.join(os.getcwd(), 'crashdir')
os.mkdir(crashdir)
wf.config = {"execution": {"crashdump_dir": crashdir}}

n1 = pe.Node(niu.Function(function=_test_function4),
name='errorfunc')
wf.add_nodes([n1])
try:
wf.run()
except RuntimeError:
pass

fl = glob(os.path.join(crashdir, 'crash*'))
assert len(fl) == 1

# Now test node overwrite
crashdir2 = os.path.join(os.getcwd(), 'crashdir2')
os.mkdir(crashdir2)
crashdir3 = os.path.join(os.getcwd(), 'crashdir3')
os.mkdir(crashdir3)
wf.config = {"execution": {"crashdump_dir": crashdir3}}
n1.config = {"execution": {"crashdump_dir": crashdir2}}

try:
wf.run()
except RuntimeError:
pass

fl = glob(os.path.join(crashdir2, 'crash*'))
assert len(fl) == 1
fl = glob(os.path.join(crashdir3, 'crash*'))
assert len(fl) == 0

0 comments on commit 9333ff8

Please sign in to comment.