From f68b09cd5d03d621184bb21f3087a67e84581443 Mon Sep 17 00:00:00 2001 From: tugtugtug Date: Wed, 13 Nov 2024 23:40:46 -0500 Subject: [PATCH 1/2] feat: Add retain_with_break to HashSet/Table/Map With the removal of the raw table, it is hard to implement an efficient loop that conditionally removes/keeps certain entries up to a limit, i.e. a loop that can be aborted and does not require rehashing the key to remove the entry. --- src/map.rs | 71 ++++++++++++++++++++++++++++++++++++++++ src/set.rs | 55 +++++++++++++++++++++++++++++++ src/table.rs | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 218 insertions(+) diff --git a/src/map.rs b/src/map.rs index c373d5958..7430ea42e 100644 --- a/src/map.rs +++ b/src/map.rs @@ -929,6 +929,53 @@ impl HashMap { } } + /// Retains only the elements specified by the predicate and breaks the iteration when + /// the predicate fails. Keeps the allocated memory for reuse. + /// + /// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Some(false)` until + /// `f(&k, &mut v)` returns `None` + /// The elements are visited in unsorted (and unspecified) order. 
+ /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashMap; + /// + /// let mut map: HashMap = (0..8).map(|x|(x, x*10)).collect(); + /// assert_eq!(map.len(), 8); + /// let mut removed = 0; + /// map.retain_with_break(|&k, _| if removed < 3 { + /// if k % 2 == 0 { + /// Some(true) + /// } else { + /// removed += 1; + /// Some(false) + /// } + /// } else { + /// None + /// }); + /// + /// // We can see, that the number of elements inside map is changed and the + /// // length matches when we have aborted the retain with the return of `None` + /// assert_eq!(map.len(), 5); + /// ``` + pub fn retain_with_break(&mut self, mut f: F) + where + F: FnMut(&K, &mut V) -> Option, + { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.table.iter() { + let &mut (ref key, ref mut value) = item.as_mut(); + match f(key, value) { + Some(false) => self.table.erase(item), + Some(true) => continue, + None => break, + } + } + } + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -5909,6 +5956,30 @@ mod test_map { assert_eq!(map[&6], 60); } + #[test] + fn test_retain_with_break() { + let mut map: HashMap = (0..100).map(|x| (x, x * 10)).collect(); + // looping and removing any key > 50, but stop after 40 iterations + let mut removed = 0; + map.retain_with_break(|&k, _| { + if removed < 40 { + if k > 50 { + removed += 1; + Some(false) + } else { + Some(true) + } + } else { + None + } + }); + assert_eq!(map.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert_eq!(map[&k], k * 10); + } + } + #[test] fn test_extract_if() { { diff --git a/src/set.rs b/src/set.rs index d57390f67..e7abe0230 100644 --- a/src/set.rs +++ b/src/set.rs @@ -372,6 +372,37 @@ impl HashSet { self.map.retain(|k, _| f(k)); } + /// Retains only the elements specified by the predicate until the predicate returns `None`. 
+ /// + /// In other words, remove all elements `e` such that `f(&e)` returns `Some(false)` until + /// `f(&e)` returns `None`. + /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashSet; + /// + /// let xs = [1,2,3,4,5,6]; + /// let mut set: HashSet = xs.into_iter().collect(); + /// let mut count = 0; + /// set.retain_with_break(|&k| if count < 2 { + /// if k % 2 == 0 { + /// Some(true) + /// } else { + /// Some(false) + /// } + /// } else { + /// None + /// }); + /// assert_eq!(set.len(), 3); + /// ``` + pub fn retain_with_break(&mut self, mut f: F) + where + F: FnMut(&T) -> Option, + { + self.map.retain_with_break(|k, _| f(k)); + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -2980,6 +3011,30 @@ mod test_set { assert!(set.contains(&6)); } + #[test] + fn test_retain_with_break() { + let mut set: HashSet = (0..100).collect(); + // looping and removing any key > 50, but stop after 40 iterations + let mut removed = 0; + set.retain_with_break(|&k| { + if removed < 40 { + if k > 50 { + removed += 1; + Some(false) + } else { + Some(true) + } + } else { + None + } + }); + assert_eq!(set.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert!(set.contains(&k)); + } + } + #[test] fn test_extract_if() { { diff --git a/src/table.rs b/src/table.rs index 7f665b75a..7d5aff9b4 100644 --- a/src/table.rs +++ b/src/table.rs @@ -870,6 +870,61 @@ where } } + /// Retains only the elements specified by the predicate until the predicate returns `None`. + /// + /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until + /// `f(&e)` returns `None`. 
+ /// + /// # Examples + /// + /// ``` + /// # #[cfg(feature = "nightly")] + /// # fn test() { + /// use hashbrown::{HashTable, DefaultHashBuilder}; + /// use std::hash::BuildHasher; + /// + /// let mut table = HashTable::new(); + /// let hasher = DefaultHashBuilder::default(); + /// let hasher = |val: &_| { + /// use core::hash::Hasher; + /// let mut state = hasher.build_hasher(); + /// core::hash::Hash::hash(&val, &mut state); + /// state.finish() + /// }; + /// let mut removed = 0; + /// for x in 1..=8 { + /// table.insert_unique(hasher(&x), x, hasher); + /// } + /// table.retain_with_break(|&mut v| if removed < 3 { + /// if v % 2 == 0 { + /// Some(true) + /// } else { + /// removed += 1; + /// Some(false) + /// } + /// } else { + /// None + /// }); + /// assert_eq!(table.len(), 5); + /// # } + /// # fn main() { + /// # #[cfg(feature = "nightly")] + /// # test() + /// # } + /// ``` + pub fn retain_with_break(&mut self, mut f: impl FnMut(&mut T) -> Option) { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.raw.iter() { + match f(item.as_mut()) { + Some(false) => self.raw.erase(item), + Some(true) => continue, + None => break, + } + } + } + } + /// Clears the set, returning all elements in an iterator. 
/// /// # Examples @@ -2372,12 +2427,49 @@ impl FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut #[cfg(test)] mod tests { + use crate::DefaultHashBuilder; + use super::HashTable; + use core::hash::BuildHasher; #[test] fn test_allocation_info() { assert_eq!(HashTable::<()>::new().allocation_size(), 0); assert_eq!(HashTable::::new().allocation_size(), 0); assert!(HashTable::::with_capacity(1).allocation_size() > core::mem::size_of::()); } + + #[test] + fn test_retain_with_break() { + let mut table = HashTable::new(); + let hasher = DefaultHashBuilder::default(); + let hasher = |val: &_| { + use core::hash::Hasher; + let mut state = hasher.build_hasher(); + core::hash::Hash::hash(&val, &mut state); + state.finish() + }; + for x in 0..100 { + table.insert_unique(hasher(&x), x, hasher); + } + // looping and removing any value > 50, but stop after 40 iterations + let mut removed = 0; + table.retain_with_break(|&mut v| { + if removed < 40 { + if v > 50 { + removed += 1; + Some(false) + } else { + Some(true) + } + } else { + None + } + }); + assert_eq!(table.len(), 60); + // check nothing up to 50 is removed + for v in 0..=50 { + assert_eq!(table.find(hasher(&v), |&val| val == v), Some(&v)); + } + } } From 35b44790e197f032d4c34528ea8448b0e1151f41 Mon Sep 17 00:00:00 2001 From: tugtugtug Date: Thu, 14 Nov 2024 15:58:07 -0500 Subject: [PATCH 2/2] feat: Add filter to HashSet/Table/Map Removing the raw table complicates implementing efficient loops that conditionally remove or keep elements and control iteration flow, particularly when aborting iteration without rehashing keys for removal. The existing `extract_if` method is cumbersome for flow control, limiting flexibility. The proposed addition addresses these shortcomings, by enabling proper flow control to allow efficient iteration and removal of filtered elements. 
--- src/map.rs | 63 +++++++++++++++++++++++++++++++++------------------- src/set.rs | 48 +++++++++++++++++++++++---------------- src/table.rs | 52 +++++++++++++++++++++++++++++-------------- 3 files changed, 104 insertions(+), 59 deletions(-) diff --git a/src/map.rs b/src/map.rs index 7430ea42e..497201925 100644 --- a/src/map.rs +++ b/src/map.rs @@ -929,48 +929,64 @@ impl HashMap { } } - /// Retains only the elements specified by the predicate and breaks the iteration when - /// the predicate fails. Keeps the allocated memory for reuse. + /// Iterates over elements, applying the specified `ControlFlow` predicate to each + /// + /// ### Element Fate + /// - Kept if `f(&k, &mut v)` returns `ControlFlow::Continue(true)` or `ControlFlow::Break(true)` + /// - Removed if `f(&k, &mut v)` returns `ControlFlow::Continue(false)` or `ControlFlow::Break(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&k, &mut v)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&k, &mut v)` + /// returns `ControlFlow::Break` /// - /// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Some(false)` until - /// `f(&k, &mut v)` returns `None` /// The elements are visited in unsorted (and unspecified) order. 
/// /// # Examples /// /// ``` /// use hashbrown::HashMap; + /// use core::ops::ControlFlow; /// /// let mut map: HashMap = (0..8).map(|x|(x, x*10)).collect(); /// assert_eq!(map.len(), 8); /// let mut removed = 0; - /// map.retain_with_break(|&k, _| if removed < 3 { + /// map.filter(|&k, _| if removed < 3 { /// if k % 2 == 0 { - /// Some(true) + /// ControlFlow::Continue(true) /// } else { /// removed += 1; - /// Some(false) + /// ControlFlow::Continue(false) /// } /// } else { - /// None + /// // keep this item and break + /// ControlFlow::Break(true) /// }); /// /// // We can see, that the number of elements inside map is changed and the - /// // length matches when we have aborted the retain with the return of `None` + /// // length matches when we have aborted the iteration with the return of `ControlFlow::Break` /// assert_eq!(map.len(), 5); /// ``` - pub fn retain_with_break(&mut self, mut f: F) + pub fn filter(&mut self, mut f: F) where - F: FnMut(&K, &mut V) -> Option, + F: FnMut(&K, &mut V) -> core::ops::ControlFlow, { // Here we only use `iter` as a temporary, preventing use-after-free unsafe { for item in self.table.iter() { let &mut (ref key, ref mut value) = item.as_mut(); match f(key, value) { - Some(false) => self.table.erase(item), - Some(true) => continue, - None => break, + core::ops::ControlFlow::Continue(kept) => { + if !kept { + self.table.erase(item); + } + } + core::ops::ControlFlow::Break(kept) => { + if !kept { + self.table.erase(item); + } + break; + } } } } @@ -5957,20 +5973,21 @@ mod test_map { } #[test] - fn test_retain_with_break() { + fn test_filter() { let mut map: HashMap = (0..100).map(|x| (x, x * 10)).collect(); - // looping and removing any key > 50, but stop after 40 iterations + // looping and removing any key > 50, but stop after 40 removed let mut removed = 0; - map.retain_with_break(|&k, _| { - if removed < 40 { - if k > 50 { - removed += 1; - Some(false) + map.filter(|&k, _| { + if k > 50 { + removed += 1; + if removed < 
40 { + core::ops::ControlFlow::Continue(false) } else { - Some(true) + // remove this item and break + core::ops::ControlFlow::Break(false) } } else { - None + core::ops::ControlFlow::Continue(true) } }); assert_eq!(map.len(), 60); diff --git a/src/set.rs b/src/set.rs index e7abe0230..39fec8575 100644 --- a/src/set.rs +++ b/src/set.rs @@ -372,35 +372,45 @@ impl HashSet { self.map.retain(|k, _| f(k)); } - /// Retains only the elements specified by the predicate until the predicate returns `None`. + /// Iterates over elements, applying the specified `ControlFlow` predicate to each /// - /// In other words, remove all elements `e` such that `f(&e)` returns `Some(false)` until - /// `f(&e)` returns `None`. + /// ### Element Fate + /// - Kept if `f(&e)` returns `ControlFlow::Continue(true)` or `ControlFlow::Break(true)` + /// - Removed if `f(&e)` returns `ControlFlow::Continue(false)` or `ControlFlow::Break(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&e)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&e)` + /// returns `ControlFlow::Break` + /// + /// The elements are visited in unsorted (and unspecified) order. 
/// /// # Examples /// /// ``` /// use hashbrown::HashSet; + /// use core::ops::ControlFlow; /// /// let xs = [1,2,3,4,5,6]; /// let mut set: HashSet = xs.into_iter().collect(); /// let mut count = 0; - /// set.retain_with_break(|&k| if count < 2 { + /// set.filter(|&k| if count < 2 { /// if k % 2 == 0 { - /// Some(true) + /// ControlFlow::Continue(true) /// } else { - /// Some(false) + /// ControlFlow::Continue(false) /// } /// } else { - /// None + /// // keep this item and break + /// ControlFlow::Break(true) /// }); /// assert_eq!(set.len(), 3); /// ``` - pub fn retain_with_break(&mut self, mut f: F) + pub fn filter(&mut self, mut f: F) where - F: FnMut(&T) -> Option, + F: FnMut(&T) -> core::ops::ControlFlow, { - self.map.retain_with_break(|k, _| f(k)); + self.map.filter(|k, _| f(k)); } /// Drains elements which are true under the given predicate, @@ -3012,20 +3022,20 @@ mod test_set { } #[test] - fn test_retain_with_break() { + fn test_filter() { let mut set: HashSet = (0..100).collect(); - // looping and removing any key > 50, but stop after 40 iterations + // looping and removing any element > 50, but stop after 40 removals let mut removed = 0; - set.retain_with_break(|&k| { - if removed < 40 { - if k > 50 { - removed += 1; - Some(false) + set.filter(|&k| { + if k > 50 { + removed += 1; + if removed < 40 { + core::ops::ControlFlow::Continue(false) } else { - Some(true) + core::ops::ControlFlow::Break(false) } } else { - None + core::ops::ControlFlow::Continue(true) } }); assert_eq!(set.len(), 60); diff --git a/src/table.rs b/src/table.rs index 7d5aff9b4..8a95ad98f 100644 --- a/src/table.rs +++ b/src/table.rs @@ -870,10 +870,18 @@ where } } - /// Retains only the elements specified by the predicate until the predicate returns `None`. + /// Iterates over elements, applying the specified `ControlFlow` predicate to each /// - /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until - /// `f(&e)` returns `None`. 
+ /// ### Element Fate + /// - Kept if `f(&e)` returns `ControlFlow::Continue(true)` or `ControlFlow::Break(true)` + /// - Removed if `f(&e)` returns `ControlFlow::Continue(false)` or `ControlFlow::Break(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&e)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&e)` + /// returns `ControlFlow::Break` + /// + /// The elements are visited in unsorted (and unspecified) order. /// /// # Examples /// @@ -882,6 +890,7 @@ where /// # fn test() { /// use hashbrown::{HashTable, DefaultHashBuilder}; /// use std::hash::BuildHasher; + /// use core::ops::ControlFlow; /// /// let mut table = HashTable::new(); /// let hasher = DefaultHashBuilder::default(); @@ -895,15 +904,16 @@ where /// for x in 1..=8 { /// table.insert_unique(hasher(&x), x, hasher); /// } - /// table.retain_with_break(|&mut v| if removed < 3 { + /// table.filter(|&mut v| if removed < 3 { /// if v % 2 == 0 { - /// Some(true) + /// ControlFlow::Continue(true) /// } else { /// removed += 1; - /// Some(false) + /// ControlFlow::Continue(false) /// } /// } else { - /// None + /// // keep this item and break + /// ControlFlow::Break(true) /// }); /// assert_eq!(table.len(), 5); /// # } @@ -912,14 +922,22 @@ where /// # test() /// # } /// ``` - pub fn retain_with_break(&mut self, mut f: impl FnMut(&mut T) -> Option) { + pub fn filter(&mut self, mut f: impl FnMut(&mut T) -> core::ops::ControlFlow) { // Here we only use `iter` as a temporary, preventing use-after-free unsafe { for item in self.raw.iter() { match f(item.as_mut()) { - Some(false) => self.raw.erase(item), - Some(true) => continue, - None => break, + core::ops::ControlFlow::Continue(kept) => { + if !kept { + self.raw.erase(item); + } + } + core::ops::ControlFlow::Break(kept) => { + if !kept { + self.raw.erase(item); + } + break; + } } } } @@ -2440,7 +2458,7 @@ mod tests { } #[test] - fn test_retain_with_break() { + fn test_filter() { let mut table = HashTable::new(); let hasher = 
DefaultHashBuilder::default(); let hasher = |val: &_| { @@ -2452,18 +2470,18 @@ mod tests { for x in 0..100 { table.insert_unique(hasher(&x), x, hasher); } - // looping and removing any value > 50, but stop after 40 iterations + // looping and removing any value > 50, but stop after 40 removals let mut removed = 0; - table.retain_with_break(|&mut v| { + table.filter(|&mut v| { if removed < 40 { if v > 50 { removed += 1; - Some(false) + core::ops::ControlFlow::Continue(false) } else { - Some(true) + core::ops::ControlFlow::Continue(true) } } else { - None + core::ops::ControlFlow::Break(true) } }); assert_eq!(table.len(), 60);