Skip to content

Commit

Permalink
add finer CLI control of concatenation method and args passed to xarray
Browse files Browse the repository at this point in the history
  • Loading branch information
danielfromearth committed Nov 8, 2023
1 parent aa51cf0 commit 5ee77ca
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 45 deletions.
1 change: 0 additions & 1 deletion concatenator/group_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def regroup_flattened_dataset(
output_file : str
Name of the output file to write the resulting NetCDF file to.
"""

with nc.Dataset(output_file, mode="w", format="NETCDF4") as base_dataset:
# Copy global attributes
base_dataset.setncatts(dataset.attrs)
Expand Down
145 changes: 108 additions & 37 deletions concatenator/run_stitchee.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from concatenator.stitchee import stitchee


def parse_args(args: list) -> tuple[list[str], str, str, bool, str | None]:
def parse_args(args: list) -> tuple[list[str], str, str, bool, str | None, str, dict]:
"""
Parse args for this script.
Expand Down Expand Up @@ -40,12 +40,6 @@ def parse_args(args: list) -> tuple[list[str], str, str, bool, str | None]:
required=True,
help="The output filename for the merged output.",
)
req_grp.add_argument(
"--concat_dim",
nargs=1,
required=True,
help="Dimension to concatenate along, if possible.",
)

# Optional arguments
parser.add_argument(
Expand All @@ -62,6 +56,29 @@ def parse_args(args: list) -> tuple[list[str], str, str, bool, str | None]:
"(1) the flattened concatenated file and "
"(2) the input directory copy if created by '--make_dir_copy'.",
)
parser.add_argument(
"--concat_method",
choices=["xarray-concat", "xarray-combine"],
default="xarray-concat",
help="Whether to use the xarray concat method or the combine-by-coords method.",
)
parser.add_argument(
"--concat_dim",
help="Dimension to concatenate along, if possible. "
"This is required if using the 'xarray-concat' method",
)
parser.add_argument(
"--concat_arg_compat",
help="'compat' argument passed to xarray.concat.",
)
parser.add_argument(
"--concat_arg_combine_attrs",
help="'combine_attrs' argument passed to xarray.concat.",
)
parser.add_argument(
"--concat_arg_join",
help="'join' argument passed to xarray.concat.",
)
parser.add_argument(
"-O", "--overwrite", action="store_true", help="Overwrite output file if it already exists."
)
Expand All @@ -77,6 +94,71 @@ def parse_args(args: list) -> tuple[list[str], str, str, bool, str | None]:
if parsed.verbose:
logging.basicConfig(level=logging.DEBUG)

# Validate the input and output paths
output_path = _validate_output_path(parsed)
input_files = _validate_input_path(parsed)

print(f"CONCAT METHOD === {parsed.concat_method}")
print(f"CONCAT DIM === {parsed.concat_dim}")
if parsed.concat_method == "xarray-concat":
if not parsed.concat_dim:
raise ValueError(
"If using the xarray-concat method, then 'concat_dim' must be specified."
)
elif parsed.concat_method == "xarray-combine":
if parsed.concat_dim:
raise ValueError(
"If using the xarray-combine method, then 'concat_dim' cannot be specified."
)

# Gather the concatenation arguments that will be passed to xarray.
concat_kwargs = {}
if parsed.concat_arg_compat:
concat_kwargs["compat"] = parsed.concat_arg_compat
if parsed.concat_arg_combine_attrs:
concat_kwargs["combine_attrs"] = parsed.concat_arg_combine_attrs
if parsed.concat_arg_join:
concat_kwargs["join"] = parsed.concat_arg_join

# If requested, make a temporary directory with new copies of the original input files
temporary_dir_to_remove = None
if not parsed.no_input_file_copies:
input_files, temporary_dir_to_remove = _make_temp_dir_with_input_file_copies(
input_files, output_path
)

return (
input_files,
str(output_path),
parsed.concat_dim,
bool(parsed.keep_tmp_files),
temporary_dir_to_remove,
parsed.concat_method,
concat_kwargs,
)


def _make_temp_dir_with_input_file_copies(input_files, output_path):
    """
    Copy the input files into a newly created directory next to the output file.

    Parameters
    ----------
    input_files : iterable of str or path-like
        Paths of the files to copy.
    output_path : Path
        Path of the output file; the new directory is created inside its parent directory.

    Returns
    -------
    tuple[list[str], str]
        The paths of the copied files, and the path of the directory that holds them
        (so that the caller can remove it later).
    """
    # The directory name is derived from "temp_copy" plus a UUID label,
    # presumably so that separate runs do not share a directory.
    new_data_dir = Path(
        add_label_to_path(str(output_path.parent / "temp_copy"), label=str(uuid.uuid4()))
    ).resolve()
    os.makedirs(new_data_dir, exist_ok=True)
    # print() does not interpolate logging-style "%s" arguments, so format explicitly.
    print(f"Created temporary directory: {new_data_dir}")

    new_input_files = []
    for file in input_files:
        # Only the base name is kept, so inputs that share a file name
        # would overwrite each other in the copy directory.
        new_path = new_data_dir / Path(file).name
        shutil.copyfile(file, new_path)
        new_input_files.append(str(new_path))

    print(f"Copied files to temporary directory: {new_data_dir}")

    return new_input_files, str(new_data_dir)


