Skip to content

Commit

Permalink
fix mock
Browse files Browse the repository at this point in the history
  • Loading branch information
malmans2 committed Apr 19, 2024
2 parents 4f06e98 + 356e280 commit 109c936
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .cruft.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"template": "https://github.com/ecmwf-projects/cookiecutter-conda-package",
"commit": "54a28e5611040c398c74454de4e6147c79cb5a39",
"commit": "280417918dc86a2f2be417c8877a590c07cf2a1c",
"checkout": null,
"context": {
"cookiecutter": {
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/on-push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: 3.x
- uses: pre-commit/[email protected].0
- uses: pre-commit/[email protected].1

combine-environments:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -230,6 +230,6 @@ jobs:
with:
name: distribution
path: dist
- uses: pypa/[email protected].11
- uses: pypa/[email protected].14
with:
verbose: true
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -17,7 +17,7 @@ repos:
- id: blackdoc
additional_dependencies: [black==23.11.0]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.3.7
hooks:
- id: ruff
args: [--fix, --show-fixes]
Expand All @@ -27,7 +27,7 @@ repos:
hooks:
- id: mdformat
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.12.0
rev: v2.13.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --preserve-quotes]
Expand Down
10 changes: 5 additions & 5 deletions c3s_eqc_automatic_quality_control/_time_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def groupby_reduce(
obj = self.obj
if isinstance(self.obj_weighted, (DataArrayWeighted | DatasetWeighted)):
obj = obj.assign_coords(__weights__=self.obj_weighted.weights)
return obj.groupby(group).map(map_func, (time_name, func_name), **kwargs)
return obj.groupby(group, squeeze=False).map(
map_func, (time_name, func_name), **kwargs
)

@utils.keep_attrs
def reduce(
Expand Down Expand Up @@ -143,8 +145,7 @@ def map_func(
time_name: str,
func: str,
**kwargs: Any,
) -> xr.DataArray:
...
) -> xr.DataArray: ...


@overload
Expand All @@ -153,8 +154,7 @@ def map_func(
time_name: str,
func: str,
**kwargs: Any,
) -> xr.Dataset:
...
) -> xr.Dataset: ...


def map_func(
Expand Down
17 changes: 10 additions & 7 deletions c3s_eqc_automatic_quality_control/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import plotly.graph_objs as go
import xarray as xr
from cartopy.mpl.geocollection import GeoQuadMesh
from cartopy.mpl.gridliner import Gridliner
from matplotlib.typing import ColorType
from xarray.plot.facetgrid import FacetGrid

Expand Down Expand Up @@ -84,9 +85,9 @@ def shaded_std(
colors = iter(colors)

if hue_dim:
_, means = zip(*ds_mean.groupby(hue_dim))
_, means = zip(*ds_mean.groupby(hue_dim, squeeze=False))
if ds_std:
_, stds = zip(*ds_std.groupby(hue_dim))
_, stds = zip(*ds_std.groupby(hue_dim, squeeze=False))
else:
stds = tuple(xr.Dataset() for _ in range(len(means)))
else:
Expand Down Expand Up @@ -218,11 +219,13 @@ def projected_map(
ax.gridlines(draw_labels=False)

for ax in plot_obj.axs[-1, :]:
for gl in ax._gridliners:
gridliners = [a for a in ax.artists if isinstance(a, Gridliner)]
for gl in gridliners:
gl.bottom_labels = True

for ax in plot_obj.axs[:, 0]:
for gl in ax._gridliners:
gridliners = [a for a in ax.artists if isinstance(a, Gridliner)]
for gl in gridliners:
gl.left_labels = True

if show_stats:
Expand All @@ -231,7 +234,8 @@ def projected_map(
)
else:
plot_obj.axes.coastlines()
if not getattr(plot_obj.axes, "_gridliners", []):
gridliners = [a for a in plot_obj.axes.artists if isinstance(a, Gridliner)]
if not gridliners:
gl = plot_obj.axes.gridlines(draw_labels=True)
gl.top_labels = gl.right_labels = False

Expand All @@ -250,7 +254,7 @@ def projected_map(
txt = "\n".join(
[
f"{k:>{n_characters}}: {v.squeeze().values:+e}{units}"
for k, v in da_stats.groupby("diagnostic")
for k, v in da_stats.groupby("diagnostic", squeeze=False)
]
)
plt.figtext(
Expand Down Expand Up @@ -454,5 +458,4 @@ def seasonal_boxplot(

da = da.stack(stacked_dim=da.dims)
df = da.to_dataframe()

return df.groupby(by=da[time_dim].dt.season.values).boxplot(**kwargs)
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ dependencies:
- cf-units
- cf_xarray
- dask
- earthkit-data
- fsspec
- geopandas
- joblib
Expand Down Expand Up @@ -43,4 +42,5 @@ dependencies:
- pip:
- cacholote
- cgul
- earthkit-data>=0.7.0
- kaleido
22 changes: 18 additions & 4 deletions tests/test_10_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,31 @@
)


class MockResult:
def __init__(self, name: str, request: dict[str, Any]) -> None:
self.name = name
self.request = request

@property
def location(self) -> str:
return tempfile.NamedTemporaryFile(suffix=".nc", delete=False).name

def download(self, target: str | pathlib.Path | None = None) -> str | pathlib.Path:
ds = xr.tutorial.open_dataset(self.name).sel(**self.request)
ds.to_netcdf(path := target or self.location)
return path


def mock_retrieve(
self: cdsapi.Client,
name: str,
request: dict[str, Any],
target: str | pathlib.Path | None = None,
) -> fsspec.spec.AbstractBufferedFile:
ds = xr.tutorial.open_dataset(name).sel(**request)
result = MockResult(name, request)
if target is None:
target = tempfile.NamedTemporaryFile(suffix=".nc", delete=False).name
ds.to_netcdf(target)
return target
return result
return result.download(target)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_20_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_regrid(self, obj: xr.DataArray | xr.Dataset) -> None:
fs, dirname = cacholote.utils.get_cache_files_fs_dirname()
assert fs.ls(dirname) == [] # cache is empty

for _, obj in obj.isel(time=slice(2)).groupby("time"):
for _, obj in obj.isel(time=slice(2)).groupby("time", squeeze=False):
expected = obj.isel(longitude=slice(10), latitude=slice(10))
actual = diagnostics.regrid(obj, expected, "nearest_s2d")
xr.testing.assert_equal(actual, expected)
Expand Down
12 changes: 4 additions & 8 deletions tests/test_21_time_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@


@overload
def weighted_mean(obj: xr.DataArray) -> xr.DataArray:
...
def weighted_mean(obj: xr.DataArray) -> xr.DataArray: ...


@overload
def weighted_mean(obj: xr.Dataset) -> xr.Dataset:
...
def weighted_mean(obj: xr.Dataset) -> xr.Dataset: ...


def weighted_mean(obj: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset:
Expand All @@ -25,13 +23,11 @@ def weighted_mean(obj: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset:


@overload
def weighted_std(obj: xr.DataArray) -> xr.DataArray:
...
def weighted_std(obj: xr.DataArray) -> xr.DataArray: ...


@overload
def weighted_std(obj: xr.Dataset) -> xr.Dataset:
...
def weighted_std(obj: xr.Dataset) -> xr.Dataset: ...


def weighted_std(obj: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset:
Expand Down
4 changes: 1 addition & 3 deletions tests/test_22_spatial_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def test_spatial_weighted_rmse_against_sklearn() -> None:

da1 = da
da2 = da**2
expected = sklearn.metrics.mean_squared_error(
da1, da2, sample_weight=weights, squared=False
)
expected = sklearn.metrics.root_mean_squared_error(da1, da2, sample_weight=weights)
actual = diagnostics.spatial_weighted_rmse(da1, da2)
assert expected == actual.values

0 comments on commit 109c936

Please sign in to comment.