From 570ae79588ecc791ae332605396c21de92702848 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 14 Mar 2024 16:36:46 +0100 Subject: [PATCH] Add average_windows function for use in new DAQ drivers --- qupulse/utils/performance.py | 90 ++++++++++++++++++++++++++++++++ tests/utils/performance_tests.py | 29 +++++++++- 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/qupulse/utils/performance.py b/qupulse/utils/performance.py index 4076b664..12abfa80 100644 --- a/qupulse/utils/performance.py +++ b/qupulse/utils/performance.py @@ -81,4 +81,94 @@ def time_windows_to_samples(begins: np.ndarray, lengths: np.ndarray, is_monotonic = _is_monotonic_numba +@njit +def _average_windows_numba(time: np.ndarray, values: np.ndarray, + begins: np.ndarray, ends: np.ndarray) -> np.ndarray: + n_samples, = time.shape + n_windows, = begins.shape + + assert len(begins) == len(ends) + assert values.shape[0] == n_samples + + result = np.zeros(begins.shape + values.shape[1:], dtype=float) + count = np.zeros(n_windows, dtype=np.uint64) + + start = 0 + for i in range(n_samples): + t = time[i] + v = values[i, ...] + + while start < n_windows and ends[start] <= t: + n = count[start] + if n == 0: + result[start] = np.nan + else: + result[start] /= n + start += 1 + + idx = start + while idx < n_windows and begins[idx] <= t: + result[idx] += v + count[idx] += 1 + idx += 1 + + for idx in range(start, n_windows): + n = count[idx] + if n == 0: + result[idx] = np.nan + else: + result[idx] /= count[idx] + + return result + + +def _average_windows_numpy(time: np.ndarray, values: np.ndarray, + begins: np.ndarray, ends: np.ndarray) -> np.ndarray: + start = np.searchsorted(time, begins) + end = np.searchsorted(time, ends) + + val_shape = values.shape[1:] + + count = end - start + val_mask = result_mask = start < end + + result = np.zeros(begins.shape + val_shape, dtype=float) + while np.any(val_mask): + result[val_mask, ...] += values[start[val_mask], ...] + start[val_mask] += 1 + val_mask = start < end + + result[~result_mask, ...] = np.nan + if result.ndim == 1: + result[result_mask, ...] /= count[result_mask] + else: + result[result_mask, ...] /= count[result_mask, None] + + return result + + +def average_windows(time: np.ndarray, values: np.ndarray, begins: np.ndarray, ends: np.ndarray): + """This function calculates the average over all windows that are defined by begins and ends. + The function assumes that the given time array is monotonically increasing and might produce + nonsensical results if not. + Args: + time: Time associated with the values of shape (n_samples,) + values: Values to average of shape (n_samples,) or (n_samples, n_channels) + begins: Beginning time stamps of the windows of shape (n_windows,) + ends: Ending time stamps of the windows of shape (n_windows,) + + Returns: + Averaged values for each window of shape (n_windows,) or (n_windows, n_channels). + Windows without samples are NaN. + """ + n_samples, = time.shape + n_windows, = begins.shape + + assert n_windows == len(ends) + assert values.shape[0] == n_samples + + if numba is None: + return _average_windows_numpy(time, values, begins, ends) + else: + return _average_windows_numba(time, values, begins, ends) diff --git a/tests/utils/performance_tests.py b/tests/utils/performance_tests.py index d158dce5..b9023751 100644 --- a/tests/utils/performance_tests.py +++ b/tests/utils/performance_tests.py @@ -2,7 +2,8 @@ import numpy as np -from qupulse.utils.performance import _time_windows_to_samples_numba, _time_windows_to_samples_numpy +from qupulse.utils.performance import (_time_windows_to_samples_numba, _time_windows_to_samples_numpy, + _average_windows_numba, _average_windows_numpy, average_windows) class TimeWindowsToSamplesTest(unittest.TestCase): @@ -28,3 +29,29 @@ def test_unsorted(self): self.assert_implementations_equal(begins, lengths, sr) +class WindowAverageTest(unittest.TestCase): + @staticmethod + def assert_implementations_equal(time, values, begins, ends): + numpy_result = _average_windows_numpy(time, values, begins, ends) + numba_result = _average_windows_numba(time, values, begins, ends) + np.testing.assert_allclose(numpy_result, numba_result) + + def setUp(self): + self.begins = np.array([1., 2., 3.] + [4.] + [6., 7., 8., 9., 10.]) + self.ends = self.begins + np.array([1., 1., 1.] + [3.] + [2., 2., 2., 2., 2.]) + self.time = np.arange(10).astype(float) + self.values = np.asarray([ + np.sin(self.time), + np.cos(self.time), + ]).T + + def test_dispatch(self): + _ = average_windows(self.time, self.values, self.begins, self.ends) + _ = average_windows(self.time, self.values[..., 0], self.begins, self.ends) + + def test_single_channel(self): + self.assert_implementations_equal(self.time, self.values[..., 0], self.begins, self.ends) + self.assert_implementations_equal(self.time, self.values[..., :1], self.begins, self.ends) + + def test_dual_channel(self): + self.assert_implementations_equal(self.time, self.values, self.begins, self.ends)