Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 1D nearest neighbour and linear interpolator #170

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions appletree/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,22 @@ def build_point(self, data):
self.coordinate_system = jnp.asarray(data["coordinate_system"], dtype=float)
self.map = jnp.asarray(data["map"], dtype=float)

setattr(self, "interpolator", interpolation.curve_interpolator)
if self.method == "IDW":
setattr(self, "interpolator", interpolation.curve_interpolator)
elif self.method == "NN":
setattr(
self,
"interpolator",
interpolation.map_interpolator_nearest_neighbor_1d,
)
elif self.method == "LERP":
setattr(
self,
"interpolator",
interpolation.map_interpolator_linear_1d,
)
else:
raise ValueError(f"Unknown method {self.method} for 1D regular binning.")
if self.coordinate_type == "log_point":
if jnp.any(self.coordinate_system <= 0):
raise ValueError(
Expand Down Expand Up @@ -262,7 +277,7 @@ def build_regbin(self, data):
)
else:
raise ValueError(f"Unknown method {self.method} for 2D regular binning.")
elif len(self.coordinate_lowers) == 3 and self.method == "IDW":
elif len(self.coordinate_lowers) == 3:
if self.method == "IDW":
setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d)
elif self.method == "NN":
Expand Down
31 changes: 31 additions & 0 deletions appletree/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,37 @@ def find_nearest_indices(x, y):
return indices


@export
@jit
def map_interpolator_linear_1d(pos, ref_pos, ref_val):
"""Linear 1D interpolation. Copied to prevent misuse of other arguments of jnp.interp.

Args:
pos: array with shape (N,), as the points to be interpolated.
ref_pos: array with shape (M,), as the reference points.
ref_val: array with shape (M,), as the reference values.

"""
return jnp.interp(pos, ref_pos, ref_val)


@export
@jit
def map_interpolator_nearest_neighbor_1d(pos, ref_pos, ref_val):
"""Nearest neighbor 1D interpolation.

Args:
pos: array with shape (N,), as the points to be interpolated.
ref_pos: array with shape (M,), as the reference points.
ref_val: array with shape (M,), as the reference values.

"""
ind = find_nearest_indices(pos, ref_pos)

val = ref_val[ind]
return val


@export
@jit
def map_interpolator_regular_binning_nearest_neighbor_2d(
Expand Down
7 changes: 6 additions & 1 deletion appletree/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def simulate(self, key, parameters, batch_size):

@export
@takes_config(
Map(name="energy_spectrum", default="_nr_spectrum.json", help="Recoil energy spectrum"),
Map(
name="energy_spectrum",
method="LERP",
default="_nr_spectrum.json",
help="Recoil energy spectrum",
),
)
class FixedEnergySpectra(Plugin):
depends_on = ["batch_size"]
Expand Down
7 changes: 6 additions & 1 deletion appletree/plugins/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ def simulate(self, key, parameters, num_s1_phd):

@export
@takes_config(
Map(name="elife", default="_elife.json", help="Electron lifetime correction"),
Map(
name="elife",
method="LERP",
default="_elife.json",
help="Electron lifetime correction",
),
)
class DriftLoss(Plugin):
depends_on = ["z"]
Expand Down
3 changes: 3 additions & 0 deletions appletree/plugins/efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def simulate(self, key, parameters, s2_area):
@takes_config(
SigmaMap(
name="s1_eff_3f",
method="NN",
default="_3fold_recon_eff.json",
help="3fold S1 reconstruction efficiency",
),
Expand All @@ -44,6 +45,7 @@ def simulate(self, key, parameters, num_s1_phd):
@takes_config(
SigmaMap(
name="s1_cut_acc",
method="LERP",
default=["_s1_cut_acc.json", "_s1_cut_acc.json", "_s1_cut_acc.json"],
help="S1 cut acceptance",
),
Expand All @@ -64,6 +66,7 @@ def simulate(self, key, parameters, s1_area):
@takes_config(
SigmaMap(
name="s2_cut_acc",
method="LERP",
default=["_s2_cut_acc.json", "_s2_cut_acc.json", "_s2_cut_acc.json", "s2_cut_acc_sigma"],
help="S2 cut acceptance",
),
Expand Down
14 changes: 12 additions & 2 deletions appletree/plugins/lyqy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

@export
@takes_config(
Map(name="ly_median", default="_nr_ly.json", help="Light yield curve from NESTv2"),
Map(
name="ly_median",
method="LERP",
default="_nr_ly.json",
help="Light yield curve from NESTv2",
),
)
class LightYield(Plugin):
depends_on = ["energy"]
Expand All @@ -39,7 +44,12 @@ def simulate(self, key, parameters, energy, light_yield):

@export
@takes_config(
Map(name="qy_median", default="_nr_qy.json", help="Charge yield curve from NESTv2"),
Map(
name="qy_median",
method="LERP",
default="_nr_qy.json",
help="Charge yield curve from NESTv2",
),
)
class ChargeYield(Plugin):
depends_on = ["energy"]
Expand Down
35 changes: 30 additions & 5 deletions appletree/plugins/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

@export
@takes_config(
Map(name="posrec_reso", default="_posrec_reso.json", help="Position reconstruction resolution"),
Map(
name="posrec_reso",
method="LERP",
default="_posrec_reso.json",
help="Position reconstruction resolution",
),
)
class PositionRecon(Plugin):
depends_on = ["x", "y", "z", "num_electron_drifted"]
Expand All @@ -35,8 +40,18 @@ def simulate(self, key, parameters, x, y, z, num_electron_drifted):

@export
@takes_config(
Map(name="s1_bias_3f", default="_s1_bias.json", help="3fold S1 reconstruction bias"),
Map(name="s1_smear_3f", default="_s1_smearing.json", help="3fold S1 reconstruction smearing"),
Map(
name="s1_bias_3f",
method="LERP",
default="_s1_bias.json",
help="3fold S1 reconstruction bias",
),
Map(
name="s1_smear_3f",
method="LERP",
default="_s1_smearing.json",
help="3fold S1 reconstruction smearing",
),
)
class S1(Plugin):
depends_on = ["num_s1_phd", "num_s1_pe"]
Expand All @@ -53,8 +68,18 @@ def simulate(self, key, parameters, num_s1_phd, num_s1_pe):

@export
@takes_config(
Map(name="s2_bias", default="_s2_bias.json", help="S2 reconstruction bias"),
Map(name="s2_smear", default="_s2_smearing.json", help="S2 reconstruction smearing"),
Map(
name="s2_bias",
method="LERP",
default="_s2_bias.json",
help="S2 reconstruction bias",
),
Map(
name="s2_smear",
method="LERP",
default="_s2_smearing.json",
help="S2 reconstruction smearing",
),
)
class S2(Plugin):
depends_on = ["num_s2_pe"]
Expand Down
Loading