From 464a1a2befac78aa917b22f7c7ba292c8d59df3c Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 12 Sep 2024 13:15:52 -0500 Subject: [PATCH 1/2] chore: code clean up --- LICENSE | 2 +- jax_galsim/angle.py | 8 +- jax_galsim/bessel.py | 152 ++++++++++++++++++++++ jax_galsim/bounds.py | 1 - jax_galsim/box.py | 2 +- jax_galsim/celestial.py | 7 +- jax_galsim/convolve.py | 5 +- jax_galsim/core/bessel.py | 165 ------------------------ jax_galsim/core/draw.py | 1 - jax_galsim/deltafunction.py | 1 + jax_galsim/moffat.py | 3 +- jax_galsim/random.py | 242 ------------------------------------ jax_galsim/sum.py | 1 + jax_galsim/transform.py | 1 + jax_galsim/wcs.py | 12 +- 15 files changed, 172 insertions(+), 431 deletions(-) delete mode 100644 jax_galsim/core/bessel.py diff --git a/LICENSE b/LICENSE index 0c455eaf..344c4a4e 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2012-2022 by the GalSim developers team on GitHub +Copyright (c) 2012-2024 by the GalSim developers team on GitHub https://github.com/GalSim-developers Redistribution and use in source and binary forms, with or without diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index 2f45daf1..3c573f1a 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -1,3 +1,5 @@ +# original source license: +# # Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC) # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -176,8 +178,10 @@ def __sub__(self, other): return _Angle(self._rad - other._rad) def __mul__(self, other): - # if other != float(other): - # raise TypeError("Cannot multiply Angle by %s of type %s" % (other, type(other))) + if other != float(other): + raise TypeError( + "Cannot multiply Angle by %s of type %s" % (other, type(other)) + ) return _Angle(self._rad * other) __rmul__ = __mul__ diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 616b4161..0471dcc3 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -111,3 +111,155 @@ def kv(nu, x): nu = 1.0 * nu x = 1.0 * x return _tfp_bessel_kve(nu, x) / jnp.exp(jnp.abs(x)) + + +@jax.jit +def _R(z, num, denom): + return jnp.polyval(num, z) / jnp.polyval(denom, z) + + +@jax.jit +def _evaluate_rational(z, num, denom): + return _R(z, num[::-1], denom[::-1]) + + +# jitted & vectorized version +_v_rational = jax.jit(jax.vmap(_evaluate_rational, in_axes=(0, None, None))) + + +@implements( + _galsim.bessel.j0, + lax_description="""\ +The JAX-GalSim implementation of ``j0`` is a vectorized version of the Boost C++ +algorith for the Bessel function of the first kind J0(x).""", +) +@jax.jit +def j0(x): + orig_shape = x.shape + + x = jnp.atleast_1d(x) + + P1 = jnp.array( + [ + -4.1298668500990866786e11, + 2.7282507878605942706e10, + -6.2140700423540120665e08, + 6.6302997904833794242e06, + -3.6629814655107086448e04, + 1.0344222815443188943e02, + -1.2117036164593528341e-01, + ] + ) + Q1 = jnp.array( + [ + 2.3883787996332290397e12, + 2.6328198300859648632e10, + 1.3985097372263433271e08, + 4.5612696224219938200e05, + 9.3614022392337710626e02, + 1.0, + 0.0, + ] + ) + + P2 = jnp.array( + [ + -1.8319397969392084011e03, + -1.2254078161378989535e04, + -7.2879702464464618998e03, + 1.0341910641583726701e04, + 1.1725046279757103576e04, + 4.4176707025325087628e03, + 7.4321196680624245801e02, + 4.8591703355916499363e01, + ] + ) + Q2 = jnp.array( + [ + -3.5783478026152301072e05, + 2.4599102262586308984e05, + -8.4055062591169562211e04, + 1.8680990008359188352e04, + -2.9458766545509337327e03, + 3.3307310774649071172e02, + -2.5258076240801555057e01, + 1.0, + ] + ) + + PC = jnp.array( + [ + 2.2779090197304684302e04, + 4.1345386639580765797e04, + 2.1170523380864944322e04, + 3.4806486443249270347e03, + 1.5376201909008354296e02, + 8.8961548424210455236e-01, + ] + ) + QC = jnp.array( + [ + 2.2779090197304684318e04, + 4.1370412495510416640e04, + 2.1215350561880115730e04, + 3.5028735138235608207e03, + 1.5711159858080893649e02, + 1.0, + ] + ) + + PS = jnp.array( + [ + -8.9226600200800094098e01, + -1.8591953644342993800e02, + -1.1183429920482737611e02, + -2.2300261666214198472e01, + -1.2441026745835638459e00, + -8.8033303048680751817e-03, + ] + ) + QS = jnp.array( + [ + 5.7105024128512061905e03, + 1.1951131543434613647e04, + 7.2642780169211018836e03, + 1.4887231232283756582e03, + 9.0593769594993125859e01, + 1.0, + ] + ) + + x1 = 2.4048255576957727686e00 + x2 = 5.5200781102863106496e00 + x11 = 6.160e02 + x12 = -1.42444230422723137837e-03 + x21 = 1.4130e03 + x22 = 5.46860286310649596604e-04 + one_div_root_pi = 5.641895835477562869480794515607725858e-01 + + def t1(x): # x<=4 + y = x * x + r = _v_rational(y, P1, Q1) + factor = (x + x1) * ((x - x11 / 256) - x12) + return factor * r + + def t2(x): # x<=8 + y = 1 - (x * x) / 64 + r = _v_rational(y, P2, Q2) + factor = (x + x2) * ((x - x21 / 256) - x22) + return factor * r + + def t3(x): # x>8 + y = 8 / x + y2 = y * y + rc = _v_rational(y2, PC, QC) + rs = _v_rational(y2, PS, QS) + factor = one_div_root_pi / jnp.sqrt(x) + sx = jnp.sin(x) + cx = jnp.cos(x) + return factor * (rc * (cx + sx) - y * rs * (sx - cx)) + + x = jnp.abs(x) + return jnp.select( + [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x + ).reshape(orig_shape) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index d7dc7101..19f11841 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -19,7 +19,6 @@ """ -# The reason for avoid these tests is that they are not easy to do for jitted code. @implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class Bounds(_galsim.Bounds): diff --git a/jax_galsim/box.py b/jax_galsim/box.py index a2475019..95b3d373 100644 --- a/jax_galsim/box.py +++ b/jax_galsim/box.py @@ -115,6 +115,7 @@ def tree_unflatten(cls, aux_data, children): **aux_data, ) + @implements(_galsim.Box._shoot) def _shoot(self, photons, rng): ud = UniformDeviate(rng) @@ -135,7 +136,6 @@ def __init__(self, scale, flux=1.0, gsparams=None): @property @implements(_galsim.Pixel.scale) def scale(self): - """The linear scale size of the `Pixel`.""" return self.width def __repr__(self): diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index 25e5d702..1e4f346f 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -1,3 +1,5 @@ +# original source license: +# # Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC) # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -71,8 +73,6 @@ def __init__(self, ra, dec=None): raise TypeError("ra must be a galsim.Angle") elif not isinstance(dec, Angle): raise TypeError("dec must be a galsim.Angle") - # elif dec/degrees > 90. or dec/degrees < -90.: - # raise ValueError("dec must be between -90 deg and +90 deg.") else: # Normal case self._ra = ra @@ -130,9 +130,6 @@ def get_xyz(self): ) def from_xyz(x, y, z): norm = jnp.sqrt(x * x + y * y + z * z) - # JAX cannot check this condition - # if norm == 0.: - # raise ValueError("CelestialCoord for position (0,0,0) is undefined.") ret = CelestialCoord.__new__(CelestialCoord) ret._x = x / norm ret._y = y / norm diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 62409d65..6961ad39 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -300,6 +300,7 @@ def _kValue(self, kpos): def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): raise NotImplementedError("Real-space convolutions are not implemented") + @implements(_galsim.Convolution._shoot) def _shoot(self, photons, rng): self.obj_list[0]._shoot(photons, rng) # It may be necessary to shuffle when convolving because we do not have a @@ -342,10 +343,6 @@ def tree_unflatten(cls, aux_data, children): lax_description="Does not support ChromaticDeconvolution", ) def Deconvolve(obj, gsparams=None, propagate_gsparams=True): - # from .chromatic import ChromaticDeconvolution - # if isinstance(obj, ChromaticObject): - # return ChromaticDeconvolution(obj, gsparams=gsparams, propagate_gsparams=propagate_gsparams) - # elif isinstance(obj, GSObject): if isinstance(obj, GSObject): return Deconvolution( obj, gsparams=gsparams, propagate_gsparams=propagate_gsparams diff --git a/jax_galsim/core/bessel.py b/jax_galsim/core/bessel.py deleted file mode 100644 index 7d65b3bd..00000000 --- a/jax_galsim/core/bessel.py +++ /dev/null @@ -1,165 +0,0 @@ -import jax -import jax.numpy as jnp - - -@jax.jit -def R(z, num, denom): - return jnp.polyval(num, z) / jnp.polyval(denom, z) - - -@jax.jit -def _evaluate_rational(z, num, denom): - return R(z, num[::-1], denom[::-1]) - - -# jitted & vectorized version -v_rational = jax.jit(jax.vmap(_evaluate_rational, in_axes=(0, None, None))) - - -@jax.jit -def j0(x): - """Bessel function of the first kind J0(x) (similar to scipy.special.j0) - code from Boost C++ implementation - boost/math/special_functions/detail/bessel_j0.hpp - - Examples:: - - >>> x = jnp.linspace(0,300,10_000) - >>> plt.plot(x,j0(x)) - >>> plt.plot(x,jax.vmap(jax.jacfwd(j0))(x)) - - Inputs: - x: scalar/array of real(s) - - Outputs: - j0(x) with same shape as x - - """ - orig_shape = x.shape - - x = jnp.atleast_1d(x) - - P1 = jnp.array( - [ - -4.1298668500990866786e11, - 2.7282507878605942706e10, - -6.2140700423540120665e08, - 6.6302997904833794242e06, - -3.6629814655107086448e04, - 1.0344222815443188943e02, - -1.2117036164593528341e-01, - ] - ) - Q1 = jnp.array( - [ - 2.3883787996332290397e12, - 2.6328198300859648632e10, - 1.3985097372263433271e08, - 4.5612696224219938200e05, - 9.3614022392337710626e02, - 1.0, - 0.0, - ] - ) - - P2 = jnp.array( - [ - -1.8319397969392084011e03, - -1.2254078161378989535e04, - -7.2879702464464618998e03, - 1.0341910641583726701e04, - 1.1725046279757103576e04, - 4.4176707025325087628e03, - 7.4321196680624245801e02, - 4.8591703355916499363e01, - ] - ) - Q2 = jnp.array( - [ - -3.5783478026152301072e05, - 2.4599102262586308984e05, - -8.4055062591169562211e04, - 1.8680990008359188352e04, - -2.9458766545509337327e03, - 3.3307310774649071172e02, - -2.5258076240801555057e01, - 1.0, - ] - ) - - PC = jnp.array( - [ - 2.2779090197304684302e04, - 4.1345386639580765797e04, - 2.1170523380864944322e04, - 3.4806486443249270347e03, - 1.5376201909008354296e02, - 8.8961548424210455236e-01, - ] - ) - QC = jnp.array( - [ - 2.2779090197304684318e04, - 4.1370412495510416640e04, - 2.1215350561880115730e04, - 3.5028735138235608207e03, - 1.5711159858080893649e02, - 1.0, - ] - ) - - PS = jnp.array( - [ - -8.9226600200800094098e01, - -1.8591953644342993800e02, - -1.1183429920482737611e02, - -2.2300261666214198472e01, - -1.2441026745835638459e00, - -8.8033303048680751817e-03, - ] - ) - QS = jnp.array( - [ - 5.7105024128512061905e03, - 1.1951131543434613647e04, - 7.2642780169211018836e03, - 1.4887231232283756582e03, - 9.0593769594993125859e01, - 1.0, - ] - ) - - x1 = 2.4048255576957727686e00 - x2 = 5.5200781102863106496e00 - x11 = 6.160e02 - x12 = -1.42444230422723137837e-03 - x21 = 1.4130e03 - x22 = 5.46860286310649596604e-04 - one_div_root_pi = 5.641895835477562869480794515607725858e-01 - - def t1(x): # x<=4 - y = x * x - r = v_rational(y, P1, Q1) - factor = (x + x1) * ((x - x11 / 256) - x12) - return factor * r - - def t2(x): # x<=8 - y = 1 - (x * x) / 64 - r = v_rational(y, P2, Q2) - factor = (x + x2) * ((x - x21 / 256) - x22) - return factor * r - - def t3(x): # x>8 - y = 8 / x - y2 = y * y - rc = v_rational(y2, PC, QC) - rs = v_rational(y2, PS, QS) - factor = one_div_root_pi / jnp.sqrt(x) - sx = jnp.sin(x) - cx = jnp.cos(x) - return factor * (rc * (cx + sx) - y * rs * (sx - cx)) - - x = jnp.abs(x) - return jnp.select( - [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x - ).reshape(orig_shape) diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index b17b8d6e..a197d151 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -69,7 +69,6 @@ def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): kcoords = jnp.dot(kcoords, jacobian) cenx, ceny = offset.x, offset.y - # # flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) ) # NB: seems that tere is no jax.lax.polar equivalent to c++ std::polar function def phase(kpos): diff --git a/jax_galsim/deltafunction.py b/jax_galsim/deltafunction.py index 7f974063..71043e96 100644 --- a/jax_galsim/deltafunction.py +++ b/jax_galsim/deltafunction.py @@ -65,6 +65,7 @@ def _kValue(self, kpos): # to match the input kpos return self.flux + kpos.x * (0.0 + 0.0j) + @implements(_galsim.DeltaFunction._shoot) def _shoot(self, photons, rng): flux_per_photon = self.flux / photons.size() photons.x = 0.0 diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 186e6f2c..07502dce 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -4,8 +4,7 @@ from jax.tree_util import Partial as partial from jax.tree_util import register_pytree_node_class -from jax_galsim.bessel import kv -from jax_galsim.core.bessel import j0 +from jax_galsim.bessel import j0, kv from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 4699739f..db027ed7 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -700,248 +700,6 @@ def __str__(self): return "galsim.Chi2Deviate(n=%r)" % (ensure_hashable(self.n),) -# class DistDeviate(BaseDeviate): -# """A class to draw random numbers from a user-defined probability distribution. - -# DistDeviate is a `BaseDeviate` class that can be used to draw from an arbitrary probability -# distribution. The probability distribution passed to DistDeviate can be given one of three -# ways: as the name of a file containing a 2d ASCII array of x and P(x), as a `LookupTable` -# mapping x to P(x), or as a callable function. - -# Once given a probability, DistDeviate creates a table of the cumulative probability and draws -# from it using a `UniformDeviate`. The precision of its outputs can be controlled with the -# keyword ``npoints``, which sets the number of points DistDeviate creates for its internal table -# of CDF(x). To prevent errors due to non-monotonicity, the interpolant for this internal table -# is always linear. - -# Two keywords, ``x_min`` and ``x_max``, define the support of the function. They must be passed -# if a callable function is given to DistDeviate, unless the function is a `LookupTable`, which -# has its own defined endpoints. If a filename or `LookupTable` is passed to DistDeviate, the -# use of ``x_min`` or ``x_max`` will result in an error. - -# If given a table in a file, DistDeviate will construct an interpolated `LookupTable` to obtain -# more finely gridded probabilities for generating the cumulative probability table. The default -# ``interpolant`` is linear, but any interpolant understood by `LookupTable` may be used. We -# caution against the use of splines because they can cause non-monotonic behavior. Passing the -# ``interpolant`` keyword next to anything but a table in a file will result in an error. - -# **Examples**: - -# Some sample initialization calls:: - -# >>> d = galsim.DistDeviate(function=f, x_min=x_min, x_max=x_max) - -# Initializes d to be a DistDeviate instance with a distribution given by the callable function -# ``f(x)`` from ``x=x_min`` to ``x=x_max`` and seeds the PRNG using current time:: - -# >>> d = galsim.DistDeviate(1062533, function=file_name, interpolant='floor') - -# Initializes d to be a DistDeviate instance with a distribution given by the data in file -# ``file_name``, which must be a 2-column ASCII table, and seeds the PRNG using the integer -# seed 1062533. It generates probabilities from ``file_name`` using the interpolant 'floor':: - -# >>> d = galsim.DistDeviate(rng, function=galsim.LookupTable(x,p)) - -# Initializes d to be a DistDeviate instance with a distribution given by P(x), defined as two -# arrays ``x`` and ``p`` which are used to make a callable `LookupTable`, and links the -# DistDeviate PRNG to the already-existing random number generator ``rng``. - -# Successive calls to ``d()`` generate pseudo-random values with the given probability -# distribution:: - -# >>> d = galsim.DistDeviate(31415926, function=lambda x: 1-abs(x), x_min=-1, x_max=1) -# >>> d() -# -0.4151921102709466 -# >>> d() -# -0.00909781188974034 - -# Parameters: -# seed: Something that can seed a `BaseDeviate`: an integer seed or another -# `BaseDeviate`. Using 0 means to generate a seed from the system. -# [default: None] -# function: A callable function giving a probability distribution or the name of a -# file containing a probability distribution as a 2-column ASCII table. -# [required] -# x_min: The minimum desired return value (required for non-`LookupTable` -# callable functions; will raise an error if not passed in that case, or if -# passed in any other case) [default: None] -# x_max: The maximum desired return value (required for non-`LookupTable` -# callable functions; will raise an error if not passed in that case, or if -# passed in any other case) [default: None] -# interpolant: Type of interpolation used for interpolating a file (causes an error if -# passed alongside a callable function). Options are given in the -# documentation for `LookupTable`. [default: 'linear'] -# npoints: Number of points DistDeviate should create for its internal interpolation -# tables. [default: 256, unless the function is a non-log `LookupTable`, in -# which case it uses the table's x values] -# """ -# def __init__(self, seed=None, function=None, x_min=None, -# x_max=None, interpolant=None, npoints=None): -# from .table import LookupTable -# from . import utilities -# from . import integ - -# # Set up the PRNG -# self._rng_type = _galsim.UniformDeviateImpl -# self._rng_args = () -# self.reset(seed) - -# # Basic input checking and setups -# if function is None: -# raise TypeError('You must pass a function to DistDeviate!') - -# self._interpolant = interpolant -# self._npoints = npoints -# self._xmin = x_min -# self._xmax = x_max - -# # Figure out if a string is a filename or something we should be using in an eval call -# if isinstance(function, str): -# self._function = function # Save the inputs to be used in repr -# import os.path -# if os.path.isfile(function): -# if interpolant is None: -# interpolant='linear' -# if x_min or x_max: -# raise GalSimIncompatibleValuesError( -# "Cannot pass x_min or x_max with a filename argument", -# function=function, x_min=x_min, x_max=x_max) -# function = LookupTable.from_file(function, interpolant=interpolant) -# x_min = function.x_min -# x_max = function.x_max -# else: -# try: -# function = utilities.math_eval('lambda x : ' + function) -# if x_min is not None: # is not None in case x_min=0. -# function(x_min) -# else: -# # Somebody would be silly to pass a string for evaluation without x_min, -# # but we'd like to throw reasonable errors in that case anyway -# function(0.6) # A value unlikely to be a singular point of a function -# except Exception as e: -# raise GalSimValueError( -# "String function must either be a valid filename or something that " -# "can eval to a function of x.\n" -# "Caught error: {0}".format(e), self._function) -# else: -# # Check that the function is actually a function -# if not hasattr(function, '__call__'): -# raise TypeError('function must be a callable function or a string') -# if interpolant: -# raise GalSimIncompatibleValuesError( -# "Cannot provide an interpolant with a callable function argument", -# interpolant=interpolant, function=function) -# if isinstance(function, LookupTable): -# if x_min or x_max: -# raise GalSimIncompatibleValuesError( -# "Cannot provide x_min or x_max with a LookupTable function", -# function=function, x_min=x_min, x_max=x_max) -# x_min = function.x_min -# x_max = function.x_max -# else: -# if x_min is None or x_max is None: -# raise GalSimIncompatibleValuesError( -# "Must provide x_min and x_max when function argument is a regular " -# "python callable function", -# function=function, x_min=x_min, x_max=x_max) - -# self._function = function # Save the inputs to be used in repr - -# # Compute the probability distribution function, pdf(x) -# if (npoints is None and isinstance(function, LookupTable) and -# not function.x_log and not function.f_log): -# xarray = np.array(function.x, dtype=float) -# pdf = np.array(function.f, dtype=float) -# # Set up pdf, so cumsum basically does a cumulative trapz integral -# # On Python 3.4, doing pdf[1:] += pdf[:-1] the last value gets messed up. -# # Writing it this way works. (Maybe slightly slower though, so if we stop -# # supporting python 3.4, consider switching to the += version.) -# pdf[1:] = pdf[1:] + pdf[:-1] -# pdf[1:] *= np.diff(xarray) -# pdf[0] = 0. -# else: -# if npoints is None: npoints = 256 -# xarray = x_min+(1.*x_max-x_min)/(npoints-1)*np.array(range(npoints),float) -# # Integrate over the range of x in case the function is doing something weird here. -# pdf = [0.] + [integ.int1d(function, xarray[i], xarray[i+1]) -# for i in range(npoints - 1)] -# pdf = np.array(pdf) - -# # Check that the probability is nonnegative -# if not np.all(pdf >= 0.): -# raise GalSimValueError('Negative probability found in DistDeviate.',function) - -# # Compute the cumulative distribution function = int(pdf(x),x) -# cdf = np.cumsum(pdf) - -# # Quietly renormalize the probability if it wasn't already normalized -# totalprobability = cdf[-1] -# cdf /= totalprobability - -# self._inverse_cdf = LookupTable(cdf, xarray, interpolant='linear') -# self.x_min = x_min -# self.x_max = x_max - -# def val(self, p): -# r""" -# Return the value :math:`x` of the input function to `DistDeviate` such that ``p`` = -# :math:`F(x)`, where :math:`F` is the cumulattive probability distribution function: - -# .. math:: - -# F(x) = \int_{-\infty}^x \mathrm{pdf}(t) dt - -# This function is typically called by `__call__`, which generates a random p -# between 0 and 1 and calls ``self.val(p)``. - -# Parameters: -# p: The desired cumulative probabilty p. - -# Returns: -# the corresponding x such that :math:`p = F(x)`. -# """ -# if p<0 or p>1: -# raise GalSimRangeError('Invalid cumulative probability for DistDeviate', p, 0., 1.) -# return self._inverse_cdf(p) - -# def __call__(self): -# """Draw a new random number from the distribution. -# """ -# return self._inverse_cdf(self._rng.generate1()) - -# def generate(self, array): -# """Generate many pseudo-random values, filling in the values of a numpy array. -# """ -# p = np.empty_like(array) -# BaseDeviate.generate(self, p) # Fill with unform deviate values -# np.copyto(array, self._inverse_cdf(p)) # Convert from p -> x - -# def add_generate(self, array): -# """Generate many pseudo-random values, adding them to the values of a numpy array. -# """ -# p = np.empty_like(array) -# BaseDeviate.generate(self, p) -# array += self._inverse_cdf(p) - -# def __repr__(self): -# return ('galsim.DistDeviate(seed=%r, function=%r, x_min=%r, x_max=%r, interpolant=%r, ' -# 'npoints=%r)')%(self._seed_repr(), self._function, self._xmin, self._xmax, -# self._interpolant, self._npoints) -# def __str__(self): -# return 'galsim.DistDeviate(function="%s", x_min=%s, x_max=%s, interpolant=%s, npoints=%s)'%( -# self._function, self._xmin, self._xmax, self._interpolant, self._npoints) - -# def __eq__(self, other): -# return (self is other or -# (isinstance(other, DistDeviate) and -# self.serialize() == other.serialize() and -# self._function == other._function and -# self._xmin == other._xmin and -# self._xmax == other._xmax and -# self._interpolant == other._interpolant and -# self._npoints == other._npoints)) - - @implements( _galsim.random.permute, lax_description="The JAX implementation of this function cannot operate in-place and so returns a new list of arrays.", diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 453dc293..958e6bfa 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -188,6 +188,7 @@ def _negative_flux(self): def _flux_per_photon(self): return self._calculate_flux_per_photon() + @implements(_galsim.Sum._shoot) def _shoot(self, photons, rng): tot_flux = self.positive_flux + self.negative_flux fluxes = jnp.array( diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 61f9d518..d0df2acf 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -383,6 +383,7 @@ def _drawKImage(self, image, jac=None): image = image * self._flux_scaling return image + @implements(_galsim.Transformation._shoot) def _shoot(self, photons, rng): self._original._shoot(photons, rng) photons.x, photons.y = self._fwd(photons.x, photons.y) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 52833939..3a177d0b 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -14,6 +14,7 @@ # We inherit from the reference BaseWCS and only redefine the methods that # make references to jax_galsim objects. +@implements(_galsim.BaseWCS) class BaseWCS(_galsim.BaseWCS): @implements(_galsim.BaseWCS.toWorld) def toWorld(self, *args, **kwargs): @@ -281,12 +282,8 @@ def from_galsim(cls, galsim_wcs): ######################################################################################### +@implements(_galsim.wcs.EuclideanWCS) class EuclideanWCS(BaseWCS): - """A EuclideanWCS is a `BaseWCS` whose world coordinates are on a Euclidean plane. - We usually use the notation (u,v) to refer to positions in world coordinates, and - they use the class `PositionD`. - """ - # All EuclideanWCS classes must define origin and world_origin. # Sometimes it is convenient to access x0,y0,u0,v0 directly. @property @@ -474,9 +471,8 @@ def __ne__(self, other): return not self.__eq__(other) +@implements(_galsim.wcs.UniformWCS) class UniformWCS(EuclideanWCS): - """A UniformWCS is a `EuclideanWCS` which has a uniform pixel size and shape.""" - @property def _isUniform(self): return True @@ -522,6 +518,7 @@ def __eq__(self, other): ) +@implements(_galsim.wcs.LocalWCS) class LocalWCS(UniformWCS): """A LocalWCS is a `UniformWCS` in which (0,0) in image coordinates is at the same place as (0,0) in world coordinates @@ -565,6 +562,7 @@ def _local(self, image_pos, color): return self +@implements(_galsim.wcs.CelestialWCS) class CelestialWCS(BaseWCS): """A CelestialWCS is a `BaseWCS` whose world coordinates are on the celestial sphere. We use the `CelestialCoord` class for the world coordinates. From 83469ba1c62eae0f6d1e17659b84f49884c714f6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 12 Sep 2024 15:44:02 -0500 Subject: [PATCH 2/2] fix: this code was commented for a reason; clean out other instance --- jax_galsim/angle.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index 3c573f1a..a7d1230a 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -178,10 +178,6 @@ def __sub__(self, other): return _Angle(self._rad - other._rad) def __mul__(self, other): - if other != float(other): - raise TypeError( - "Cannot multiply Angle by %s of type %s" % (other, type(other)) - ) return _Angle(self._rad * other) __rmul__ = __mul__ @@ -189,12 +185,8 @@ def __mul__(self, other): def __div__(self, other): if isinstance(other, AngleUnit): return self._rad / other.value - elif other == float(other): - return _Angle(self._rad / other) else: - raise TypeError( - "Cannot divide Angle by %s of type %s" % (other, type(other)) - ) + return _Angle(self._rad / other) __truediv__ = __div__