Skip to content

Commit

Permalink
Add average_windows function for use in new DAQ drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Mar 14, 2024
1 parent f0a957d commit 570ae79
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 1 deletion.
90 changes: 90 additions & 0 deletions qupulse/utils/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,94 @@ def time_windows_to_samples(begins: np.ndarray, lengths: np.ndarray,
is_monotonic = _is_monotonic_numba


@njit
def _average_windows_numba(time: np.ndarray, values: np.ndarray,
begins: np.ndarray, ends: np.ndarray) -> np.ndarray:
n_samples, = time.shape
n_windows, = begins.shape

assert len(begins) == len(ends)
assert values.shape[0] == n_samples

result = np.zeros(begins.shape + values.shape[1:], dtype=float)
count = np.zeros(n_windows, dtype=np.uint64)

start = 0
for i in range(n_samples):
t = time[i]
v = values[i, ...]

while start < n_windows and ends[start] <= t:
n = count[start]
if n == 0:
result[start] = np.nan
else:
result[start] /= n
start += 1

idx = start
while idx < n_windows and begins[idx] <= t:
result[idx] += v
count[idx] += 1
idx += 1

for idx in range(start, n_windows):
n = count[idx]
if n == 0:
result[idx] = np.nan
else:
result[idx] /= count[idx]

return result


def _average_windows_numpy(time: np.ndarray, values: np.ndarray,
begins: np.ndarray, ends: np.ndarray) -> np.ndarray:
start = np.searchsorted(time, begins)
end = np.searchsorted(time, ends)

val_shape = values.shape[1:]

count = end - start
val_mask = result_mask = start < end

result = np.zeros(begins.shape + val_shape, dtype=float)
while np.any(val_mask):
result[val_mask, ...] += values[start[val_mask], ...]
start[val_mask] += 1
val_mask = start < end

result[~result_mask, ...] = np.nan
if result.ndim == 1:
result[result_mask, ...] /= count[result_mask]
else:
result[result_mask, ...] /= count[result_mask, None]

return result


def average_windows(time: np.ndarray, values: np.ndarray, begins: np.ndarray, ends: np.ndarray):
"""This function calculates the average over all windows that are defined by begins and ends.
The function assumes that the given time array is monotonically increasing and might produce
nonsensical results if not.
Args:
time: Time associated with the values of shape (n_samples,)
values: Values to average of shape (n_samples,) or (n_samples, n_channels)
begins: Beginning time stamps of the windows of shape (n_windows,)
ends: Ending time stamps of the windows of shape (n_windows,)
Returns:
Averaged values for each window of shape (n_windows,) or (n_windows, n_channels).
Windows without samples are NaN.
"""
n_samples, = time.shape
n_windows, = begins.shape

assert n_windows == len(ends)
assert values.shape[0] == n_samples

if numba is None:
return _average_windows_numpy(time, values, begins, ends)
else:
return _average_windows_numba(time, values, begins, ends)
29 changes: 28 additions & 1 deletion tests/utils/performance_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy as np

from qupulse.utils.performance import _time_windows_to_samples_numba, _time_windows_to_samples_numpy
from qupulse.utils.performance import (_time_windows_to_samples_numba, _time_windows_to_samples_numpy,
_average_windows_numba, _average_windows_numpy, average_windows)


class TimeWindowsToSamplesTest(unittest.TestCase):
Expand All @@ -28,3 +29,29 @@ def test_unsorted(self):
self.assert_implementations_equal(begins, lengths, sr)


class WindowAverageTest(unittest.TestCase):
@staticmethod
def assert_implementations_equal(time, values, begins, ends):
numpy_result = _average_windows_numpy(time, values, begins, ends)
numba_result = _average_windows_numba(time, values, begins, ends)
np.testing.assert_allclose(numpy_result, numba_result)

def setUp(self):
self.begins = np.array([1., 2., 3.] + [4.] + [6., 7., 8., 9., 10.])
self.ends = self.begins + np.array([1., 1., 1.] + [3.] + [2., 2., 2., 2., 2.])
self.time = np.arange(10).astype(float)
self.values = np.asarray([
np.sin(self.time),
np.cos(self.time),
]).T

def test_dispatch(self):
_ = average_windows(self.time, self.values, self.begins, self.ends)
_ = average_windows(self.time, self.values[..., 0], self.begins, self.ends)

def test_single_channel(self):
self.assert_implementations_equal(self.time, self.values[..., 0], self.begins, self.ends)
self.assert_implementations_equal(self.time, self.values[..., :1], self.begins, self.ends)

def test_dual_channel(self):
self.assert_implementations_equal(self.time, self.values, self.begins, self.ends)

0 comments on commit 570ae79

Please sign in to comment.