From 38f077c317bcaad598e265efa0391d889852ad80 Mon Sep 17 00:00:00 2001 From: Brent Pedersen Date: Mon, 21 Aug 2023 11:28:47 +0200 Subject: [PATCH] more efficient identify_trim_point for oscillations --- src/lib/base_quality.rs | 57 +++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/src/lib/base_quality.rs b/src/lib/base_quality.rs index 6afd974..7a412a0 100644 --- a/src/lib/base_quality.rs +++ b/src/lib/base_quality.rs @@ -17,34 +17,24 @@ pub fn identify_trim_point( window_size: usize, max_oscillations: usize, ) -> Option { - // Compute a vector of booleans to represent whether each position is immediately - // after an oscillation - let mut is_osc = vec![false; quals.len()]; - - for i in 1..quals.len() { - is_osc[i] = (quals[i] as i32 - quals[i - 1] as i32).abs() >= osc_delta; - } - - // TODO [brent]: this is ~O(n^2/2) and could be ~O(n) - // instead, convert is_osc into a vector of indices of oscillations. then can use a - // window of `max_oscillations` and check if there are at least `window_size` oscillations in the window. - for i in 1..is_osc.len() { - if !is_osc[i] { - continue; - } - let mut n = 1; - #[allow(clippy::needless_range_loop)] // will refactor this later anyway - for j in (i + 1)..std::cmp::min(i + window_size, quals.len()) { - if is_osc[j] { - n += 1; - if n > max_oscillations { - return Some(i); - } - } - } - } - - None + // collect indices of oscillations where quals[i] - quals[i-1] >= osc_delta + let osc = quals + .windows(2) + .enumerate() + .filter(|(_, w)| (w[1] as i32 - w[0] as i32).abs() >= osc_delta) + .map(|(i, _)| i + 1) + .collect::>(); + // NOTE [performance]: we can do this without allocating but the .windows() method used below is only + // available on a slice and this only allocates for as many oscillations as are found. + + // here we have indices of oscillations. e.g.: [50, 52, 53, 54, 55, 57, 58, 59, 60, 62, 63, 64, 65, 67, 68, 69, 70, 72, 73, 74] + + // use a window of `max_oscilations` and check if there are at least `max_oscillations` oscillations in the window. + osc.windows(max_oscillations).find(|w| + // given a e.g. [50, 52, 55] means we found 3 oscillations in 5 bases (55 - 50) + // and we had window_size to find that many. so if the last - first < window_size, we found an osc window. + w[w.len() - 1] - w[0] < window_size + ).map(|w| w[0]) } // Indicates which tail(s) to clip. @@ -159,4 +149,15 @@ mod tests { let trim_point = identify_trim_point(&quals, 10, 20, 3); assert_eq!(trim_point, Some(50)); } + #[test] + fn test_identify_trim_point_finds_point_within() { + let left = vec![20u8; 50]; + let right = [15u8, 22, 35, 20, 32]; + let quals = + [&left[..], &right.into_iter().cycle().take(right.len() * 5).collect::>()] + .concat(); + + let trim_point = identify_trim_point(&quals, 10, 20, 3); + assert_eq!(trim_point, Some(52)); + } }