Skip to content

Commit

Permalink
Merge pull request #2057 from CortexFoundation/dev
Browse files Browse the repository at this point in the history
trie: iterate values pre-order and fix seek behavior
  • Loading branch information
ucwong authored Jun 13, 2024
2 parents 9ffb4fb + f901cc5 commit a9f64f3
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 22 deletions.
79 changes: 63 additions & 16 deletions trie/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ type nodeIteratorState struct {
node node // Trie node being iterated
parent common.Hash // Hash of the first full ancestor node (nil if current is the root)
index int // Child to be processed next
pathlen int // Length of the path to this node
pathlen int // Length of the path to the parent node
}

type nodeIterator struct {
Expand Down Expand Up @@ -250,14 +250,15 @@ func (it *nodeIterator) seek(prefix []byte) error {
// The path we're looking for is the hex encoded key without terminator.
key := keybytesToHex(prefix)
key = key[:len(key)-1]

// Move forward until we're just before the closest match to key.
for {
state, parentIndex, path, err := it.peekSeek(key)
if err == errIteratorEnd {
return errIteratorEnd
} else if err != nil {
return seekError{prefix, err}
} else if bytes.Compare(path, key) >= 0 {
} else if reachedPath(path, key) {
return nil
}
it.push(state, parentIndex, path)
Expand Down Expand Up @@ -285,7 +286,6 @@ func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, er
// If we're skipping children, pop the current node first
it.pop()
}

// Continue iteration to the next child
for len(it.stack) > 0 {
parent := it.stack[len(it.stack)-1]
Expand Down Expand Up @@ -318,7 +318,6 @@ func (it *nodeIterator) peekSeek(seekKey []byte) (*nodeIteratorState, *int, []by
// If we're skipping children, pop the current node first
it.pop()
}

// Continue iteration to the next child
for len(it.stack) > 0 {
parent := it.stack[len(it.stack)-1]
Expand Down Expand Up @@ -378,16 +377,18 @@ func (it *nodeIterator) findChild(n *fullNode, index int, ancestor common.Hash)
state *nodeIteratorState
childPath []byte
)
for ; index < len(n.Children); index++ {
for ; index < len(n.Children); index = nextChildIndex(index) {
if n.Children[index] != nil {
child = n.Children[index]
hash, _ := child.cache()

state = it.getFromPool()
state.hash = common.BytesToHash(hash)
state.node = child
state.parent = ancestor
state.index = -1
state.pathlen = len(path)

childPath = append(childPath, path...)
childPath = append(childPath, byte(index))
return child, state, childPath, index
Expand All @@ -400,8 +401,8 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
switch node := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child.
if child, state, path, index := it.findChild(node, parent.index+1, ancestor); child != nil {
parent.index = index - 1
if child, state, path, index := it.findChild(node, nextChildIndex(parent.index), ancestor); child != nil {
parent.index = prevChildIndex(index)
return state, path, true
}
case *shortNode:
Expand All @@ -427,23 +428,23 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
switch n := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child before the desired key position
child, state, path, index := it.findChild(n, parent.index+1, ancestor)
child, state, path, index := it.findChild(n, nextChildIndex(parent.index), ancestor)
if child == nil {
// No more children in this fullnode
return parent, it.path, false
}
// If the child we found is already past the seek position, just return it.
if bytes.Compare(path, key) >= 0 {
parent.index = index - 1
if reachedPath(path, key) {
parent.index = prevChildIndex(index)
return state, path, true
}
// The child is before the seek position. Try advancing
for {
nextChild, nextState, nextPath, nextIndex := it.findChild(n, index+1, ancestor)
nextChild, nextState, nextPath, nextIndex := it.findChild(n, nextChildIndex(index), ancestor)
// If we run out of children, or skipped past the target, return the
// previous one
if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
parent.index = index - 1
if nextChild == nil || reachedPath(nextPath, key) {
parent.index = prevChildIndex(index)
return state, path, true
}
// We found a better child closer to the target
Expand All @@ -470,7 +471,7 @@ func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []
it.path = path
it.stack = append(it.stack, state)
if parentIndex != nil {
*parentIndex++
*parentIndex = nextChildIndex(*parentIndex)
}
}

Expand All @@ -479,8 +480,54 @@ func (it *nodeIterator) pop() {
it.path = it.path[:last.pathlen]
it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1]
// last is now unused
it.putInPool(last)

it.putInPool(last) // last is now unused
}

// reachedPath reports whether path has arrived at, or moved beyond, target.
//
// Before comparing, a trailing terminator (as detected by hasTerm) is dropped
// from path. The comparison itself is a plain lexicographic byte comparison.
// Stripping the terminator means the path of a value node embedded in a full
// node compares less than the paths of that full node's children.
func reachedPath(path, target []byte) bool {
	normalized := path
	if hasTerm(normalized) {
		// Drop the terminator so only the key nibbles take part in the comparison.
		normalized = normalized[:len(normalized)-1]
	}
	// target <= normalized is the same condition as normalized >= target.
	return bytes.Compare(target, normalized) <= 0
}

// A value embedded in a full node occupies the last slot (16) of the array of
// children. In order to produce a pre-order traversal when iterating children,
// we jump to this last slot first, then go back to iterate the child nodes (and
// skip the last slot at the end).

// prevChildIndex maps a full-node child index to the index visited immediately
// before it in the pre-order traversal, where the embedded value slot (16) is
// visited ahead of child slots 0..15.
//
// Special cases:
//   - 0  -> 16: child 0 is reached from the value slot.
//   - 16 -> -1: the value slot is reached from the placeholder index.
//   - 17 -> 15: the end position is reached from child 15, skipping slot 16.
//
// Every other index simply maps to index-1.
func prevChildIndex(index int) int {
	if index == 0 {
		// Stepping back from the first child lands on the value slot.
		return 16
	}
	if index == 16 {
		// Stepping back from the value slot lands on the placeholder.
		return -1
	}
	if index == 17 {
		// Stepping back from the end skips over the value slot.
		return 15
	}
	// Ordinary sequential step between child slots.
	return index - 1
}

// nextChildIndex maps a full-node child index to the index visited immediately
// after it in the pre-order traversal, where the embedded value slot (16) is
// visited ahead of child slots 0..15.
//
// Special cases:
//   - -1 -> 16: from the placeholder index, go to the value slot first.
//   - 16 -> 0:  from the value slot, go back to the first child.
//   - 15 -> 17: after the last child, skip the already-visited value slot.
//
// Every other index simply maps to index+1.
func nextChildIndex(index int) int {
	if index == -1 {
		// Start of traversal: visit the embedded value before any child.
		return 16
	}
	if index == 16 {
		// Value slot done: continue with the first regular child.
		return 0
	}
	if index == 15 {
		// Last regular child done: step past the value slot.
		return 17
	}
	// Ordinary sequential step between child slots.
	return index + 1
}

func compareNodes(a, b NodeIterator) int {
Expand Down
18 changes: 12 additions & 6 deletions trie/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,14 @@ func TestNodeIteratorCoverage(t *testing.T) {
type kvs struct{ k, v string }

var testdata1 = []kvs{
{"bar", "b"},
{"barb", "ba"},
{"bard", "bc"},
{"bars", "bb"},
{"bar", "b"},
{"fab", "z"},
{"foo", "a"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
}

var testdata2 = []kvs{
Expand Down Expand Up @@ -190,7 +190,7 @@ func TestIteratorSeek(t *testing.T) {

// Seek to a non-existent key.
it = NewIterator(trie.NodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil {
if err := checkIteratorOrder(testdata1[2:], it); err != nil {
t.Fatal(err)
}

Expand All @@ -199,6 +199,12 @@ func TestIteratorSeek(t *testing.T) {
if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err)
}

// Seek to a key for which a prefixing key exists.
it = NewIterator(trie.MustNodeIterator([]byte("food")))
if err := checkIteratorOrder(testdata1[6:], it); err != nil {
t.Fatal(err)
}
}

func checkIteratorOrder(want []kvs, it *Iterator) error {
Expand Down Expand Up @@ -271,16 +277,16 @@ func TestUnionIterator(t *testing.T) {

all := []struct{ k, v string }{
{"aardvark", "c"},
{"bar", "b"},
{"barb", "ba"},
{"barb", "bd"},
{"bard", "bc"},
{"bars", "bb"},
{"bars", "be"},
{"bar", "b"},
{"fab", "z"},
{"foo", "a"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
{"jars", "d"},
}

Expand Down Expand Up @@ -444,7 +450,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
diskdb.Put(barNodeHash[:], barNodeBlob)
}
// Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
if err := checkIteratorOrder(testdata1[3:], NewIterator(it)); err != nil {
t.Fatal(err)
}
}
Expand Down
13 changes: 13 additions & 0 deletions trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,19 @@ func newWithRootNode(root node) *Trie {
}
}

// NewEmpty is a convenience constructor that opens a trie rooted at
// types.EmptyRootHash on top of db. It is mostly used in tests.
//
// The error returned by New is discarded, so the returned trie is whatever
// New produced alongside it.
// NOTE(review): presumably New cannot fail for the empty root; confirm
// against New before relying on a non-nil result.
func NewEmpty(db *Database) *Trie {
	emptyID := TrieID(types.EmptyRootHash)
	empty, _ := New(emptyID, db)
	return empty
}

// MustNodeIterator is a thin wrapper around NodeIterator: it passes start
// through unchanged and returns the iterator that NodeIterator produces.
// The wrapper itself performs no error handling or logging.
func (t *Trie) MustNodeIterator(start []byte) NodeIterator {
	return t.NodeIterator(start)
}

// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at
// the key after the given start key.
func (t *Trie) NodeIterator(start []byte) NodeIterator {
Expand Down

0 comments on commit a9f64f3

Please sign in to comment.