Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[skip ci] Interpolate missing nominal values during Averaging #246

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

0.2.11 (YYYY-MM-DD)
-------------------
* Interpolate missing nominal values during Averaging (:pr:`246`)
* Baseline-Dependent Time-and-Channel Averaging (:pr:`173`, :pr:`243`)

0.2.10 (2021-02-09)
Expand Down
33 changes: 19 additions & 14 deletions africanus/averaging/bda_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ def start_bin(self, row, time, interval, flag_row):
self.rs = row
self.re = row
self.bin_count = 1
self.time_sum = time[row]
self.interval_sum = interval[row]
self.bin_flag_count = (1 if flag_row is not None and flag_row[row] != 0
else 0)

Expand All @@ -196,8 +194,6 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row):
self.re = row
self.bin_half_Δψ = self.decorrelation
self.bin_count += 1
self.time_sum += time[row]
self.interval_sum += interval[row]

if flag_row is not None and flag_row[row] != 0:
self.bin_flag_count += 1
Expand Down Expand Up @@ -233,8 +229,6 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row):
self.re = row
self.bin_half_Δψ = half_𝞓𝞇
self.bin_count += 1
self.time_sum += time[row]
self.interval_sum += interval[row]

if flag_row is not None and flag_row[row] != 0:
self.bin_flag_count += 1
Expand All @@ -245,7 +239,9 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row):
def empty(self):
return self.bin_count == 0

