From 691999b59e9ca7c8fecec0f041c7b1049e4730d1 Mon Sep 17 00:00:00 2001 From: Philip Chmielowiec <67855069+philipc2@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:54:26 -0500 Subject: [PATCH] Selection for GeoDataFrame `engine` in plotting routines (#987) * add engine, fix project * add test case and comments * initial updates to poly and lc * fix comment and update tests * update to_geodataframe docstring * update edge plot * add default clabel for edge plot * update docstrings in geometry functions * update to_geodataframe docstring to warn about split polygon projections * remove unused parameter * update call after removed unused argument * remove commented out bit --- test/test_plot.py | 42 ++++++++++++++++++++------------------- uxarray/core/dataarray.py | 9 +++++---- uxarray/grid/geometry.py | 35 ++++++++++++++++++-------------- uxarray/grid/grid.py | 17 +++++++--------- uxarray/plot/accessor.py | 36 ++++++++++++++++++++++++++------- 5 files changed, 83 insertions(+), 56 deletions(-) diff --git a/test/test_plot.py b/test/test_plot.py index dc5d8a39b..c04fea9b1 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -1,5 +1,7 @@ import os import uxarray as ux +import holoviews as hv + from unittest import TestCase from pathlib import Path @@ -44,15 +46,9 @@ def test_face_centered_data(self): uxds = ux.open_dataset(gridfile_mpas, gridfile_mpas) for backend in ['matplotlib', 'bokeh']: - - uxds['bottomDepth'].plot(backend=backend) - - uxds['bottomDepth'].plot.polygons(backend=backend) - - uxds['bottomDepth'].plot.points(backend=backend) - - uxds['bottomDepth'].plot.rasterize(method='polygon', - backend=backend) + assert(isinstance(uxds['bottomDepth'].plot(backend=backend), hv.DynamicMap)) + assert(isinstance(uxds['bottomDepth'].plot.polygons(backend=backend), hv.DynamicMap)) + assert(isinstance(uxds['bottomDepth'].plot.points(backend=backend), hv.Points)) def test_face_centered_remapped_dim(self): """Tests execution of plotting method on a data variable whose @@ -60,14 +56,10 @@ def test_face_centered_remapped_dim(self): uxds = ux.open_dataset(gridfile_ne30, datafile_ne30) for backend in ['matplotlib', 'bokeh']: + assert(isinstance(uxds['psi'].plot(backend=backend), hv.DynamicMap)) + assert(isinstance(uxds['psi'].plot.polygons(backend=backend), hv.DynamicMap)) + assert(isinstance(uxds['psi'].plot.points(backend=backend), hv.Points)) - uxds['psi'].plot(backend=backend) - - uxds['psi'].plot.polygons(backend=backend) - - uxds['psi'].plot.points(backend=backend) - - uxds['psi'].plot.rasterize(method='polygon', backend=backend) def test_node_centered_data(self): """Tests execution of plotting methods on node-centered data.""" @@ -75,11 +67,12 @@ def test_node_centered_data(self): uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow) for backend in ['matplotlib', 'bokeh']: - uxds['v1'][0][0].plot(backend=backend) + assert(isinstance(uxds['v1'][0][0].plot(backend=backend), hv.Points)) - uxds['v1'][0][0].plot.points(backend=backend) + assert(isinstance(uxds['v1'][0][0].plot.points(backend=backend), hv.Points)) + + assert(isinstance(uxds['v1'][0][0].topological_mean(destination='face').plot.polygons(backend=backend), hv.DynamicMap)) - uxds['v1'][0][0].topological_mean(destination='face').plot.polygons(backend=backend) def test_clabel(self): @@ -88,9 +81,18 @@ def test_clabel(self): uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow) raster_no_clabel = uxds['v1'][0][0].plot.rasterize(method='point') - raster_with_clabel = uxds['v1'][0][0].plot.rasterize(method='point', clabel='Foo') + def test_engine(self): + uxds = ux.open_dataset(gridfile_mpas, gridfile_mpas) + _plot_sp = uxds['bottomDepth'].plot.polygons(rasterize=True, engine='spatialpandas') + _plot_gp = uxds['bottomDepth'].plot.polygons(rasterize=True, engine='geopandas') + + assert isinstance(_plot_sp, hv.DynamicMap) + assert isinstance(_plot_gp, hv.DynamicMap) + + + class TestXarrayMethods(TestCase): def test_dataset(self): diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index a038ca5a0..1cb63b070 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -154,11 +154,11 @@ def to_geodataframe( self, periodic_elements: Optional[str] = "exclude", projection: Optional[ccrs.Projection] = None, - project: Optional[bool] = False, cache: Optional[bool] = True, override: Optional[bool] = False, engine: Optional[str] = "spatialpandas", exclude_antimeridian: Optional[bool] = None, + **kwargs, ): """Constructs a ``GeoDataFrame`` consisting of polygons representing the faces of the current ``Grid`` with a face-centered data variable @@ -178,7 +178,8 @@ def to_geodataframe( - 'split': Periodic elements will be identified and split using the ``antimeridian`` package - 'ignore': No processing will be applied to periodic elements. projection: ccrs.Projection, optional - Geographic projection used to transform polygons + Geographic projection used to transform polygons. Only supported when periodic_elements is set to + 'ignore' or 'exclude' cache: bool, optional Flag used to select whether to cache the computed GeoDataFrame override: bool, optional @@ -191,7 +192,7 @@ def to_geodataframe( Returns ------- - gdf : spatialpandas.GeoDataFrame + gdf : spatialpandas.GeoDataFrame or geopandas.GeoDataFrame The output ``GeoDataFrame`` with a filled out "geometry" column of polygons and a data column with the same name as the ``UxDataArray`` (or named ``var`` if no name exists) """ @@ -207,7 +208,7 @@ def to_geodataframe( gdf, non_nan_polygon_indices = self.uxgrid.to_geodataframe( periodic_elements=periodic_elements, projection=projection, - project=project, + project=kwargs.get("project", True), cache=cache, override=override, exclude_antimeridian=exclude_antimeridian, diff --git a/uxarray/grid/geometry.py b/uxarray/grid/geometry.py index bb2f68f43..09f5e50b0 100644 --- a/uxarray/grid/geometry.py +++ b/uxarray/grid/geometry.py @@ -162,7 +162,7 @@ def _correct_central_longitude(node_lon, node_lat, projection): def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project, engine): """Converts the faces of a ``Grid`` into a ``spatialpandas.GeoDataFrame`` - with a geometry column of polygons.""" + or ``geopandas.GeoDataFrame`` with a geometry column of polygons.""" node_lon, node_lat, central_longitude = _correct_central_longitude( grid.node_lon.values, grid.node_lat.values, projection @@ -214,9 +214,8 @@ def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project, gdf = _build_geodataframe_with_antimeridian( polygon_shells, projected_polygon_shells, - projection, antimeridian_face_indices, - engine=geopandas, + engine=engine, ) elif periodic_elements == "ignore": if engine == "geopandas": @@ -248,8 +247,9 @@ def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project, def _build_geodataframe_without_antimeridian( polygon_shells, projected_polygon_shells, antimeridian_face_indices, engine ): - """Builds a ``spatialpandas.GeoDataFrame`` excluding any faces that cross - the antimeridian.""" + """Builds a ``spatialpandas.GeoDataFrame`` or + ``geopandas.GeoDataFrame``excluding any faces that cross the + antimeridian.""" if projected_polygon_shells is not None: # use projected shells if a projection is applied shells_without_antimeridian = np.delete( @@ -276,12 +276,11 @@ def _build_geodataframe_without_antimeridian( def _build_geodataframe_with_antimeridian( polygon_shells, projected_polygon_shells, - projection, antimeridian_face_indices, engine, ): - """Builds a ``spatialpandas.GeoDataFrame`` including any faces that cross - the antimeridian.""" + """Builds a ``spatialpandas.GeoDataFrame`` or ``geopandas.GeoDataFrame`` + including any faces that cross the antimeridian.""" polygons = _build_corrected_shapely_polygons( polygon_shells, projected_polygon_shells, antimeridian_face_indices ) @@ -425,7 +424,8 @@ def _grid_to_matplotlib_polycollection( # Handle unsupported configuration: splitting periodic elements with projection if periodic_elements == "split" and projection is not None: raise ValueError( - "Projections are not supported when splitting periodic elements.'" + "Explicitly projecting lines is not supported. Please pass in your projection" + "using the 'transform' parameter" ) # Correct the central longitude and build polygon shells @@ -533,7 +533,7 @@ def _grid_to_matplotlib_polycollection( return PolyCollection(polygon_shells, **kwargs), [] -def _get_polygons(grid, periodic_elements, projection=None): +def _get_polygons(grid, periodic_elements, projection=None, apply_projection=True): # Correct the central longitude if projection is provided node_lon, node_lat, central_longitude = _correct_central_longitude( grid.node_lon.values, grid.node_lat.values, projection @@ -552,7 +552,7 @@ def _get_polygons(grid, periodic_elements, projection=None): ) # If projection is provided, create the projected polygon shells - if projection: + if projection and apply_projection: projected_polygon_shells = _build_polygon_shells( node_lon, node_lat, @@ -625,8 +625,14 @@ def _grid_to_matplotlib_linecollection( ): """Constructs and returns a ``matplotlib.collections.LineCollection``""" + if periodic_elements == "split" and projection is not None: + apply_projection = False + else: + apply_projection = True + + # do not explicitly project when splitting elements polygons, central_longitude, _, _ = _get_polygons( - grid, periodic_elements, projection + grid, periodic_elements, projection, apply_projection ) # Convert polygons to line segments for the LineCollection @@ -639,14 +645,13 @@ def _grid_to_matplotlib_linecollection( else: lines.append(np.array(boundary.coords)) - # Set default transform if not provided if "transform" not in kwargs: - if projection is None: + # Set default transform if one is not provided not provided + if projection is None or not apply_projection: kwargs["transform"] = ccrs.PlateCarree(central_longitude=central_longitude) else: kwargs["transform"] = projection - # Return a LineCollection of the line segments return LineCollection(lines, **kwargs) diff --git a/uxarray/grid/grid.py b/uxarray/grid/grid.py index 211aa7949..c90b97786 100644 --- a/uxarray/grid/grid.py +++ b/uxarray/grid/grid.py @@ -1635,13 +1635,13 @@ def to_geodataframe( self, periodic_elements: Optional[str] = "exclude", projection: Optional[ccrs.Projection] = None, - project: Optional[bool] = False, cache: Optional[bool] = True, override: Optional[bool] = False, engine: Optional[str] = "spatialpandas", exclude_antimeridian: Optional[bool] = None, return_non_nan_polygon_indices: Optional[bool] = False, exclude_nan_polygons: Optional[bool] = True, + **kwargs, ): """Constructs a ``GeoDataFrame`` consisting of polygons representing the faces of the current ``Grid`` @@ -1661,7 +1661,8 @@ def to_geodataframe( - 'split': Periodic elements will be identified and split using the ``antimeridian`` package - 'ignore': No processing will be applied to periodic elements. projection: ccrs.Projection, optional - Geographic projection used to transform polygons + Geographic projection used to transform polygons. Only supported when periodic_elements is set to + 'ignore' or 'exclude' cache: bool, optional Flag used to select whether to cache the computed GeoDataFrame override: bool, optional @@ -1679,7 +1680,7 @@ def to_geodataframe( Returns ------- - gdf : spatialpandas.GeoDataFrame + gdf : spatialpandas.GeoDataFrame or geopandas.GeoDataFrame The output ``GeoDataFrame`` with a filled out "geometry" column of polygons. """ @@ -1688,6 +1689,9 @@ def to_geodataframe( f"Invalid engine. Expected one of ['spatialpandas', 'geopandas'] but received {engine}" ) + # if project is false, projection is only used for determining central coordinates + project = kwargs.get("project", True) + if projection and project: if periodic_elements == "split": raise ValueError( @@ -1871,13 +1875,6 @@ def to_linecollection( f"Invalid value for 'periodic_elements'. Expected one of ['ignore', 'exclude', 'split'] but received: {periodic_elements}" ) - if projection is not None: - if periodic_elements == "split": - raise ValueError( - "Setting ``periodic_elements='split'`` is not supported when a " - "projection is provided." - ) - if self._line_collection_cached_parameters["line_collection"] is not None: if ( self._line_collection_cached_parameters["periodic_elements"] diff --git a/uxarray/plot/accessor.py b/uxarray/plot/accessor.py index bfaa3c16e..8f4938f5a 100644 --- a/uxarray/plot/accessor.py +++ b/uxarray/plot/accessor.py @@ -163,7 +163,13 @@ def face_centers(self, backend=None, **kwargs): face_centers.__doc__ = face_coords.__doc__ - def edges(self, periodic_elements="exclude", backend=None, **kwargs): + def edges( + self, + periodic_elements="exclude", + backend=None, + engine="spatialpandas", + **kwargs, + ): """Plots the edges of a Grid. This function plots the edges of the grid as geographical paths using `hvplot`. @@ -182,6 +188,8 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs): - "split": Split periodic elements. backend : str or None, optional Plotting backend to use. One of ['matplotlib', 'bokeh']. Equivalent to running holoviews.extension(backend) + engine: str, optional + Engine to use for GeoDataFrame construction. One of ['spatialpandas', 'geopandas'] **kwargs : dict Additional keyword arguments passed to `hvplot.paths`. These can include: - "rasterize" (bool): Whether to rasterize the plot (default: False), @@ -195,7 +203,6 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs): gdf.hvplot.paths : hvplot.paths A paths plot of the edges of the unstructured grid """ - uxarray.plot.utils.backend.assign(backend) if "rasterize" not in kwargs: @@ -212,8 +219,11 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs): kwargs["crs"] = ccrs.PlateCarree(central_longitude=central_longitude) gdf = self._uxgrid.to_geodataframe( - periodic_elements=periodic_elements, projection=kwargs.get("projection") - )[["geometry"]] + periodic_elements=periodic_elements, + projection=kwargs.get("projection"), + engine=engine, + project=False, + ) return gdf.hvplot.paths(geo=True, **kwargs) @@ -260,8 +270,15 @@ def __getattr__(self, name: str) -> Any: else: raise AttributeError(f"Unsupported Plotting Method: '{name}'") - def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs): - """Generate a shaded polygon plot of a face-centered data variable. + def polygons( + self, + periodic_elements="exclude", + backend=None, + engine="spatialpandas", + *args, + **kwargs, + ): + """Generated a shaded polygon plot. This function plots the faces of an unstructured grid shaded with a face-centered data variable using hvplot. It allows for rasterization, projection settings, and labeling of the data variable to be @@ -278,6 +295,8 @@ def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs): - "ignore": Include periodic elements without any corrections backend : str or None, optional Plotting backend to use. One of ['matplotlib', 'bokeh']. Equivalent to running holoviews.extension(backend) + engine: str, optional + Engine to use for GeoDataFrame construction. One of ['spatialpandas', 'geopandas'] *args : tuple Additional positional arguments to be passed to `hvplot.polygons`. **kwargs : dict @@ -309,7 +328,10 @@ def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs): kwargs["crs"] = ccrs.PlateCarree(central_longitude=central_longitude) gdf = self._uxda.to_geodataframe( - periodic_elements=periodic_elements, projection=kwargs.get("projection") + periodic_elements=periodic_elements, + projection=kwargs.get("projection"), + engine=engine, + project=False, ) return gdf.hvplot.polygons(