def _validate_output_path(parsed):
# The output file path is validated.
output_path = Path(parsed.output_path).resolve()
if output_path.is_file(): # the file already exists
Expand All @@ -86,7 +168,12 @@ def parse_args(args: list) -> tuple[list[str], str, str, bool, str | None]:
raise FileExistsError(
f"File already exists at <{output_path}>. Run again with option '-O' to overwrite."
)
if output_path.is_dir(): # the specified path is an existing directory
raise TypeError("Output path cannot be a directory. Please specify a new filepath.")
return output_path


def _validate_input_path(parsed):
# The input directory or file is validated.
print(f"parsed_input === {parsed.input}")
if len(parsed.input) > 1:
Expand All @@ -99,38 +186,12 @@ def parse_args(args: list) -> tuple[list[str], str, str, bool, str | None]:
input_files = _get_list_of_filepaths_from_file(directory_or_path)
else:
raise TypeError(
"if one path is provided for 'data_dir_or_file_or_filepaths', "
"If one path is provided for 'data_dir_or_file_or_filepaths', "
"then it must be an existing directory or file."
)
else:
raise TypeError("input argument must be one path/directory or a list of paths.")

# If requested, make a temporary directory with copies of the original input files
temporary_dir_to_remove = None
if not parsed.no_input_file_copies:
new_data_dir = Path(
add_label_to_path(str(output_path.parent / "temp_copy"), label=str(uuid.uuid4()))
).resolve()
os.makedirs(new_data_dir, exist_ok=True)
print("Created temporary directory: %s", str(new_data_dir))

new_input_files = []
for file in input_files:
new_path = new_data_dir / Path(file).name
shutil.copyfile(file, new_path)
new_input_files.append(str(new_path))

input_files = new_input_files
print("Copied files to temporary directory: %s", new_data_dir)
temporary_dir_to_remove = str(new_data_dir)

return (
input_files,
str(output_path),
parsed.concat_dim[0],
bool(parsed.keep_tmp_files),
temporary_dir_to_remove,
)
return input_files


def _get_list_of_filepaths_from_file(file_with_paths: Path):
Expand All @@ -144,7 +205,7 @@ def _get_list_of_filepaths_from_file(file_with_paths: Path):


def _get_list_of_filepaths_from_dir(data_dir: Path):
# Get list of files (ignoring hidden files) in directory.
# Get a list of files (ignoring hidden files) in directory.
input_files = [str(f) for f in data_dir.iterdir() if not f.name.startswith(".")]
return input_files

Expand All @@ -153,7 +214,15 @@ def run_stitchee(args: list) -> None:
"""
Parse arguments and run subsetter on the specified input file
"""
input_files, output_path, concat_dim, keep_tmp_files, temporary_dir_to_remove = parse_args(args)
(
input_files,
output_path,
concat_dim,
keep_tmp_files,
temporary_dir_to_remove,
concat_method,
concat_kwargs,
) = parse_args(args)
num_inputs = len(input_files)

logging.info("Executing stitchee concatenation on %d files...", num_inputs)
Expand All @@ -162,7 +231,9 @@ def run_stitchee(args: list) -> None:
output_path,
write_tmp_flat_concatenated=keep_tmp_files,
keep_tmp_files=keep_tmp_files,
concat_method=concat_method,
concat_dim=concat_dim,
concat_kwargs=concat_kwargs,
)
logging.info("STITCHEE complete. Result in %s", output_path)

Expand Down
25 changes: 18 additions & 7 deletions concatenator/stitchee.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def stitchee(
output_file: str,
write_tmp_flat_concatenated: bool = False,
keep_tmp_files: bool = True,
concat_method: str = "xarray-concat",
concat_dim: str = "",
concat_kwargs: dict | None = None,
logger: Logger = default_logger,
Expand Down Expand Up @@ -97,13 +98,23 @@ def stitchee(
if concat_kwargs is None:
concat_kwargs = {}

combined_ds = xr.concat(
xrdataset_list,
dim=GROUP_DELIM + concat_dim,
data_vars="minimal",
coords="minimal",
**concat_kwargs,
)
if concat_method == "xarray-concat":
combined_ds = xr.concat(
xrdataset_list,
dim=GROUP_DELIM + concat_dim,
data_vars="minimal",
coords="minimal",
**concat_kwargs,
)
elif concat_method == "xarray-combine":
combined_ds = xr.combine_by_coords(
xrdataset_list,
data_vars="minimal",
coords="minimal",
**concat_kwargs,
)
else:
raise ValueError("Unexpected concatenation method, <%s>." % concat_method)

benchmark_log["concatenating"] = time.time() - start_time

Expand Down

0 comments on commit 5ee77ca

Please sign in to comment.