Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: code clean up #117

Merged
merged 2 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright (c) 2012-2022 by the GalSim developers team on GitHub
Copyright (c) 2012-2024 by the GalSim developers team on GitHub
https://github.com/GalSim-developers

Redistribution and use in source and binary forms, with or without
Expand Down
10 changes: 3 additions & 7 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# original source license:
#
# Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -176,21 +178,15 @@ def __sub__(self, other):
return _Angle(self._rad - other._rad)

def __mul__(self, other):
# if other != float(other):
# raise TypeError("Cannot multiply Angle by %s of type %s" % (other, type(other)))
return _Angle(self._rad * other)

__rmul__ = __mul__

def __div__(self, other):
if isinstance(other, AngleUnit):
return self._rad / other.value
elif other == float(other):
return _Angle(self._rad / other)
else:
raise TypeError(
"Cannot divide Angle by %s of type %s" % (other, type(other))
)
return _Angle(self._rad / other)

__truediv__ = __div__

Expand Down
152 changes: 152 additions & 0 deletions jax_galsim/bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,155 @@ def kv(nu, x):
nu = 1.0 * nu
x = 1.0 * x
return _tfp_bessel_kve(nu, x) / jnp.exp(jnp.abs(x))


@jax.jit
def _R(z, num, denom):
return jnp.polyval(num, z) / jnp.polyval(denom, z)


@jax.jit
def _evaluate_rational(z, num, denom):
return _R(z, num[::-1], denom[::-1])


# jitted & vectorized version
_v_rational = jax.jit(jax.vmap(_evaluate_rational, in_axes=(0, None, None)))


@implements(
_galsim.bessel.j0,
lax_description="""\
The JAX-GalSim implementation of ``j0`` is a vectorized version of the Boost C++
algorith for the Bessel function of the first kind J0(x).""",
)
@jax.jit
def j0(x):
orig_shape = x.shape

x = jnp.atleast_1d(x)

P1 = jnp.array(
[
-4.1298668500990866786e11,
2.7282507878605942706e10,
-6.2140700423540120665e08,
6.6302997904833794242e06,
-3.6629814655107086448e04,
1.0344222815443188943e02,
-1.2117036164593528341e-01,
]
)
Q1 = jnp.array(
[
2.3883787996332290397e12,
2.6328198300859648632e10,
1.3985097372263433271e08,
4.5612696224219938200e05,
9.3614022392337710626e02,
1.0,
0.0,
]
)

P2 = jnp.array(
[
-1.8319397969392084011e03,
-1.2254078161378989535e04,
-7.2879702464464618998e03,
1.0341910641583726701e04,
1.1725046279757103576e04,
4.4176707025325087628e03,
7.4321196680624245801e02,
4.8591703355916499363e01,
]
)
Q2 = jnp.array(
[
-3.5783478026152301072e05,
2.4599102262586308984e05,
-8.4055062591169562211e04,
1.8680990008359188352e04,
-2.9458766545509337327e03,
3.3307310774649071172e02,
-2.5258076240801555057e01,
1.0,
]
)

PC = jnp.array(
[
2.2779090197304684302e04,
4.1345386639580765797e04,
2.1170523380864944322e04,
3.4806486443249270347e03,
1.5376201909008354296e02,
8.8961548424210455236e-01,
]
)
QC = jnp.array(
[
2.2779090197304684318e04,
4.1370412495510416640e04,
2.1215350561880115730e04,
3.5028735138235608207e03,
1.5711159858080893649e02,
1.0,
]
)

PS = jnp.array(
[
-8.9226600200800094098e01,
-1.8591953644342993800e02,
-1.1183429920482737611e02,
-2.2300261666214198472e01,
-1.2441026745835638459e00,
-8.8033303048680751817e-03,
]
)
QS = jnp.array(
[
5.7105024128512061905e03,
1.1951131543434613647e04,
7.2642780169211018836e03,
1.4887231232283756582e03,
9.0593769594993125859e01,
1.0,
]
)

x1 = 2.4048255576957727686e00
x2 = 5.5200781102863106496e00
x11 = 6.160e02
x12 = -1.42444230422723137837e-03
x21 = 1.4130e03
x22 = 5.46860286310649596604e-04
one_div_root_pi = 5.641895835477562869480794515607725858e-01

def t1(x): # x<=4
y = x * x
r = _v_rational(y, P1, Q1)
factor = (x + x1) * ((x - x11 / 256) - x12)
return factor * r

def t2(x): # x<=8
y = 1 - (x * x) / 64
r = _v_rational(y, P2, Q2)
factor = (x + x2) * ((x - x21 / 256) - x22)
return factor * r

def t3(x): # x>8
y = 8 / x
y2 = y * y
rc = _v_rational(y2, PC, QC)
rs = _v_rational(y2, PS, QS)
factor = one_div_root_pi / jnp.sqrt(x)
sx = jnp.sin(x)
cx = jnp.cos(x)
return factor * (rc * (cx + sx) - y * rs * (sx - cx))

x = jnp.abs(x)
return jnp.select(
[x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x
).reshape(orig_shape)
1 change: 0 additions & 1 deletion jax_galsim/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""


# The reason for avoid these tests is that they are not easy to do for jitted code.
@implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class Bounds(_galsim.Bounds):
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def tree_unflatten(cls, aux_data, children):
**aux_data,
)

@implements(_galsim.Box._shoot)
def _shoot(self, photons, rng):
ud = UniformDeviate(rng)

Expand All @@ -135,7 +136,6 @@ def __init__(self, scale, flux=1.0, gsparams=None):
@property
@implements(_galsim.Pixel.scale)
def scale(self):
"""The linear scale size of the `Pixel`."""
return self.width

def __repr__(self):
Expand Down
7 changes: 2 additions & 5 deletions jax_galsim/celestial.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# original source license:
#
# Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -71,8 +73,6 @@ def __init__(self, ra, dec=None):
raise TypeError("ra must be a galsim.Angle")
elif not isinstance(dec, Angle):
raise TypeError("dec must be a galsim.Angle")
# elif dec/degrees > 90. or dec/degrees < -90.:
# raise ValueError("dec must be between -90 deg and +90 deg.")
else:
# Normal case
self._ra = ra
Expand Down Expand Up @@ -130,9 +130,6 @@ def get_xyz(self):
)
def from_xyz(x, y, z):
norm = jnp.sqrt(x * x + y * y + z * z)
# JAX cannot check this condition
# if norm == 0.:
# raise ValueError("CelestialCoord for position (0,0,0) is undefined.")
ret = CelestialCoord.__new__(CelestialCoord)
ret._x = x / norm
ret._y = y / norm
Expand Down
5 changes: 1 addition & 4 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def _kValue(self, kpos):
def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
raise NotImplementedError("Real-space convolutions are not implemented")

@implements(_galsim.Convolution._shoot)
def _shoot(self, photons, rng):
self.obj_list[0]._shoot(photons, rng)
# It may be necessary to shuffle when convolving because we do not have a
Expand Down Expand Up @@ -342,10 +343,6 @@ def tree_unflatten(cls, aux_data, children):
lax_description="Does not support ChromaticDeconvolution",
)
def Deconvolve(obj, gsparams=None, propagate_gsparams=True):
# from .chromatic import ChromaticDeconvolution
# if isinstance(obj, ChromaticObject):
# return ChromaticDeconvolution(obj, gsparams=gsparams, propagate_gsparams=propagate_gsparams)
# elif isinstance(obj, GSObject):
if isinstance(obj, GSObject):
return Deconvolution(
obj, gsparams=gsparams, propagate_gsparams=propagate_gsparams
Expand Down
Loading
Loading