diff options
Diffstat (limited to 'trie/iterator.go')
-rw-r--r-- | trie/iterator.go | 127 |
1 files changed, 76 insertions, 51 deletions
diff --git a/trie/iterator.go b/trie/iterator.go index fef5b2593..26ae1d5ad 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -19,10 +19,13 @@ package trie import ( "bytes" "container/heap" + "errors" "github.com/ethereum/go-ethereum/common" ) +var iteratorEnd = errors.New("end of iteration") + // Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { nodeIt NodeIterator @@ -79,25 +82,24 @@ type nodeIteratorState struct { hash common.Hash // Hash of the node being iterated (nil if not standalone) node node // Trie node being iterated parent common.Hash // Hash of the first full ancestor node (nil if current is the root) - child int // Child to be processed next + index int // Child to be processed next pathlen int // Length of the path to this node } type nodeIterator struct { trie *Trie // Trie being iterated stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state - - err error // Failure set in case of an internal error in the iterator - - path []byte // Path to the current node + err error // Failure set in case of an internal error in the iterator + path []byte // Path to the current node } -// newNodeIterator creates an post-order trie iterator. -func newNodeIterator(trie *Trie) NodeIterator { +func newNodeIterator(trie *Trie, start []byte) NodeIterator { if trie.Hash() == emptyState { return new(nodeIterator) } - return &nodeIterator{trie: trie} + it := &nodeIterator{trie: trie} + it.seek(start) + return it } // Hash returns the hash of the current node @@ -147,6 +149,9 @@ func (it *nodeIterator) Path() []byte { // Error returns the error set in case of an internal error in the iterator func (it *nodeIterator) Error() error { + if it.err == iteratorEnd { + return nil + } return it.err } @@ -155,47 +160,54 @@ func (it *nodeIterator) Error() error { // sets the Error field to the encountered failure. If `descend` is false, // skips iterating over any subnodes of the current node. func (it *nodeIterator) Next(descend bool) bool { - // If the iterator failed previously, don't do anything if it.err != nil { return false } // Otherwise step forward with the iterator and report any errors - if err := it.step(descend); err != nil { + state, parentIndex, path, err := it.peek(descend) + if err != nil { it.err = err return false } - return it.trie != nil + it.push(state, parentIndex, path) + return true } -// step moves the iterator to the next node of the trie. -func (it *nodeIterator) step(descend bool) error { - if it.trie == nil { - // Abort if we reached the end of the iteration - return nil +func (it *nodeIterator) seek(prefix []byte) { + // The path we're looking for is the hex encoded key without terminator. + key := keybytesToHex(prefix) + key = key[:len(key)-1] + // Move forward until we're just before the closest match to key. + for { + state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path)) + if err != nil || bytes.Compare(path, key) >= 0 { + it.err = err + return + } + it.push(state, parentIndex, path) } +} + +// peek creates the next state of the iterator. +func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) { if len(it.stack) == 0 { // Initialize the iterator if we've just started. root := it.trie.Hash() - state := &nodeIteratorState{node: it.trie.root, child: -1} + state := &nodeIteratorState{node: it.trie.root, index: -1} if root != emptyRoot { state.hash = root } - it.stack = append(it.stack, state) - return nil + return state, nil, nil, nil } - if !descend { // If we're skipping children, pop the current node first - it.path = it.path[:it.stack[len(it.stack)-1].pathlen] - it.stack = it.stack[:len(it.stack)-1] + it.pop() } // Continue iteration to the next child -outer: for { if len(it.stack) == 0 { - it.trie = nil - return nil + return nil, nil, nil, iteratorEnd } parent := it.stack[len(it.stack)-1] ancestor := parent.hash @@ -203,63 +215,76 @@ outer: ancestor = parent.parent } if node, ok := parent.node.(*fullNode); ok { - // Full node, iterate over children - for parent.child++; parent.child < len(node.Children); parent.child++ { - child := node.Children[parent.child] + // Full node, move to the first non-nil child. + for i := parent.index + 1; i < len(node.Children); i++ { + child := node.Children[i] if child != nil { hash, _ := child.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: child, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - it.path = append(it.path, byte(parent.child)) - break outer + } + path := append(it.path, byte(i)) + parent.index = i - 1 + return state, &parent.index, path, nil } } } else if node, ok := parent.node.(*shortNode); ok { // Short node, return the pointer singleton child - if parent.child < 0 { - parent.child++ + if parent.index < 0 { hash, _ := node.Val.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node.Val, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) + } + var path []byte if hasTerm(node.Key) { - it.path = append(it.path, node.Key[:len(node.Key)-1]...) + path = append(it.path, node.Key[:len(node.Key)-1]...) } else { - it.path = append(it.path, node.Key...) + path = append(it.path, node.Key...) } - break + return state, &parent.index, path, nil } } else if hash, ok := parent.node.(hashNode); ok { // Hash node, resolve the hash child from the database - if parent.child < 0 { - parent.child++ + if parent.index < 0 { node, err := it.trie.resolveHash(hash, nil, nil) if err != nil { - return err + return it.stack[len(it.stack)-1], &parent.index, it.path, err } - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - break + } + return state, &parent.index, it.path, nil } } - it.path = it.path[:parent.pathlen] - it.stack = it.stack[:len(it.stack)-1] + // No more child nodes, move back up. + it.pop() } - return nil +} + +func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) { + it.path = path + it.stack = append(it.stack, state) + if parentIndex != nil { + *parentIndex += 1 + } +} + +func (it *nodeIterator) pop() { + parent := it.stack[len(it.stack)-1] + it.path = it.path[:parent.pathlen] + it.stack = it.stack[:len(it.stack)-1] } func compareNodes(a, b NodeIterator) int { |