Merge pull request #3130 from AKSoo/patch-1
FIX: DataSink to S3 buckets
effigies authored Jan 18, 2020
2 parents e5664a2 + 847553e commit 3468d9f
Showing 1 changed file with 7 additions and 21 deletions.
28 changes: 7 additions & 21 deletions nipype/interfaces/io.py
@@ -212,12 +212,12 @@ class DataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
     """
 
     # Init inputspec data attributes
-    base_directory = Directory(desc="Path to the base directory for storing data.")
+    base_directory = Str(desc="Path to the base directory for storing data.")
     container = Str(desc="Folder within base directory in which to store output")
     parameterization = traits.Bool(
         True, usedefault=True, desc="store output in parametrized structure"
     )
-    strip_dir = Directory(desc="path to strip out of filename")
+    strip_dir = Str(desc="path to strip out of filename")
     substitutions = InputMultiPath(
         traits.Tuple(Str, Str),
         desc=(
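
For context, a minimal usage sketch of the case this trait change unbreaks (bucket name, container, and file path below are hypothetical): with base_directory declared as Str, an s3:// value is passed through verbatim to DataSink's S3 handling instead of being treated as a local Directory path.

    # Hypothetical usage sketch: pointing a DataSink at an S3 bucket.
    # Bucket name, container, and file path are examples only.
    from nipype.interfaces.io import DataSink

    ds = DataSink()
    ds.inputs.base_directory = "s3://my-bucket/derivatives"  # kept verbatim now that the trait is Str
    ds.inputs.container = "sub-01"
    ds.inputs.anat = "/tmp/sub-01_T1w.nii.gz"  # DataSink inputs are dynamic; 'anat' becomes the output subfolder
    # ds.run()  # the actual upload requires boto3 and valid AWS credentials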
@@ -440,7 +440,6 @@ def _check_s3_base_dir(self):
             is not a valid S3 path, defaults to '<N/A>'
         """
 
-        # Init variables
         s3_str = "s3://"
         bucket_name = "<N/A>"
         base_directory = self.inputs.base_directory
@@ -449,22 +448,10 @@ def _check_s3_base_dir(self):
             s3_flag = False
             return s3_flag, bucket_name
 
-        # Explicitly lower-case the "s3"
-        if base_directory.lower().startswith(s3_str):
-            base_dir_sp = base_directory.split("/")
-            base_dir_sp[0] = base_dir_sp[0].lower()
-            base_directory = "/".join(base_dir_sp)
-
-        # Check if 's3://' in base dir
-        if base_directory.startswith(s3_str):
-            # Expects bucket name to be 's3://bucket_name/base_dir/..'
-            bucket_name = base_directory.split(s3_str)[1].split("/")[0]
-            s3_flag = True
-        # Otherwise it's just a normal datasink
-        else:
-            s3_flag = False
+        s3_flag = base_directory.lower().startswith(s3_str)
+        if s3_flag:
+            bucket_name = base_directory[len(s3_str):].partition('/')[0]
 
-        # Return s3_flag
         return s3_flag, bucket_name
 
     # Function to return AWS secure environment variables
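
As a standalone illustration of the simplified check above (the helper name and sample paths are mine, not part of nipype), the new logic detects the s3:// prefix case-insensitively and takes everything between the prefix and the next slash as the bucket name:

    # Standalone sketch mirroring the rewritten _check_s3_base_dir logic above.
    def check_s3_base_dir(base_directory):
        s3_str = "s3://"
        bucket_name = "<N/A>"
        s3_flag = base_directory.lower().startswith(s3_str)
        if s3_flag:
            bucket_name = base_directory[len(s3_str):].partition("/")[0]
        return s3_flag, bucket_name

    print(check_s3_base_dir("s3://my-bucket/derivatives/sub-01"))  # (True, 'my-bucket')
    print(check_s3_base_dir("S3://my-bucket"))                     # (True, 'my-bucket')
    print(check_s3_base_dir("/data/derivatives"))                  # (False, '<N/A>')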
@@ -618,13 +605,12 @@ def _upload_to_s3(self, bucket, src, dst):
 
         from botocore.exceptions import ClientError
 
-        # Init variables
         s3_str = "s3://"
         s3_prefix = s3_str + bucket.name
 
         # Explicitly lower-case the "s3"
-        if dst[: len(s3_str)].lower() == s3_str:
-            dst = s3_str + dst[len(s3_str) :]
+        if dst.lower().startswith(s3_str):
+            dst = s3_str + dst[len(s3_str):]
 
         # If src is a directory, collect files (this assumes dst is a dir too)
         if os.path.isdir(src):
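
The same normalization, sketched as a standalone helper (name and sample key are mine): only the s3:// scheme is forced to lower case, and the rest of the destination key is left untouched.

    # Standalone sketch of the rewritten dst normalization above.
    S3_STR = "s3://"

    def normalize_s3_dst(dst):
        if dst.lower().startswith(S3_STR):
            dst = S3_STR + dst[len(S3_STR):]
        return dst

    print(normalize_s3_dst("S3://MyBucket/sub-01/anat.nii.gz"))  # s3://MyBucket/sub-01/anat.nii.gz
    print(normalize_s3_dst("/local/output/anat.nii.gz"))         # unchanged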
