diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py index e5504c4f..92db5f97 100644 --- a/africanus/averaging/bda_mapping.py +++ b/africanus/averaging/bda_mapping.py @@ -5,7 +5,7 @@ import numpy as np import numba from numba.experimental import jitclass -import numba.types +from numba import types from africanus.constants import c as lightspeed from africanus.util.numba import ( @@ -64,7 +64,7 @@ def max_chan_width(ref_freq, fractional_bandwidth): "nchan", "flag"]) -class Binner(object): +class Binner: def __init__(self, row_start, row_end, max_lm, decorrelation, time_bin_secs, max_chan_freq): @@ -305,7 +305,7 @@ def bda_mapper_impl(time, interval, ant1, ant2, uvw, min_nchan=1): return NotImplementedError -@overload(bda_mapper_impl, jit_options=JIT_OPTIONS) +@overload(bda_mapper_impl, jit_options={"nogil": True}) def nb_bda_mapper(time, interval, ant1, ant2, uvw, chan_width, chan_freq, max_uvw_dist, @@ -316,7 +316,7 @@ def nb_bda_mapper(time, interval, ant1, ant2, uvw, min_nchan=1): have_time_bin_secs = not is_numba_type_none(time_bin_secs) - Omitted = numba.types.misc.Omitted + Omitted = types.misc.Omitted decorr_type = (numba.typeof(decorrelation.value) if isinstance(decorrelation, Omitted) @@ -347,7 +347,7 @@ def nb_bda_mapper(time, interval, ant1, ant2, uvw, ('max_chan_freq', chan_freq.dtype), ('max_uvw_dist', max_uvw_dist)] - JitBinner = st(spec)(Binner) + JitBinner = jitclass(spec)(Binner) def impl(time, interval, ant1, ant2, uvw, chan_width, chan_freq, diff --git a/africanus/averaging/tests/test_bda_mapping.py b/africanus/averaging/tests/test_bda_mapping.py index b0ed2e65..58b1b8a5 100644 --- a/africanus/averaging/tests/test_bda_mapping.py +++ b/africanus/averaging/tests/test_bda_mapping.py @@ -5,7 +5,7 @@ import pytest from africanus.averaging.bda_mapping import bda_mapper, Binner - +from africanus.util.numba import njit @pytest.fixture(scope="session", params=[4096]) def nchan(request): @@ -172,43 +172,43 @@ def synthesized_uvw(ants, time, phase_dir, auto_correlations): return ant1, ant2, uvw -# @pytest.mark.parametrize("decorrelation", [0.95]) -# @pytest.mark.parametrize("min_nchan", [1]) -# def test_bda_mapper(time, synthesized_uvw, interval, -# chan_freq, chan_width, -# decorrelation, min_nchan): -# time = np.unique(time) -# ant1, ant2, uvw = synthesized_uvw - -# nbl = ant1.shape[0] -# ntime = time.shape[0] - -# time = np.repeat(time, nbl) -# interval = np.repeat(interval, nbl) -# ant1 = np.tile(ant1, ntime) -# ant2 = np.tile(ant2, ntime) -# flag_row = np.zeros(time.shape[0], dtype=np.int8) - -# max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max() - -# row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841 -# chan_width, chan_freq, -# max_uvw_dist, -# flag_row=flag_row, -# max_fov=3.0, -# decorrelation=decorrelation, -# min_nchan=min_nchan) - -# offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0]) -# assert_array_equal(offsets, row_meta.offsets[:-1]) -# assert row_meta.map.max() + 1 == row_meta.offsets[-1] - -# # NUM_CHAN divides number of channels exactly -# num_chan = np.diff(row_meta.offsets) -# _, remainder = np.divmod(chan_width.shape[0], num_chan) -# assert np.all(remainder == 0) -# decorr_cw = chan_width.sum() / num_chan -# assert_array_equal(decorr_cw, row_meta.decorr_chan_width) +@pytest.mark.parametrize("decorrelation", [0.95]) +@pytest.mark.parametrize("min_nchan", [1]) +def test_bda_mapper(time, synthesized_uvw, interval, + chan_freq, chan_width, + decorrelation, min_nchan): + time = np.unique(time) + ant1, ant2, uvw = synthesized_uvw + + nbl = ant1.shape[0] + ntime = time.shape[0] + + time = np.repeat(time, nbl) + interval = np.repeat(interval, nbl) + ant1 = np.tile(ant1, ntime) + ant2 = np.tile(ant2, ntime) + flag_row = np.zeros(time.shape[0], dtype=np.int8) + + max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max() + + row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841 + chan_width, chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=3.0, + decorrelation=decorrelation, + min_nchan=min_nchan) + + offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0]) + assert_array_equal(offsets, row_meta.offsets[:-1]) + assert row_meta.map.max() + 1 == row_meta.offsets[-1] + + # NUM_CHAN divides number of channels exactly + num_chan = np.diff(row_meta.offsets) + _, remainder = np.divmod(chan_width.shape[0], num_chan) + assert np.all(remainder == 0) + decorr_cw = chan_width.sum() / num_chan + assert_array_equal(decorr_cw, row_meta.decorr_chan_width) def test_bda_binner(time, interval, synthesized_uvw,