Skip to content

Commit

Permalink
Implement range iterators for inner nodes
Browse files Browse the repository at this point in the history
**Description**
 - Add a `range` iterator constructor to the `InnerNode` trait
 - Implement the `range` iterator for each inner node type and
   add tests

**Motivation**
I think this is a prerequisite for a tree-level range iterator

**Testing Done**
`./scripts/full-test.sh nightly`
  • Loading branch information
declanvk committed Jul 8, 2024
1 parent 140ac74 commit f883bd4
Show file tree
Hide file tree
Showing 4 changed files with 489 additions and 48 deletions.
13 changes: 11 additions & 2 deletions src/nodes/representation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
iter::FusedIterator,
marker::PhantomData,
mem::{self, ManuallyDrop},
ops::Range,
ops::{Range, RangeBounds},
ptr::{self, NonNull},
};

Expand Down Expand Up @@ -662,10 +662,19 @@ pub trait InnerNode<const PREFIX_LEN: usize>: Node<PREFIX_LEN> + Sized {
self.header().num_children() >= Self::TYPE.upper_capacity()
}

/// Create an iterator over all (key bytes, child pointers) in this inner
/// Create an iterator over all `(key bytes, child pointers)` in this inner
/// node.
fn iter(&self) -> Self::Iter<'_>;

/// Create an iterator over a subset of `(key bytes, child pointers)`, using
/// the given `bound` as a restriction on the set of key bytes.
///
/// The returned iterator is double-ended and fused, so the selected pairs
/// can be consumed from either end.
///
/// # Panics
///
/// NOTE(review): the `InnerNode48` and `InnerNode256` tests expect a panic
/// when both ends of `bound` exclude the same byte (for example
/// `(Excluded(80), Excluded(80))`); confirm whether every implementation is
/// meant to share that behaviour.
fn range(
&self,
bound: impl RangeBounds<u8>,
) -> impl Iterator<Item = (u8, OpaqueNodePtr<Self::Key, Self::Value, PREFIX_LEN>)>
+ DoubleEndedIterator
+ FusedIterator;

/// Compares the compressed path of a node with the key and returns the
/// number of equal bytes.
///
Expand Down
135 changes: 117 additions & 18 deletions src/nodes/representation/inner_node_256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,28 @@ impl<K: AsBytes, V, const PREFIX_LEN: usize> InnerNode<PREFIX_LEN>
}
}

/// Iterate over the `(key byte, child pointer)` pairs whose key byte falls
/// inside `bound`.
///
/// The bound is widened to `usize` and used to slice `child_pointers`; each
/// occupied slot is reported under its position in that slice plus the first
/// key byte the slice covers. Empty slots are skipped.
///
/// # Panics
///
/// Panics if the widened bound is not a valid slice range, for example when
/// both ends exclude the same byte.
fn range(
    &self,
    bound: impl std::ops::RangeBounds<u8>,
) -> impl Iterator<Item = (u8, OpaqueNodePtr<Self::Key, Self::Value, PREFIX_LEN>)>
       + DoubleEndedIterator
       + FusedIterator {
    let lower = bound.start_bound().map(|byte| usize::from(*byte));
    // Key byte that corresponds to position 0 of the slice taken below.
    let first_key = match bound.start_bound() {
        std::ops::Bound::Unbounded => 0,
        std::ops::Bound::Included(byte) => *byte,
        std::ops::Bound::Excluded(byte) => byte.saturating_add(1),
    };
    let upper = bound.end_bound().map(|byte| usize::from(*byte));

    let window = &self.child_pointers[(lower, upper)];
    window
        .iter()
        .enumerate()
        .filter_map(move |(position, slot)| {
            // Skip slots that hold no child.
            let child = (*slot)?;
            Some(((position as u8).saturating_add(first_key), child))
        })
}

#[cfg(feature = "nightly")]
fn min(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
use crate::rust_nightly_apis::assume;
Expand Down Expand Up @@ -313,6 +335,8 @@ impl<'a, K: AsBytes, V, const PREFIX_LEN: usize> FusedIterator

#[cfg(test)]
mod tests {
use std::ops::Bound;

use crate::{
nodes::representation::tests::{
inner_node_remove_child_test, inner_node_shrink_test, inner_node_write_child_test,
Expand Down Expand Up @@ -382,7 +406,7 @@ mod tests {
}

fn fixture() -> FixtureReturn<InnerNode256<Box<[u8]>, (), 16>, 4> {
let mut n4 = InnerNode256::empty();
let mut n256 = InnerNode256::empty();
let mut l1 = LeafNode::new(vec![].into(), ());
let mut l2 = LeafNode::new(vec![].into(), ());
let mut l3 = LeafNode::new(vec![].into(), ());
Expand All @@ -392,29 +416,104 @@ mod tests {
let l3_ptr = NodePtr::from(&mut l3).to_opaque();
let l4_ptr = NodePtr::from(&mut l4).to_opaque();

n4.write_child(3, l1_ptr);
n4.write_child(255, l2_ptr);
n4.write_child(0u8, l3_ptr);
n4.write_child(85, l4_ptr);
n256.write_child(3, l1_ptr);
n256.write_child(255, l2_ptr);
n256.write_child(0u8, l3_ptr);
n256.write_child(85, l4_ptr);

(n4, [l1, l2, l3, l4], [l1_ptr, l2_ptr, l3_ptr, l4_ptr])
(n256, [l1, l2, l3, l4], [l1_ptr, l2_ptr, l3_ptr, l4_ptr])
}

#[test]
fn iterate() {
// Checks that `iter()` yields each of the four fixture children, that the
// pairs come out in ascending key-byte order (0, 3, 85, 255), and that the
// iterator is exhausted afterwards.
let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 3 && ptr == l1_ptr));
assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 255 && ptr == l2_ptr));
assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 0u8 && ptr == l3_ptr));
assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 85 && ptr == l4_ptr));
let mut iter = node.iter();

