From f68b09cd5d03d621184bb21f3087a67e84581443 Mon Sep 17 00:00:00 2001 From: tugtugtug Date: Wed, 13 Nov 2024 23:40:46 -0500 Subject: [PATCH] feat: Add retain_with_break to HashSet/Table/Map With the removal of the raw table, it is hard to implement an efficient loop that conditionally removes/keeps certain entries up to a limit, i.e. a loop that can be aborted and does not require rehashing the key to remove the entry. --- src/map.rs | 71 ++++++++++++++++++++++++++++++++++++++++ src/set.rs | 55 +++++++++++++++++++++++++++++++ src/table.rs | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 218 insertions(+) diff --git a/src/map.rs b/src/map.rs index c373d5958..7430ea42e 100644 --- a/src/map.rs +++ b/src/map.rs @@ -929,6 +929,53 @@ impl<K, V, S, A: Allocator> HashMap<K, V, S, A> { } } + /// Retains only the elements specified by the predicate and stops iterating as soon as + /// the predicate returns `None`. Keeps the allocated memory for reuse. + /// + /// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Some(false)` until + /// `f(&k, &mut v)` returns `None`. + /// The elements are visited in unsorted (and unspecified) order. 
+ /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashMap; + /// + /// let mut map: HashMap<i32, i32> = (0..8).map(|x|(x, x*10)).collect(); + /// assert_eq!(map.len(), 8); + /// let mut removed = 0; + /// map.retain_with_break(|&k, _| if removed < 3 { + /// if k % 2 == 0 { + /// Some(true) + /// } else { + /// removed += 1; + /// Some(false) + /// } + /// } else { + /// None + /// }); + /// + /// // Exactly three odd keys were removed before the predicate returned `None` + /// // and aborted the traversal, so five of the eight entries remain. + /// assert_eq!(map.len(), 5); + /// ``` + pub fn retain_with_break<F>(&mut self, mut f: F) + where + F: FnMut(&K, &mut V) -> Option<bool>, + { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.table.iter() { + let &mut (ref key, ref mut value) = item.as_mut(); + match f(key, value) { + Some(false) => self.table.erase(item), + Some(true) => continue, + None => break, + } + } + } + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -5909,6 +5956,30 @@ mod test_map { assert_eq!(map[&6], 60); } + #[test] + fn test_retain_with_break() { + let mut map: HashMap<i32, i32> = (0..100).map(|x| (x, x * 10)).collect(); + // looping and removing any key > 50, but stop after 40 removals + let mut removed = 0; + map.retain_with_break(|&k, _| { + if removed < 40 { + if k > 50 { + removed += 1; + Some(false) + } else { + Some(true) + } + } else { + None + } + }); + assert_eq!(map.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert_eq!(map[&k], k * 10); + } + } + #[test] fn test_extract_if() { { diff --git a/src/set.rs b/src/set.rs index d57390f67..e7abe0230 100644 --- a/src/set.rs +++ b/src/set.rs @@ -372,6 +372,37 @@ impl<T, S, A: Allocator> HashSet<T, S, A> { self.map.retain(|k, _| f(k)); } + /// Retains only the elements specified by the predicate until the predicate returns `None`. 
+ /// + /// In other words, remove all elements `e` such that `f(&e)` returns `Some(false)` until + /// `f(&e)` returns `None`. + /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashSet; + /// + /// let xs = [1,2,3,4,5,6]; + /// let mut set: HashSet<i32> = xs.into_iter().collect(); + /// let mut count = 0; + /// set.retain_with_break(|&k| if count < 2 { + /// if k % 2 == 0 { + /// Some(true) + /// } else { + /// Some(false) + /// } + /// } else { + /// None + /// }); + /// assert_eq!(set.len(), 3); + /// ``` + pub fn retain_with_break<F>(&mut self, mut f: F) + where + F: FnMut(&T) -> Option<bool>, + { + self.map.retain_with_break(|k, _| f(k)); + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -2980,6 +3011,30 @@ mod test_set { assert!(set.contains(&6)); } + #[test] + fn test_retain_with_break() { + let mut set: HashSet<i32> = (0..100).collect(); + // looping and removing any key > 50, but stop after 40 removals + let mut removed = 0; + set.retain_with_break(|&k| { + if removed < 40 { + if k > 50 { + removed += 1; + Some(false) + } else { + Some(true) + } + } else { + None + } + }); + assert_eq!(set.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert!(set.contains(&k)); + } + } + #[test] fn test_extract_if() { { diff --git a/src/table.rs b/src/table.rs index 7f665b75a..7d5aff9b4 100644 --- a/src/table.rs +++ b/src/table.rs @@ -870,6 +870,61 @@ where } } + /// Retains only the elements specified by the predicate until the predicate returns `None`. + /// + /// In other words, remove all elements `e` such that `f(&e)` returns `Some(false)` until + /// `f(&e)` returns `None`. 
+ /// + /// # Examples + /// + /// ``` + /// # #[cfg(feature = "nightly")] + /// # fn test() { + /// use hashbrown::{HashTable, DefaultHashBuilder}; + /// use std::hash::BuildHasher; + /// + /// let mut table = HashTable::new(); + /// let hasher = DefaultHashBuilder::default(); + /// let hasher = |val: &_| { + /// use core::hash::Hasher; + /// let mut state = hasher.build_hasher(); + /// core::hash::Hash::hash(&val, &mut state); + /// state.finish() + /// }; + /// let mut removed = 0; + /// for x in 1..=8 { + /// table.insert_unique(hasher(&x), x, hasher); + /// } + /// table.retain_with_break(|&mut v| if removed < 3 { + /// if v % 2 == 0 { + /// Some(true) + /// } else { + /// removed += 1; + /// Some(false) + /// } + /// } else { + /// None + /// }); + /// assert_eq!(table.len(), 5); + /// # } + /// # fn main() { + /// # #[cfg(feature = "nightly")] + /// # test() + /// # } + /// ``` + pub fn retain_with_break(&mut self, mut f: impl FnMut(&mut T) -> Option<bool>) { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.raw.iter() { + match f(item.as_mut()) { + Some(false) => self.raw.erase(item), + Some(true) => continue, + None => break, + } + } + } + } + /// Clears the set, returning all elements in an iterator. 
/// /// # Examples @@ -2372,12 +2427,49 @@ impl<T, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut #[cfg(test)] mod tests { + use crate::DefaultHashBuilder; + use super::HashTable; + use core::hash::BuildHasher; #[test] fn test_allocation_info() { assert_eq!(HashTable::<()>::new().allocation_size(), 0); assert_eq!(HashTable::<u32>::new().allocation_size(), 0); assert!(HashTable::<u32>::with_capacity(1).allocation_size() > core::mem::size_of::<u32>()); } + + #[test] + fn test_retain_with_break() { + let mut table = HashTable::new(); + let hasher = DefaultHashBuilder::default(); + let hasher = |val: &_| { + use core::hash::Hasher; + let mut state = hasher.build_hasher(); + core::hash::Hash::hash(&val, &mut state); + state.finish() + }; + for x in 0..100 { + table.insert_unique(hasher(&x), x, hasher); + } + // looping and removing any value > 50, but stop after 40 removals + let mut removed = 0; + table.retain_with_break(|&mut v| { + if removed < 40 { + if v > 50 { + removed += 1; + Some(false) + } else { + Some(true) + } + } else { + None + } + }); + assert_eq!(table.len(), 60); + // check nothing up to 50 is removed + for v in 0..=50 { + assert_eq!(table.find(hasher(&v), |&val| val == v), Some(&v)); + } + } }