diff --git a/src/map.rs b/src/map.rs index c373d5958..497201925 100644 --- a/src/map.rs +++ b/src/map.rs @@ -929,6 +929,69 @@ impl HashMap { } } + /// Iterates over elements, applying the specified `ControlFlow` predicate to each + /// + /// ### Element Fate + /// - Kept if `f(&k, &mut v)` returns `ControlFlow::(true)` + /// - Removed if `f(&k, &mut v)` returns `ControlFlow::(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&k, &mut v)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&k, &mut v)` + /// returns `ControlFlow::Break` + /// + /// The elements are visited in unsorted (and unspecified) order. + /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashMap; + /// use core::ops::ControlFlow; + /// + /// let mut map: HashMap = (0..8).map(|x|(x, x*10)).collect(); + /// assert_eq!(map.len(), 8); + /// let mut removed = 0; + /// map.filter(|&k, _| if removed < 3 { + /// if k % 2 == 0 { + /// ControlFlow::Continue(true) + /// } else { + /// removed += 1; + /// ControlFlow::Continue(false) + /// } + /// } else { + /// // keep this item and break + /// ControlFlow::Break(true) + /// }); + /// + /// // We can see, that the number of elements inside map is changed and the + /// // length matches when we have aborted the iteration with the return of `ControlFlow::Break` + /// assert_eq!(map.len(), 5); + /// ``` + pub fn filter(&mut self, mut f: F) + where + F: FnMut(&K, &mut V) -> core::ops::ControlFlow, + { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.table.iter() { + let &mut (ref key, ref mut value) = item.as_mut(); + match f(key, value) { + core::ops::ControlFlow::Continue(kept) => { + if !kept { + self.table.erase(item); + } + } + core::ops::ControlFlow::Break(kept) => { + if !kept { + self.table.erase(item); + } + break; + } + } + } + } + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -5909,6 +5972,31 @@ mod test_map { assert_eq!(map[&6], 60); } + #[test] + fn test_filter() { + let mut map: HashMap = (0..100).map(|x| (x, x * 10)).collect(); + // looping and removing any key > 50, but stop after 40 removed + let mut removed = 0; + map.filter(|&k, _| { + if k > 50 { + removed += 1; + if removed < 40 { + core::ops::ControlFlow::Continue(false) + } else { + // remove this item and break + core::ops::ControlFlow::Break(false) + } + } else { + core::ops::ControlFlow::Continue(true) + } + }); + assert_eq!(map.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert_eq!(map[&k], k * 10); + } + } + #[test] fn test_extract_if() { { diff --git a/src/set.rs b/src/set.rs index d57390f67..39fec8575 100644 --- a/src/set.rs +++ b/src/set.rs @@ -372,6 +372,47 @@ impl HashSet { self.map.retain(|k, _| f(k)); } + /// Iterates over elements, applying the specified `ControlFlow` predicate to each + /// + /// ### Element Fate + /// - Kept if `f(&e)` returns `ControlFlow::(true)` + /// - Removed if `f(&e)` returns `ControlFlow::(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&e)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&e)` + /// returns `ControlFlow::Break` + /// + /// The elements are visited in unsorted (and unspecified) order. + /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashSet; + /// use core::ops::ControlFlow; + /// + /// let xs = [1,2,3,4,5,6]; + /// let mut set: HashSet = xs.into_iter().collect(); + /// let mut count = 0; + /// set.filter(|&k| if count < 2 { + /// if k % 2 == 0 { + /// ControlFlow::Continue(true) + /// } else { + /// ControlFlow::Continue(false) + /// } + /// } else { + /// // keep this item and break + /// ControlFlow::Break(true) + /// }); + /// assert_eq!(set.len(), 3); + /// ``` + pub fn filter(&mut self, mut f: F) + where + F: FnMut(&T) -> core::ops::ControlFlow, + { + self.map.filter(|k, _| f(k)); + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -2980,6 +3021,30 @@ mod test_set { assert!(set.contains(&6)); } + #[test] + fn test_filter() { + let mut set: HashSet = (0..100).collect(); + // looping and removing any element > 50, but stop after 40 removals + let mut removed = 0; + set.filter(|&k| { + if k > 50 { + removed += 1; + if removed < 40 { + core::ops::ControlFlow::Continue(false) + } else { + core::ops::ControlFlow::Break(false) + } + } else { + core::ops::ControlFlow::Continue(true) + } + }); + assert_eq!(set.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert!(set.contains(&k)); + } + } + #[test] fn test_extract_if() { { diff --git a/src/table.rs b/src/table.rs index 7f665b75a..8a95ad98f 100644 --- a/src/table.rs +++ b/src/table.rs @@ -870,6 +870,79 @@ where } } + /// Iterates over elements, applying the specified `ControlFlow` predicate to each + /// + /// ### Element Fate + /// - Kept if `f(&e)` returns `ControlFlow::(true)` + /// - Removed if `f(&e)` returns `ControlFlow::(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&e)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&e)` + /// returns `ControlFlow::Break` + /// + /// The elements are visited in unsorted (and unspecified) order. + /// + /// # Examples + /// + /// ``` + /// # #[cfg(feature = "nightly")] + /// # fn test() { + /// use hashbrown::{HashTable, DefaultHashBuilder}; + /// use std::hash::BuildHasher; + /// use core::ops::ControlFlow; + /// + /// let mut table = HashTable::new(); + /// let hasher = DefaultHashBuilder::default(); + /// let hasher = |val: &_| { + /// use core::hash::Hasher; + /// let mut state = hasher.build_hasher(); + /// core::hash::Hash::hash(&val, &mut state); + /// state.finish() + /// }; + /// let mut removed = 0; + /// for x in 1..=8 { + /// table.insert_unique(hasher(&x), x, hasher); + /// } + /// table.filter(|&mut v| if removed < 3 { + /// if v % 2 == 0 { + /// ControlFlow::Continue(true) + /// } else { + /// removed += 1; + /// ControlFlow::Continue(false) + /// } + /// } else { + /// // keep this item and break + /// ControlFlow::Break(true) + /// }); + /// assert_eq!(table.len(), 5); + /// # } + /// # fn main() { + /// # #[cfg(feature = "nightly")] + /// # test() + /// # } + /// ``` + pub fn filter(&mut self, mut f: impl FnMut(&mut T) -> core::ops::ControlFlow) { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.raw.iter() { + match f(item.as_mut()) { + core::ops::ControlFlow::Continue(kept) => { + if !kept { + self.raw.erase(item); + } + } + core::ops::ControlFlow::Break(kept) => { + if !kept { + self.raw.erase(item); + } + break; + } + } + } + } + } + /// Clears the set, returning all elements in an iterator. /// /// # Examples @@ -2372,12 +2445,49 @@ impl FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut #[cfg(test)] mod tests { + use crate::DefaultHashBuilder; + use super::HashTable; + use core::hash::BuildHasher; #[test] fn test_allocation_info() { assert_eq!(HashTable::<()>::new().allocation_size(), 0); assert_eq!(HashTable::::new().allocation_size(), 0); assert!(HashTable::::with_capacity(1).allocation_size() > core::mem::size_of::()); } + + #[test] + fn test_filter() { + let mut table = HashTable::new(); + let hasher = DefaultHashBuilder::default(); + let hasher = |val: &_| { + use core::hash::Hasher; + let mut state = hasher.build_hasher(); + core::hash::Hash::hash(&val, &mut state); + state.finish() + }; + for x in 0..100 { + table.insert_unique(hasher(&x), x, hasher); + } + // looping and removing any value > 50, but stop after 40 removals + let mut removed = 0; + table.filter(|&mut v| { + if removed < 40 { + if v > 50 { + removed += 1; + core::ops::ControlFlow::Continue(false) + } else { + core::ops::ControlFlow::Continue(true) + } + } else { + core::ops::ControlFlow::Break(true) + } + }); + assert_eq!(table.len(), 60); + // check nothing up to 50 is removed + for v in 0..=50 { + assert_eq!(table.find(hasher(&v), |&val| val == v), Some(&v)); + } + } }