assert_eq!(iter.next().unwrap(), (0u8, l3_ptr));
assert_eq!(iter.next().unwrap(), (3, l1_ptr));
assert_eq!(iter.next().unwrap(), (85, l4_ptr));
assert_eq!(iter.next().unwrap(), (255, l2_ptr));
assert_eq!(iter.next(), None);
}

#[test]
fn iterate_rev() {
    // Walking the node iterator backwards must produce the fixture children
    // in descending key-byte order and nothing else.
    let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

    let reversed = node.iter().rev().collect::<Vec<_>>();

    assert_eq!(
        reversed,
        &[(255u8, l2_ptr), (85, l4_ptr), (3, l1_ptr), (0, l3_ptr)]
    );
}

#[test]
fn range_iterate() {
    // Table of (bound, expected pairs) cases; the fixture stores children at
    // key bytes 0, 3, 85 and 255.
    let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

    let cases: [((Bound<u8>, Bound<u8>), Vec<_>); 9] = [
        (
            (Bound::Included(0), Bound::Included(3)),
            vec![(0u8, l3_ptr), (3, l1_ptr)],
        ),
        ((Bound::Excluded(0), Bound::Excluded(3)), vec![]),
        ((Bound::Included(0), Bound::Included(0)), vec![(0u8, l3_ptr)]),
        (
            (Bound::Included(0), Bound::Included(255)),
            vec![(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)],
        ),
        (
            (Bound::Included(255), Bound::Included(255)),
            vec![(255u8, l2_ptr)],
        ),
        ((Bound::Included(255), Bound::Excluded(255)), vec![]),
        ((Bound::Excluded(255), Bound::Included(255)), vec![]),
        (
            (Bound::Excluded(0), Bound::Excluded(255)),
            vec![(3u8, l1_ptr), (85, l4_ptr)],
        ),
        (
            (Bound::Unbounded, Bound::Unbounded),
            vec![(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)],
        ),
    ];

    for (bounds, expected) in cases {
        let pairs = node.range(bounds).collect::<Vec<_>>();
        assert_eq!(pairs, expected);
    }
}

