Skip to content

Commit

Permalink
Deprecate @generated_jit and remove upper bound on numba version (#289)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins authored Jan 30, 2024
1 parent 8620280 commit 88fa423
Show file tree
Hide file tree
Showing 29 changed files with 633 additions and 214 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Deprecate use of @generated_jit. Remove upper bound on numba. (:pr:`289`)
* Remove unnecessary new_axes in calibration utils after upstream fix in dask (:pr:`288`)
* Check that ncorr is never larger than 2 in calibration utils (:pr:`287`)
* Optionally check NRT allocations (:pr:`286`)
Expand Down
171 changes: 118 additions & 53 deletions africanus/averaging/bda_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
merge_flags,
vis_output_arrays)
from africanus.util.docs import DocstringTemplate
from africanus.util.numba import (generated_jit,
from africanus.util.numba import (njit,
overload,
JIT_OPTIONS,
intrinsic,
is_numba_type_none)

Expand All @@ -20,11 +22,25 @@
RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields)


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def row_average(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
return row_average_impl(meta, ant1, ant2, flag_row=flag_row,
time_centroid=time_centroid, exposure=exposure,
uvw=uvw, weight=weight, sigma=sigma)


def row_average_impl(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
return NotImplementedError


@overload(row_average_impl, jit_options=JIT_OPTIONS)
def nb_row_average_impl(meta, ant1, ant2, flag_row=None,
time_centroid=None, exposure=None, uvw=None,
weight=None, sigma=None):
have_flag_row = not is_numba_type_none(flag_row)
have_time_centroid = not is_numba_type_none(time_centroid)
have_exposure = not is_numba_type_none(exposure)
Expand Down Expand Up @@ -310,13 +326,35 @@ def codegen(context, builder, signature, args):
return sig, codegen


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def row_chan_average(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

return row_chan_average_impl(meta, flag_row=flag_row, weight=weight,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)


def row_chan_average_impl(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

return NotImplementedError


@overload(row_chan_average_impl, jit_options=JIT_OPTIONS)
def nb_row_chan_average(meta, flag_row=None, weight=None,
visibilities=None,
flag=None,
weight_spectrum=None,
sigma_spectrum=None):

have_vis = not is_numba_type_none(visibilities)
have_flag = not is_numba_type_none(flag)
have_flag_row = not is_numba_type_none(flag_row)
Expand Down Expand Up @@ -523,7 +561,7 @@ def impl(meta, flag_row=None, weight=None,
_rowchan_output_fields)


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def bda(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
Expand All @@ -535,7 +573,21 @@ def bda(time, interval, antenna1, antenna2,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
def impl(time, interval, antenna1, antenna2,

return bda_impl(time, interval, antenna1, antenna2,
time_centroid=time_centroid, exposure=exposure,
flag_row=flag_row, uvw=uvw, weight=weight, sigma=sigma,
chan_freq=chan_freq, chan_width=chan_width,
effective_bw=effective_bw, resolution=resolution,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum,
max_uvw_dist=max_uvw_dist, max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs, min_nchan=min_nchan)


def bda_impl(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
Expand All @@ -546,54 +598,67 @@ def impl(time, interval, antenna1, antenna2,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)

meta = bda_mapper(time, interval, antenna1, antenna2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)

row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841
time_centroid, exposure, uvw,
weight=weight, sigma=sigma)

row_chan_avg = row_chan_average(meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)

# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum)

return impl
return NotImplementedError


@overload(bda_impl, jit_options=JIT_OPTIONS)
def nb_bda_impl(time, interval, antenna1, antenna2,
time_centroid=None, exposure=None, flag_row=None,
uvw=None, weight=None, sigma=None,
chan_freq=None, chan_width=None,
effective_bw=None, resolution=None,
visibilities=None, flag=None,
weight_spectrum=None, sigma_spectrum=None,
max_uvw_dist=None, max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
# Merge flag_row and flag arrays
flag_row = merge_flags(flag_row, flag)

meta = bda_mapper(time, interval, antenna1, antenna2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)

row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841
time_centroid, exposure, uvw,
weight=weight, sigma=sigma)

row_chan_avg = row_chan_average(meta, # noqa: F841
flag_row=flag_row,
visibilities=visibilities, flag=flag,
weight_spectrum=weight_spectrum,
sigma_spectrum=sigma_spectrum)

# Have to explicitly write it out because numba tuples
# are highly constrained types
return AverageOutput(meta.map,
meta.offsets,
meta.decorr_chan_width,
meta.time,
meta.interval,
meta.chan_width,
meta.flag_row,
row_avg.antenna1,
row_avg.antenna2,
row_avg.time_centroid,
row_avg.exposure,
row_avg.uvw,
row_avg.weight,
row_avg.sigma,
# None, # chan_data.chan_freq,
# None, # chan_data.chan_width,
# None, # chan_data.effective_bw,
# None, # chan_data.resolution,
row_chan_avg.visibilities,
row_chan_avg.flag,
row_chan_avg.weight_spectrum,
row_chan_avg.sigma_spectrum)


BDA_DOCS = DocstringTemplate("""
Expand Down
112 changes: 39 additions & 73 deletions africanus/averaging/bda_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,21 @@
import numpy as np
import numba
from numba.experimental import jitclass
import numba.types
from numba import types

from africanus.constants import c as lightspeed
from africanus.util.numba import generated_jit, njit, is_numba_type_none
from africanus.util.numba import (
JIT_OPTIONS,
overload,
njit,
is_numba_type_none)
from africanus.averaging.support import unique_time, unique_baselines


class RowMapperError(Exception):
pass


@njit(nogil=True, cache=True)
def erf26(x):
"""Implements 7.1.26 erf approximation from Abramowitz and
Stegun (1972), pg. 299. Accurate for abs(eps(x)) <= 1.5e-7."""

# Constants
p = 0.3275911
a1 = 0.254829592
a2 = -0.284496736
a3 = 1.421413741
a4 = -1.453152027
a5 = 1.061405429
e = 2.718281828

# t
t = 1.0/(1.0 + (p * x))

# Erf calculation
erf = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t)
erf *= e ** -(x ** 2)

return -round(erf, 9) if x < 0 else round(erf, 0)


@njit(nogil=True, cache=True)
def time_decorrelation(u, v, w, max_lm, time_bin_secs, min_wavelength):
sidereal_rotation_rate = 7.292118516e-5
diffraction_limit = min_wavelength / np.sqrt(u**2 + v**2 + w**2)
term = max_lm * time_bin_secs * sidereal_rotation_rate / diffraction_limit
return 1.0 - 1.0645 * erf26(0.8326*term) / term


_SERIES_COEFFS = (1./40, 107./67200, 3197./24192000, 49513./3973939200)


@njit(nogil=True, cache=True, inline='always')
def inv_sinc(sinc_x, tol=1e-12):
# Invalid input
if sinc_x > 1.0:
raise ValueError("sinc_x > 1.0")

# Initial guess from reversion of Taylor series
# https://math.stackexchange.com/questions/3189307/inverse-of-frac-sinxx
x = t_pow = np.sqrt(6*np.abs((1 - sinc_x)))
t_squared = t_pow*t_pow

for coeff in numba.literal_unroll(_SERIES_COEFFS):
t_pow *= t_squared
x += coeff * t_pow

# Use Newton Raphson to go the rest of the way
# https://www.wolframalpha.com/input/?i=simplify+%28sinc%5Bx%5D+-+c%29+%2F+D%5Bsinc%5Bx%5D%2Cx%5D
while True:
# evaluate delta between this iteration sinc(x) and original
sinx = np.sin(x)
𝞓sinc_x = (1.0 if x == 0.0 else sinx/x) - sinc_x

# Stop if converged
if np.abs(𝞓sinc_x) < tol:
break

# Next iteration
x -= (x*x * 𝞓sinc_x) / (x*np.cos(x) - sinx)

return x


@njit(nogil=True, cache=True, inline='always')
def factors(n):
assert n >= 1
Expand Down Expand Up @@ -126,7 +63,7 @@ def max_chan_width(ref_freq, fractional_bandwidth):
"nchan", "flag"])


class Binner(object):
class Binner:
def __init__(self, row_start, row_end,
max_lm, decorrelation, time_bin_secs,
max_chan_freq):
Expand Down Expand Up @@ -338,7 +275,7 @@ def finalise_bin(self, auto_corr, uvw, time, interval,
"time", "interval", "chan_width", "flag_row"])


@generated_jit(nopython=True, nogil=True, cache=True)
@njit(**JIT_OPTIONS)
def bda_mapper(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
Expand All @@ -347,10 +284,39 @@ def bda_mapper(time, interval, ant1, ant2, uvw,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):

return bda_mapper_impl(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=max_fov,
decorrelation=decorrelation,
time_bin_secs=time_bin_secs,
min_nchan=min_nchan)


def bda_mapper_impl(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=None,
max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
return NotImplementedError


@overload(bda_mapper_impl, jit_options={"nogil": True})
def nb_bda_mapper(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
flag_row=None,
max_fov=3.0,
decorrelation=0.98,
time_bin_secs=None,
min_nchan=1):
have_time_bin_secs = not is_numba_type_none(time_bin_secs)

Omitted = numba.types.misc.Omitted
Omitted = types.misc.Omitted

decorr_type = (numba.typeof(decorrelation.value)
if isinstance(decorrelation, Omitted)
Expand Down
Loading

0 comments on commit 88fa423

Please sign in to comment.