-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* first pass at refactoring detrend * remove comments * fix chunking * add 3D case * add error checking; implement 2D detrend * add more 2D test cases * black * refactor xrft and update tests * fix typo in detrend test * doc updates
- Loading branch information
Showing
6 changed files
with
241 additions
and
412 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
del get_versions | ||
|
||
from .xrft import * # noqa | ||
from .detrend import detrend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Functions for detrending xarray data. | ||
""" | ||
|
||
import numpy as np | ||
import xarray as xr | ||
import scipy.signal as sps | ||
import scipy.linalg as spl | ||
|
||
|
||
def detrend(da, dim, detrend_type="constant"): | ||
"""Detrend a DataArray | ||
Parameters | ||
---------- | ||
da : xarray.DataArray | ||
The data to detrend | ||
dim : str or list | ||
Dimensions along which to apply detrend. | ||
Can be either one dimension or a list with two dimensions. | ||
Higher-dimensional detrending is not supported. | ||
If dask data are passed, the data must be chunked along dim. | ||
detrend_type : {'constant', 'linear'} | ||
If ``constant``, a constant offset will be removed from each dim. | ||
If ``linear``, a linear least-squares fit will be estimated and removed | ||
from the data. | ||
Returns | ||
------- | ||
da : xarray.DataArray | ||
The detrended data. | ||
Notes | ||
----- | ||
This function will act lazily in the presence of dask arrays on the | ||
input. | ||
""" | ||
|
||
if detrend_type not in ["constant", "linear", None]: | ||
raise NotImplementedError( | ||
"%s is not a valid detrending option. Valid " | ||
"options are: 'constant','linear', or None." % detrend_type | ||
) | ||
|
||
if detrend_type is None: | ||
return da | ||
elif detrend_type == "constant": | ||
return da - da.mean(dim=dim) | ||
elif detrend_type == "linear": | ||
data = da.data | ||
axis_num = [da.get_axis_num(d) for d in dim] | ||
chunks = getattr(data, "chunks", None) | ||
if chunks: | ||
axis_chunks = [data.chunks[a] for a in axis_num] | ||
if not all([len(ac) == 1 for ac in axis_chunks]): | ||
raise ValueError("Contiguous chunks required for detrending.") | ||
if len(dim) == 1: | ||
dt = xr.apply_ufunc( | ||
sps.detrend, | ||
da, | ||
axis_num[0], | ||
output_dtypes=[da.dtype], | ||
dask="parallelized", | ||
) | ||
elif len(dim) == 2: | ||
dt = xr.apply_ufunc( | ||
_detrend_2d_ufunc, | ||
da, | ||
input_core_dims=[dim], | ||
output_core_dims=[dim], | ||
output_dtypes=[da.dtype], | ||
vectorize=True, | ||
dask="parallelized", | ||
) | ||
else: # pragma: no cover | ||
raise NotImplementedError( | ||
"Only 1D and 2D detrending are implemented so far." | ||
) | ||
|
||
return dt | ||
|
||
|
||
def _detrend_2d_ufunc(arr): | ||
assert arr.ndim == 2 | ||
N = arr.shape | ||
|
||
col0 = np.ones(N[0] * N[1]) | ||
col1 = np.repeat(np.arange(N[0]), N[1]) + 1 | ||
col2 = np.tile(np.arange(N[1]), N[0]) + 1 | ||
G = np.stack([col0, col1, col2]).transpose() | ||
|
||
d_obs = np.reshape(arr, (N[0] * N[1], 1)) | ||
m_est = np.dot(np.dot(spl.inv(np.dot(G.T, G)), G.T), d_obs) | ||
d_est = np.dot(G, m_est) | ||
linear_fit = np.reshape(d_est, N) | ||
return arr - linear_fit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import numpy as np | ||
import xarray as xr | ||
import scipy.signal as sps | ||
|
||
import pytest | ||
import numpy.testing as npt | ||
import xarray.testing as xrt | ||
|
||
import xrft | ||
from xrft.detrend import detrend | ||
|
||
|
||
def detrended_noise(N, amplitude=1.0): | ||
return sps.detrend(amplitude * np.random.rand(N)) | ||
|
||
|
||
def noise(dims, shape): | ||
assert len(dims) == len(shape) | ||
coords = {d: (d, np.arange(n)) for d, n in zip(dims, shape)} | ||
data = np.random.rand(*shape) | ||
for n in range(len(shape)): | ||
data = sps.detrend(data, n) | ||
da = xr.DataArray(data, dims=dims, coords=coords) | ||
return da | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"array_dims, array_shape, detrend_dim, chunks, linear_error", | ||
( | ||
(["x"], [16], "x", None, None), | ||
(["y", "x"], [32, 16], "x", None, None), | ||
(["y", "x"], [32, 16], "x", {"y": 4}, None), | ||
(["y", "x"], [32, 16], "y", None, None), | ||
(["y", "x"], [32, 16], "y", {"x": 4}, None), | ||
(["time", "y", "x"], [4, 32, 16], "x", None, None), | ||
(["time", "y", "x"], [4, 32, 16], "x", {"y": 4}, None), | ||
(["time", "y", "x"], [4, 32, 16], "x", {"time": 1, "y": 4}, None), | ||
# error cases for linear detrending | ||
(["x"], [16], "x", {"x": 1}, ValueError), | ||
(["y", "x"], [32, 16], "x", {"x": 4}, ValueError), | ||
), | ||
) | ||
@pytest.mark.parametrize("detrend_type", [None, "constant", "linear"]) | ||
@pytest.mark.parametrize("trend_amplitude", [0.01, 100]) | ||
def test_detrend_1D( | ||
array_dims, | ||
array_shape, | ||
detrend_dim, | ||
chunks, | ||
detrend_type, | ||
trend_amplitude, | ||
linear_error, | ||
): | ||
da_original = noise(array_dims, array_shape) | ||
da_trend = da_original + trend_amplitude * da_original[detrend_dim] | ||
if chunks: | ||
da_trend = da_trend.chunk(chunks) | ||
|
||
# bail out if we are expecting an error | ||
if detrend_type == "linear" and linear_error: | ||
with pytest.raises(linear_error): | ||
detrend(da_trend, detrend_dim, detrend_type=detrend_type) | ||
return | ||
|
||
detrended = detrend(da_trend, detrend_dim, detrend_type=detrend_type) | ||
assert detrended.chunks == da_trend.chunks | ||
if detrend_type is None: | ||
xrt.assert_equal(detrended, da_trend) | ||
elif detrend_type == "constant": | ||
xrt.assert_allclose(detrended, da_trend - da_trend.mean(dim=detrend_dim)) | ||
elif detrend_type == "linear": | ||
xrt.assert_allclose(detrended, da_original) | ||
|
||
|
||
# always detrend on x y dims | ||
@pytest.mark.parametrize( | ||
"array_dims, array_shape, chunks", | ||
( | ||
(["y", "x"], [32, 16], None), | ||
(["z", "y", "x"], [2, 32, 16], None), | ||
(["z", "y", "x"], [2, 32, 16], {"z": 1}), | ||
), | ||
) | ||
@pytest.mark.parametrize("detrend_type", [None, "constant", "linear"]) | ||
@pytest.mark.parametrize( | ||
"trend_amplitude", [{"x": 0.1, "y": 0.1}, {"x": 10.0, "y": 0.01}] | ||
) | ||
def test_detrend_2D(array_dims, array_shape, chunks, detrend_type, trend_amplitude): | ||
da_original = noise(array_dims, array_shape) | ||
da_trend = ( | ||
da_original | ||
+ trend_amplitude["x"] * da_original["x"] | ||
+ trend_amplitude["y"] * da_original["y"] | ||
) | ||
if chunks: | ||
da_trend = da_trend.chunk(chunks) | ||
|
||
detrend_dim = ["y", "x"] | ||
detrended = detrend(da_trend, detrend_dim, detrend_type=detrend_type) | ||
assert detrended.chunks == da_trend.chunks | ||
if detrend_type is None: | ||
xrt.assert_equal(detrended, da_trend) | ||
elif detrend_type == "constant": | ||
xrt.assert_allclose(detrended, da_trend - da_trend.mean(dim=detrend_dim)) | ||
elif detrend_type == "linear": | ||
xrt.assert_allclose(detrended, da_original) |
Oops, something went wrong.