Skip to content

Commit

Permalink
Add dolphin filter cli to run a long wavelgnth filter on a set of r…
Browse files Browse the repository at this point in the history
…asters (#382)

* add a `dolphin filter` CLI for long wavelegnth filtering

* optionally pass through cor/conncomp, add overviews

* fix cli test

* move filtering with io functions into `filtering` away from cli

* add test

* avoid set ctx

* add check for too large cutoff
  • Loading branch information
scottstanie authored Aug 2, 2024
1 parent 54ffd54 commit 5cc6f4a
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 7 deletions.
97 changes: 97 additions & 0 deletions src/dolphin/_cli_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
from pathlib import Path
from typing import TYPE_CHECKING, Any

from dolphin.filtering import filter_rasters

if TYPE_CHECKING:
_SubparserType = argparse._SubParsersAction[argparse.ArgumentParser]
else:
_SubparserType = Any


def get_parser(subparser=None, subcommand_name="unwrap") -> argparse.ArgumentParser:
"""Set up the command line interface."""
metadata = {
"description": (
"Filter unwrapped interferograms using a long-wavelength filter."
),
"formatter_class": argparse.ArgumentDefaultsHelpFormatter,
# https://docs.python.org/3/library/argparse.html#fromfile-prefix-chars
"fromfile_prefix_chars": "@",
}
if subparser:
# Used by the subparser to make a nested command line interface
parser = subparser.add_parser(subcommand_name, **metadata)
else:
parser = argparse.ArgumentParser(**metadata) # type: ignore[arg-type]

# parser._action_groups.pop()
parser.add_argument(
"-o",
"--output-dir",
type=Path,
help=(
"Path to output directory to store results. None stores in same location as"
" inputs"
),
)
# Get Inputs from the command line
inputs = parser.add_argument_group("Input options")
inputs.add_argument(
"--unw-filenames",
nargs=argparse.ONE_OR_MORE,
type=Path,
help=(
"List the paths of unwrapped files to filter. Can pass a newline delimited"
" file with @ifg_filelist.txt"
),
)
inputs.add_argument(
"--temporal-coherence-filename",
type=Path,
help="Optionally, list the path of the temporal coherence to mask.",
)
inputs.add_argument(
"--cor-filenames",
nargs=argparse.ZERO_OR_MORE,
help="Optionally, list the paths of the correlation files to use for masking",
)
inputs.add_argument(
"--conncomp-filenames",
nargs=argparse.ZERO_OR_MORE,
help="Optionally, list the paths of the connected component labels for masking",
)
parser.add_argument(
"--wavelength-cutoff",
type=float,
default=50_000,
help="Spatial wavelength_cutoff (in meters) of filter to use.",
)

parser.add_argument(
"--max-workers",
type=int,
default=1,
help="Number of parallel files to filter.",
)

parser.set_defaults(run_func=_run_filter)

return parser


def _run_filter(*args, **kwargs):
"""Run `dolphin.filtering.filter_long_wavelength`."""
return filter_rasters(*args, **kwargs)


def main(args=None):
"""Get the command line arguments and filter files."""
parser = get_parser()
parsed_args = parser.parse_args(args)
return filter_rasters(**vars(parsed_args))


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions src/dolphin/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import sys

import dolphin._cli_filter
import dolphin._cli_timeseries
import dolphin._cli_unwrap
import dolphin.workflows._cli_config
Expand All @@ -22,6 +23,7 @@ def main(args=None):
dolphin.workflows._cli_config.get_parser(subparser, "config")
dolphin._cli_unwrap.get_parser(subparser, "unwrap")
dolphin._cli_timeseries.get_parser(subparser, "timeseries")
dolphin._cli_filter.get_parser(subparser, "filter")
parsed_args = parser.parse_args(args=args)

arg_dict = vars(parsed_args)
Expand Down
137 changes: 136 additions & 1 deletion src/dolphin/filtering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
from pathlib import Path

import numpy as np
from numpy.typing import ArrayLike
from numpy.typing import ArrayLike, NDArray
from scipy import fft, ndimage


Expand Down Expand Up @@ -36,9 +41,15 @@ def filter_long_wavelength(
filtered interferogram that does not contain signals with spatial wavelength
longer than a threshold.
Raises
------
ValueError
If wavelength_cutoff too large for image size/pixel spacing.
"""
good_pixel_mask = ~bad_pixel_mask

rows, cols = unwrapped_phase.shape
unw0 = np.nan_to_num(unwrapped_phase)
# Take either nan or 0 pixels in `unwrapped_phase` to be nodata
nodata_mask = unw0 == 0
Expand All @@ -53,6 +64,10 @@ def filter_long_wavelength(
# Find the filter `sigma` which gives the correct cutoff in meters
sigma = _compute_filter_sigma(wavelength_cutoff, pixel_spacing, cutoff_value=0.5)

if sigma > unw0.shape[0] or sigma > unw0.shape[0]:
msg = f"{wavelength_cutoff = } too large for image."
msg += f"Shape = {(rows, cols)}, and {pixel_spacing = }"
raise ValueError(msg)
# Pad the array with edge values
# The padding extends further than the default "radius = 2*sigma + 1",
# which given specified in `gaussian_filter`
Expand Down Expand Up @@ -124,3 +139,123 @@ def fit_ramp_plane(unw_ifg: ArrayLike, mask: ArrayLike) -> np.ndarray:
plane = np.reshape(X_ @ theta, (nrow, ncol))

return plane


def filter_rasters(
unw_filenames: list[Path],
cor_filenames: list[Path] | None = None,
conncomp_filenames: list[Path] | None = None,
temporal_coherence_filename: Path | None = None,
wavelength_cutoff: float = 50_000,
correlation_cutoff: float = 0.5,
output_dir: Path | None = None,
max_workers: int = 4,
) -> list[Path]:
"""Filter a list of unwrapped interferogram files using a long-wavelength filter.
Remove long-wavelength components from each unwrapped interferogram.
It can optionally use temporal coherence, correlation, and connected component
information for masking.
Parameters
----------
unw_filenames : list[Path]
List of paths to unwrapped interferogram files to be filtered.
cor_filenames : list[Path] | None
List of paths to correlation files
Passing None skips filtering on correlation.
conncomp_filenames : list[Path] | None
List of paths to connected component files, filters any 0 labelled pixels.
Passing None skips filtering on connected component labels.
temporal_coherence_filename : Path | None
Path to the temporal coherence file for masking.
Passing None skips filtering on temporal coherence.
wavelength_cutoff : float, optional
Spatial wavelength cutoff (in meters) for the filter. Default is 50,000 meters.
correlation_cutoff : float, optional
Threshold of correlation (if passing `cor_filenames`) to use to ignore pixels
during filtering.
output_dir : Path | None, optional
Directory to save the filtered results.
If None, saves in the same location as inputs with .filt.tif extension.
max_workers : int, optional
Number of parallel images to process. Default is 4.
Returns
-------
list[Path]
Output filtered rasters.
Notes
-----
- If temporal_coherence_filename is provided, pixels with coherence < 0.5 are masked
"""
from dolphin import io

bad_pixel_mask = np.zeros(
io.get_raster_xysize(unw_filenames[0])[::-1], dtype="bool"
)
if temporal_coherence_filename:
bad_pixel_mask = bad_pixel_mask | (
io.load_gdal(temporal_coherence_filename) < 0.5
)

if output_dir is None:
assert unw_filenames
output_dir = unw_filenames[0].parent
output_dir.mkdir(exist_ok=True)
ctx = mp.get_context("spawn")

with ProcessPoolExecutor(max_workers, mp_context=ctx) as pool:
return list(
pool.map(
_filter_and_save,
unw_filenames,
cor_filenames or repeat(None),
conncomp_filenames or repeat(None),
repeat(output_dir),
repeat(wavelength_cutoff),
repeat(bad_pixel_mask),
repeat(correlation_cutoff),
)
)


def _filter_and_save(
unw_filename: Path,
cor_path: Path | None,
conncomp_path: Path | None,
output_dir: Path,
wavelength_cutoff: float,
bad_pixel_mask: NDArray[np.bool_],
correlation_cutoff: float = 0.5,
) -> Path:
"""Filter one interferogram (wrapper for multiprocessing)."""
from dolphin import io
from dolphin._overviews import Resampling, create_image_overviews

# Average for the pixel spacing for filtering
_, x_res, _, _, _, y_res = io.get_raster_gt(unw_filename)
pixel_spacing = (abs(x_res) + abs(y_res)) / 2

if cor_path is not None:
bad_pixel_mask |= io.load_gdal(cor_path) < correlation_cutoff
if conncomp_path is not None:
bad_pixel_mask |= io.load_gdal(conncomp_path, masked=True).astype(bool) == 0

unw = io.load_gdal(unw_filename)
filt_arr = filter_long_wavelength(
unwrapped_phase=unw,
wavelength_cutoff=wavelength_cutoff,
bad_pixel_mask=bad_pixel_mask,
pixel_spacing=pixel_spacing,
workers=1,
)
io.round_mantissa(filt_arr, keep_bits=9)
output_name = output_dir / Path(unw_filename).with_suffix(".filt.tif").name
io.write_arr(arr=filt_arr, like_filename=unw_filename, output_name=output_name)

create_image_overviews(output_name, resampling=Resampling.AVERAGE)

return output_name
6 changes: 3 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ def test_help(capsys, option):
with contextlib.suppress(SystemExit):
main([option])
output = capsys.readouterr().out
assert " dolphin [-h] [--version] {run,config,unwrap,timeseries}" in output
assert " dolphin [-h] [--version] {run,config,unwrap,timeseries,filter}" in output


def test_empty(capsys):
with contextlib.suppress(SystemExit):
main([])
output = capsys.readouterr().out
assert " dolphin [-h] [--version] {run,config,unwrap,timeseries}" in output
assert " dolphin [-h] [--version] {run,config,unwrap,timeseries,filter}" in output


@pytest.mark.parametrize("sub_cmd", ["run", "config", "unwrap", "timeseries"])
@pytest.mark.parametrize("sub_cmd", ["run", "config", "filter", "unwrap", "timeseries"])
@pytest.mark.parametrize("option", ["-h", "--help"])
def test_subcommand_help(capsys, sub_cmd, option):
with contextlib.suppress(SystemExit):
Expand Down
50 changes: 47 additions & 3 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from pathlib import Path

import numpy as np
import pytest

from dolphin import filtering
from dolphin import filtering, io


def test_filter_long_wavelegnth():
def test_filter_long_wavelength():
# Check filtering with ramp phase
y, x = np.ogrid[-3:3:512j, -3:3:512j]
unw_ifg = np.pi * (x + y)
Expand All @@ -12,10 +15,51 @@ def test_filter_long_wavelegnth():

# Filtering
filtered_ifg = filtering.filter_long_wavelength(
unw_ifg, bad_pixel_mask=bad_pixel_mask, pixel_spacing=300
unw_ifg, bad_pixel_mask=bad_pixel_mask, pixel_spacing=1000
)
np.testing.assert_allclose(
filtered_ifg[10:-10, 10:-10],
np.zeros(filtered_ifg[10:-10, 10:-10].shape),
atol=1.0,
)


def test_filter_long_wavelength_too_large_cutoff():
# Check filtering with ramp phase
y, x = np.ogrid[-3:3:512j, -3:3:512j]
unw_ifg = np.pi * (x + y)
bad_pixel_mask = np.zeros(unw_ifg.shape, dtype=bool)

with pytest.raises(ValueError):
filtering.filter_long_wavelength(
unw_ifg,
bad_pixel_mask=bad_pixel_mask,
pixel_spacing=1,
wavelength_cutoff=50_000,
)


@pytest.fixture()
def unw_files(tmp_path):
"""Make series of files offset in lat/lon."""
shape = (3, 9, 9)

y, x = np.ogrid[-3:3:512j, -3:3:512j]
file_list = []
for i in range(shape[0]):
unw_arr = (i + 1) * np.pi * (x + y)
fname = tmp_path / f"unw_{i}.tif"
io.write_arr(arr=unw_arr, output_name=fname)
file_list.append(Path(fname))

return file_list


def test_filter(tmp_path, unw_files):
output_dir = Path(tmp_path) / "filtered"
filtering.filter_rasters(
unw_filenames=unw_files,
output_dir=output_dir,
max_workers=1,
wavelength_cutoff=50,
)

0 comments on commit 5cc6f4a

Please sign in to comment.