From 35b44790e197f032d4c34528ea8448b0e1151f41 Mon Sep 17 00:00:00 2001 From: tugtugtug Date: Thu, 14 Nov 2024 15:58:07 -0500 Subject: [PATCH] feat: Add filter to HashSet/Table/Map Removing the raw table complicates implementing efficient loops that conditionally remove or keep elements and control iteration flow, particularly when aborting iteration without rehashing keys for removal. The existing `extract_if` method is cumbersome for flow control, limiting flexibility. The proposed addition addresses these shortcomings, by enabling proper flow control to allow efficient iteration and removal of filtered elements. --- src/map.rs | 63 +++++++++++++++++++++++++++++++++------------------- src/set.rs | 48 +++++++++++++++++++++++---------------- src/table.rs | 52 +++++++++++++++++++++++++++++-------------- 3 files changed, 104 insertions(+), 59 deletions(-) diff --git a/src/map.rs b/src/map.rs index 7430ea42e..497201925 100644 --- a/src/map.rs +++ b/src/map.rs @@ -929,48 +929,64 @@ impl HashMap { } } - /// Retains only the elements specified by the predicate and breaks the iteration when - /// the predicate fails. Keeps the allocated memory for reuse. + /// Iterates over elements, applying the specified `ControlFlow` predicate to each + /// + /// ### Element Fate + /// - Kept if `f(&k, &mut v)` returns `ControlFlow::(true)` + /// - Removed if `f(&k, &mut v)` returns `ControlFlow::(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&k, &mut v)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&k, &mut v)` + /// returns `ControlFlow::Break` /// - /// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Some(false)` until - /// `f(&k, &mut v)` returns `None` /// The elements are visited in unsorted (and unspecified) order. /// /// # Examples /// /// ``` /// use hashbrown::HashMap; + /// use core::ops::ControlFlow; /// /// let mut map: HashMap = (0..8).map(|x|(x, x*10)).collect(); /// assert_eq!(map.len(), 8); /// let mut removed = 0; - /// map.retain_with_break(|&k, _| if removed < 3 { + /// map.filter(|&k, _| if removed < 3 { /// if k % 2 == 0 { - /// Some(true) + /// ControlFlow::Continue(true) /// } else { /// removed += 1; - /// Some(false) + /// ControlFlow::Continue(false) /// } /// } else { - /// None + /// // keep this item and break + /// ControlFlow::Break(true) /// }); /// /// // We can see, that the number of elements inside map is changed and the - /// // length matches when we have aborted the retain with the return of `None` + /// // length matches when we have aborted the iteration with the return of `ControlFlow::Break` /// assert_eq!(map.len(), 5); /// ``` - pub fn retain_with_break(&mut self, mut f: F) + pub fn filter(&mut self, mut f: F) where - F: FnMut(&K, &mut V) -> Option, + F: FnMut(&K, &mut V) -> core::ops::ControlFlow, { // Here we only use `iter` as a temporary, preventing use-after-free unsafe { for item in self.table.iter() { let &mut (ref key, ref mut value) = item.as_mut(); match f(key, value) { - Some(false) => self.table.erase(item), - Some(true) => continue, - None => break, + core::ops::ControlFlow::Continue(kept) => { + if !kept { + self.table.erase(item); + } + } + core::ops::ControlFlow::Break(kept) => { + if !kept { + self.table.erase(item); + } + break; + } } } } @@ -5957,20 +5973,21 @@ mod test_map { } #[test] - fn test_retain_with_break() { + fn test_filter() { let mut map: HashMap = (0..100).map(|x| (x, x * 10)).collect(); - // looping and removing any key > 50, but stop after 40 iterations + // looping and removing any key > 50, but stop after 40 removed let mut removed = 0; - map.retain_with_break(|&k, _| { - if removed < 40 { - if k > 50 { - removed += 1; - Some(false) + map.filter(|&k, _| { + if k > 50 { + removed += 1; + if removed < 40 { + core::ops::ControlFlow::Continue(false) } else { - Some(true) + // remove this item and break + core::ops::ControlFlow::Break(false) } } else { - None + core::ops::ControlFlow::Continue(true) } }); assert_eq!(map.len(), 60); diff --git a/src/set.rs b/src/set.rs index e7abe0230..39fec8575 100644 --- a/src/set.rs +++ b/src/set.rs @@ -372,35 +372,45 @@ impl HashSet { self.map.retain(|k, _| f(k)); } - /// Retains only the elements specified by the predicate until the predicate returns `None`. + /// Iterates over elements, applying the specified `ControlFlow` predicate to each /// - /// In other words, remove all elements `e` such that `f(&e)` returns `Some(false)` until - /// `f(&e)` returns `None`. + /// ### Element Fate + /// - Kept if `f(&e)` returns `ControlFlow::(true)` + /// - Removed if `f(&e)` returns `ControlFlow::(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&e)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&e)` + /// returns `ControlFlow::Break` + /// + /// The elements are visited in unsorted (and unspecified) order. /// /// # Examples /// /// ``` /// use hashbrown::HashSet; + /// use core::ops::ControlFlow; /// /// let xs = [1,2,3,4,5,6]; /// let mut set: HashSet = xs.into_iter().collect(); /// let mut count = 0; - /// set.retain_with_break(|&k| if count < 2 { + /// set.filter(|&k| if count < 2 { /// if k % 2 == 0 { - /// Some(true) + /// ControlFlow::Continue(true) /// } else { - /// Some(false) + /// ControlFlow::Continue(false) /// } /// } else { - /// None + /// // keep this item and break + /// ControlFlow::Break(true) /// }); /// assert_eq!(set.len(), 3); /// ``` - pub fn retain_with_break(&mut self, mut f: F) + pub fn filter(&mut self, mut f: F) where - F: FnMut(&T) -> Option, + F: FnMut(&T) -> core::ops::ControlFlow, { - self.map.retain_with_break(|k, _| f(k)); + self.map.filter(|k, _| f(k)); } /// Drains elements which are true under the given predicate, @@ -3012,20 +3022,20 @@ mod test_set { } #[test] - fn test_retain_with_break() { + fn test_filter() { let mut set: HashSet = (0..100).collect(); - // looping and removing any key > 50, but stop after 40 iterations + // looping and removing any element > 50, but stop after 40 removals let mut removed = 0; - set.retain_with_break(|&k| { - if removed < 40 { - if k > 50 { - removed += 1; - Some(false) + set.filter(|&k| { + if k > 50 { + removed += 1; + if removed < 40 { + core::ops::ControlFlow::Continue(false) } else { - Some(true) + core::ops::ControlFlow::Break(false) } } else { - None + core::ops::ControlFlow::Continue(true) } }); assert_eq!(set.len(), 60); diff --git a/src/table.rs b/src/table.rs index 7d5aff9b4..8a95ad98f 100644 --- a/src/table.rs +++ b/src/table.rs @@ -870,10 +870,18 @@ where } } - /// Retains only the elements specified by the predicate until the predicate returns `None`. + /// Iterates over elements, applying the specified `ControlFlow` predicate to each /// - /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until - /// `f(&e)` returns `None`. + /// ### Element Fate + /// - Kept if `f(&e)` returns `ControlFlow::(true)` + /// - Removed if `f(&e)` returns `ControlFlow::(false)` + /// + /// ### Iteration Control + /// - Continue iterating if `f(&e)` returns `ControlFlow::Continue` + /// - Abort iteration immediately (after applying the element fate) if `f(&e)` + /// returns `ControlFlow::Break` + /// + /// The elements are visited in unsorted (and unspecified) order. /// /// # Examples /// @@ -882,6 +890,7 @@ where /// # fn test() { /// use hashbrown::{HashTable, DefaultHashBuilder}; /// use std::hash::BuildHasher; + /// use core::ops::ControlFlow; /// /// let mut table = HashTable::new(); /// let hasher = DefaultHashBuilder::default(); @@ -895,15 +904,16 @@ where /// for x in 1..=8 { /// table.insert_unique(hasher(&x), x, hasher); /// } - /// table.retain_with_break(|&mut v| if removed < 3 { + /// table.filter(|&mut v| if removed < 3 { /// if v % 2 == 0 { - /// Some(true) + /// ControlFlow::Continue(true) /// } else { /// removed += 1; - /// Some(false) + /// ControlFlow::Continue(false) /// } /// } else { - /// None + /// // keep this item and break + /// ControlFlow::Break(true) /// }); /// assert_eq!(table.len(), 5); /// # } @@ -912,14 +922,22 @@ where /// # test() /// # } /// ``` - pub fn retain_with_break(&mut self, mut f: impl FnMut(&mut T) -> Option) { + pub fn filter(&mut self, mut f: impl FnMut(&mut T) -> core::ops::ControlFlow) { // Here we only use `iter` as a temporary, preventing use-after-free unsafe { for item in self.raw.iter() { match f(item.as_mut()) { - Some(false) => self.raw.erase(item), - Some(true) => continue, - None => break, + core::ops::ControlFlow::Continue(kept) => { + if !kept { + self.raw.erase(item); + } + } + core::ops::ControlFlow::Break(kept) => { + if !kept { + self.raw.erase(item); + } + break; + } } } } @@ -2440,7 +2458,7 @@ mod tests { } #[test] - fn test_retain_with_break() { + fn test_filter() { let mut table = HashTable::new(); let hasher = DefaultHashBuilder::default(); let hasher = |val: &_| { @@ -2452,18 +2470,18 @@ mod tests { for x in 0..100 { table.insert_unique(hasher(&x), x, hasher); } - // looping and removing any value > 50, but stop after 40 iterations + // looping and removing any value > 50, but stop after 40 removals let mut removed = 0; - table.retain_with_break(|&mut v| { + table.filter(|&mut v| { if removed < 40 { if v > 50 { removed += 1; - Some(false) + core::ops::ControlFlow::Continue(false) } else { - Some(true) + core::ops::ControlFlow::Continue(true) } } else { - None + core::ops::ControlFlow::Break(true) } }); assert_eq!(table.len(), 60);