Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xarray percentile - support non-dask input and return percentiles in a new dimension #248

Merged
merged 12 commits into from
Sep 21, 2021
3 changes: 2 additions & 1 deletion libs/algo/odc/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@

from ._tiff import save_cog

from ._percentile import xr_quantile
from ._broadcast import (
pool_broadcast,
)
Expand All @@ -72,7 +73,6 @@
seq_to_bags,
)


__all__ = (
"apply_numexpr",
"safe_div",
Expand Down Expand Up @@ -117,6 +117,7 @@
"colorize",
"xr_reproject",
"save_cog",
"xr_quantile",
"pool_broadcast",
"dask_compute_stream",
"seq_to_bags",
Expand Down
92 changes: 77 additions & 15 deletions libs/algo/odc/algo/_percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from ._masking import keep_good_np
from dask.base import tokenize
import dask
from functools import partial


Expand Down Expand Up @@ -34,9 +35,9 @@ def np_percentile(xx, percentile, nodata):
return keep_good_np(xx, (valid_counts >= 3), nodata)


def xr_percentile(
def xr_quantile_bands(
src: xr.Dataset,
percentiles: Sequence,
quantiles: Sequence,
nodata,
) -> xr.Dataset:

Expand All @@ -49,7 +50,7 @@ def xr_percentile(
float or integer with `nodata` values to indicate gaps in data.
`nodata` must be the largest or smallest values in the dataset or NaN.

:param percentiles: A sequence of percentiles in the [0.0, 1.0] range
:param percentiles: A sequence of quantiles in the [0.0, 1.0] range

:param nodata: The `nodata` value
"""
Expand All @@ -58,20 +59,81 @@ def xr_percentile(
for band, xx in src.data_vars.items():

xx_data = xx.data
if len(xx.chunks[0]) > 1:
xx_data = xx_data.rechunk({0: -1})

if dask.is_dask_collection(xx_data):
if len(xx.chunks[0]) > 1:
xx_data = xx_data.rechunk({0: -1})

tk = tokenize(xx_data, percentiles, nodata)
for percentile in percentiles:
name = f"{band}_pc_{int(100 * percentile)}"
yy = da.map_blocks(
partial(np_percentile, percentile=percentile, nodata=nodata),
xx_data,
drop_axis=0,
meta=np.array([], dtype=xx.dtype),
name=f"{name}-{tk}",
)
tk = tokenize(xx_data, quantiles, nodata)
for quantile in quantiles:
name = f"{band}_pc_{int(100 * quantile)}"
if dask.is_dask_collection(xx_data):
yy = da.map_blocks(
partial(np_percentile, percentile=quantile, nodata=nodata),
xx_data,
drop_axis=0,
meta=np.array([], dtype=xx.dtype),
name=f"{name}-{tk}",
)
else:
yy = np_percentile(xx_data, percentile=quantile, nodata=nodata)
data_vars[name] = xr.DataArray(yy, dims=xx.dims[1:], attrs=xx.attrs)

coords = dict((dim, src.coords[dim]) for dim in xx.dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=src.attrs)


def xr_quantile(
    src: xr.Dataset,
    quantiles: Sequence,
    nodata,
) -> xr.Dataset:

    """
    Calculates the quantiles of the input data along the time dimension,
    returning one variable per band with a new leading 'quantile' dimension.

    This approach is approximately 700x faster than the `numpy` and `xarray`
    nanpercentile functions.

    :param src: xr.Dataset, bands can be either
        float or integer with `nodata` values to indicate gaps in data.
        `nodata` must be the largest or smallest values in the dataset or NaN.

    :param quantiles: A sequence of quantiles in the [0.0, 1.0] range

    :param nodata: The `nodata` value
    """

    data_vars = {}
    for band, xx in src.data_vars.items():

        xx_data = xx.data
        # Quantiles reduce over the leading (time) axis; it is replaced by
        # a new 'quantile' axis in the output.
        out_dims = ('quantile',) + xx.dims[1:]

        # Whether the band is dask-backed is invariant per band, so decide
        # once here instead of re-testing inside the per-quantile loop (the
        # original also referenced the loop variable after the loop, which
        # would raise NameError for an empty `quantiles` sequence).
        is_dask = dask.is_dask_collection(xx_data)
        if is_dask:
            # np_percentile reduces along axis 0, so every block must span
            # the full time dimension.
            if len(xx.chunks[0]) > 1:
                xx_data = xx_data.rechunk({0: -1})
            # Token only needed to build deterministic dask task names.
            tk = tokenize(xx_data, quantiles, nodata)

        data = []
        for quantile in quantiles:
            if is_dask:
                name = f"{band}_pc_{int(100 * quantile)}"
                yy = da.map_blocks(
                    partial(np_percentile, percentile=quantile, nodata=nodata),
                    xx_data,
                    drop_axis=0,
                    meta=np.array([], dtype=xx.dtype),
                    name=f"{name}-{tk}",
                )
            else:
                yy = np_percentile(xx_data, percentile=quantile, nodata=nodata)
            data.append(yy)

        # Stack the per-quantile planes along the new leading axis, staying
        # lazy for dask-backed inputs.
        stack = da.stack if is_dask else np.stack
        data_vars[band] = (out_dims, stack(data, axis=0))

    coords = dict((dim, src.coords[dim]) for dim in xx.dims[1:])
    coords['quantile'] = np.array(quantiles)
    return xr.Dataset(data_vars=data_vars, coords=coords, attrs=src.attrs)
59 changes: 45 additions & 14 deletions libs/algo/tests/test_percentile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from odc.algo._percentile import np_percentile, xr_percentile
from odc.algo._percentile import np_percentile, xr_quantile_bands, xr_quantile
import numpy as np
import pytest
import dask.array as da
Expand Down Expand Up @@ -55,8 +55,9 @@ def test_np_percentile_bad_data(nodata):
np.testing.assert_equal(np_percentile(arr, 0.0, nodata), np.array([nodata, 3]))


@pytest.mark.parametrize("nodata", [255, 200, np.nan, -1]) #should do -1
def test_xr_percentile(nodata):
@pytest.mark.parametrize("nodata", [255, 200, np.nan, -1])
@pytest.mark.parametrize("use_dask", [False, True])
def test_xr_quantile_bands(nodata, use_dask):
band_1 = np.random.randint(0, 100, size=(10, 100, 200)).astype(type(nodata))
band_2 = np.random.randint(0, 100, size=(10, 100, 200)).astype(type(nodata))

Expand All @@ -69,8 +70,42 @@ def test_xr_percentile(nodata):
true_results["band_1_pc_60"] = np_percentile(band_1, 0.6, nodata)
true_results["band_2_pc_60"] = np_percentile(band_2, 0.6, nodata)

band_1 = da.from_array(band_1, chunks=(2, 20, 20))
band_2 = da.from_array(band_2, chunks=(2, 20, 20))
if use_dask:
band_1 = da.from_array(band_1, chunks=(2, 20, 20))
band_2 = da.from_array(band_2, chunks=(2, 20, 20))

attrs = {"test": "attrs"}
coords = {
"x": np.linspace(10, 20, band_1.shape[2]),
"y": np.linspace(0, 5, band_1.shape[1]),
"t": np.linspace(0, 5, band_1.shape[0])
}

data_vars = {"band_1": (("t", "y", "x"), band_1), "band_2": (("t", "y", "x"), band_2)}

dataset = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
output = xr_quantile_bands(dataset, [0.2, 0.6], nodata).compute()

for band in output.keys():
np.testing.assert_equal(output[band], true_results[band])


@pytest.mark.parametrize("nodata", [255, 200, np.nan, -1])
@pytest.mark.parametrize("use_dask", [False, True])
def test_xr_quantile(nodata, use_dask):
band_1 = np.random.randint(0, 100, size=(10, 100, 200)).astype(type(nodata))
band_2 = np.random.randint(0, 100, size=(10, 100, 200)).astype(type(nodata))

band_1[np.random.random(size=band_1.shape) > 0.5] = nodata
band_2[np.random.random(size=band_1.shape) > 0.5] = nodata

true_results = dict()
true_results["band_1"] = np.stack([np_percentile(band_1, 0.2, nodata), np_percentile(band_1, 0.6, nodata)], axis=0)
true_results["band_2"] = np.stack([np_percentile(band_2, 0.2, nodata), np_percentile(band_2, 0.6, nodata)], axis=0)

if use_dask:
band_1 = da.from_array(band_1, chunks=(2, 20, 20))
band_2 = da.from_array(band_2, chunks=(2, 20, 20))

attrs = {"test": "attrs"}
coords = {
Expand All @@ -85,12 +120,8 @@ def test_xr_percentile(nodata):
}

dataset = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
output = xr_percentile(dataset, [0.2, 0.6], nodata).compute()

for key in output.keys():
np.testing.assert_equal(output[key], true_results[key])

assert output["band_1_pc_20"].attrs["test_attr"] == 1
assert output["band_1_pc_60"].attrs["test_attr"] == 1
assert output["band_2_pc_20"].attrs["test_attr"] == 2
assert output["band_2_pc_20"].attrs["test_attr"] == 2
output = xr_quantile(dataset, [0.2, 0.6], nodata).compute()

for band in output.keys():
np.testing.assert_equal(output[band], true_results[band])

4 changes: 2 additions & 2 deletions libs/stats/odc/stats/_fc_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from odc.stats.model import Task
from odc.algo.io import load_with_native_transform
from odc.algo import keep_good_only
from odc.algo._percentile import xr_percentile
from odc.algo._percentile import xr_quantile_bands
from odc.algo._masking import _xr_fuse, _or_fuser, _fuse_mean_np, _fuse_or_np, _fuse_and_np
from .model import StatsPluginInterface
from . import _plugins
Expand Down Expand Up @@ -97,7 +97,7 @@ def reduce(xx: xr.Dataset) -> xr.Dataset:
wet = xx["wet"]
xx = xx.drop_vars(["wet"])

yy = xr_percentile(xx, [0.1, 0.5, 0.9], nodata=NODATA)
yy = xr_quantile_bands(xx, [0.1, 0.5, 0.9], nodata=NODATA)
is_ever_wet = _or_fuser(wet).squeeze(wet.dims[0], drop=True)

band, *bands = yy.data_vars.keys()
Expand Down