Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and additions for .odc.reproject(..) #161

Merged
merged 2 commits into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions odc/geo/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,13 @@
from ._blocks import BlockAssembler
from .gcp import GCPGeoBox
from .geobox import GeoBox, GeoboxTiles
from .warp import Nodata, Resampling, _rio_reproject, resampling_s2rio


def resolve_fill_value(dst_nodata, src_nodata, dtype):
dtype = np.dtype(dtype)

if dst_nodata is not None:
return dtype.type(dst_nodata)
if src_nodata is not None:
return dtype.type(src_nodata)
if np.issubdtype(dtype, np.floating):
return dtype.type("nan")
return dtype.type(0)
from .warp import (
Nodata,
Resampling,
_rio_reproject,
resampling_s2rio,
resolve_fill_value,
)


def _do_chunked_reproject(
Expand Down Expand Up @@ -50,7 +44,11 @@ def _do_chunked_reproject(
dtype = ba.dtype

dst_shape = ba.with_yx(ba.shape, dst_gbox.shape)
dst = np.zeros(dst_shape, dtype=dtype)
dst = np.full(
dst_shape,
resolve_fill_value(dst_nodata, src_nodata, dtype),
dtype=dtype,
)

for src_roi in ba.planes_yx():
src = ba.extract(src_nodata, dtype=dtype, casting=casting, roi=src_roi)
Expand All @@ -70,7 +68,7 @@ def _do_chunked_reproject(
return dst


def _dask_rio_reproject(
def dask_rio_reproject(
src: da.Array,
s_gbox: Union[GeoBox, GCPGeoBox],
d_gbox: GeoBox,
Expand All @@ -79,6 +77,7 @@ def _dask_rio_reproject(
dst_nodata: Nodata = None,
ydim: int = 0,
chunks: Optional[Tuple[int, int]] = None,
dtype=None,
**kwargs,
) -> da.Array:
# pylint: disable=too-many-arguments, too-many-locals
Expand All @@ -92,15 +91,18 @@ def _dask_rio_reproject(
def with_yx(a, yx):
return (*a[:ydim], *yx, *a[ydim + 2 :])

if dtype is None:
dtype = src.dtype

name: str = kwargs.pop("name", "reproject")

assert isinstance(s_gbox, GeoBox)
gbt_src = GeoboxTiles(s_gbox, src.chunks[ydim : ydim + 2])
gbt_dst = GeoboxTiles(d_gbox, chunks)
d2s_idx = gbt_dst.grid_intersect(gbt_src)

dst_shape = with_yx(src.shape, d_gbox.shape.yx)
dst_chunks: Tuple[Tuple[int, ...], ...] = with_yx(src.chunks, gbt_dst.chunks)
fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype)

tk = uuid4().hex
name = f"{name}-{tk}"
Expand All @@ -112,15 +114,14 @@ def with_yx(a, yx):
gbt_src,
gbt_dst,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
dst_nodata=fill_value,
axis=ydim,
resampling=resampling,
dtype=dtype,
**kwargs,
)
src_block_keys = src.__dask_keys__()

fill_value = resolve_fill_value(dst_nodata, src_nodata, src.dtype)

def _src(idx):
a = src_block_keys
for i in idx:
Expand All @@ -142,4 +143,4 @@ def _src(idx):

dsk = HighLevelGraph.from_collections(name, dsk, dependencies=(src,))

return da.Array(dsk, name, chunks=dst_chunks, dtype=src.dtype, shape=dst_shape)
return da.Array(dsk, name, chunks=dst_chunks, dtype=dtype, shape=dst_shape)
47 changes: 31 additions & 16 deletions odc/geo/_xr_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .overlap import compute_output_geobox
from .roi import roi_is_empty
from .types import Resolution, SomeResolution, SomeShape, xy_
from .warp import resolve_fill_value

# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-lines
Expand All @@ -63,6 +64,8 @@

# these attributes are pruned during reproject
SPATIAL_ATTRIBUTES = ("crs", "crs_wkt", "grid_mapping", "gcps", "epsg")
NODATA_ATTRIBUTES = ("nodata", "_FillValue")
REPROJECT_SKIP_ATTRS: set[str] = set(SPATIAL_ATTRIBUTES + NODATA_ATTRIBUTES)

# dimensions with these names are considered spatial
STANDARD_SPATIAL_DIMS = [
Expand Down Expand Up @@ -654,6 +657,7 @@ def xr_reproject(
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dtype=None,
resolution: Union[SomeResolution, Literal["auto", "fit", "same"]] = "auto",
shape: Union[SomeShape, int, None] = None,
tight: bool = False,
Expand Down Expand Up @@ -728,10 +732,10 @@ def xr_reproject(
}
if isinstance(src, xarray.DataArray):
return _xr_reproject_da(
src, how, resampling=resampling, dst_nodata=dst_nodata, **kw
src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw
)
return _xr_reproject_ds(
src, how, resampling=resampling, dst_nodata=dst_nodata, **kw
src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw
)


Expand All @@ -750,6 +754,7 @@ def _xr_reproject_ds(
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dtype=None,
**kw,
) -> xarray.Dataset:
assert isinstance(src, xarray.Dataset)
Expand All @@ -776,7 +781,12 @@ def _maybe_reproject(dv: xarray.DataArray):
dv = dv.drop_vars(strip_coords)
return dv
return _xr_reproject_da(
dv, how=dst_geobox, resampling=resampling, dst_nodata=dst_nodata, **kw
dv,
how=dst_geobox,
resampling=resampling,
dst_nodata=dst_nodata,
dtype=dtype,
**kw,
)

return src.map(_maybe_reproject)
Expand All @@ -788,6 +798,7 @@ def _xr_reproject_da(
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dtype=None,
**kw,
) -> xarray.DataArray:
# pylint: disable=too-many-locals
Expand All @@ -809,6 +820,9 @@ def _xr_reproject_da(
else:
dst_geobox = src.odc.output_geobox(how, **kw_gbox)

if dtype is None:
dtype = src.dtype

# compute destination shape by replacing spatial dimensions shape
ydim = src.odc.ydim
assert ydim + 1 == src.odc.xdim
Expand All @@ -817,24 +831,25 @@ def _xr_reproject_da(
src_nodata = kw.pop("src_nodata", None)
if src_nodata is None:
src_nodata = src.odc.nodata
if dst_nodata is None:
dst_nodata = src_nodata

fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype)

if is_dask_collection(src):
from ._dask import _dask_rio_reproject
from ._dask import dask_rio_reproject

dst: Any = _dask_rio_reproject(
dst: Any = dask_rio_reproject(
src.data,
src_gbox,
dst_geobox,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
dst_nodata=fill_value,
ydim=ydim,
dtype=dtype,
**kw,
)
else:
dst = numpy.empty(dst_shape, dtype=src.dtype)
dst = numpy.full(dst_shape, fill_value, dtype=dtype)

dst = rio_reproject(
src.values,
Expand All @@ -843,17 +858,17 @@ def _xr_reproject_da(
dst_geobox,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
dst_nodata=fill_value,
ydim=ydim,
dtype=dtype,
**kw,
)

attrs = {k: v for k, v in src.attrs.items() if k not in SPATIAL_ATTRIBUTES}
if dst_nodata is None:
attrs.pop("nodata", None)
attrs.pop("_FillValue", None)
else:
attrs.update(nodata=maybe_int(dst_nodata, 1e-6))
attrs = {k: v for k, v in src.attrs.items() if k not in REPROJECT_SKIP_ATTRS}
if numpy.isfinite(fill_value) and (
dst_nodata is not None or src_nodata is not None
):
attrs.update({k: maybe_int(float(fill_value), 1e-6) for k in NODATA_ATTRIBUTES})

# new set of coords (replace x,y dims)
# discard all coords that reference spatial dimensions
Expand Down
32 changes: 26 additions & 6 deletions odc/geo/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
Nodata = Optional[Union[int, float]]
_WRP_CRS = "epsg:3857"

__all__ = [
"resampling_s2rio",
"is_resampling_nn",
"resolve_fill_value",
"warp_affine",
"warp_affine_rio",
"rio_reproject",
]


def resampling_s2rio(name: str) -> rasterio.warp.Resampling:
"""
Expand All @@ -38,6 +47,18 @@ def is_resampling_nn(resampling: Resampling) -> bool:
return resampling == rasterio.warp.Resampling.nearest


def resolve_fill_value(dst_nodata, src_nodata, dtype):
dtype = np.dtype(dtype)

if dst_nodata is not None:
return dtype.type(dst_nodata)
if np.issubdtype(dtype, np.floating):
return dtype.type("nan")
if src_nodata is not None:
return dtype.type(src_nodata)
return dtype.type(0)


def warp_affine_rio(
src: np.ndarray,
dst: np.ndarray,
Expand Down Expand Up @@ -129,15 +150,14 @@ def rio_reproject(
:returns: dst
"""
assert src.ndim == dst.ndim
if dst_nodata is None:
if dst.dtype.kind == "f":
dst_nodata = np.nan

if src.ndim == 2:
return _rio_reproject(
src, dst, s_gbox, d_gbox, resampling, src_nodata, dst_nodata, **kwargs
)

fill_value = resolve_fill_value(dst_nodata, src_nodata, dst.dtype)

if ydim is None:
# Assume last two dimensions are Y/X
ydim = src.ndim - 2
Expand All @@ -154,9 +174,9 @@ def rio_reproject(
dst[roi],
s_gbox,
d_gbox,
resampling,
src_nodata,
dst_nodata,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=fill_value,
**kwargs,
)
return dst
Expand Down
3 changes: 3 additions & 0 deletions tests/test_xr_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,9 @@ def test_xr_reproject(xx_epsg4326: xr.DataArray):
assert xx.odc.geobox == dst_gbox
assert xx.encoding["grid_mapping"] == "spatial_ref"
assert "crs" not in xx.attrs
assert xx.dtype == xx_epsg4326.dtype

assert xx_epsg4326.odc.reproject(3857, dtype="float32").dtype == "float32"

yy = xr.Dataset({"a": xx0, "b": xx0 + 1, "c": xr.DataArray([2, 3, 4])})
assert isinstance(yy.odc, ODCExtensionDs)
Expand Down
Loading