Skip to content

Commit

Permalink
Merge pull request #85 from rom-py/84-remove-dependency-on-np1darray-…
Browse files Browse the repository at this point in the history
…np2darray

Changed basegrid object to not take x, y and inputs and instead infer…
  • Loading branch information
rafa-guedes authored Aug 13, 2024
2 parents d1d5c79 + c3eaecd commit 69653c1
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 59 deletions.
3 changes: 1 addition & 2 deletions rompy/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import logging
import os
from abc import ABC, abstractmethod
from datetime import timedelta
from functools import cached_property
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import copytree
from typing import Literal, Optional, Union
Expand Down
36 changes: 18 additions & 18 deletions rompy/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import matplotlib.pyplot as plt
import numpy as np
from pydantic import Field, model_validator
from pydantic_numpy.typing import Np1DArray, Np2DArray
from shapely.geometry import MultiPoint, Polygon

from rompy.core.types import Bbox, RompyBaseModel
Expand All @@ -24,14 +23,16 @@ class BaseGrid(RompyBaseModel):
"""

x: Optional[Union[Np1DArray, Np2DArray]] = Field(
default=None, description="The x coordinates"
)
y: Optional[Union[Np1DArray, Np2DArray]] = Field(
default=None, description="The y coordinates"
)
grid_type: Literal["base"] = "base"

@property
def x(self) -> np.ndarray:
raise NotImplementedError

@property
def y(self) -> np.ndarray:
raise NotImplementedError

@property
def minx(self) -> float:
return np.nanmin(self.x)
Expand Down Expand Up @@ -221,21 +222,20 @@ class RegularGrid(BaseGrid):
def generate(self) -> "RegularGrid":
"""Generate the grid from the provided parameters."""
keys = ["x0", "y0", "dx", "dy", "nx", "ny"]
if self.x is not None and self.y is not None:
for key in keys:
if getattr(self, key) is not None:
logger.warning(f"x, y provided explicitly, can't process {key}")
self._attrs_from_xy()
elif None in [getattr(self, key) for key in keys]:
if None in [getattr(self, key) for key in keys]:
raise ValueError(f"All of {','.join(keys)} must be provided for REG grid")
# Ensure x, y 2D coordinates are defined
self._regen_grid()
return self

def _regen_grid(self):
_x, _y = self._gen_reg_cgrid()
self.x = _x
self.y = _y
@property
def x(self) -> np.ndarray:
x, y = self._gen_reg_cgrid()
return x

@property
def y(self) -> np.ndarray:
x, y = self._gen_reg_cgrid()
return y

def _attrs_from_xy(self):
"""Generate regular grid attributes from x, y coordinates."""
Expand Down
14 changes: 7 additions & 7 deletions rompy/schism/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,13 @@ def vgrid_validator(cls, v, values):
v = VgridGenerator()
return v

@model_validator(mode="after")
def set_xy(cls, v):
if v.hgrid is not None:
v._pyschism_hgrid = Hgrid.open(v.hgrid._copied or v.hgrid.source, crs=v.crs)
v.x = v._pyschism_hgrid.x
v.y = v._pyschism_hgrid.y
return v
@property
def x(self) -> np.ndarray:
return self.pyschism_hgrid.x

@property
def y(self) -> np.ndarray:
return self.pyschism_hgrid.y

@property
def pyschism_hgrid(self):
Expand Down
39 changes: 28 additions & 11 deletions tests/test_basegrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,31 @@
from rompy.core import BaseGrid, RegularGrid


class CustomGrid(BaseGrid):
xsize: int = 10
ysize: int = 10

def meshgrid(self):
x = np.arange(self.xsize)
y = np.arange(self.ysize)
xx, yy = np.meshgrid(x, y)
return xx, yy

@property
def x(self):
xx, yy = self.meshgrid()
return xx

@property
def y(self):
xx, yy = self.meshgrid()
return yy


# test class based on pytest fixtures
@pytest.fixture
def grid():
x = np.arange(10)
y = np.arange(10)
xx, yy = np.meshgrid(x, y)
return BaseGrid(x=xx, y=yy)
return CustomGrid()


@pytest.fixture
Expand All @@ -23,13 +41,12 @@ def regulargrid():


def test_regulargrid_xy(regulargrid):
grid = RegularGrid(x=regulargrid.x, y=regulargrid.y)
assert np.array_equal(grid.x, regulargrid.x)
assert np.array_equal(grid.y, regulargrid.y)
assert grid.nx == regulargrid.nx
assert grid.ny == regulargrid.ny
for attr in ["x0", "y0", "dx", "dy", "rot"]:
assert getattr(grid, attr) == pytest.approx(getattr(regulargrid, attr))
xx, yy = np.meshgrid(np.arange(10), np.arange(10))
assert np.array_equal(regulargrid.x, xx)
assert np.array_equal(regulargrid.y, yy)
assert regulargrid.nx == 10
assert regulargrid.ny == 10


def test_bbox(grid):
assert grid.bbox() == [0.0, 0.0, 9.0, 9.0]
Expand Down
38 changes: 22 additions & 16 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import os
import intake
from pathlib import Path
from pydantic import ValidationError

import intake
import numpy as np
import pandas as pd
import pytest
import xarray as xr
from pydantic import ValidationError

from rompy.core import DataBlob, DataGrid, RegularGrid, TimeRange
from rompy.core.data import (SourceDatamesh, SourceDataset, SourceFile,
SourceIntake)
from rompy.core.filters import Filter
from rompy.core.types import DatasetCoords
from rompy.core.data import SourceDataset, SourceFile, SourceIntake, SourceDatamesh
from rompy.core import BaseGrid, DataBlob, DataGrid, TimeRange


HERE = Path(__file__).parent
DATAMESH_TOKEN = os.environ.get("DATAMESH_TOKEN")
Expand All @@ -27,6 +28,11 @@ def txt_data_source(tmp_path):
return DataBlob(id="test", source=source)


@pytest.fixture
def grid():
return RegularGrid(x0=2, y0=3, dx=1, dy=1, nx=5, ny=4)


@pytest.fixture
def grid_data_source():
return DataGrid(
Expand Down Expand Up @@ -98,27 +104,24 @@ def test_netcdf_grid(nc_data_source):
assert data.ds.longitude.min() == 0


def test_grid_filter(nc_data_source):
grid = BaseGrid(x=np.arange(2, 7), y=np.arange(3, 7))
def test_grid_filter(nc_data_source, grid):
nc_data_source._filter_grid(grid)
assert nc_data_source.ds.latitude.max() == 6
assert nc_data_source.ds.latitude.min() == 3
assert nc_data_source.ds.longitude.max() == 6
assert nc_data_source.ds.longitude.min() == 2


def test_grid_filter_buffer(nc_data_source):
grid = BaseGrid(x=np.arange(3, 7), y=np.arange(3, 7))
def test_grid_filter_buffer(nc_data_source, grid):
nc_data_source.buffer = 1.0
nc_data_source._filter_grid(grid)
assert nc_data_source.ds.latitude.max() == 7
assert nc_data_source.ds.latitude.min() == 2
assert nc_data_source.ds.longitude.max() == 7
assert nc_data_source.ds.longitude.min() == 2
assert nc_data_source.ds.longitude.min() == 1


def test_time_filter(nc_data_source):
grid = BaseGrid(x=np.arange(3, 7), y=np.arange(3, 7))
def test_time_filter(nc_data_source, grid):
nc_data_source._filter_time(TimeRange(start="2000-01-02", end="2000-01-03"))
assert nc_data_source.ds.time.max() == np.datetime64("2000-01-03")
assert nc_data_source.ds.time.min() == np.datetime64("2000-01-02")
Expand Down Expand Up @@ -167,7 +170,7 @@ def test_source_intake_uri_or_yaml():

def test_intake_grid_plot(grid_data_source):
data = grid_data_source
data.plot(param='u10', isel={'time': 0})
data.plot(param="u10", isel={"time": 0})


@pytest.mark.skip(reason="This won't work with pydantic<2, fix once migrated")
Expand All @@ -179,6 +182,9 @@ def test_source_datamesh():
filters.crop.update(
dict(longitude=slice(115.5, 116.0), latitude=slice(-33.0, -32.5))
)
dset = dataset.open(variables=["u10"], filters=filters, coords=DatasetCoords(x="longitude", y="latitude"))
assert(isinstance(dset, xr.Dataset))

dset = dataset.open(
variables=["u10"],
filters=filters,
coords=DatasetCoords(x="longitude", y="latitude"),
)
assert isinstance(dset, xr.Dataset)
7 changes: 2 additions & 5 deletions tests/test_swangrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
# test class based on pytest fixtures
@pytest.fixture
def grid():
x = np.arange(10)
y = np.arange(10)
xx, yy = np.meshgrid(x, y)
return SwanGrid(x=xx, y=yy)
return SwanGrid(x0=0, y0=0, nx=10, ny=10, dx=1, dy=1)


@pytest.fixture
Expand Down Expand Up @@ -80,4 +77,4 @@ def test_grid_from_component(grid):
my=grid.ny - 1,
)
grid2 = SwanGrid.from_component(regular_grid_component)
assert grid == grid2
assert grid == grid2

0 comments on commit 69653c1

Please sign in to comment.