Skip to content

Commit

Permalink
feat: Add filter to HashSet/Table/Map
Browse files Browse the repository at this point in the history
The removal of the raw table makes it harder to implement efficient loops
that conditionally remove or keep elements while also controlling iteration
flow, particularly when the iteration must be aborted early without
rehashing keys for removal.

The existing `extract_if` method is cumbersome for flow control,
limiting flexibility.

The proposed `filter` method addresses these shortcomings by giving the
predicate explicit flow control, which allows efficient iteration and
removal of filtered-out elements.
  • Loading branch information
tugtugtug committed Nov 14, 2024
1 parent f68b09c commit 35b4479
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 59 deletions.
63 changes: 40 additions & 23 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -929,48 +929,64 @@ impl<K, V, S, A: Allocator> HashMap<K, V, S, A> {
}
}

/// Retains only the elements specified by the predicate and breaks the iteration when
/// the predicate fails. Keeps the allocated memory for reuse.
/// Iterates over the elements, applying the specified `ControlFlow` predicate to each one.
///
/// ### Element Fate
/// - Kept if `f(&k, &mut v)` returns `ControlFlow::Continue(true)` or `ControlFlow::Break(true)`
/// - Removed if `f(&k, &mut v)` returns `ControlFlow::Continue(false)` or `ControlFlow::Break(false)`
///
/// ### Iteration Control
/// - Continue iterating if `f(&k, &mut v)` returns `ControlFlow::Continue`
/// - Abort iteration immediately (after applying the element fate) if `f(&k, &mut v)`
/// returns `ControlFlow::Break`
///
/// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Some(false)` until
/// `f(&k, &mut v)` returns `None`
/// The elements are visited in unsorted (and unspecified) order.
///
/// # Examples
///
/// ```
/// use hashbrown::HashMap;
/// use core::ops::ControlFlow;
///
/// let mut map: HashMap<i32, i32> = (0..8).map(|x|(x, x*10)).collect();
/// assert_eq!(map.len(), 8);
/// let mut removed = 0;
/// map.retain_with_break(|&k, _| if removed < 3 {
/// map.filter(|&k, _| if removed < 3 {
/// if k % 2 == 0 {
/// Some(true)
/// ControlFlow::Continue(true)
/// } else {
/// removed += 1;
/// Some(false)
/// ControlFlow::Continue(false)
/// }
/// } else {
/// None
/// // keep this item and break
/// ControlFlow::Break(true)
/// });
///
/// // We can see that the number of elements inside the map has changed and that the
/// // length matches when we have aborted the retain with the return of `None`
/// // length matches when we have aborted the iteration with the return of `ControlFlow::Break`
/// assert_eq!(map.len(), 5);
/// ```
pub fn retain_with_break<F>(&mut self, mut f: F)
pub fn filter<F>(&mut self, mut f: F)
where
F: FnMut(&K, &mut V) -> Option<bool>,
F: FnMut(&K, &mut V) -> core::ops::ControlFlow<bool, bool>,
{
// Here we only use `iter` as a temporary, preventing use-after-free
unsafe {
for item in self.table.iter() {
let &mut (ref key, ref mut value) = item.as_mut();
match f(key, value) {
Some(false) => self.table.erase(item),
Some(true) => continue,
None => break,
core::ops::ControlFlow::Continue(kept) => {
if !kept {
self.table.erase(item);
}
}
core::ops::ControlFlow::Break(kept) => {
if !kept {
self.table.erase(item);
}
break;
}
}
}
}
Expand Down Expand Up @@ -5957,20 +5973,21 @@ mod test_map {
}

#[test]
fn test_retain_with_break() {
fn test_filter() {
let mut map: HashMap<i32, i32> = (0..100).map(|x| (x, x * 10)).collect();
// looping and removing any key > 50, but stop after 40 iterations
// looping and removing any key > 50, but stop after 40 removals
let mut removed = 0;
map.retain_with_break(|&k, _| {
if removed < 40 {
if k > 50 {
removed += 1;
Some(false)
map.filter(|&k, _| {
if k > 50 {
removed += 1;
if removed < 40 {
core::ops::ControlFlow::Continue(false)
} else {
Some(true)
// remove this item and break
core::ops::ControlFlow::Break(false)
}
} else {
None
core::ops::ControlFlow::Continue(true)
}
});
assert_eq!(map.len(), 60);
Expand Down
48 changes: 29 additions & 19 deletions src/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,35 +372,45 @@ impl<T, S, A: Allocator> HashSet<T, S, A> {
self.map.retain(|k, _| f(k));
}

/// Retains only the elements specified by the predicate until the predicate returns `None`.
/// Iterates over the elements, applying the specified `ControlFlow` predicate to each one.
///
/// In other words, remove all elements `e` such that `f(&e)` returns `Some(false)` until
/// `f(&e)` returns `None`.
/// ### Element Fate
/// - Kept if `f(&e)` returns `ControlFlow::Continue(true)` or `ControlFlow::Break(true)`
/// - Removed if `f(&e)` returns `ControlFlow::Continue(false)` or `ControlFlow::Break(false)`
///
/// ### Iteration Control
/// - Continue iterating if `f(&e)` returns `ControlFlow::Continue`
/// - Abort iteration immediately (after applying the element fate) if `f(&e)`
/// returns `ControlFlow::Break`
///
/// The elements are visited in unsorted (and unspecified) order.
///
/// # Examples
///
/// ```
/// use hashbrown::HashSet;
/// use core::ops::ControlFlow;
///
/// let xs = [1,2,3,4,5,6];
/// let mut set: HashSet<i32> = xs.into_iter().collect();
/// let mut count = 0;
/// set.retain_with_break(|&k| if count < 2 {
/// set.filter(|&k| if count < 2 {
/// if k % 2 == 0 {
/// Some(true)
/// ControlFlow::Continue(true)
/// } else {
/// Some(false)
/// ControlFlow::Continue(false)
/// }
/// } else {
/// None
/// // keep this item and break
/// ControlFlow::Break(true)
/// });
/// assert_eq!(set.len(), 3);
/// ```
pub fn retain_with_break<F>(&mut self, mut f: F)
pub fn filter<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> Option<bool>,
F: FnMut(&T) -> core::ops::ControlFlow<bool, bool>,
{
self.map.retain_with_break(|k, _| f(k));
self.map.filter(|k, _| f(k));
}

/// Drains elements which are true under the given predicate,
Expand Down Expand Up @@ -3012,20 +3022,20 @@ mod test_set {
}

#[test]
fn test_retain_with_break() {
fn test_filter() {
let mut set: HashSet<i32> = (0..100).collect();
// looping and removing any key > 50, but stop after 40 iterations
// looping and removing any element > 50, but stop after 40 removals
let mut removed = 0;
set.retain_with_break(|&k| {
if removed < 40 {
if k > 50 {
removed += 1;
Some(false)
set.filter(|&k| {
if k > 50 {
removed += 1;
if removed < 40 {
core::ops::ControlFlow::Continue(false)
} else {
Some(true)
core::ops::ControlFlow::Break(false)
}
} else {
None
core::ops::ControlFlow::Continue(true)
}
});
assert_eq!(set.len(), 60);
Expand Down
52 changes: 35 additions & 17 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -870,10 +870,18 @@ where
}
}

/// Retains only the elements specified by the predicate until the predicate returns `None`.
/// Iterates over the elements, applying the specified `ControlFlow` predicate to each one.
///
/// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until
/// `f(&e)` returns `None`.
/// ### Element Fate
/// - Kept if `f(&e)` returns `ControlFlow::Continue(true)` or `ControlFlow::Break(true)`
/// - Removed if `f(&e)` returns `ControlFlow::Continue(false)` or `ControlFlow::Break(false)`
///
/// ### Iteration Control
/// - Continue iterating if `f(&e)` returns `ControlFlow::Continue`
/// - Abort iteration immediately (after applying the element fate) if `f(&e)`
/// returns `ControlFlow::Break`
///
/// The elements are visited in unsorted (and unspecified) order.
///
/// # Examples
///
Expand All @@ -882,6 +890,7 @@ where
/// # fn test() {
/// use hashbrown::{HashTable, DefaultHashBuilder};
/// use std::hash::BuildHasher;
/// use core::ops::ControlFlow;
///
/// let mut table = HashTable::new();
/// let hasher = DefaultHashBuilder::default();
Expand All @@ -895,15 +904,16 @@ where
/// for x in 1..=8 {
/// table.insert_unique(hasher(&x), x, hasher);
/// }
/// table.retain_with_break(|&mut v| if removed < 3 {
/// table.filter(|&mut v| if removed < 3 {
/// if v % 2 == 0 {
/// Some(true)
/// ControlFlow::Continue(true)
/// } else {
/// removed += 1;
/// Some(false)
/// ControlFlow::Continue(false)
/// }
/// } else {
/// None
/// // keep this item and break
/// ControlFlow::Break(true)
/// });
/// assert_eq!(table.len(), 5);
/// # }
Expand All @@ -912,14 +922,22 @@ where
/// # test()
/// # }
/// ```
pub fn retain_with_break(&mut self, mut f: impl FnMut(&mut T) -> Option<bool>) {
pub fn filter(&mut self, mut f: impl FnMut(&mut T) -> core::ops::ControlFlow<bool, bool>) {
// Here we only use `iter` as a temporary, preventing use-after-free
unsafe {
for item in self.raw.iter() {
match f(item.as_mut()) {
Some(false) => self.raw.erase(item),
Some(true) => continue,
None => break,
core::ops::ControlFlow::Continue(kept) => {
if !kept {
self.raw.erase(item);
}
}
core::ops::ControlFlow::Break(kept) => {
if !kept {
self.raw.erase(item);
}
break;
}
}
}
}
Expand Down Expand Up @@ -2440,7 +2458,7 @@ mod tests {
}

#[test]
fn test_retain_with_break() {
fn test_filter() {
let mut table = HashTable::new();
let hasher = DefaultHashBuilder::default();
let hasher = |val: &_| {
Expand All @@ -2452,18 +2470,18 @@ mod tests {
for x in 0..100 {
table.insert_unique(hasher(&x), x, hasher);
}
// looping and removing any value > 50, but stop after 40 iterations
// looping and removing any value > 50, but stop after 40 removals
let mut removed = 0;
table.retain_with_break(|&mut v| {
table.filter(|&mut v| {
if removed < 40 {
if v > 50 {
removed += 1;
Some(false)
core::ops::ControlFlow::Continue(false)
} else {
Some(true)
core::ops::ControlFlow::Continue(true)
}
} else {
None
core::ops::ControlFlow::Break(true)
}
});
assert_eq!(table.len(), 60);
Expand Down

0 comments on commit 35b4479

Please sign in to comment.