diff --git a/src/core/src/index/mod.rs b/src/core/src/index/mod.rs index 0bd9d9fec8..57ff5b656c 100644 --- a/src/core/src/index/mod.rs +++ b/src/core/src/index/mod.rs @@ -25,7 +25,7 @@ use crate::index::search::{search_minhashes, search_minhashes_containment}; use crate::prelude::*; use crate::selection::Selection; use crate::signature::SigsTrait; -use crate::sketch::minhash::KmerMinHash; +use crate::sketch::minhash::KmerMinHashBTree; use crate::storage::SigStore; use crate::Error::CannotUpsampleScaled; use crate::Result; @@ -208,8 +208,8 @@ where #[allow(clippy::too_many_arguments)] pub fn calculate_gather_stats( - orig_query: &KmerMinHash, - remaining_query: KmerMinHash, + orig_query: &KmerMinHashBTree, + remaining_query: KmerMinHashBTree, match_sig: SigStore, match_size: usize, gather_result_rank: u32, @@ -219,6 +219,8 @@ pub fn calculate_gather_stats( calc_ani_ci: bool, confidence: Option, ) -> Result<(GatherResult, (Vec, u64))> { + use crate::sketch::minhash::Intersection; + // get match_mh let match_mh = match_sig.minhash().expect("cannot retrieve sketch"); @@ -234,10 +236,18 @@ pub fn calculate_gather_stats( .expect("cannot downsample match"); // calculate intersection - let isect = match_mh - .intersection(&remaining_query) - .expect("could not do intersection"); - let isect_size = isect.0.len(); + // Using Intersection directly here has a pretty big requirement: + // the sketches MUST BE COMPATIBLE + // (as in: same ksize, max_hash, hash_function, seed) + // this should be covered by the call to downsample_scaled above, + // but important to keep in mind in the future if code changes + let isect_values: Vec<_> = Intersection::new(match_mh.iter_mins(), remaining_query.iter_mins()) + .copied() + .collect(); + + let isect_size = isect_values.len(); + let isect = (isect_values, isect_size as u64); + trace!("isect_size: {}", isect_size); trace!("query.size: {}", remaining_query.size()); @@ -246,7 +256,14 @@ pub fn calculate_gather_stats( (remaining_query.size() - isect_size) as u64 * remaining_query.scaled() as u64; // stats for this match vs original query - let (intersect_orig, _) = match_mh.intersection_size(orig_query).unwrap(); + // Using Intersection directly here has a pretty big requirement: + // the sketches MUST BE COMPATIBLE + // (as in: same ksize, max_hash, hash_function, seed) + // this should be covered by the call to downsample_scaled above, + // but important to keep in mind in the future if code changes + let intersect_orig = + Intersection::new(match_mh.iter_mins(), orig_query.iter_mins()).count() as u64; + let intersect_bp = match_mh.scaled() as u64 * intersect_orig; let f_orig_query = intersect_orig as f64 / orig_query.size() as f64; let f_match_orig = intersect_orig as f64 / match_mh.size() as f64; @@ -303,12 +320,8 @@ pub fn calculate_gather_stats( // If abundance, calculate abund-related metrics (vs current query) if calc_abund_stats { // take abunds from subtracted query - let (abunds, unique_weighted_found) = match match_mh.inflated_abundances(&remaining_query) { - Ok((abunds, unique_weighted_found)) => (abunds, unique_weighted_found), - Err(e) => { - return Err(e); - } - }; + let (abunds, unique_weighted_found) = match_mh + .inflated_abundances(remaining_query.iter_mins(), remaining_query.iter_abunds())?; n_unique_weighted_found = unique_weighted_found; sum_total_weighted_found = sum_weighted_found + n_unique_weighted_found; @@ -399,6 +412,7 @@ mod test_calculate_gather_stats { orig_query.add_hash_with_abundance(8, 1); orig_query.add_hash_with_abundance(10, 1); // Non-matching hash + let orig_query: KmerMinHashBTree = orig_query.into(); let query = orig_query.clone(); let total_weighted_hashes = orig_query.sum_abunds(); diff --git a/src/core/src/index/revindex/disk_revindex.rs b/src/core/src/index/revindex/disk_revindex.rs index 46552c2c67..6f4b2de127 100644 --- a/src/core/src/index/revindex/disk_revindex.rs +++ b/src/core/src/index/revindex/disk_revindex.rs @@ -365,12 +365,13 @@ impl RevIndexOps for RevIndex { query_colors: QueryColors, hash_to_color: HashToColor, threshold: usize, - orig_query: &KmerMinHash, + orig_query: KmerMinHash, selection: Option, ) -> Result> { let mut match_size = usize::MAX; let mut matches = vec![]; - let mut query = KmerMinHashBTree::from(orig_query.clone()); + let orig_query: KmerMinHashBTree = orig_query.into(); + let mut query = orig_query.clone(); let mut sum_weighted_found = 0; let _selection = selection.unwrap_or_else(|| self.collection.selection()); let total_weighted_hashes = orig_query.sum_abunds(); @@ -405,10 +406,10 @@ impl RevIndexOps for RevIndex { // repeatedly downsample query, then extract to KmerMinHash // => calculate_gather_stats - query = query + let query_mh = query + .clone() .downsample_scaled(max_scaled) .expect("cannot downsample query"); - let query_mh = KmerMinHash::from(query.clone()); // just calculate essentials here let gather_result_rank = matches.len() as u32; @@ -416,7 +417,7 @@ impl RevIndexOps for RevIndex { // grab the specific intersection: // Calculate stats let (gather_result, isect) = calculate_gather_stats( - orig_query, + &orig_query, query_mh, match_sig, match_size, diff --git a/src/core/src/index/revindex/mod.rs b/src/core/src/index/revindex/mod.rs index f1248be714..3e95e0635c 100644 --- a/src/core/src/index/revindex/mod.rs +++ b/src/core/src/index/revindex/mod.rs @@ -75,7 +75,7 @@ pub trait RevIndexOps { query_colors: QueryColors, hash_to_color: HashToColor, threshold: usize, - query: &KmerMinHash, + query: KmerMinHash, selection: Option, ) -> Result>; @@ -553,7 +553,7 @@ mod test { query_colors, hash_to_color, 0, - &query, + query, Some(selection), )?; @@ -620,7 +620,7 @@ mod test { query_colors, hash_to_color, 5, // 50kb threshold - &query, + query, Some(selection), )?; @@ -770,7 +770,7 @@ mod test { query_colors, hash_to_color, 0, - &query, + query, Some(selection), )?; @@ -909,7 +909,7 @@ mod test { query_colors, hash_to_color, 0, - &query, + query.clone(), Some(selection.clone()), ) .expect("failed to gather!"); @@ -927,7 +927,7 @@ mod test { query_colors, hash_to_color, 0, - &query, + query.clone(), Some(selection.clone()), )?; assert_eq!(matches_external, matches_internal); @@ -944,7 +944,7 @@ mod test { query_colors, hash_to_color, 0, - &query, + query, Some(selection.clone()), )?; assert_eq!(matches_external, matches_moved); diff --git a/src/core/src/sketch/minhash.rs b/src/core/src/sketch/minhash.rs index 438294e098..d6ec701248 100644 --- a/src/core/src/sketch/minhash.rs +++ b/src/core/src/sketch/minhash.rs @@ -717,6 +717,10 @@ impl KmerMinHash { self.mins.iter() } + pub fn iter_abunds(&self) -> Option> { + self.abunds.as_ref().map(|abunds| abunds.iter()) + } + pub fn abunds(&self) -> Option> { self.abunds.clone() } @@ -828,33 +832,21 @@ impl KmerMinHash { Ok(()) } - pub fn inflated_abundances(&self, abunds_from: &KmerMinHash) -> Result<(Vec, u64), Error> { - self.check_compatible(abunds_from)?; + pub fn inflated_abundances<'a, 'b, M: Iterator, A: Iterator>( + &'b self, + mins_from: M, + abunds_from: Option, + ) -> Result<(Vec, u64), Error> { + //self.check_compatible(abunds_from)?; + // check that abunds_from has abundances - if abunds_from.abunds.is_none() { + if abunds_from.is_none() { return Err(Error::NeedsAbundanceTracking); } let self_iter = self.mins.iter(); - let abunds_iter = abunds_from.abunds.as_ref().unwrap().iter(); - let abunds_from_iter = abunds_from.mins.iter().zip(abunds_iter); - - let (abundances, total_abundance): (Vec, u64) = self_iter - .merge_join_by(abunds_from_iter, |&self_val, &(other_val, _)| { - self_val.cmp(other_val) - }) - .filter_map(|either| match either { - itertools::EitherOrBoth::Both(_self_val, (_other_val, other_abund)) => { - Some(*other_abund) - } - _ => None, - }) - .fold((Vec::new(), 0u64), |(mut acc_vec, acc_sum), abund| { - acc_vec.push(abund); - (acc_vec, acc_sum + abund) - }); - Ok((abundances, total_abundance)) + inflated_abundances(self_iter, mins_from, abunds_from) } } @@ -912,13 +904,13 @@ impl SigsTrait for KmerMinHash { } } -struct Intersection> { +pub(crate) struct Intersection, J: Iterator> { iter: Peekable, - other: Peekable, + other: Peekable, } -impl> Intersection { - pub fn new(left: I, right: I) -> Self { +impl, J: Iterator> Intersection { + pub fn new(left: I, right: J) -> Self { Intersection { iter: left.peekable(), other: right.peekable(), @@ -926,7 +918,7 @@ impl> Intersection { } } -impl> Iterator for Intersection { +impl, J: Iterator> Iterator for Intersection { type Item = T; fn next(&mut self) -> Option { @@ -1534,6 +1526,10 @@ impl KmerMinHashBTree { self.mins.iter() } + pub fn iter_abunds(&self) -> Option> { + self.abunds.as_ref().map(|abunds| abunds.values()) + } + pub fn abunds(&self) -> Option> { self.abunds .as_ref() @@ -1551,6 +1547,13 @@ impl KmerMinHashBTree { } } + // Approximate total number of kmers + // this could be improved by generating an HLL estimate while sketching instead + // (for scaled minhashes) + pub fn n_unique_kmers(&self) -> u64 { + self.size() as u64 * self.scaled() as u64 // + (self.ksize - 1) for bp estimation + } + // create a downsampled copy of self pub fn downsample_scaled(self, scaled: ScaledType) -> Result { if self.scaled() == scaled || self.scaled() == 0 { @@ -1594,6 +1597,53 @@ impl KmerMinHashBTree { self.size() as u64 } } + + pub fn inflated_abundances<'a, 'b, M: Iterator, A: Iterator>( + &'b self, + mins_from: M, + abunds_from: Option, + ) -> Result<(Vec, u64), Error> { + // check that abunds_from has abundances + if abunds_from.is_none() { + return Err(Error::NeedsAbundanceTracking); + } + + let self_iter = self.mins.iter(); + + inflated_abundances(self_iter, mins_from, abunds_from) + } +} + +fn inflated_abundances< + 'a, + 'b, + M: Iterator, + N: Iterator, + A: Iterator, +>( + self_iter: M, + mins_from: N, + abunds_from: Option, +) -> Result<(Vec, u64), Error> { + let abunds_iter = abunds_from.unwrap(); + let abunds_from_iter = mins_from.zip(abunds_iter); + + let (abundances, total_abundance): (Vec, u64) = self_iter + .merge_join_by(abunds_from_iter, |&self_val, &(other_val, _)| { + self_val.cmp(other_val) + }) + .filter_map(|either| match either { + itertools::EitherOrBoth::Both(_self_val, (_other_val, other_abund)) => { + Some(*other_abund) + } + _ => None, + }) + .fold((Vec::new(), 0u64), |(mut acc_vec, acc_sum), abund| { + acc_vec.push(abund); + (acc_vec, acc_sum + abund) + }); + + Ok((abundances, total_abundance)) } impl SigsTrait for KmerMinHashBTree { diff --git a/src/core/tests/minhash.rs b/src/core/tests/minhash.rs index bdbba0cc20..1b7a2234a2 100644 --- a/src/core/tests/minhash.rs +++ b/src/core/tests/minhash.rs @@ -831,7 +831,7 @@ fn test_inflated_abundances() { // Attempt to inflate minhash_a using minhash_b's abundances assert!(a.inflate(&b).is_ok()); - let (abunds, total_abund) = a.inflated_abundances(&b).unwrap(); + let (abunds, total_abund) = a.inflated_abundances(b.iter_mins(), b.iter_abunds()).unwrap(); assert_eq!(abunds, vec![2, 4]); assert_eq!(total_abund, 6); } @@ -858,7 +858,7 @@ fn test_inflated_abunds_noabund() { a.add_hash(10); a.add_hash(20); a.add_hash(30); - let result = a.inflated_abundances(&a); + let result = a.inflated_abundances(a.iter_mins(), a.iter_abunds()); assert!(matches!( result, Err(sourmash::Error::NeedsAbundanceTracking)