From d0ab3a6fdf490f745f815249cdb0c384aa4e41d1 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 23 May 2024 11:40:48 +0200 Subject: [PATCH 1/2] Fix bda overload to return an implementation --- africanus/averaging/bda_avg.py | 177 ++++++++++++------ .../averaging/tests/test_bda_averaging.py | 16 +- 2 files changed, 133 insertions(+), 60 deletions(-) diff --git a/africanus/averaging/bda_avg.py b/africanus/averaging/bda_avg.py index 1589c204..5c32478e 100644 --- a/africanus/averaging/bda_avg.py +++ b/africanus/averaging/bda_avg.py @@ -3,6 +3,7 @@ from collections import namedtuple import numpy as np +from numba import types from africanus.averaging.bda_mapping import bda_mapper, RowMapOutput from africanus.averaging.shared import chan_corrs, merge_flags, vis_output_arrays @@ -757,72 +758,132 @@ def nb_bda_impl( time_bin_secs=None, min_nchan=1, ): - # Merge flag_row and flag arrays - flag_row = merge_flags(flag_row, flag) + if is_numba_type_none(chan_width): + return TypeError(f"chan_width must be provided") - meta = bda_mapper( + if is_numba_type_none(chan_freq): + return TypeError(f"chan_freq must be provided") + + if is_numba_type_none(uvw): + raise TypeError(f"uvw must be provided") + + valid_types = ( + types.misc.NoneType, + types.misc.Omitted, + types.scalars.Float, + types.scalars.Integer, + ) + + if not isinstance(max_uvw_dist, valid_types): + raise TypeError(f"max_uvw_dist ({max_uvw_dist}) must be a scalar float") + + if not isinstance(max_fov, valid_types): + raise TypeError(f"max_fov ({max_fov}) must be a scalar float") + + if not isinstance(decorrelation, valid_types): + raise TypeError(f"decorrelation ({decorrelation}) must be a scalar float") + + if not isinstance(time_bin_secs, valid_types): + raise TypeError(f"time_bin_secs ({time_bin_secs}) must be a scalar float") + + valid_types = (types.misc.NoneType, types.misc.Omitted, types.scalars.Integer) + + if not isinstance(min_nchan, valid_types): + raise TypeError(f"min_nchan ({min_nchan}) must be an integer") + + def impl( time, interval, antenna1, antenna2, - uvw, - chan_width, - chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan, - ) + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + max_uvw_dist=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, + ): + # Merge flag_row and flag arrays + flag_row = merge_flags(flag_row, flag) + + meta = bda_mapper( + time, + interval, + antenna1, + antenna2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan, + ) - row_avg = row_average( - meta, - antenna1, - antenna2, - flag_row, # noqa: F841 - time_centroid, - exposure, - uvw, - weight=weight, - sigma=sigma, - ) + row_avg = row_average( + meta, + antenna1, + antenna2, + flag_row, # noqa: F841 + time_centroid, + exposure, + uvw, + weight=weight, + sigma=sigma, + ) - row_chan_avg = row_chan_average( - meta, # noqa: F841 - flag_row=flag_row, - visibilities=visibilities, - flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum, - ) + row_chan_avg = row_chan_average( + meta, # noqa: F841 + flag_row=flag_row, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + ) - # Have to explicitly write it out because numba tuples - # are highly constrained types - return AverageOutput( - meta.map, - meta.offsets, - meta.decorr_chan_width, - meta.time, - meta.interval, - meta.chan_width, - meta.flag_row, - row_avg.antenna1, - row_avg.antenna2, - row_avg.time_centroid, - row_avg.exposure, - row_avg.uvw, - row_avg.weight, - row_avg.sigma, - # None, # chan_data.chan_freq, - # None, # chan_data.chan_width, - # None, # chan_data.effective_bw, - # None, # chan_data.resolution, - row_chan_avg.visibilities, - row_chan_avg.flag, - row_chan_avg.weight_spectrum, - row_chan_avg.sigma_spectrum, - ) + # Have to explicitly write it out because numba tuples + # are highly constrained types + return AverageOutput( + meta.map, + meta.offsets, + meta.decorr_chan_width, + meta.time, + meta.interval, + meta.chan_width, + meta.flag_row, + row_avg.antenna1, + row_avg.antenna2, + row_avg.time_centroid, + row_avg.exposure, + row_avg.uvw, + row_avg.weight, + row_avg.sigma, + # None, # chan_data.chan_freq, + # None, # chan_data.chan_width, + # None, # chan_data.effective_bw, + # None, # chan_data.resolution, + row_chan_avg.visibilities, + row_chan_avg.flag, + row_chan_avg.weight_spectrum, + row_chan_avg.sigma_spectrum, + ) + + return impl BDA_DOCS = DocstringTemplate( diff --git a/africanus/averaging/tests/test_bda_averaging.py b/africanus/averaging/tests/test_bda_averaging.py index 4ee72352..fb5a0fcd 100644 --- a/africanus/averaging/tests/test_bda_averaging.py +++ b/africanus/averaging/tests/test_bda_averaging.py @@ -7,7 +7,7 @@ import pytest from africanus.averaging.bda_mapping import RowMapOutput -from africanus.averaging.bda_avg import row_average, row_chan_average +from africanus.averaging.bda_avg import bda as bda_avg, row_average, row_chan_average from africanus.averaging.dask import bda as dask_bda @@ -94,7 +94,7 @@ def _calc_sigma(weight, sigma, rows): return np.sqrt(numerator / denominator) -def test_bda_avg(bda_test_map, inv_bda_test_map, flags): +def test_bda_avg_in_parts(bda_test_map, inv_bda_test_map, flags): rs = np.random.RandomState(42) # Derive flag_row from flags @@ -119,6 +119,7 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): weight = rs.normal(size=(in_row, in_corr)) sigma = rs.normal(size=(in_row, in_corr)) chan_width = np.repeat(0.856e9 / out_chan, out_chan) + chan_freq = np.linspace(0.856e9, 2 * 0.856e9, chan_width.size) # Aggregate time and interval, in_row => out_row # first channel in the map. We're only averaging over @@ -236,6 +237,17 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): assert_array_almost_equal(row_chan_avg.weight_spectrum, out_ws) assert_array_almost_equal(row_chan_avg.sigma_spectrum, out_ss) + result = bda_avg( + time=time, + interval=interval, + antenna1=ant1, + antenna2=ant2, + visibilities=vis, + uvw=uvw, + chan_width=chan_width, + chan_freq=chan_freq, + ) + @pytest.mark.parametrize("vis_format", ["ragged", "flat"]) def test_dask_bda_avg(vis_format): From 2b6b41f7622a7a0e00fedbcba9de0cd77b8a9ab1 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 23 May 2024 11:42:32 +0200 Subject: [PATCH 2/2] [skip ci] Update HISTORY.rst --- HISTORY.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/HISTORY.rst b/HISTORY.rst index 87bf23fc..b78b09a9 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History X.Y.Z (YYYY-MM-DD) ------------------ +* Fix bda overload to return an implementation (:pr:`307`) * Upgrade obsolete readthedocs configuration (:pr:`304`) 0.3.6 (2024-04-15)