#[test]
#[should_panic]
fn range_iterate_out_of_bounds_panic_both_excluded() {
    // A bound that excludes the same byte on both ends is expected to panic
    // while the range iterator is being built.
    let (node, _, _) = fixture();

    let both_excluded = (Bound::Excluded(80), Bound::Excluded(80));
    let pairs: Vec<_> = node.range(both_excluded).collect();

    assert!(pairs.is_empty());
}
}
144 changes: 131 additions & 13 deletions src/nodes/representation/inner_node_48.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ impl<K: AsBytes, V, const PREFIX_LEN: usize> InnerNode<PREFIX_LEN>
#[cfg(not(feature = "nightly"))]
type Iter<'a> = Node48Iter<'a, K, V, PREFIX_LEN> where Self: 'a;
#[cfg(feature = "nightly")]
type Iter<'a> = Map<FilterMap<Enumerate<Iter<'a, RestrictedNodeIndex<48>>>, impl FnMut((usize, &'a RestrictedNodeIndex<48>)) -> Option<(u8, usize)>>, impl FnMut((u8, usize)) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>)> where Self: 'a;
type Iter<'a> = Map<
FilterMap<
Enumerate<Iter<'a, RestrictedNodeIndex<48>>>,
impl FnMut((usize, &'a RestrictedNodeIndex<48>)) -> Option<(u8, usize)>,
>,
impl FnMut((u8, usize)) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>),
> where Self: 'a;
type ShrunkNode = InnerNode16<K, V, PREFIX_LEN>;

fn header(&self) -> &Header<PREFIX_LEN> {
Expand Down Expand Up @@ -369,6 +375,40 @@ impl<K: AsBytes, V, const PREFIX_LEN: usize> InnerNode<PREFIX_LEN>
}
}

/// Iterate over the `(key byte, child pointer)` pairs whose key byte falls
/// inside `bound`.
///
/// The bound is widened to `usize` and used to slice `child_indices`; every
/// non-empty entry of that slice is resolved to a child pointer and reported
/// under its position in the slice plus the first key byte the slice covers.
///
/// # Panics
///
/// Panics if the widened bound is not a valid slice range, for example when
/// both ends exclude the same byte.
fn range(
    &self,
    bound: impl std::ops::RangeBounds<u8>,
) -> impl Iterator<Item = (u8, OpaqueNodePtr<Self::Key, Self::Value, PREFIX_LEN>)>
       + DoubleEndedIterator
       + FusedIterator {
    let live_children = self.initialized_child_pointers();

    let lower = bound.start_bound().map(|byte| usize::from(*byte));
    // Key byte that corresponds to position 0 of the slice taken below.
    let first_key = match bound.start_bound() {
        std::ops::Bound::Unbounded => 0,
        std::ops::Bound::Included(byte) => *byte,
        std::ops::Bound::Excluded(byte) => byte.saturating_add(1),
    };
    let upper = bound.end_bound().map(|byte| usize::from(*byte));

    let window = &self.child_indices[(lower, upper)];
    window
        .iter()
        .enumerate()
        .filter_map(move |(position, entry)| {
            // `enumerate()` normally hands out (index, value); here the index
            // is the key byte relative to the start of the slice.
            let (relative_key, slot) =
                (!entry.is_empty()).then_some((position as u8, usize::from(*entry)))?;

            // SAFETY: By the construction of `Self` the stored index must
            // always be inbounds
            let child = unsafe { *live_children.get_unchecked(slot) };

            Some((relative_key.saturating_add(first_key), child))
        })
}

#[cfg(feature = "nightly")]
fn min(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
// SAFETY: Since `RestrictedNodeIndex` is
Expand Down Expand Up @@ -576,6 +616,8 @@ impl<'a, K: AsBytes, V, const PREFIX_LEN: usize> FusedIterator

#[cfg(test)]
mod tests {
use std::ops::Bound;

use crate::{
nodes::representation::tests::{
inner_node_remove_child_test, inner_node_shrink_test, inner_node_write_child_test,
Expand Down Expand Up @@ -623,6 +665,7 @@ mod tests {
inner_node_remove_child_test(InnerNode48::<_, _, 16>::empty(), 48)
}

// TODO
// #[test]
// #[should_panic]
// fn write_child_full_panic() {
Expand Down Expand Up @@ -685,17 +728,92 @@ mod tests {
fn iterate() {
// Checks that `iter()` yields each of the four fixture children, that the
// pairs come out in ascending key-byte order (0, 3, 85, 255), and that the
// iterator is exhausted afterwards.
let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 3 && ptr == l1_ptr));
assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 255 && ptr == l2_ptr));
assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 0u8 && ptr == l3_ptr));
assert!(node
.iter()
.any(|(key_fragment, ptr)| key_fragment == 85 && ptr == l4_ptr));
let mut iter = node.iter();

assert_eq!(iter.next().unwrap(), (0u8, l3_ptr));
assert_eq!(iter.next().unwrap(), (3, l1_ptr));
assert_eq!(iter.next().unwrap(), (85, l4_ptr));
assert_eq!(iter.next().unwrap(), (255, l2_ptr));
assert_eq!(iter.next(), None);
}

#[test]
fn iterate_rev() {
    // Walking the node iterator backwards must produce the fixture children
    // in descending key-byte order and nothing else.
    let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

    let reversed = node.iter().rev().collect::<Vec<_>>();

    assert_eq!(
        reversed,
        &[(255u8, l2_ptr), (85, l4_ptr), (3, l1_ptr), (0, l3_ptr)]
    );
}

#[test]
fn range_iterate() {
    // Table of (bound, expected pairs) cases; each expected list is in
    // ascending key-byte order.
    let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

    let cases: [((Bound<u8>, Bound<u8>), Vec<_>); 9] = [
        (
            (Bound::Included(0), Bound::Included(3)),
            vec![(0u8, l3_ptr), (3, l1_ptr)],
        ),
        ((Bound::Excluded(0), Bound::Excluded(3)), vec![]),
        ((Bound::Included(0), Bound::Included(0)), vec![(0u8, l3_ptr)]),
        (
            (Bound::Included(0), Bound::Included(255)),
            vec![(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)],
        ),
        (
            (Bound::Included(255), Bound::Included(255)),
            vec![(255u8, l2_ptr)],
        ),
        ((Bound::Included(255), Bound::Excluded(255)), vec![]),
        ((Bound::Excluded(255), Bound::Included(255)), vec![]),
        (
            (Bound::Excluded(0), Bound::Excluded(255)),
            vec![(3u8, l1_ptr), (85, l4_ptr)],
        ),
        (
            (Bound::Unbounded, Bound::Unbounded),
            vec![(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)],
        ),
    ];

    for (bounds, expected) in cases {
        let pairs = node.range(bounds).collect::<Vec<_>>();
        assert_eq!(pairs, expected);
    }
}

#[test]
#[should_panic]
fn range_iterate_out_of_bounds_panic_both_excluded() {
    // A bound that excludes the same byte on both ends is expected to panic
    // while the range iterator is being built.
    let (node, _, _) = fixture();

    let both_excluded = (Bound::Excluded(80), Bound::Excluded(80));
    let pairs: Vec<_> = node.range(both_excluded).collect();

    assert!(pairs.is_empty());
}
}
Loading

0 comments on commit f883bd4

Please sign in to comment.