def finalise_bin(self, auto_corr, uvw, nchan_factors,
def finalise_bin(self, auto_corr,
time, interval,
uvw, nchan_factors,
chan_width, chan_freq):
""" Finalise the contents of this bin """
if self.bin_count == 0:
Expand Down Expand Up @@ -301,10 +297,20 @@ def finalise_bin(self, auto_corr, uvw, nchan_factors,
s = np.searchsorted(nchan_factors, nchan, side='left')
nchan = nchan_factors[min(nchan_factors.shape[0] - 1, s)]

if rs == re:
# single value in the bin, re-use time and interval
bin_time = time[rs]
bin_interval = interval[rs]
else:
# take the midpoint
dt = time[re] - time[rs]
bin_time = 0.5*(time[re] + time[rs])
bin_interval = 0.5*interval[re] + 0.5*interval[rs] + dt

# Finalise bin values for return
out = FinaliseOutput(self.tbin,
self.time_sum / self.bin_count,
self.interval_sum,
bin_time,
bin_interval,
nchan,
self.bin_count == self.bin_flag_count)

Expand Down Expand Up @@ -487,8 +493,8 @@ def update_lookups(finalised, bl):
elif not binner.add_row(r, auto_corr,
time, interval,
uvw, flag_row):
f = binner.finalise_bin(auto_corr, uvw,
nchan_factors,
f = binner.finalise_bin(auto_corr, time, interval,
uvw, nchan_factors,
chan_width, chan_freq)
update_lookups(f, bl)
# Post-finalisation, the bin is empty, start a new bin
Expand All @@ -499,9 +505,8 @@ def update_lookups(finalised, bl):

# Finalise any remaining data in the bin
if not binner.empty:
f = binner.finalise_bin(auto_corr, uvw,
nchan_factors,
chan_width, chan_freq)
f = binner.finalise_bin(auto_corr, time, interval, uvw,
nchan_factors, chan_width, chan_freq)
update_lookups(f, bl)

nr_of_time_bins += binner.tbin
Expand Down
97 changes: 94 additions & 3 deletions africanus/averaging/tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ def time():


@pytest.fixture
def interval():
data = np.asarray([1.9, 2.0, 2.1, 1.85, 1.95, 2.0, 2.05, 2.1, 2.05, 1.9])
return data*0.1
def interval(time):
    """Nominal interval fixture: unit width for every time sample."""
    return np.ones_like(time)


@pytest.fixture
Expand Down Expand Up @@ -108,6 +107,68 @@ def test_row_mapper(time, interval, ant1, ant2,
assert_array_almost_equal(new_exp, new_exp2)


@pytest.mark.parametrize("time_bin_secs", [3.0])
@pytest.mark.parametrize("keep", [
    [0, 1, 3, 4, 5, 7, 8, 9],
    [0, 1, 2, 3, 4, 5, 7, 8, 9],
])
def test_interpolation(time_bin_secs, keep):
    """A mapping built from rows with gaps must interpolate the same
    bin times and intervals as the mapping built from all rows."""
    time = np.linspace(1.0, 10.0, 10)
    interval = np.full_like(time, 1.0, time.dtype)
    ant1 = np.full_like(time, 0, np.int32)
    ant2 = np.full_like(time, 1, np.int32)
    flag_row = np.full_like(time, 0, np.uint8)

    args = (time, interval, ant1, ant2, flag_row)

    # Mapping over the complete set of rows
    full = row_mapper(*args, time_bin_secs)
    # Mapping over the retained subset only
    holes = row_mapper(*(a[keep] for a in args), time_bin_secs)

    assert_array_almost_equal(full.time, holes.time)
    assert_array_almost_equal(full.interval, holes.interval)


@pytest.mark.parametrize("time_bin_secs", [3.0])
def test_interpolation_edge(time_bin_secs):
    """Bin times/intervals are interpolated correctly when edge
    and/or interior rows are missing from the input.

    Fixes a latent bug in the original: each ``keep = [...]`` line
    ended with a trailing comma, silently binding a 1-tuple
    ``([...],)`` instead of a list.  It only worked because numpy
    treats a 1-tuple index on a 1-D array like the bare list.
    """
    time = np.linspace(1.0, 10.0, 10)
    interval = np.full_like(time, 1.0, time.dtype)

    ant1 = np.full_like(time, 0, np.int32)
    ant2 = np.full_like(time, 1, np.int32)
    flag_row = np.full_like(time, 0, np.uint8)

    def _map_keep(keep):
        # Build the mapping from the retained rows only
        return row_mapper(time[keep], interval[keep],
                          ant1[keep], ant2[keep],
                          flag_row[keep], time_bin_secs)

    # First and last time centroids removed
    holes = _map_keep([1, 2, 3, 4, 5, 6, 7, 8])
    assert_array_almost_equal(holes.time, [3, 6, 8.5])
    assert_array_almost_equal(holes.interval, [3, 3, 2])

    # First and last time centroids removed as well
    # as an interval value
    holes = _map_keep([1, 2, 3, 4, 5, 6, 8])
    assert_array_almost_equal(holes.time, [3, 6, 9])
    assert_array_almost_equal(holes.interval, [3, 3, 1])

    # First and last time centroids removed as well
    # as an internal value
    holes = _map_keep([1, 3, 4, 5, 6, 7, 8])
    assert_array_almost_equal(holes.time, [3, 6, 8.5])
    assert_array_almost_equal(holes.interval, [3, 3, 2])


def test_channel_mapper():
chan_map, out_chans = channel_mapper(64, 17)

Expand All @@ -122,3 +183,33 @@ def test_channel_mapper():
assert_array_equal(counts, [17, 17, 17, 13])

assert out_chans == 4


@pytest.mark.parametrize("time_bin_secs", [3])
def test_row_mapper2(time_bin_secs):
    """Exploratory check of bin-edge grid construction.

    Builds a grid of bin edges spanning [earliest sample start,
    latest sample end] in steps of ``time_bin_secs`` and prints, for
    each row, the grid edges its nominal interval straddles.

    Fixes: ``next`` shadowed the builtin, ``l`` was an ambiguous
    name suppressed with ``# noqa``, and ``enumerate`` produced an
    unused index.
    """
    time = np.linspace(1.0, 10.0, 10)
    interval = np.full_like(time, 1.0)

    min_time_i = time.argmin()
    max_time_i = time.argmax()

    # Grid spans the first sample's start to the last sample's end
    time_min = time[min_time_i] - interval[min_time_i] / 2
    time_max = time[max_time_i] + interval[max_time_i] / 2
    grid = [time_min]
    edge = time_min + time_bin_secs

    while edge < time_max:
        grid.append(edge)
        edge += time_bin_secs

    grid.append(time_max)
    grid = np.asarray(grid)
    print(grid, np.diff(grid))

    for t, i in zip(time, interval):
        half_i = i / 2
        lo = np.searchsorted(grid, t - half_i, side="left")
        hi = np.searchsorted(grid, t + half_i, side="left")
        vals = ([((t - half_i, t + half_i), (lo, hi), (grid[lo], grid[hi]))] +
                [time[k] for k in range(lo, hi)])
        print(*vals, sep=", ", end="\n")
70 changes: 51 additions & 19 deletions africanus/averaging/tests/test_time_and_channel_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def uvw():


@pytest.fixture
def interval():
data = np.asarray([1.9, 2.0, 2.1, 1.85, 1.95, 2.0, 2.05, 2.1, 2.05, 1.9])
return 0.1 * data
def interval(time):
    """Nominal interval fixture: unit width for every time sample."""
    return np.ones_like(time)


@pytest.fixture
Expand Down Expand Up @@ -140,6 +139,37 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
# data
# 2. Nominal row bin, which includes both flagged and unflagged rows

    def _can_add_row(high, low):
        # True if rows `low`..`high` still fit in one bin: the span
        # from the start of row `low`'s nominal interval to the end
        # of row `high`'s must not exceed time_bin_secs.
        dt = ((time[high] + 0.5*interval[high]) -
              (time[low] - 0.5*interval[low]))

        if dt > time_bin_secs:
            return False

        return True

    def _time_avg(nominal_rows):
        # Averaged bin time: a single row keeps its own time;
        # otherwise the midpoint of the first and last row times,
        # which interpolates over any missing interior rows.
        if len(nominal_rows) == 0:
            raise ValueError("nominal_rows == 0")
        elif len(nominal_rows) == 1:
            return time[nominal_rows[0]]
        else:
            low = nominal_rows[0]
            high = nominal_rows[-1]
            return 0.5*(time[high] + time[low])

    def _int_sum(nominal_rows):
        # Summed bin interval: a single row keeps its own interval;
        # otherwise half the first and last intervals plus the time
        # span between first and last rows — this covers gaps left
        # by missing interior rows rather than summing raw intervals.
        if len(nominal_rows) == 0:
            raise ValueError("nominal_rows == 0")
        elif len(nominal_rows) == 1:
            return interval[nominal_rows[0]]
        else:
            low = nominal_rows[0]
            high = nominal_rows[-1]

            return (0.5*interval[high] + 0.5*interval[low] +
                    (time[high] - time[low]))

for bl, (a1, a2) in enumerate(ubl):
bl_row_idx = bl_time_lookup[bl, :]

Expand All @@ -153,13 +183,13 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
if ri == -1:
continue

half_int = 0.5 * interval[ri]

# We're starting a new bin
if len(nominal_map) == 0:
bin_low = time[ri] - half_int
rs = ri
effective_map = []
nominal_map = []
# Reached past the endpoint of the bin, start a new one
elif time[ri] + half_int - bin_low > time_bin_secs:
elif not _can_add_row(ri, rs):
if len(effective_map) > 0:
effective_bin_map.append(effective_map)
nominal_bin_map.append(nominal_map)
Expand All @@ -170,6 +200,7 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
else:
raise ValueError("Zero-filled bin")

rs = ri
effective_map = []
nominal_map = []

Expand All @@ -190,13 +221,15 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
effective_bin_map.append(nominal_map)
nominal_bin_map.append(nominal_map)

# Produce a (avg_time, bl, effective_rows, nominal_rows) tuple
time_bl_row_map.extend((time[nrows].mean(), (a1, a2), erows, nrows)
# Produce a tuple of the form
# (avg_time, bl, interval, effective_rows, nominal_rows)
time_bl_row_map.extend((_time_avg(nrows), (a1, a2),
_int_sum(nrows), erows, nrows)
for erows, nrows
in zip(effective_bin_map, nominal_bin_map))

# Sort lookup sorted on averaged times
return sorted(time_bl_row_map, key=lambda tup: tup[0])
# Sort the lookup on averaged times and baselines
return sorted(time_bl_row_map, key=lambda tup: tup[:2])


def _calc_sigma(sigma, weight, idx):
Expand Down Expand Up @@ -239,19 +272,22 @@ def test_averager(time, ant1, ant2, flagged_rows,
row_meta = row_mapper(time, interval, ant1, ant2, flag_row, time_bin_secs)
chan_map, chan_bins = channel_mapper(nchan, chan_bin_size)

time_bl_row_map = _gen_testing_lookup(time_centroid, exposure, ant1, ant2,
time_bl_row_map = _gen_testing_lookup(time, interval, ant1, ant2,
flag_row, time_bin_secs,
row_meta)

# Effective and Nominal rows associated with each output row
eff_idx, nom_idx = zip(*[(nrows, erows) for _, _, nrows, erows
eff_idx, nom_idx = zip(*[(nrows, erows) for _, _, _, nrows, erows
in time_bl_row_map])

eff_idx = [ei for ei in eff_idx if len(ei) > 0]

# Check that the averaged times from the test and accelerated lookup match
assert_array_equal([t for t, _, _, _ in time_bl_row_map],
# Check that the times and intervals from the test lookup
# match those of the accelerated lookup
assert_array_equal([t for t, _, _, _, _ in time_bl_row_map],
row_meta.time)
assert_array_equal([i for _, _, i, _, _ in time_bl_row_map],
row_meta.interval)

avg = time_and_channel(time, interval, ant1, ant2,
flag_row=flag_row,
Expand All @@ -266,20 +302,16 @@ def test_averager(time, ant1, ant2, flagged_rows,

# Take mean time, but first ant1 and ant2
expected_time_centroids = [time_centroid[i].mean(axis=0) for i in eff_idx]
expected_times = [time[i].mean(axis=0) for i in nom_idx]
expected_ant1 = [ant1[i[0]] for i in nom_idx]
expected_ant2 = [ant2[i[0]] for i in nom_idx]
expected_flag_row = [flag_row[i].any(axis=0) for i in eff_idx]

# Take mean average, but sum of interval and exposure
expected_uvw = [uvw[i].mean(axis=0) for i in eff_idx]
expected_interval = [interval[i].sum(axis=0) for i in nom_idx]
expected_exposure = [exposure[i].sum(axis=0) for i in eff_idx]
expected_weight = [weight[i].sum(axis=0) for i in eff_idx]
expected_sigma = [_calc_sigma(sigma, weight, i) for i in eff_idx]

assert_array_equal(row_meta.time, expected_times)
assert_array_equal(row_meta.interval, expected_interval)
assert_array_equal(row_meta.flag_row, expected_flag_row)
assert_array_equal(avg.antenna1, expected_ant1)
assert_array_equal(avg.antenna2, expected_ant2)
Expand Down
Loading