diff --git a/src/dolphin/_cli_filter.py b/src/dolphin/_cli_filter.py new file mode 100644 index 00000000..0ee51e51 --- /dev/null +++ b/src/dolphin/_cli_filter.py @@ -0,0 +1,97 @@ +import argparse +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from dolphin.filtering import filter_rasters + +if TYPE_CHECKING: + _SubparserType = argparse._SubParsersAction[argparse.ArgumentParser] +else: + _SubparserType = Any + + +def get_parser(subparser=None, subcommand_name="unwrap") -> argparse.ArgumentParser: + """Set up the command line interface.""" + metadata = { + "description": ( + "Filter unwrapped interferograms using a long-wavelength filter." + ), + "formatter_class": argparse.ArgumentDefaultsHelpFormatter, + # https://docs.python.org/3/library/argparse.html#fromfile-prefix-chars + "fromfile_prefix_chars": "@", + } + if subparser: + # Used by the subparser to make a nested command line interface + parser = subparser.add_parser(subcommand_name, **metadata) + else: + parser = argparse.ArgumentParser(**metadata) # type: ignore[arg-type] + + # parser._action_groups.pop() + parser.add_argument( + "-o", + "--output-dir", + type=Path, + help=( + "Path to output directory to store results. None stores in same location as" + " inputs" + ), + ) + # Get Inputs from the command line + inputs = parser.add_argument_group("Input options") + inputs.add_argument( + "--unw-filenames", + nargs=argparse.ONE_OR_MORE, + type=Path, + help=( + "List the paths of unwrapped files to filter. Can pass a newline delimited" + " file with @ifg_filelist.txt" + ), + ) + inputs.add_argument( + "--temporal-coherence-filename", + type=Path, + help="Optionally, list the path of the temporal coherence to mask.", + ) + inputs.add_argument( + "--cor-filenames", + nargs=argparse.ZERO_OR_MORE, + help="Optionally, list the paths of the correlation files to use for masking", + ) + inputs.add_argument( + "--conncomp-filenames", + nargs=argparse.ZERO_OR_MORE, + help="Optionally, list the paths of the connected component labels for masking", + ) + parser.add_argument( + "--wavelength-cutoff", + type=float, + default=50_000, + help="Spatial wavelength_cutoff (in meters) of filter to use.", + ) + + parser.add_argument( + "--max-workers", + type=int, + default=1, + help="Number of parallel files to filter.", + ) + + parser.set_defaults(run_func=_run_filter) + + return parser + + +def _run_filter(*args, **kwargs): + """Run `dolphin.filtering.filter_long_wavelength`.""" + return filter_rasters(*args, **kwargs) + + +def main(args=None): + """Get the command line arguments and filter files.""" + parser = get_parser() + parsed_args = parser.parse_args(args) + return filter_rasters(**vars(parsed_args)) + + +if __name__ == "__main__": + main() diff --git a/src/dolphin/cli.py b/src/dolphin/cli.py index 1772c669..36245a54 100644 --- a/src/dolphin/cli.py +++ b/src/dolphin/cli.py @@ -1,6 +1,7 @@ import argparse import sys +import dolphin._cli_filter import dolphin._cli_timeseries import dolphin._cli_unwrap import dolphin.workflows._cli_config @@ -22,6 +23,7 @@ def main(args=None): dolphin.workflows._cli_config.get_parser(subparser, "config") dolphin._cli_unwrap.get_parser(subparser, "unwrap") dolphin._cli_timeseries.get_parser(subparser, "timeseries") + dolphin._cli_filter.get_parser(subparser, "filter") parsed_args = parser.parse_args(args=args) arg_dict = vars(parsed_args) diff --git a/src/dolphin/filtering.py b/src/dolphin/filtering.py index c8d5d3a2..f8b8f601 100644 --- a/src/dolphin/filtering.py +++ b/src/dolphin/filtering.py @@ -1,5 +1,10 @@ +import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor +from itertools import repeat +from pathlib import Path + import numpy as np -from numpy.typing import ArrayLike +from numpy.typing import ArrayLike, NDArray from scipy import fft, ndimage @@ -36,9 +41,15 @@ def filter_long_wavelength( filtered interferogram that does not contain signals with spatial wavelength longer than a threshold. + Raises + ------ + ValueError + If wavelength_cutoff too large for image size/pixel spacing. + """ good_pixel_mask = ~bad_pixel_mask + rows, cols = unwrapped_phase.shape unw0 = np.nan_to_num(unwrapped_phase) # Take either nan or 0 pixels in `unwrapped_phase` to be nodata nodata_mask = unw0 == 0 @@ -53,6 +64,10 @@ def filter_long_wavelength( # Find the filter `sigma` which gives the correct cutoff in meters sigma = _compute_filter_sigma(wavelength_cutoff, pixel_spacing, cutoff_value=0.5) + if sigma > unw0.shape[0] or sigma > unw0.shape[0]: + msg = f"{wavelength_cutoff = } too large for image." + msg += f"Shape = {(rows, cols)}, and {pixel_spacing = }" + raise ValueError(msg) # Pad the array with edge values # The padding extends further than the default "radius = 2*sigma + 1", # which given specified in `gaussian_filter` @@ -124,3 +139,123 @@ def fit_ramp_plane(unw_ifg: ArrayLike, mask: ArrayLike) -> np.ndarray: plane = np.reshape(X_ @ theta, (nrow, ncol)) return plane + + +def filter_rasters( + unw_filenames: list[Path], + cor_filenames: list[Path] | None = None, + conncomp_filenames: list[Path] | None = None, + temporal_coherence_filename: Path | None = None, + wavelength_cutoff: float = 50_000, + correlation_cutoff: float = 0.5, + output_dir: Path | None = None, + max_workers: int = 4, +) -> list[Path]: + """Filter a list of unwrapped interferogram files using a long-wavelength filter. + + Remove long-wavelength components from each unwrapped interferogram. + It can optionally use temporal coherence, correlation, and connected component + information for masking. + + Parameters + ---------- + unw_filenames : list[Path] + List of paths to unwrapped interferogram files to be filtered. + cor_filenames : list[Path] | None + List of paths to correlation files + Passing None skips filtering on correlation. + conncomp_filenames : list[Path] | None + List of paths to connected component files, filters any 0 labelled pixels. + Passing None skips filtering on connected component labels. + temporal_coherence_filename : Path | None + Path to the temporal coherence file for masking. + Passing None skips filtering on temporal coherence. + wavelength_cutoff : float, optional + Spatial wavelength cutoff (in meters) for the filter. Default is 50,000 meters. + correlation_cutoff : float, optional + Threshold of correlation (if passing `cor_filenames`) to use to ignore pixels + during filtering. + output_dir : Path | None, optional + Directory to save the filtered results. + If None, saves in the same location as inputs with .filt.tif extension. + max_workers : int, optional + Number of parallel images to process. Default is 4. + + Returns + ------- + list[Path] + Output filtered rasters. + + Notes + ----- + - If temporal_coherence_filename is provided, pixels with coherence < 0.5 are masked + + """ + from dolphin import io + + bad_pixel_mask = np.zeros( + io.get_raster_xysize(unw_filenames[0])[::-1], dtype="bool" + ) + if temporal_coherence_filename: + bad_pixel_mask = bad_pixel_mask | ( + io.load_gdal(temporal_coherence_filename) < 0.5 + ) + + if output_dir is None: + assert unw_filenames + output_dir = unw_filenames[0].parent + output_dir.mkdir(exist_ok=True) + ctx = mp.get_context("spawn") + + with ProcessPoolExecutor(max_workers, mp_context=ctx) as pool: + return list( + pool.map( + _filter_and_save, + unw_filenames, + cor_filenames or repeat(None), + conncomp_filenames or repeat(None), + repeat(output_dir), + repeat(wavelength_cutoff), + repeat(bad_pixel_mask), + repeat(correlation_cutoff), + ) + ) + + +def _filter_and_save( + unw_filename: Path, + cor_path: Path | None, + conncomp_path: Path | None, + output_dir: Path, + wavelength_cutoff: float, + bad_pixel_mask: NDArray[np.bool_], + correlation_cutoff: float = 0.5, +) -> Path: + """Filter one interferogram (wrapper for multiprocessing).""" + from dolphin import io + from dolphin._overviews import Resampling, create_image_overviews + + # Average for the pixel spacing for filtering + _, x_res, _, _, _, y_res = io.get_raster_gt(unw_filename) + pixel_spacing = (abs(x_res) + abs(y_res)) / 2 + + if cor_path is not None: + bad_pixel_mask |= io.load_gdal(cor_path) < correlation_cutoff + if conncomp_path is not None: + bad_pixel_mask |= io.load_gdal(conncomp_path, masked=True).astype(bool) == 0 + + unw = io.load_gdal(unw_filename) + filt_arr = filter_long_wavelength( + unwrapped_phase=unw, + wavelength_cutoff=wavelength_cutoff, + bad_pixel_mask=bad_pixel_mask, + pixel_spacing=pixel_spacing, + workers=1, + ) + io.round_mantissa(filt_arr, keep_bits=9) + output_name = output_dir / Path(unw_filename).with_suffix(".filt.tif").name + io.write_arr(arr=filt_arr, like_filename=unw_filename, output_name=output_name) + + create_image_overviews(output_name, resampling=Resampling.AVERAGE) + + return output_name diff --git a/tests/test_cli.py b/tests/test_cli.py index 26025e37..0c1facab 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,17 +11,17 @@ def test_help(capsys, option): with contextlib.suppress(SystemExit): main([option]) output = capsys.readouterr().out - assert " dolphin [-h] [--version] {run,config,unwrap,timeseries}" in output + assert " dolphin [-h] [--version] {run,config,unwrap,timeseries,filter}" in output def test_empty(capsys): with contextlib.suppress(SystemExit): main([]) output = capsys.readouterr().out - assert " dolphin [-h] [--version] {run,config,unwrap,timeseries}" in output + assert " dolphin [-h] [--version] {run,config,unwrap,timeseries,filter}" in output -@pytest.mark.parametrize("sub_cmd", ["run", "config", "unwrap", "timeseries"]) +@pytest.mark.parametrize("sub_cmd", ["run", "config", "filter", "unwrap", "timeseries"]) @pytest.mark.parametrize("option", ["-h", "--help"]) def test_subcommand_help(capsys, sub_cmd, option): with contextlib.suppress(SystemExit): diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 09d15a61..34b693fc 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -1,9 +1,12 @@ +from pathlib import Path + import numpy as np +import pytest -from dolphin import filtering +from dolphin import filtering, io -def test_filter_long_wavelegnth(): +def test_filter_long_wavelength(): # Check filtering with ramp phase y, x = np.ogrid[-3:3:512j, -3:3:512j] unw_ifg = np.pi * (x + y) @@ -12,10 +15,51 @@ def test_filter_long_wavelegnth(): # Filtering filtered_ifg = filtering.filter_long_wavelength( - unw_ifg, bad_pixel_mask=bad_pixel_mask, pixel_spacing=300 + unw_ifg, bad_pixel_mask=bad_pixel_mask, pixel_spacing=1000 ) np.testing.assert_allclose( filtered_ifg[10:-10, 10:-10], np.zeros(filtered_ifg[10:-10, 10:-10].shape), atol=1.0, ) + + +def test_filter_long_wavelength_too_large_cutoff(): + # Check filtering with ramp phase + y, x = np.ogrid[-3:3:512j, -3:3:512j] + unw_ifg = np.pi * (x + y) + bad_pixel_mask = np.zeros(unw_ifg.shape, dtype=bool) + + with pytest.raises(ValueError): + filtering.filter_long_wavelength( + unw_ifg, + bad_pixel_mask=bad_pixel_mask, + pixel_spacing=1, + wavelength_cutoff=50_000, + ) + + +@pytest.fixture() +def unw_files(tmp_path): + """Make series of files offset in lat/lon.""" + shape = (3, 9, 9) + + y, x = np.ogrid[-3:3:512j, -3:3:512j] + file_list = [] + for i in range(shape[0]): + unw_arr = (i + 1) * np.pi * (x + y) + fname = tmp_path / f"unw_{i}.tif" + io.write_arr(arr=unw_arr, output_name=fname) + file_list.append(Path(fname)) + + return file_list + + +def test_filter(tmp_path, unw_files): + output_dir = Path(tmp_path) / "filtered" + filtering.filter_rasters( + unw_filenames=unw_files, + output_dir=output_dir, + max_workers=1, + wavelength_cutoff=50, + )