From b6b9ce371864c4da01b03955858e13dda69797fc Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Tue, 11 Jun 2024 18:56:48 +1000 Subject: [PATCH 1/2] fix: Dask reproject from GCP based source - remove incorrect type assert - make dask_rio_reproject "public" --- odc/geo/_dask.py | 3 +-- odc/geo/_xr_interop.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/odc/geo/_dask.py b/odc/geo/_dask.py index a8d7d4e..68ca299 100644 --- a/odc/geo/_dask.py +++ b/odc/geo/_dask.py @@ -70,7 +70,7 @@ def _do_chunked_reproject( return dst -def _dask_rio_reproject( +def dask_rio_reproject( src: da.Array, s_gbox: Union[GeoBox, GCPGeoBox], d_gbox: GeoBox, @@ -94,7 +94,6 @@ def with_yx(a, yx): name: str = kwargs.pop("name", "reproject") - assert isinstance(s_gbox, GeoBox) gbt_src = GeoboxTiles(s_gbox, src.chunks[ydim : ydim + 2]) gbt_dst = GeoboxTiles(d_gbox, chunks) d2s_idx = gbt_dst.grid_intersect(gbt_src) diff --git a/odc/geo/_xr_interop.py b/odc/geo/_xr_interop.py index b149eca..a1637c3 100644 --- a/odc/geo/_xr_interop.py +++ b/odc/geo/_xr_interop.py @@ -821,9 +821,9 @@ def _xr_reproject_da( dst_nodata = src_nodata if is_dask_collection(src): - from ._dask import _dask_rio_reproject + from ._dask import dask_rio_reproject - dst: Any = _dask_rio_reproject( + dst: Any = dask_rio_reproject( src.data, src_gbox, dst_geobox, From d9bd9c96b47a555e2d9840a0b68c982f2d3a7a7c Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Sun, 16 Jun 2024 11:04:06 +1000 Subject: [PATCH 2/2] refactor: reproject dtype and nodata handling - allow dtype change as part of reprojection - more consistent handling of destination fill value defaults between different implementations --- odc/geo/_dask.py | 38 ++++++++++++++++++----------------- odc/geo/_xr_interop.py | 43 +++++++++++++++++++++++++++------------- odc/geo/warp.py | 32 ++++++++++++++++++++++++------ tests/test_xr_interop.py | 3 +++ 4 files changed, 78 insertions(+), 38 deletions(-) diff --git a/odc/geo/_dask.py b/odc/geo/_dask.py index 68ca299..004583c 100644 --- a/odc/geo/_dask.py +++ b/odc/geo/_dask.py @@ -9,19 +9,13 @@ from ._blocks import BlockAssembler from .gcp import GCPGeoBox from .geobox import GeoBox, GeoboxTiles -from .warp import Nodata, Resampling, _rio_reproject, resampling_s2rio - - -def resolve_fill_value(dst_nodata, src_nodata, dtype): - dtype = np.dtype(dtype) - - if dst_nodata is not None: - return dtype.type(dst_nodata) - if src_nodata is not None: - return dtype.type(src_nodata) - if np.issubdtype(dtype, np.floating): - return dtype.type("nan") - return dtype.type(0) +from .warp import ( + Nodata, + Resampling, + _rio_reproject, + resampling_s2rio, + resolve_fill_value, +) def _do_chunked_reproject( @@ -50,7 +44,11 @@ def _do_chunked_reproject( dtype = ba.dtype dst_shape = ba.with_yx(ba.shape, dst_gbox.shape) - dst = np.zeros(dst_shape, dtype=dtype) + dst = np.full( + dst_shape, + resolve_fill_value(dst_nodata, src_nodata, dtype), + dtype=dtype, + ) for src_roi in ba.planes_yx(): src = ba.extract(src_nodata, dtype=dtype, casting=casting, roi=src_roi) @@ -79,6 +77,7 @@ def dask_rio_reproject( dst_nodata: Nodata = None, ydim: int = 0, chunks: Optional[Tuple[int, int]] = None, + dtype=None, **kwargs, ) -> da.Array: # pylint: disable=too-many-arguments, too-many-locals @@ -92,6 +91,9 @@ def dask_rio_reproject( def with_yx(a, yx): return (*a[:ydim], *yx, *a[ydim + 2 :]) + if dtype is None: + dtype = src.dtype + name: str = kwargs.pop("name", "reproject") gbt_src = GeoboxTiles(s_gbox, src.chunks[ydim : ydim + 2]) @@ -100,6 +102,7 @@ def with_yx(a, yx): dst_shape = with_yx(src.shape, d_gbox.shape.yx) dst_chunks: Tuple[Tuple[int, ...], ...] = with_yx(src.chunks, gbt_dst.chunks) + fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype) tk = uuid4().hex name = f"{name}-{tk}" @@ -111,15 +114,14 @@ def with_yx(a, yx): gbt_src, gbt_dst, src_nodata=src_nodata, - dst_nodata=dst_nodata, + dst_nodata=fill_value, axis=ydim, resampling=resampling, + dtype=dtype, **kwargs, ) src_block_keys = src.__dask_keys__() - fill_value = resolve_fill_value(dst_nodata, src_nodata, src.dtype) - def _src(idx): a = src_block_keys for i in idx: @@ -141,4 +143,4 @@ def _src(idx): dsk = HighLevelGraph.from_collections(name, dsk, dependencies=(src,)) - return da.Array(dsk, name, chunks=dst_chunks, dtype=src.dtype, shape=dst_shape) + return da.Array(dsk, name, chunks=dst_chunks, dtype=dtype, shape=dst_shape) diff --git a/odc/geo/_xr_interop.py b/odc/geo/_xr_interop.py index a1637c3..45ffeec 100644 --- a/odc/geo/_xr_interop.py +++ b/odc/geo/_xr_interop.py @@ -45,6 +45,7 @@ from .overlap import compute_output_geobox from .roi import roi_is_empty from .types import Resolution, SomeResolution, SomeShape, xy_ +from .warp import resolve_fill_value # pylint: disable=import-outside-toplevel # pylint: disable=too-many-lines @@ -63,6 +64,8 @@ # these attributes are pruned during reproject SPATIAL_ATTRIBUTES = ("crs", "crs_wkt", "grid_mapping", "gcps", "epsg") +NODATA_ATTRIBUTES = ("nodata", "_FillValue") +REPROJECT_SKIP_ATTRS: set[str] = set(SPATIAL_ATTRIBUTES + NODATA_ATTRIBUTES) # dimensions with these names are considered spatial STANDARD_SPATIAL_DIMS = [ @@ -654,6 +657,7 @@ def xr_reproject( *, resampling: Union[str, int] = "nearest", dst_nodata: Optional[float] = None, + dtype=None, resolution: Union[SomeResolution, Literal["auto", "fit", "same"]] = "auto", shape: Union[SomeShape, int, None] = None, tight: bool = False, @@ -728,10 +732,10 @@ def xr_reproject( } if isinstance(src, xarray.DataArray): return _xr_reproject_da( - src, how, resampling=resampling, dst_nodata=dst_nodata, **kw + src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw ) return _xr_reproject_ds( - src, how, resampling=resampling, dst_nodata=dst_nodata, **kw + src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw ) @@ -750,6 +754,7 @@ def _xr_reproject_ds( *, resampling: Union[str, int] = "nearest", dst_nodata: Optional[float] = None, + dtype=None, **kw, ) -> xarray.Dataset: assert isinstance(src, xarray.Dataset) @@ -776,7 +781,12 @@ def _maybe_reproject(dv: xarray.DataArray): dv = dv.drop_vars(strip_coords) return dv return _xr_reproject_da( - dv, how=dst_geobox, resampling=resampling, dst_nodata=dst_nodata, **kw + dv, + how=dst_geobox, + resampling=resampling, + dst_nodata=dst_nodata, + dtype=dtype, + **kw, ) return src.map(_maybe_reproject) @@ -788,6 +798,7 @@ def _xr_reproject_da( *, resampling: Union[str, int] = "nearest", dst_nodata: Optional[float] = None, + dtype=None, **kw, ) -> xarray.DataArray: # pylint: disable=too-many-locals @@ -809,6 +820,9 @@ def _xr_reproject_da( else: dst_geobox = src.odc.output_geobox(how, **kw_gbox) + if dtype is None: + dtype = src.dtype + # compute destination shape by replacing spatial dimensions shape ydim = src.odc.ydim assert ydim + 1 == src.odc.xdim @@ -817,8 +831,8 @@ def _xr_reproject_da( src_nodata = kw.pop("src_nodata", None) if src_nodata is None: src_nodata = src.odc.nodata - if dst_nodata is None: - dst_nodata = src_nodata + + fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype) if is_dask_collection(src): from ._dask import dask_rio_reproject @@ -829,12 +843,13 @@ def _xr_reproject_da( dst_geobox, resampling=resampling, src_nodata=src_nodata, - dst_nodata=dst_nodata, + dst_nodata=fill_value, ydim=ydim, + dtype=dtype, **kw, ) else: - dst = numpy.empty(dst_shape, dtype=src.dtype) + dst = numpy.full(dst_shape, fill_value, dtype=dtype) dst = rio_reproject( src.values, @@ -843,17 +858,17 @@ def _xr_reproject_da( dst_geobox, resampling=resampling, src_nodata=src_nodata, - dst_nodata=dst_nodata, + dst_nodata=fill_value, ydim=ydim, + dtype=dtype, **kw, ) - attrs = {k: v for k, v in src.attrs.items() if k not in SPATIAL_ATTRIBUTES} - if dst_nodata is None: - attrs.pop("nodata", None) - attrs.pop("_FillValue", None) - else: - attrs.update(nodata=maybe_int(dst_nodata, 1e-6)) + attrs = {k: v for k, v in src.attrs.items() if k not in REPROJECT_SKIP_ATTRS} + if numpy.isfinite(fill_value) and ( + dst_nodata is not None or src_nodata is not None + ): + attrs.update({k: maybe_int(float(fill_value), 1e-6) for k in NODATA_ATTRIBUTES}) # new set of coords (replace x,y dims) # discard all coords that reference spatial dimensions diff --git a/odc/geo/warp.py b/odc/geo/warp.py index e9f50d7..b441d01 100644 --- a/odc/geo/warp.py +++ b/odc/geo/warp.py @@ -17,6 +17,15 @@ Nodata = Optional[Union[int, float]] _WRP_CRS = "epsg:3857" +__all__ = [ + "resampling_s2rio", + "is_resampling_nn", + "resolve_fill_value", + "warp_affine", + "warp_affine_rio", + "rio_reproject", +] + def resampling_s2rio(name: str) -> rasterio.warp.Resampling: """ @@ -38,6 +47,18 @@ def is_resampling_nn(resampling: Resampling) -> bool: return resampling == rasterio.warp.Resampling.nearest +def resolve_fill_value(dst_nodata, src_nodata, dtype): + dtype = np.dtype(dtype) + + if dst_nodata is not None: + return dtype.type(dst_nodata) + if np.issubdtype(dtype, np.floating): + return dtype.type("nan") + if src_nodata is not None: + return dtype.type(src_nodata) + return dtype.type(0) + + def warp_affine_rio( src: np.ndarray, dst: np.ndarray, @@ -129,15 +150,14 @@ def rio_reproject( :returns: dst """ assert src.ndim == dst.ndim - if dst_nodata is None: - if dst.dtype.kind == "f": - dst_nodata = np.nan if src.ndim == 2: return _rio_reproject( src, dst, s_gbox, d_gbox, resampling, src_nodata, dst_nodata, **kwargs ) + fill_value = resolve_fill_value(dst_nodata, src_nodata, dst.dtype) + if ydim is None: # Assume last two dimensions are Y/X ydim = src.ndim - 2 @@ -154,9 +174,9 @@ def rio_reproject( dst[roi], s_gbox, d_gbox, - resampling, - src_nodata, - dst_nodata, + resampling=resampling, + src_nodata=src_nodata, + dst_nodata=fill_value, **kwargs, ) return dst diff --git a/tests/test_xr_interop.py b/tests/test_xr_interop.py index 10e74ca..1b8272b 100644 --- a/tests/test_xr_interop.py +++ b/tests/test_xr_interop.py @@ -425,6 +425,9 @@ def test_xr_reproject(xx_epsg4326: xr.DataArray): assert xx.odc.geobox == dst_gbox assert xx.encoding["grid_mapping"] == "spatial_ref" assert "crs" not in xx.attrs + assert xx.dtype == xx_epsg4326.dtype + + assert xx_epsg4326.odc.reproject(3857, dtype="float32").dtype == "float32" yy = xr.Dataset({"a": xx0, "b": xx0 + 1, "c": xr.DataArray([2, 3, 4])}) assert isinstance(yy.odc, ODCExtensionDs)