Skip to content

Commit

Permalink
Merge pull request #518 from UXARRAY/zedwick/kdtree
Browse files Browse the repository at this point in the history
KDTree Data Structure for Center and Corner Nodes
  • Loading branch information
philipc2 authored Nov 2, 2023
2 parents c541e82 + 816ccce commit d5f61da
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 12 deletions.
10 changes: 10 additions & 0 deletions docs/user_api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ Methods
Grid.compute_face_areas
Grid.encode_as
Grid.get_ball_tree
Grid.get_kd_tree
Grid.copy


Expand Down Expand Up @@ -229,6 +230,15 @@ UxDataArray Plotting Methods
Nearest Neighbor Data Structures
================================

KDTree
------
.. autosummary::
:toctree: _autosummary

grid.neighbors.KDTree
grid.neighbors.KDTree.query
grid.neighbors.KDTree.query_radius

BallTree
--------
.. autosummary::
Expand Down
43 changes: 33 additions & 10 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

from uxarray.grid.connectivity import _populate_edge_node_connectivity, _populate_face_edge_connectivity, _build_edge_face_connectivity

from uxarray.grid.coordinates import _populate_cartesian_xyz_coord, _populate_lonlat_coord

from uxarray.grid.neighbors import BallTree
from uxarray.grid.coordinates import _populate_lonlat_coord

from uxarray.constants import INT_FILL_VALUE

Expand All @@ -38,7 +36,6 @@


class TestGrid(TestCase):

grid_CSne30 = ux.open_grid(gridfile_CSne30)
grid_RLL1deg = ux.open_grid(gridfile_RLL1deg)
grid_RLL10deg_CSne4 = ux.open_grid(gridfile_RLL10deg_CSne4)
Expand Down Expand Up @@ -282,7 +279,6 @@ def test_ne(self):


class TestFaceAreas(TestCase):

grid_CSne30 = ux.open_grid(gridfile_CSne30)

def test_calculate_total_face_area_triangle(self):
Expand All @@ -294,7 +290,7 @@ def test_calculate_total_face_area_triangle(self):

grid_verts = ux.open_grid(verts, latlon=False)

#calculate area
# calculate area
area_gaussian = grid_verts.calculate_total_face_area(
quadrature_rule="gaussian", order=5)
nt.assert_almost_equal(area_gaussian, constants.TRI_AREA, decimal=3)
Expand Down Expand Up @@ -406,7 +402,7 @@ def test_populate_cartesian_xyz_coord(self):
verts_degree = np.stack((lon_deg, lat_deg), axis=1)

vgrid = ux.open_grid(verts_degree, latlon=True)
#_populate_cartesian_xyz_coord(vgrid)
# _populate_cartesian_xyz_coord(vgrid)

for i in range(0, vgrid.nMesh2_node):
nt.assert_almost_equal(vgrid.Mesh2_node_cart_x.values[i],
Expand Down Expand Up @@ -879,14 +875,12 @@ def test_edge_face_connectivity_sample(self):


class TestClassMethods(TestCase):

gridfile_ugrid = current_path / "meshfiles" / "ugrid" / "geoflow-small" / "grid.nc"
gridfile_mpas = current_path / "meshfiles" / "mpas" / "QU" / "mesh.QU.1920km.151026.nc"
gridfile_exodus = current_path / "meshfiles" / "exodus" / "outCSne8" / "outCSne8.g"
gridfile_scrip = current_path / "meshfiles" / "scrip" / "outCSne8" / "outCSne8.nc"

def test_from_dataset(self):

# UGRID
xrds = xr.open_dataset(self.gridfile_ugrid)
uxgrid = ux.Grid.from_dataset(xrds)
Expand Down Expand Up @@ -918,7 +912,6 @@ def test_from_face_vertices(self):


class TestBallTree(TestCase):

corner_grid_files = [gridfile_CSne30, gridfile_mpas]
center_grid_files = [gridfile_mpas]

Expand Down Expand Up @@ -986,3 +979,33 @@ def test_antimeridian_distance_nodes(self):
def test_antimeridian_distance_face_centers(self):
"""TODO: Write addition tests once construction and representation of face centers is implemented."""
pass


class TestKDTree:
corner_grid_files = [gridfile_CSne30, gridfile_mpas]
center_grid_file = gridfile_mpas

def test_construction_from_nodes(self):
"""Test the KDTree creation and query function using the grids
nodes."""

for grid_file in self.corner_grid_files:
uxgrid = ux.open_grid(grid_file)
d, ind = uxgrid.get_kd_tree(tree_type="nodes").query(
[0.0, 0.0, 1.0])

def test_construction_from_face_centers(self):
"""Test the KDTree creation and query function using the grids face
centers."""

uxgrid = ux.open_grid(self.center_grid_file)
d, ind = uxgrid.get_kd_tree(tree_type="face centers").query(
[1.0, 0.0, 0.0], k=5)

def test_query_radius(self):
"""Test the KDTree creation and query_radius function using the grids
face centers."""

uxgrid = ux.open_grid(self.center_grid_file)
d, ind = uxgrid.get_kd_tree(tree_type="face centers").query_radius(
[0.0, 0.0, 1.0], r=5)
31 changes: 30 additions & 1 deletion uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_grid_to_matplotlib_linecollection,
_grid_to_polygons)

from uxarray.grid.neighbors import BallTree
from uxarray.grid.neighbors import BallTree, KDTree

from uxarray.plot.accessor import GridPlotAccessor

Expand Down Expand Up @@ -124,6 +124,7 @@ def __init__(self,

# initialize cached data structures (nearest neighbor operations)
self._ball_tree = None
self._kd_tree = None

self._mesh2_warning_raised = False

Expand Down Expand Up @@ -757,6 +758,34 @@ def get_ball_tree(self, tree_type: Optional[str] = "nodes"):

return self._ball_tree

def get_kd_tree(self, tree_type: Optional[str] = "nodes"):
"""Get the KDTree data structure of this Grid that allows for nearest
neighbor queries (k nearest or within some radius) on either the nodes
(``Mesh2_node_cart_x``, ``Mesh2_node_cart_y``, ``Mesh2_node_cart_z``)
or face centers (``Mesh2_face_cart_x``, ``Mesh2_face_cart_y``,
``Mesh2_face_cart_z``).
Parameters
----------
tree_type : str, default="nodes"
Selects which tree to query, with "nodes" selecting the Corner Nodes and "face centers" selecting the Face
Centers of each face
Returns
-------
self._kd_tree : grid.Neighbors.KDTree
KDTree instance
"""
if self._kd_tree is None:
self._kd_tree = KDTree(self,
tree_type=tree_type,
distance_metric='minkowski')
else:
if tree_type != self._kd_tree._tree_type:
self._kd_tree.tree_type = tree_type

return self._kd_tree

def copy(self):
"""Returns a deep copy of this grid."""

Expand Down
Loading

0 comments on commit d5f61da

Please sign in to comment.