diff options
Diffstat (limited to 'trie/iterator.go')
-rw-r--r-- | trie/iterator.go | 152 |
1 files changed, 59 insertions, 93 deletions
diff --git a/trie/iterator.go b/trie/iterator.go index 1114715a6..f0dae28bb 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -1,124 +1,73 @@ package trie -/* -import ( - "bytes" - - "github.com/ethereum/go-ethereum/ethutil" -) - -type NodeType byte - -const ( - EmptyNode NodeType = iota - BranchNode - LeafNode - ExtNode -) - -func getType(node *ethutil.Value) NodeType { - if node.Len() == 0 { - return EmptyNode - } - - if node.Len() == 2 { - k := CompactDecode(node.Get(0).Str()) - if HasTerm(k) { - return LeafNode - } - - return ExtNode - } - - return BranchNode -} +import "bytes" type Iterator struct { - Path [][]byte trie *Trie Key []byte - Value *ethutil.Value + Value []byte } func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie} + return &Iterator{trie: trie, Key: make([]byte, 32)} } -func (self *Iterator) key(node *ethutil.Value, path [][]byte) []byte { - switch getType(node) { - case LeafNode: - k := RemTerm(CompactDecode(node.Get(0).Str())) +func (self *Iterator) Next() bool { + self.trie.mu.Lock() + defer self.trie.mu.Unlock() - self.Path = append(path, k) - self.Value = node.Get(1) + key := RemTerm(CompactHexDecode(string(self.Key))) + k := self.next(self.trie.root, key) - return k - case BranchNode: - if node.Get(16).Len() > 0 { - return []byte{16} - } + self.Key = []byte(DecodeCompact(k)) - for i := byte(0); i < 16; i++ { - o := self.key(self.trie.getNode(node.Get(int(i)).Raw()), append(path, []byte{i})) - if o != nil { - return append([]byte{i}, o...) - } - } - case ExtNode: - currKey := node.Get(0).Bytes() + return len(k) > 0 - return self.key(self.trie.getNode(node.Get(1).Raw()), append(path, currKey)) - } - - return nil } -func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byte { - switch typ := getType(node); typ { - case EmptyNode: +func (self *Iterator) next(node Node, key []byte) []byte { + if node == nil { return nil - case BranchNode: - if len(key) > 0 { - subNode := self.trie.getNode(node.Get(int(key[0])).Raw()) + } - o := self.next(subNode, key[1:], append(path, key[:1])) - if o != nil { - return append([]byte{key[0]}, o...) + switch node := node.(type) { + case *FullNode: + if len(key) > 0 { + k := self.next(node.branch(key[0]), key[1:]) + if k != nil { + return append([]byte{key[0]}, k...) } } - var r byte = 0 + var r byte if len(key) > 0 { r = key[0] + 1 } for i := r; i < 16; i++ { - subNode := self.trie.getNode(node.Get(int(i)).Raw()) - o := self.key(subNode, append(path, []byte{i})) - if o != nil { - return append([]byte{i}, o...) + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{i}, k...) } } - case LeafNode, ExtNode: - k := RemTerm(CompactDecode(node.Get(0).Str())) - if typ == LeafNode { - if bytes.Compare([]byte(k), []byte(key)) > 0 { - self.Value = node.Get(1) - self.Path = append(path, k) + case *ShortNode: + k := RemTerm(node.Key()) + if vnode, ok := node.Value().(*ValueNode); ok { + if bytes.Compare([]byte(k), key) > 0 { + self.Value = vnode.Val() return k } } else { - subNode := self.trie.getNode(node.Get(1).Raw()) - subKey := key[len(k):] + cnode := node.Value() + var ret []byte + skey := key[len(k):] if BeginsWith(key, k) { - ret = self.next(subNode, subKey, append(path, k)) + ret = self.next(cnode, skey) } else if bytes.Compare(k, key[:len(k)]) > 0 { - ret = self.key(node, append(path, k)) - } else { - ret = nil + ret = self.key(node) } if ret != nil { @@ -130,16 +79,33 @@ func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byt return nil } -// Get the next in keys -func (self *Iterator) Next(key string) []byte { - self.trie.mut.Lock() - defer self.trie.mut.Unlock() +func (self *Iterator) key(node Node) []byte { + switch node := node.(type) { + case *ShortNode: + // Leaf node + if vnode, ok := node.Value().(*ValueNode); ok { + k := RemTerm(node.Key()) + self.Value = vnode.Val() - k := RemTerm(CompactHexDecode(key)) - n := self.next(self.trie.getNode(self.trie.Root), k, nil) + return k + } else { + k := RemTerm(node.Key()) + return append(k, self.key(node.Value())...) + } + case *FullNode: + if node.Value() != nil { + self.Value = node.Value().(*ValueNode).Val() - self.Key = []byte(DecodeCompact(n)) + return []byte{16} + } - return self.Key + for i := 0; i < 16; i++ { + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{byte(i)}, k...) + } + } + } + + return nil } -*/ |