Skip to content

Commit

Permalink
Add some tests for interpolation methods
Browse files Browse the repository at this point in the history
  • Loading branch information
VeckoTheGecko committed Jan 8, 2025
1 parent 838f8cf commit 0ab5bd3
Showing 1 changed file with 57 additions and 19 deletions.
76 changes: 57 additions & 19 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,31 @@
import parcels._interpolation as interpolation




@pytest.fixture
def tmp_interpolator_registry():
"""Resets the interpolator registry after the test. Vital when testing manipulating the registry."""
old_2d = interpolation.interpolator_registry_2d.copy()
old_3d = interpolation.interpolator_registry_3d.copy()
yield
interpolation.interpolator_registry_2d = old_2d
interpolation.interpolator_registry_3d = old_3d


def test_interpolation_registry(tmp_interpolator_registry):
@interpolation.register_3d_interpolator("test")
@interpolation.register_2d_interpolator("test")
def some_function():
return "test"

assert "test" in interpolation.interpolator_registry_2d
assert "test" in interpolation.interpolator_registry_3d

f = interpolation.interpolator_registry_2d["test"]
g = interpolation.interpolator_registry_3d["test"]
assert f() == g() == "test"

def create_interpolation_data():
"""Reference data used for testing interpolation.
Expand Down Expand Up @@ -35,25 +60,38 @@ def data_3d():
return create_interpolation_data().values


@pytest.fixture
def tmp_interpolator_registry():
"""Resets the interpolator registry after the test. Vital when testing manipulating the registry."""
old_2d = interpolation.interpolator_registry_2d.copy()
old_3d = interpolation.interpolator_registry_3d.copy()
yield
interpolation.interpolator_registry_2d = old_2d
interpolation.interpolator_registry_3d = old_3d


def test_interpolation_registry(tmp_interpolator_registry):
@interpolation.register_3d_interpolator("test")
@interpolation.register_2d_interpolator("test")
def some_function():
return "test"
class TestInterpolationMethods:
ti = 0
zi, yi, xi = 1, 1, 1

assert "test" in interpolation.interpolator_registry_2d
assert "test" in interpolation.interpolator_registry_3d
@pytest.mark.parametrize(
"func, eta, xi, expected",
[
pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"),
pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"),
pytest.param(interpolation._nearest_2d, 0.51, 0.49, 5.0, id="nearest_2d-3"),
pytest.param(interpolation._nearest_2d, 0.51, 0.51, 6.0, id="nearest_2d-4"),
pytest.param(interpolation._tracer_2d, None, None, 6.0, id="tracer_2d"),
# pytest.param(interpolation._linear_2d, ...),
# pytest.param(interpolation._linear_invdist_land_tracer_2d, ...),
],
)
def test_2d(self, data_2d, func, eta, xi, expected):
ctx = interpolation.InterpolationContext2D(data_2d, eta, xi, self.ti, self.yi, self.xi)
assert func(ctx) == expected

f = interpolation.interpolator_registry_2d["test"]
g = interpolation.interpolator_registry_3d["test"]
assert f() == g() == "test"
@pytest.mark.parametrize(
"func, eta, xi, expected",
[
# pytest.param(interpolation._nearest_3d, ...),
# pytest.param(interpolation._cgrid_velocity_3d, ...),
# pytest.param(interpolation._linear_invdist_land_tracer_3d, ...),
# pytest.param(interpolation._linear_3d, ...),
# pytest.param(interpolation._tracer_3d, ...),
],
)
def test_3d(self, data_3d, func, zeta, eta, xi, expected):
ctx = interpolation.InterpolationContext3D(data_2d, zeta, eta, xi, self.ti, self.zi, self.yi, self.xi)
assert func(ctx) == expected

0 comments on commit 0ab5bd3

Please sign in to comment.