From f4aca368868ba0ce2811380ea063a9ed2722923f Mon Sep 17 00:00:00 2001 From: Philip Chmielowiec Date: Fri, 22 Sep 2023 09:07:16 -0500 Subject: [PATCH] to_dataset --- docs/user_api/index.rst | 1 + test/test_dataarray.py | 13 +++++++++++++ uxarray/core/dataarray.py | 13 ++++++++++++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/docs/user_api/index.rst b/docs/user_api/index.rst index 623cf148a..ef6a5fc94 100644 --- a/docs/user_api/index.rst +++ b/docs/user_api/index.rst @@ -71,6 +71,7 @@ IO .. autosummary:: :toctree: _autosummary + UxDataArray.to_dataset UxDataArray.to_geodataframe UxDataArray.to_polycollection diff --git a/test/test_dataarray.py b/test/test_dataarray.py index 9fc482ed4..49b2fe08f 100644 --- a/test/test_dataarray.py +++ b/test/test_dataarray.py @@ -7,6 +7,8 @@ from uxarray.grid.geometry import _build_polygon_shells, _build_corrected_polygon_shells +from uxarray.core.dataset import UxDataset + current_path = Path(os.path.dirname(os.path.realpath(__file__))) gridfile_ne30 = current_path / "meshfiles" / "ugrid" / "outCSne30" / "outCSne30.ug" @@ -19,6 +21,17 @@ dsfile_v1_geoflow = current_path / "meshfiles" / "ugrid" / "geoflow-small" / "v1.nc" +class TestDataArray(TestCase): + + def test_to_dataset(self): + """Tests the conversion of UxDataArrays to a UXDataset.""" + uxds = ux.open_dataset(gridfile_ne30, dsfile_var2_ne30) + uxds_converted = uxds['psi'].to_dataset() + + assert isinstance(uxds_converted, UxDataset) + assert uxds_converted.uxgrid == uxds.uxgrid + + class TestGeometryConversions(TestCase): def test_to_geodataframe(self): diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 9629d8d25..03b3532ea 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -1,9 +1,15 @@ +from __future__ import annotations import xarray as xr import numpy as np -from typing import Optional +from typing import Optional, TYPE_CHECKING from uxarray.grid import Grid +import uxarray.core.dataset + +if TYPE_CHECKING: + from uxarray.core.dataarray import UxDataArray + from uxarray.core.dataset import UxDataset class UxDataArray(xr.DataArray): @@ -97,6 +103,11 @@ def uxgrid(self): def uxgrid(self, ugrid_obj): self._uxgrid = ugrid_obj + def to_dataset(self) -> UxDataset: + """Convert a UxDataArray to a UxDataset.""" + xrds = super().to_dataset() + return uxarray.core.dataset.UxDataset(xrds, uxgrid=self.uxgrid) + def to_geodataframe(self, override=False, cache=True,