diff --git a/appletree/config.py b/appletree/config.py index 3cecd64..2fc7614 100644 --- a/appletree/config.py +++ b/appletree/config.py @@ -149,6 +149,10 @@ class Map(Config): """ + def __init__(self, method="IDW", **kwargs): + super().__init__(**kwargs) + self.method = method + def build(self, llh_name: Optional[str] = None): """Cache the map to jnp.array.""" @@ -248,9 +252,27 @@ def build_regbin(self, data): if len(self.coordinate_lowers) == 1: setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_1d) elif len(self.coordinate_lowers) == 2: - setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_2d) - elif len(self.coordinate_lowers) == 3: - setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d) + if self.method == "IDW": + setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_2d) + elif self.method == "NN": + setattr( + self, + "interpolator", + interpolation.map_interpolator_regular_binning_nearest_neighbor_2d, + ) + else: + raise ValueError(f"Unknown method {self.method} for 2D regular binning.") + elif len(self.coordinate_lowers) == 3 and self.method == "IDW": + if self.method == "IDW": + setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d) + elif self.method == "NN": + setattr( + self, + "interpolator", + interpolation.map_interpolator_regular_binning_nearest_neighbor_3d, + ) + else: + raise ValueError(f"Unknown method {self.method} for 3D regular binning.") if self.coordinate_type == "log_regbin": if jnp.any(self.coordinate_lowers <= 0) or jnp.any(self.coordinate_uppers <= 0): raise ValueError( @@ -303,6 +325,10 @@ class SigmaMap(Config): """ + def __init__(self, method="IDW", **kwargs): + super().__init__(**kwargs) + self.method = method + def build(self, llh_name: Optional[str] = None): """Read maps.""" self.llh_name = llh_name @@ -330,7 +356,7 @@ def build(self, llh_name: Optional[str] = None): ) # If only one file is given, then use the same file for all sigmas default = _configs_default - maps[sigma] = Map(name=self.name + f"_{sigma}", default=default) + maps[sigma] = Map(method=self.method, name=self.name + f"_{sigma}", default=default) setattr(self, sigma, maps[sigma]) diff --git a/appletree/interpolation.py b/appletree/interpolation.py index 71db361..d11103c 100644 --- a/appletree/interpolation.py +++ b/appletree/interpolation.py @@ -257,3 +257,66 @@ def map_interpolator_regular_binning_3d(pos, ref_pos_lowers, ref_pos_uppers, ref ) return val + + +@jit +def find_nearest_indices(x, y): + x = x[:, jnp.newaxis] + differences = jnp.abs(x - y) + indices = jnp.argmin(differences, axis=1) + return indices + + +@export +@jit +def map_interpolator_regular_binning_nearest_neighbor_2d( + pos, ref_pos_lowers, ref_pos_uppers, ref_val +): + """Nearest neighbor 2D interpolation. A uniform mesh grid binning is assumed. + + Args: + pos: array with shape (N, 2), positions at which the interp is calculated. + ref_pos_lowers: array with shape (2,), the lower edges of the binning on each dimension. + ref_pos_uppers: array with shape (2,), the upper edges of the binning on each dimension. + ref_val: array with shape (M1, M2), map values. + + """ + n0, n1 = ref_val.shape + + bins0 = jnp.linspace(ref_pos_lowers[0], ref_pos_uppers[0], n0) + ind0 = find_nearest_indices(pos[:, 0], bins0) + + bins1 = jnp.linspace(ref_pos_lowers[1], ref_pos_uppers[1], n1) + ind1 = find_nearest_indices(pos[:, 1], bins1) + + val = ref_val[ind0, ind1] + return val + + +@export +@jit +def map_interpolator_regular_binning_nearest_neighbor_3d( + pos, ref_pos_lowers, ref_pos_uppers, ref_val +): + """Nearest neighbor 3D interpolation. A uniform mesh grid binning is assumed. + + Args: + pos: array with shape (N, 3), positions at which the interp is calculated. + ref_pos_lowers: array with shape (3,), the lower edges of the binning on each dimension. + ref_pos_uppers: array with shape (3,), the upper edges of the binning on each dimension. + ref_val: array with shape (M1, M2, M3), map values. + + """ + n0, n1, n2 = ref_val.shape + + bins0 = jnp.linspace(ref_pos_lowers[0], ref_pos_uppers[0], n0) + ind0 = find_nearest_indices(pos[:, 0], bins0) + + bins1 = jnp.linspace(ref_pos_lowers[1], ref_pos_uppers[1], n1) + ind1 = find_nearest_indices(pos[:, 1], bins1) + + bins2 = jnp.linspace(ref_pos_lowers[2], ref_pos_uppers[2], n2) + ind2 = find_nearest_indices(pos[:, 2], bins2) + + val = ref_val[ind0, ind1, ind2] + return val