diff --git a/HISTORY.rst b/HISTORY.rst index f573727d7..e20954b4e 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History X.Y.Z (YYYY-MM-DD) ------------------ +* Parallel Numba RIME Implementation (:pr:`186`) * Update classifiers and correct license in setup.py to BSD3 0.2.4 (2020-05-29) @@ -18,6 +19,7 @@ X.Y.Z (YYYY-MM-DD) 0.2.3 (2020-05-14) ------------------ + * Fix incorrect SPI calculation and make predict defaults MeqTree equivalent (:pr:`189`) * Depend on pytest-flake8 >= 1.0.6 (:pr:`187`, :pr:`188`) * MeqTrees Comparison Script Updates (:pr:`160`) diff --git a/africanus/config.py b/africanus/config.py new file mode 100644 index 000000000..729b20275 --- /dev/null +++ b/africanus/config.py @@ -0,0 +1,20 @@ +from donfig import Config + + +class AfricanusConfig(Config): + def numba_parallel(self, key): + value = self.get(key, False) + + if value is False: + return {'parallel': False} + elif value is True: + return {'parallel': True} + elif isinstance(value, dict): + value['parallel'] = True + return value + else: + raise TypeError("key %s (%s) is not a bool or a dict", + key, value) + + +config = AfricanusConfig("africanus") diff --git a/africanus/conftest.py b/africanus/conftest.py index f04d969f5..a1af6abc5 100644 --- a/africanus/conftest.py +++ b/africanus/conftest.py @@ -1,9 +1,38 @@ # -*- coding: utf-8 -*- +import importlib +import pytest from africanus.util.testing import mark_in_pytest +@pytest.fixture +def cfg_parallel(request): + """ Performs parallel configuration setting and module reloading """ + from africanus.config import config + + module, cfg = request.param + + assert isinstance(cfg, dict) and len(cfg) == 1 + + # Get module object, because importlib.reload doesn't take strings + mod = importlib.import_module(module) + + with config.set(cfg): + importlib.reload(mod) + + cfg = cfg.copy().popitem()[1] + + if isinstance(cfg, dict): + yield cfg['parallel'] + elif isinstance(cfg, bool): + yield cfg + else: + raise TypeError("Unhandled cfg type %s" % type(cfg)) + + importlib.reload(mod) + + # content of conftest.py def pytest_configure(config): mark_in_pytest(True) diff --git a/africanus/install/requirements.py b/africanus/install/requirements.py new file mode 100644 index 000000000..fedbe1ab5 --- /dev/null +++ b/africanus/install/requirements.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + + +# NOTE(sjperkins) +# Non standard library imports should be avoided, +# or should fail gracefully as functionality +# in these modules is called by setup.py +import os + +# requirements +on_rtd = os.environ.get('READTHEDOCS') == 'True' + +# Basic requirements containing no C extensions. +# This is necessary for building on RTD +requirements = ['appdirs >= 1.4.3', + 'decorator', + 'donfig'] + +if not on_rtd: + requirements += [ + # astropy breaks with numpy 1.15.3 + # https://github.com/astropy/astropy/issues/7943 + 'numpy >= 1.14.0, != 1.15.3', + 'numba >= 0.46.0'] + +extras_require = { + 'cuda': ['cupy >= 5.0.0', 'jinja2 >= 2.10'], + 'dask': ['dask[array] >= 1.1.0'], + 'jax': ['jax == 0.1.27', 'jaxlib == 0.1.14'], + 'scipy': ['scipy >= 1.0.0'], + 'astropy': ['astropy >= 3.0'], + 'python-casacore': ['python-casacore >= 3.2.0'], + 'testing': ['pytest', 'flaky', 'pytest-flake8'] +} + +_non_cuda_extras = [er for n, er in extras_require.items() if n != "cuda"] +_all_extras = extras_require.values() + +extras_require['complete'] = sorted(set(sum(_non_cuda_extras, []))) +extras_require['complete-cuda'] = sorted(set(sum(_all_extras, []))) + +setup_requirements = [] +test_requirements = (extras_require['testing'] + + extras_require['astropy'] + + extras_require['python-casacore'] + + extras_require['dask'] + + extras_require['scipy']) diff --git a/africanus/model/shape/gaussian_shape.py b/africanus/model/shape/gaussian_shape.py index 8a5b09050..af8bb254b 100644 --- a/africanus/model/shape/gaussian_shape.py +++ b/africanus/model/shape/gaussian_shape.py @@ -3,12 +3,18 @@ import numpy as np +from africanus.config import config from africanus.util.docs import DocstringTemplate from africanus.util.numba import generated_jit from africanus.constants import c as lightspeed +cfg = config.numba_parallel('model.shape.gaussian.parallel') +parallel = cfg.get('parallel', False) +axes = cfg.get('axes', set(('source', 'row')) if parallel else ()) -@generated_jit(nopython=True, nogil=True, cache=True) + +@generated_jit(nopython=True, nogil=True, + cache=not parallel, parallel=parallel) def gaussian(uvw, frequency, shape_params): # https://en.wikipedia.org/wiki/Full_width_at_half_maximum fwhm = 2.0 * np.sqrt(2.0 * np.log(2.0)) @@ -18,7 +24,16 @@ def gaussian(uvw, frequency, shape_params): dtype = np.result_type(*(np.dtype(a.dtype.name) for a in (uvw, frequency, shape_params))) + from numba import prange, set_num_threads, get_num_threads + srange = prange if parallel and 'source' in axes else range + rrange = prange if parallel and 'row' in axes else range + threads = cfg.get("threads", None) if parallel else None + def impl(uvw, frequency, shape_params): + if parallel and threads is not None: + prev_threads = get_num_threads() + set_num_threads(threads) + nsrc = shape_params.shape[0] nrow = uvw.shape[0] nchan = frequency.shape[0] @@ -30,7 +45,7 @@ def impl(uvw, frequency, shape_params): for f in range(frequency.shape[0]): scaled_freq[f] = frequency[f] * gauss_scale - for s in range(shape_params.shape[0]): + for s in srange(shape_params.shape[0]): emaj, emin, angle = shape_params[s] # Convert to l-projection, m-projection, ratio @@ -38,7 +53,7 @@ def impl(uvw, frequency, shape_params): em = emaj * np.cos(angle) er = emin / (1.0 if emaj == 0.0 else emaj) - for r in range(uvw.shape[0]): + for r in rrange(uvw.shape[0]): u, v, w = uvw[r] u1 = (u*em - v*el)*er @@ -50,6 +65,9 @@ def impl(uvw, frequency, shape_params): shape[s, r, f] = np.exp(-(fu1*fu1 + fv1*fv1)) + if parallel and threads is not None: + set_num_threads(prev_threads) + return shape return impl diff --git a/africanus/model/shape/tests/test_gaussian_shape.py b/africanus/model/shape/tests/test_gaussian_shape.py index a3e7d0271..a1a40e2c5 100644 --- a/africanus/model/shape/tests/test_gaussian_shape.py +++ b/africanus/model/shape/tests/test_gaussian_shape.py @@ -5,13 +5,22 @@ from numpy.testing import assert_array_almost_equal import pytest -from africanus.model.shape import gaussian as np_gaussian - -def test_gauss_shape(): +@pytest.mark.parametrize("cfg_parallel", [ + ("africanus.model.shape.gaussian_shape", + {"model.shape.gaussian.parallel": True}), + ("africanus.model.shape.gaussian_shape", { + "model.shape.gaussian.parallel": {'threads': 2}}), + ("africanus.model.shape.gaussian_shape", + {"model.shape.gaussian.parallel": False}), + ], ids=["parallel", "parallel-2", "serial"], indirect=True) +def test_gauss_shape(cfg_parallel): + from africanus.model.shape.gaussian_shape import gaussian as np_gaussian row = 10 chan = 16 + assert np_gaussian.targetoptions['parallel'] == cfg_parallel + shape_params = np.array([[.4, .3, .2], [.4, .3, .2]]) uvw = np.random.random((row, 3)) @@ -24,6 +33,7 @@ def test_gauss_shape(): def test_dask_gauss_shape(): da = pytest.importorskip('dask.array') + from africanus.model.shape import gaussian as np_gaussian from africanus.model.shape.dask import gaussian as da_gaussian row_chunks = (5, 5) diff --git a/africanus/model/spectral/spec_model.py b/africanus/model/spectral/spec_model.py index 07e74546f..5770b0655 100644 --- a/africanus/model/spectral/spec_model.py +++ b/africanus/model/spectral/spec_model.py @@ -4,10 +4,16 @@ from numba import types import numpy as np +from africanus.config import config from africanus.util.numba import generated_jit, njit from africanus.util.docs import DocstringTemplate +cfg = config.numba_parallel('model.spectral_model.parallel') +parallel = cfg.get('parallel', False) +axes = cfg.get('axes', set(('source', )) if parallel else ()) + + def numpy_spectral_model(stokes, spi, ref_freq, frequency, base): out_shape = (stokes.shape[0], frequency.shape[0]) + stokes.shape[1:] @@ -57,9 +63,11 @@ def numpy_spectral_model(stokes, spi, ref_freq, frequency, base): def pol_getter_factory(npoldims): if npoldims == 0: + @njit def impl(pol_shape): return 1 else: + @njit def impl(pol_shape): npols = 1 @@ -68,32 +76,37 @@ def impl(pol_shape): return npols - return njit(nogil=True, cache=True)(impl) + return impl def promote_base_factory(is_base_list): if is_base_list: + @njit def impl(base, npol): return base + [base[-1]] * (npol - len(base)) else: + @njit def impl(base, npol): return [base] * npol - return njit(nogil=True, cache=True)(impl) + return impl def add_pol_dim_factory(have_pol_dim): if have_pol_dim: + @njit def impl(array): return array else: + @njit def impl(array): return array.reshape(array.shape + (1,)) - return njit(nogil=True, cache=True)(impl) + return impl -@generated_jit(nopython=True, nogil=True, cache=True) +@generated_jit(nopython=True, nogil=True, + cache=not parallel, parallel=parallel) def spectral_model(stokes, spi, ref_freq, frequency, base=0): arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (stokes, spi, ref_freq, frequency)) @@ -108,31 +121,33 @@ def spectral_model(stokes, spi, ref_freq, frequency, base=0): promote_base = promote_base_factory(is_base_list) if isinstance(base, types.scalars.Integer): + @njit def is_std(base): return base == 0 + @njit def is_log(base): return base == 1 + @njit def is_log10(base): return base == 2 elif isinstance(base, types.misc.UnicodeType): + @njit def is_std(base): return base == "std" + @njit def is_log(base): return base == "log" + @njit def is_log10(base): return base == "log10" else: raise TypeError("base '%s' should be a string or integer" % base) - is_std = njit(nogil=True, cache=True)(is_std) - is_log = njit(nogil=True, cache=True)(is_log) - is_log10 = njit(nogil=True, cache=True)(is_log10) - npoldims = stokes.ndim - 1 pol_get_fn = pol_getter_factory(npoldims) add_pol_dim = add_pol_dim_factory(npoldims > 0) @@ -140,7 +155,16 @@ def is_log10(base): if spi.ndim - 2 != npoldims: raise ValueError("Dimensions on stokes and spi don't agree") + from numba import prange, set_num_threads, get_num_threads + threads = cfg.get('threads', None) if parallel else None + srange = prange if parallel and 'source' in axes else range + crange = prange if parallel and 'chan' in axes else range + def impl(stokes, spi, ref_freq, frequency, base=0): + if parallel and threads is not None: + prev_threads = get_num_threads() + set_num_threads(threads) + nsrc = stokes.shape[0] nchan = frequency.shape[0] nspi = spi.shape[1] @@ -161,12 +185,13 @@ def impl(stokes, spi, ref_freq, frequency, base=0): # TODO(sjperkins) # Polarisation + associated base on the outer loop # The output cache patterns could be improved. - for p, b in enumerate(list_base[:npol]): + for p in range(npol): + b = list_base[p] if is_std(b): - for s in range(nsrc): + for s in srange(nsrc): rf = ref_freq[s] - for f in range(nchan): + for f in crange(nchan): freq_ratio = frequency[f] / rf spec_model = estokes[s, p] @@ -177,10 +202,10 @@ def impl(stokes, spi, ref_freq, frequency, base=0): spectral_model[s, f, p] = spec_model elif is_log(b): - for s in range(nsrc): + for s in srange(nsrc): rf = ref_freq[s] - for f in range(nchan): + for f in crange(nchan): freq_ratio = np.log(frequency[f] / rf) spec_model = np.log(estokes[s, p]) @@ -191,10 +216,10 @@ def impl(stokes, spi, ref_freq, frequency, base=0): spectral_model[s, f, p] = np.exp(spec_model) elif is_log10(b): - for s in range(nsrc): + for s in srange(nsrc): rf = ref_freq[s] - for f in range(nchan): + for f in crange(nchan): freq_ratio = np.log10(frequency[f] / rf) spec_model = np.log10(estokes[s, p]) @@ -207,7 +232,11 @@ def impl(stokes, spi, ref_freq, frequency, base=0): else: raise ValueError("Invalid base") + if parallel and threads is not None: + set_num_threads(prev_threads) + out_shape = (stokes.shape[0], frequency.shape[0]) + stokes.shape[1:] + return spectral_model.reshape(out_shape) return impl diff --git a/africanus/model/spectral/tests/test_spectral_model.py b/africanus/model/spectral/tests/test_spectral_model.py index d46ee3430..8acb10ff9 100644 --- a/africanus/model/spectral/tests/test_spectral_model.py +++ b/africanus/model/spectral/tests/test_spectral_model.py @@ -5,9 +5,6 @@ from numpy.testing import assert_array_almost_equal import pytest -from africanus.model.spectral.spec_model import (spectral_model, - numpy_spectral_model) - @pytest.fixture def flux(): @@ -44,7 +41,22 @@ def impl(shape): @pytest.mark.parametrize("base", [0, 1, 2, "std", "log", "log10", ["log", "std", "std", "std"]]) @pytest.mark.parametrize("npol", [0, 1, 2, 4]) -def test_spectral_model_multiple_spi(flux, ref_freq, frequency, base, npol): +@pytest.mark.parametrize("cfg_parallel", [ + ("africanus.model.spectral.spec_model", + {"model.spectral_model.parallel": True}), + ("africanus.model.spectral.spec_model", { + "model.spectral_model.parallel": {'threads': 2}}), + ("africanus.model.spectral.spec_model", + {"model.spectral_model.parallel": False}), + ], ids=["parallel", "parallel-2", "serial"], indirect=True) +def test_spectral_model_multiple_spi(flux, ref_freq, frequency, + base, npol, cfg_parallel): + + from africanus.model.spectral.spec_model import (spectral_model, + numpy_spectral_model) + + assert spectral_model.targetoptions['parallel'] == cfg_parallel + nsrc = 10 nchan = 16 nspi = 6 diff --git a/africanus/rime/fast_beam_cubes.py b/africanus/rime/fast_beam_cubes.py index e3ceb6b3b..9d887bf1c 100644 --- a/africanus/rime/fast_beam_cubes.py +++ b/africanus/rime/fast_beam_cubes.py @@ -1,10 +1,14 @@ # -*- coding: utf-8 -*- -from functools import reduce - import numpy as np + +from africanus.config import config from africanus.util.docs import DocstringTemplate -from africanus.util.numba import njit +from africanus.util.numba import njit, generated_jit + +cfg = config.numba_parallel("rime.beam_cube_dde.parallel") +parallel = cfg.get('parallel', False) +axes = cfg.get("axes", set(('source', 'row')) if parallel else set()) @njit(nogil=True, cache=True) @@ -55,184 +59,208 @@ def freq_grid_interp(frequency, beam_freq_map): return freq_data -@njit(nogil=True, cache=True) +@generated_jit(nopython=not parallel, nogil=True, + cache=not parallel, parallel=parallel) def beam_cube_dde(beam, beam_lm_extents, beam_freq_map, lm, parallactic_angles, point_errors, antenna_scaling, frequency): - nsrc = lm.shape[0] - ntime, nants = parallactic_angles.shape - nchan = frequency.shape[0] - beam_lw, beam_mh, beam_nud = beam.shape[:3] - corrs = beam.shape[3:] + from numba import prange, set_num_threads, get_num_threads, literal_unroll + + rrange = prange if parallel and 'row' in axes else range + threads = cfg.get('threads', None) if parallel else None + + def impl(beam, beam_lm_extents, beam_freq_map, + lm, parallactic_angles, point_errors, antenna_scaling, + frequency): + + if parallel and threads is not None: + prev_threads = get_num_threads() + set_num_threads(threads) + + nsrc = lm.shape[0] + ntime, nants = parallactic_angles.shape + nchan = frequency.shape[0] + beam_lw, beam_mh, beam_nud = beam.shape[:3] + corrs = beam.shape[3:] + + if beam_lw < 2 or beam_mh < 2 or beam_nud < 2: + raise ValueError("beam_lw, beam_mh and beam_nud " + "must be >= 2") + + # Flatten correlations + ncorrs = 1 + + for c in literal_unroll(corrs): + ncorrs *= c - if beam_lw < 2 or beam_mh < 2 or beam_nud < 2: - raise ValueError("beam_lw, beam_mh and beam_nud must be >= 2") + lower_l, upper_l = beam_lm_extents[0] + lower_m, upper_m = beam_lm_extents[1] - # Flatten correlations - ncorrs = reduce(lambda x, y: x*y, corrs, 1) + ex_dtype = beam_lm_extents.dtype - lower_l, upper_l = beam_lm_extents[0] - lower_m, upper_m = beam_lm_extents[1] + # Maximum l and m indices in float and int + lmaxf = ex_dtype.type(beam_lw - 1) + mmaxf = ex_dtype.type(beam_mh - 1) + lmaxi = beam_lw - 1 + mmaxi = beam_mh - 1 - ex_dtype = beam_lm_extents.dtype + lscale = lmaxf / (upper_l - lower_l) + mscale = mmaxf / (upper_m - lower_m) - # Maximum l and m indices in float and int - lmaxf = ex_dtype.type(beam_lw - 1) - mmaxf = ex_dtype.type(beam_mh - 1) - lmaxi = beam_lw - 1 - mmaxi = beam_mh - 1 + one = ex_dtype.type(1) + zero = ex_dtype.type(0) - lscale = lmaxf / (upper_l - lower_l) - mscale = mmaxf / (upper_m - lower_m) + # Flatten the beam on correlation + fbeam = beam.reshape((beam_lw, beam_mh, beam_nud, ncorrs)) - one = ex_dtype.type(1) - zero = ex_dtype.type(0) + # Allocate output array with correlations flattened + fjones = np.empty( + (nsrc, ntime, nants, nchan, ncorrs), dtype=beam.dtype) - # Flatten the beam on correlation - fbeam = beam.reshape((beam_lw, beam_mh, beam_nud, ncorrs)) + # Compute frequency interpolation stuff + freq_data = freq_grid_interp(frequency, beam_freq_map) - # Allocate output array with correlations flattened - fjones = np.empty((nsrc, ntime, nants, nchan, ncorrs), dtype=beam.dtype) + corr_sum = np.zeros((ncorrs,), dtype=beam.dtype) + absc_sum = np.zeros((ncorrs,), dtype=beam.real.dtype) + beam_scratch = np.zeros((ncorrs,), dtype=beam.dtype) - # Compute frequency interpolation stuff - freq_data = freq_grid_interp(frequency, beam_freq_map) + for t in rrange(ntime): + for a in rrange(nants): + sin_pa = np.sin(parallactic_angles[t, a]) + cos_pa = np.cos(parallactic_angles[t, a]) - corr_sum = np.zeros((ncorrs,), dtype=beam.dtype) - absc_sum = np.zeros((ncorrs,), dtype=beam.real.dtype) - beam_scratch = np.zeros((ncorrs,), dtype=beam.dtype) + for s in range(nsrc): + # Extract lm coordinates + l, m = lm[s] - for t in range(ntime): - for a in range(nants): - sin_pa = np.sin(parallactic_angles[t, a]) - cos_pa = np.cos(parallactic_angles[t, a]) + for f in range(nchan): + # Unpack frequency data + freq_scale = freq_data[f, 0] + # lower and upper frequency weights + nud = freq_data[f, 1] + inv_nud = 1.0 - nud + # lower and upper frequency grid position + gc0 = np.int32(freq_data[f, 2]) + gc1 = gc0 + 1 - for s in range(nsrc): - # Extract lm coordinates - l, m = lm[s] + # Apply any frequency scaling + sl = l * freq_scale + sm = m * freq_scale - for f in range(nchan): - # Unpack frequency data - freq_scale = freq_data[f, 0] - # lower and upper frequency weights - nud = freq_data[f, 1] - inv_nud = 1.0 - nud - # lower and upper frequency grid position - gc0 = np.int32(freq_data[f, 2]) - gc1 = gc0 + 1 + # Add pointing errors + tl = sl + point_errors[t, a, f, 0] + tm = sm + point_errors[t, a, f, 1] - # Apply any frequency scaling - sl = l * freq_scale - sm = m * freq_scale + # Rotate lm coordinate angle + vl = tl*cos_pa - tm*sin_pa + vm = tl*sin_pa + tm*cos_pa - # Add pointing errors - tl = sl + point_errors[t, a, f, 0] - tm = sm + point_errors[t, a, f, 1] + # Scale by antenna scaling + vl *= antenna_scaling[a, f, 0] + vm *= antenna_scaling[a, f, 1] - # Rotate lm coordinate angle - vl = tl*cos_pa - tm*sin_pa - vm = tl*sin_pa + tm*cos_pa + # Shift into the cube coordinate system + vl = lscale*(vl - lower_l) + vm = mscale*(vm - lower_m) - # Scale by antenna scaling - vl *= antenna_scaling[a, f, 0] - vm *= antenna_scaling[a, f, 1] + # Clamp the coordinates to the edges of the cube + vl = max(zero, min(vl, lmaxf)) + vm = max(zero, min(vm, mmaxf)) - # Shift into the cube coordinate system - vl = lscale*(vl - lower_l) - vm = mscale*(vm - lower_m) + # Snap to the lower grid coordinates + gl0 = np.int32(np.floor(vl)) + gm0 = np.int32(np.floor(vm)) - # Clamp the coordinates to the edges of the cube - vl = max(zero, min(vl, lmaxf)) - vm = max(zero, min(vm, mmaxf)) + # Snap to the upper grid coordinates + gl1 = min(gl0 + 1, lmaxi) + gm1 = min(gm0 + 1, mmaxi) - # Snap to the lower grid coordinates - gl0 = np.int32(np.floor(vl)) - gm0 = np.int32(np.floor(vm)) + # Difference between grid and offset coordinates + ld = vl - gl0 + md = vm - gm0 - # Snap to the upper grid coordinates - gl1 = min(gl0 + 1, lmaxi) - gm1 = min(gm0 + 1, mmaxi) + # Zero accumulation arrays + corr_sum[:] = 0 + absc_sum[:] = 0 - # Difference between grid and offset coordinates - ld = vl - gl0 - md = vm - gm0 + # Accumulate lower cube correlations + beam_scratch[:] = fbeam[gl0, gm0, gc0, :] + weight = (one - ld)*(one - md)*nud - # Zero accumulation arrays - corr_sum[:] = 0 - absc_sum[:] = 0 + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - # Accumulate lower cube correlations - beam_scratch[:] = fbeam[gl0, gm0, gc0, :] - weight = (one - ld)*(one - md)*nud + beam_scratch[:] = fbeam[gl1, gm0, gc0, :] + weight = ld*(one - md)*nud - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - beam_scratch[:] = fbeam[gl1, gm0, gc0, :] - weight = ld*(one - md)*nud + beam_scratch[:] = fbeam[gl0, gm1, gc0, :] + weight = (one - ld)*md*nud - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - beam_scratch[:] = fbeam[gl0, gm1, gc0, :] - weight = (one - ld)*md*nud + beam_scratch[:] = fbeam[gl1, gm1, gc0, :] + weight = ld*md*nud - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - beam_scratch[:] = fbeam[gl1, gm1, gc0, :] - weight = ld*md*nud + # Accumulate upper cube correlations + beam_scratch[:] = fbeam[gl0, gm0, gc1, :] + weight = (one - ld)*(one - md)*inv_nud - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - # Accumulate upper cube correlations - beam_scratch[:] = fbeam[gl0, gm0, gc1, :] - weight = (one - ld)*(one - md)*inv_nud + beam_scratch[:] = fbeam[gl1, gm0, gc1, :] + weight = ld*(one - md)*inv_nud - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - beam_scratch[:] = fbeam[gl1, gm0, gc1, :] - weight = ld*(one - md)*inv_nud + beam_scratch[:] = fbeam[gl0, gm1, gc1, :] + weight = (one - ld)*md*inv_nud - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - beam_scratch[:] = fbeam[gl0, gm1, gc1, :] - weight = (one - ld)*md*inv_nud + beam_scratch[:] = fbeam[gl1, gm1, gc1, :] + weight = ld*md*inv_nud - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + for c in range(ncorrs): + absc_sum[c] += weight * np.abs(beam_scratch[c]) + corr_sum[c] += weight * beam_scratch[c] - beam_scratch[:] = fbeam[gl1, gm1, gc1, :] - weight = ld*md*inv_nud + for c in range(ncorrs): + # Added all correlations, normalise + div = np.abs(corr_sum[c]) - for c in range(ncorrs): - absc_sum[c] += weight * np.abs(beam_scratch[c]) - corr_sum[c] += weight * beam_scratch[c] + if div == 0.0: + # This case probably works out to a zero assign + corr_sum[c] *= absc_sum[c] + else: + corr_sum[c] *= absc_sum[c] / div - for c in range(ncorrs): - # Added all correlations, normalise - div = np.abs(corr_sum[c]) + # Assign normalised values + fjones[s, t, a, f, :] = corr_sum - if div == 0.0: - # This case probably works out to a zero assign - corr_sum[c] *= absc_sum[c] - else: - corr_sum[c] *= absc_sum[c] / div + if parallel and threads is not None: + set_num_threads(prev_threads) - # Assign normalised values - fjones[s, t, a, f, :] = corr_sum + return fjones.reshape((nsrc, ntime, nants, nchan) + corrs) - return fjones.reshape((nsrc, ntime, nants, nchan) + corrs) + return impl BEAM_CUBE_DOCS = DocstringTemplate( diff --git a/africanus/rime/feeds.py b/africanus/rime/feeds.py index 8002071f5..833e96b05 100644 --- a/africanus/rime/feeds.py +++ b/africanus/rime/feeds.py @@ -1,75 +1,67 @@ # -*- coding: utf-8 -*- - -from functools import reduce -from operator import mul - import numpy as np +from africanus.config import config from africanus.util.docs import DocstringTemplate -from africanus.util.numba import jit - - -@jit(nopython=True, nogil=True, cache=True) -def _nb_feed_rotation(parallactic_angles, feed_type, feed_rotation): - shape = parallactic_angles.shape - parangles = parallactic_angles.flat - - # Linear feeds - if feed_type == 0: - for i, pa in enumerate(parangles): - pa_cos = np.cos(pa) - pa_sin = np.sin(pa) - - feed_rotation.real[i, 0, 0] = pa_cos - feed_rotation.imag[i, 0, 0] = 0.0 - feed_rotation.real[i, 0, 1] = pa_sin - feed_rotation.imag[i, 0, 1] = 0.0 - feed_rotation.real[i, 1, 0] = -pa_sin - feed_rotation.imag[i, 1, 0] = 0.0 - feed_rotation.real[i, 1, 1] = pa_cos - feed_rotation.imag[i, 1, 1] = 0.0 - - # Circular feeds - elif feed_type == 1: - for i, pa in enumerate(parangles): - pa_cos = np.cos(pa) - pa_sin = np.sin(pa) - - feed_rotation.real[i, 0, 0] = pa_cos - feed_rotation.imag[i, 0, 0] = -pa_sin - feed_rotation[i, 0, 1] = 0.0 + 0.0*1j - feed_rotation[i, 1, 0] = 0.0 + 0.0*1j - feed_rotation.real[i, 1, 1] = pa_cos - feed_rotation.imag[i, 1, 1] = pa_sin - else: - raise ValueError("Invalid feed_type") - - return feed_rotation.reshape(shape + (2, 2)) +from africanus.util.numba import generated_jit + +cfg = config.numba_parallel("rime.feed_rotation.parallel") +parallel = cfg.get('parallel', False) +@generated_jit(nopython=True, nogil=True, + cache=not parallel, parallel=parallel) def feed_rotation(parallactic_angles, feed_type='linear'): - if feed_type == 'linear': - poltype = 0 - elif feed_type == 'circular': - poltype = 1 - else: - raise ValueError("Invalid feed_type '%s'" % feed_type) - - if parallactic_angles.dtype == np.float32: - dtype = np.complex64 - elif parallactic_angles.dtype == np.float64: - dtype = np.complex128 - else: - raise ValueError("parallactic_angles has " - "none-floating point type %s" - % parallactic_angles.dtype) - - # Create result array with flattened parangles - shape = (reduce(mul, parallactic_angles.shape),) + (2, 2) - result = np.empty(shape, dtype=dtype) - - return _nb_feed_rotation(parallactic_angles, poltype, result) + pa_np_dtype = np.dtype(parallactic_angles.dtype.name) + dtype = np.result_type(pa_np_dtype, np.complex64) + + import numba + threads = cfg.get("threads", None) if parallel else None + + def impl(parallactic_angles, feed_type='linear'): + if parallel and threads is not None: + prev_threads = numba.get_num_threads() + numba.set_num_threads(threads) + + parangles = parallactic_angles.ravel() + # Can't prepend shape tuple till the following is fixed + # https://github.com/numba/numba/issues/5439 + # We know parangles.ndim == 1 though + result = np.zeros((parangles.shape[0], 2, 2), dtype=dtype) + + # Linear feeds + if feed_type == 'linear': + for i in numba.prange(parangles.shape[0]): + pa = parangles[i] + pa_cos = np.cos(pa) + pa_sin = np.sin(pa) + + result[i, 0, 0] = pa_cos + 0j + result[i, 0, 1] = pa_sin + 0j + result[i, 1, 0] = -pa_sin + 0j + result[i, 1, 1] = pa_cos + 0j + + # Circular feeds + elif feed_type == 'circular': + for i in numba.prange(parangles.shape[0]): + pa = parangles[i] + pa_cos = np.cos(pa) + pa_sin = np.sin(pa) + + result[i, 0, 0] = pa_cos - pa_sin*1j + result[i, 0, 1] = 0.0 + result[i, 1, 0] = 0.0 + result[i, 1, 1] = pa_cos + pa_sin*1j + else: + raise ValueError("feed_type not in ('linear', 'circular')") + + if parallel and threads is not None: + numba.set_num_threads(prev_threads) + + return result.reshape(parallactic_angles.shape + (2, 2)) + + return impl FEED_ROTATION_DOCS = DocstringTemplate(r""" diff --git a/africanus/rime/phase.py b/africanus/rime/phase.py index 48e9060ba..9eee72fa4 100644 --- a/africanus/rime/phase.py +++ b/africanus/rime/phase.py @@ -2,20 +2,36 @@ import numpy as np +from africanus.config import config from africanus.constants import minus_two_pi_over_c from africanus.util.docs import DocstringTemplate from africanus.util.numba import generated_jit from africanus.util.type_inference import infer_complex_dtype +cfg = config.numba_parallel("rime.phase_delay.parallel") +parallel = cfg.get('parallel', False) +axes = cfg.get('axes', set(('source', 'row')) if parallel else ()) -@generated_jit(nopython=True, nogil=True, cache=True) + +@generated_jit(nopython=True, nogil=True, + cache=not parallel, parallel=parallel) def phase_delay(lm, uvw, frequency, convention='fourier'): # Bake constants in with the correct type one = lm.dtype(1.0) neg_two_pi_over_c = lm.dtype(minus_two_pi_over_c) out_dtype = infer_complex_dtype(lm, uvw, frequency) + from numba import prange, set_num_threads, get_num_threads + + srange = prange if 'source' in axes else range + rrange = prange if 'row' in axes else range + threads = cfg.get('threads', None) if parallel else None + def _phase_delay_impl(lm, uvw, frequency, convention='fourier'): + if parallel and threads is not None: + prev_threads = get_num_threads() + set_num_threads(threads) + if convention == 'fourier': constant = neg_two_pi_over_c elif convention == 'casa': @@ -27,12 +43,12 @@ def _phase_delay_impl(lm, uvw, frequency, convention='fourier'): complex_phase = np.zeros(shape, dtype=out_dtype) # For each source - for source in range(lm.shape[0]): + for source in srange(lm.shape[0]): l, m = lm[source] n = np.sqrt(one - l**2 - m**2) - one # For each uvw coordinate - for row in range(uvw.shape[0]): + for row in rrange(uvw.shape[0]): u, v, w = uvw[row] # e^(-2*pi*(l*u + m*v + n*w)/c) real_phase = constant * (l * u + m * v + n * w) @@ -44,8 +60,10 @@ def _phase_delay_impl(lm, uvw, frequency, convention='fourier'): # Our phase input is purely imaginary # so we can can elide a call to exp # and just compute the cos and sin - complex_phase.real[source, row, chan] = np.cos(p) - complex_phase.imag[source, row, chan] = np.sin(p) + complex_phase[source, row, chan] = np.cos(p) + np.sin(p)*1j + + if parallel and threads is not None: + set_num_threads(prev_threads) return complex_phase diff --git a/africanus/rime/predict.py b/africanus/rime/predict.py index c0f4fc2ec..8369eaa53 100644 --- a/africanus/rime/predict.py +++ b/africanus/rime/predict.py @@ -3,9 +3,13 @@ import numpy as np +from africanus.config import config from africanus.util.docs import DocstringTemplate from africanus.util.numba import is_numba_type_none, generated_jit, njit +cfg = config.numba_parallel("rime.predict_vis.parallel") +parallel = cfg.get('parallel', False) +axes = cfg.get("axes", set(('source', 'row')) if parallel else set()) JONES_NOT_PRESENT = 0 JONES_1_OR_2 = 1 @@ -189,10 +193,15 @@ def sum_coherencies_factory(have_ddes, have_coh, jones_type): """ Factory function generating a function that sums coherencies """ jones_mul = jones_mul_factory(have_ddes, have_coh, jones_type, True) + from numba import prange + + srange = prange if 'source' in axes else range + rrange = prange if 'row' in axes else range + if have_ddes and have_coh: def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out): - for s in range(a1j.shape[0]): - for r in range(time.shape[0]): + for s in srange(a1j.shape[0]): + for r in rrange(time.shape[0]): ti = time[r] - tmin a1 = ant1[r] a2 = ant2[r] @@ -205,8 +214,8 @@ def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out): elif have_ddes and not have_coh: def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out): - for s in range(a1j.shape[0]): - for r in range(time.shape[0]): + for s in srange(a1j.shape[0]): + for r in rrange(time.shape[0]): ti = time[r] - tmin a1 = ant1[r] a2 = ant2[r] @@ -219,16 +228,16 @@ def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out): elif not have_ddes and have_coh: if jones_type == JONES_2X2: def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out): - for s in range(blj.shape[0]): - for r in range(blj.shape[1]): + for s in srange(blj.shape[0]): + for r in rrange(blj.shape[1]): for f in range(blj.shape[2]): for c1 in range(blj.shape[3]): for c2 in range(blj.shape[4]): out[r, f, c1, c2] += blj[s, r, f, c1, c2] else: def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out): - for s in range(blj.shape[0]): - for r in range(blj.shape[1]): + for s in srange(blj.shape[0]): + for r in rrange(blj.shape[1]): for f in range(blj.shape[2]): for c in range(blj.shape[3]): out[r, f, c] += blj[s, r, f, c] @@ -298,16 +307,20 @@ def apply_dies_factory(have_dies, have_bvis, jones_type): Factory function returning a function that applies Direction Independent Effects """ + from numba import prange # We always "have visibilities", (the output array) jones_mul = jones_mul_factory(have_dies, True, jones_type, False) + # time is rowlike + trange = prange if 'row' in axes else range + if have_dies and have_bvis: def apply_dies(time, ant1, ant2, die1_jones, die2_jones, tmin, out): # Iterate over rows - for r in range(time.shape[0]): + for r in trange(time.shape[0]): ti = time[r] - tmin a1 = ant1[r] a2 = ant2[r] @@ -322,7 +335,7 @@ def apply_dies(time, ant1, ant2, die1_jones, die2_jones, tmin, out): # Iterate over rows - for r in range(time.shape[0]): + for r in trange(time.shape[0]): ti = time[r] - tmin a1 = ant1[r] a2 = ant2[r] @@ -426,7 +439,8 @@ def predict_checks(time_index, antenna1, antenna2, have_dies1, have_bvis, have_dies2) -@generated_jit(nopython=True, nogil=True, cache=True) +@generated_jit(nopython=True, nogil=True, + cache=not parallel, parallel=parallel) def predict_vis(time_index, antenna1, antenna2, dde1_jones=None, source_coh=None, dde2_jones=None, die1_jones=None, base_vis=None, die2_jones=None): @@ -474,10 +488,17 @@ def predict_vis(time_index, antenna1, antenna2, apply_dies_fn = apply_dies_factory(have_dies, have_bvis, jones_type) add_coh_fn = add_coh_factory(have_bvis) + from numba import set_num_threads, get_num_threads + threads = cfg.get("threads", None) if parallel else None + def _predict_vis_fn(time_index, antenna1, antenna2, dde1_jones=None, source_coh=None, dde2_jones=None, die1_jones=None, base_vis=None, die2_jones=None): + if parallel and threads is not None: + prev_threads = get_num_threads() + set_num_threads(threads) + # Get the output shape out = out_fn(time_index, dde1_jones, source_coh, dde2_jones, die1_jones, base_vis, die2_jones) @@ -498,6 +519,9 @@ def _predict_vis_fn(time_index, antenna1, antenna2, die1_jones, die2_jones, tmin, out) + if parallel and threads is not None: + set_num_threads(prev_threads) + return out return _predict_vis_fn diff --git a/africanus/rime/tests/conftest.py b/africanus/rime/tests/conftest.py index 576635e7d..83bac2ae5 100644 --- a/africanus/rime/tests/conftest.py +++ b/africanus/rime/tests/conftest.py @@ -3,7 +3,6 @@ """Tests for `codex-africanus` package.""" - import numpy as np import pytest diff --git a/africanus/rime/tests/test_fast_beams.py b/africanus/rime/tests/test_fast_beams.py index 4645bebf0..722842a79 100644 --- a/africanus/rime/tests/test_fast_beams.py +++ b/africanus/rime/tests/test_fast_beams.py @@ -3,9 +3,6 @@ import pytest -from africanus.rime.fast_beam_cubes import beam_cube_dde, freq_grid_interp - - def rf(*a, **kw): return np.random.random(*a, **kw) @@ -40,8 +37,17 @@ def freqs(): return np.array([.4, .5, .6, .7, .8, .9, 1.0, 1.1]) -def test_fast_beam_small(): +@pytest.mark.parametrize("cfg_parallel", [ + ("africanus.rime.fast_beam_cubes", {"rime.beam_cube_dde.parallel": True}), + ("africanus.rime.fast_beam_cubes", { + "rime.beam_cube_dde.parallel": {'threads': 2}}), + ("africanus.rime.fast_beam_cubes", {"rime.beam_cube_dde.parallel": False}), + ], ids=["parallel", "parallel-2", "serial"], indirect=True) +def test_fast_beam_small(cfg_parallel): """ Small beam test, interpolation of one soure at [0.1, 0.1] """ + from africanus.rime.fast_beam_cubes import beam_cube_dde + assert beam_cube_dde.targetoptions['parallel'] == cfg_parallel + np.random.seed(42) # One frequency, to the lower side of the beam frequency map @@ -121,6 +127,8 @@ def test_fast_beam_small(): def test_grid_interpolate(freqs, beam_freq_map): + from africanus.rime.fast_beam_cubes import freq_grid_interp + freq_data = freq_grid_interp(freqs, beam_freq_map) freq_scale = freq_data[:, 0] @@ -153,6 +161,7 @@ def test_grid_interpolate(freqs, beam_freq_map): def test_dask_fast_beams(freqs, beam_freq_map): da = pytest.importorskip("dask.array") + from africanus.rime.fast_beam_cubes import beam_cube_dde from africanus.rime.dask import beam_cube_dde as dask_beam_cube_dde beam_lw = 10 @@ -215,6 +224,7 @@ def test_fast_beams_vs_montblanc(freqs, beam_freq_map_montblanc, dtype): """ Test that the numba beam matches montblanc implementation """ mb_tf_mod = pytest.importorskip("montblanc.impl.rime.tensorflow") tf = pytest.importorskip("tensorflow") + from africanus.rime.fast_beam_cubes import beam_cube_dde freqs = freqs.astype(dtype) beam_freq_map = beam_freq_map_montblanc.astype(dtype) diff --git a/africanus/rime/tests/test_predict.py b/africanus/rime/tests/test_predict.py index 041715162..8a880e026 100644 --- a/africanus/rime/tests/test_predict.py +++ b/africanus/rime/tests/test_predict.py @@ -52,9 +52,14 @@ def rc(*a, **kw): @dde_presence_parametrization @die_presence_parametrization @chunk_parametrization +@pytest.mark.parametrize("cfg_parallel", [ + ("africanus.rime.predict", {"rime.predict_vis.parallel": True}), + ("africanus.rime.predict", {"rime.predict_vis.parallel": {'threads': 2}}), + ("africanus.rime.predict", {"rime.predict_vis.parallel": False}), + ], ids=["parallel", "parallel-2", "serial"], indirect=True) def test_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, a1j, blj, a2j, g1j, bvis, g2j, - chunks): + chunks, cfg_parallel): from africanus.rime.predict import predict_vis s = sum(chunks['source']) @@ -77,6 +82,8 @@ def test_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, assert ant1.size == r + assert predict_vis.targetoptions['parallel'] == cfg_parallel + model_vis = predict_vis(time_idx, ant1, ant2, a1_jones if a1j else None, bl_jones if blj else None, diff --git a/africanus/rime/tests/test_rime.py b/africanus/rime/tests/test_rime.py index eac29d4f8..73dbe6073 100644 --- a/africanus/rime/tests/test_rime.py +++ b/africanus/rime/tests/test_rime.py @@ -4,7 +4,6 @@ """Tests for `codex-africanus` package.""" import numpy as np - import pytest @@ -20,8 +19,15 @@ def rc(*a, **kw): ('fourier', 1), ('casa', -1) ]) -def test_phase_delay(convention, sign): - from africanus.rime import phase_delay +@pytest.mark.parametrize("cfg_parallel", [ + ("africanus.rime.phase", {"rime.phase_delay.parallel": True}), + ("africanus.rime.phase", {"rime.phase_delay.parallel": {"threads": 2}}), + ("africanus.rime.phase", {"rime.phase_delay.parallel": False}), + ], ids=["parallel", "parallel-2", "serial"], indirect=True) +def test_phase_delay(convention, sign, cfg_parallel): + from africanus.rime.phase import phase_delay + + assert phase_delay.targetoptions['parallel'] == cfg_parallel uvw = np.random.random(size=(100, 3)) lm = np.random.random(size=(10, 2)) @@ -50,9 +56,14 @@ def test_phase_delay(convention, sign): assert np.all(np.exp(1j*phase) == complex_phase[lm_i, uvw_i, freq_i]) -def test_feed_rotation(): - import numpy as np - from africanus.rime import feed_rotation +@pytest.mark.parametrize("cfg_parallel", [ + ("africanus.rime.feeds", {"rime.feed_rotation.parallel": True}), + ("africanus.rime.feeds", {"rime.feed_rotation.parallel": False}), + ], ids=["parallel", "serial"], indirect=True) +def test_feed_rotation(cfg_parallel): + from africanus.rime.feeds import feed_rotation + + assert feed_rotation.targetoptions['parallel'] == cfg_parallel parangles = np.random.random((10, 5)) pa_sin = np.sin(parangles) diff --git a/africanus/util/numba.py b/africanus/util/numba.py index 02e1a27f5..88e2efd9d 100644 --- a/africanus/util/numba.py +++ b/africanus/util/numba.py @@ -22,9 +22,10 @@ def wrapper(*args, **kwargs): jit = _fake_decorator njit = _fake_decorator stencil = _fake_decorator + prange = _fake_decorator else: - from numba import cfunc, jit, njit, generated_jit, stencil # noqa + from numba import cfunc, jit, njit, generated_jit, stencil, prange # noqa def is_numba_type_none(arg): diff --git a/setup.py b/setup.py index 4ac6d5321..03e3d9118 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,8 @@ # Basic requirements containing no C extensions. # This is necessary for building on RTD requirements = ['appdirs >= 1.4.3', - 'decorator'] + 'decorator', + 'donfig'] if not on_rtd: requirements += [