diff options
Diffstat (limited to 'trie')
-rw-r--r-- | trie/encoding.go | 114 | ||||
-rw-r--r-- | trie/encoding_test.go | 147 | ||||
-rw-r--r-- | trie/errors.go | 5 | ||||
-rw-r--r-- | trie/hasher.go | 2 | ||||
-rw-r--r-- | trie/iterator.go | 141 | ||||
-rw-r--r-- | trie/iterator_test.go | 69 | ||||
-rw-r--r-- | trie/node.go | 4 | ||||
-rw-r--r-- | trie/proof.go | 4 | ||||
-rw-r--r-- | trie/secure_trie.go | 10 | ||||
-rw-r--r-- | trie/sync_test.go | 2 | ||||
-rw-r--r-- | trie/trie.go | 14 | ||||
-rw-r--r-- | trie/trie_test.go | 2 |
12 files changed, 266 insertions, 248 deletions
diff --git a/trie/encoding.go b/trie/encoding.go index 2037118dd..e96a786e4 100644 --- a/trie/encoding.go +++ b/trie/encoding.go @@ -16,49 +16,54 @@ package trie -func compactEncode(hexSlice []byte) []byte { +// Trie keys are dealt with in three distinct encodings: +// +// KEYBYTES encoding contains the actual key and nothing else. This encoding is the +// input to most API functions. +// +// HEX encoding contains one byte for each nibble of the key and an optional trailing +// 'terminator' byte of value 0x10 which indicates whether or not the node at the key +// contains a value. Hex key encoding is used for nodes loaded in memory because it's +// convenient to access. +// +// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix +// encoding" there) and contains the bytes of the key and a flag. The high nibble of the +// first byte contains the flag; the lowest bit encoding the oddness of the length and +// the second-lowest encoding whether the node at the key is a value node. The low nibble +// of the first byte is zero in the case of an even number of nibbles and the first nibble +// in the case of an odd number. All remaining nibbles (now an even number) fit properly +// into the remaining bytes. Compact encoding is used for nodes stored on disk. + +func hexToCompact(hex []byte) []byte { terminator := byte(0) - if hexSlice[len(hexSlice)-1] == 16 { + if hasTerm(hex) { terminator = 1 - hexSlice = hexSlice[:len(hexSlice)-1] - } - var ( - odd = byte(len(hexSlice) % 2) - buflen = len(hexSlice)/2 + 1 - bi, hi = 0, 0 // indices - hs = byte(0) // shift: flips between 0 and 4 - ) - if odd == 0 { - bi = 1 - hs = 4 + hex = hex[:len(hex)-1] } - buf := make([]byte, buflen) - buf[0] = terminator<<5 | byte(odd)<<4 - for bi < len(buf) && hi < len(hexSlice) { - buf[bi] |= hexSlice[hi] << hs - if hs == 0 { - bi++ - } - hi, hs = hi+1, hs^(1<<2) + buf := make([]byte, len(hex)/2+1) + buf[0] = terminator << 5 // the flag byte + if len(hex)&1 == 1 { + buf[0] |= 1 << 4 // odd flag + buf[0] |= hex[0] // first nibble is contained in the first byte + hex = hex[1:] } + decodeNibbles(hex, buf[1:]) return buf } -func compactDecode(str []byte) []byte { - base := compactHexDecode(str) +func compactToHex(compact []byte) []byte { + base := keybytesToHex(compact) base = base[:len(base)-1] + // apply terminator flag if base[0] >= 2 { base = append(base, 16) } - if base[0]%2 == 1 { - base = base[1:] - } else { - base = base[2:] - } - return base + // apply odd flag + chop := 2 - base[0]&1 + return base[chop:] } -func compactHexDecode(str []byte) []byte { +func keybytesToHex(str []byte) []byte { l := len(str)*2 + 1 var nibbles = make([]byte, l) for i, b := range str { @@ -69,35 +74,24 @@ func compactHexDecode(str []byte) []byte { return nibbles } -// compactHexEncode encodes a series of nibbles into a byte array -func compactHexEncode(nibbles []byte) []byte { - nl := len(nibbles) - if nl == 0 { - return nil - } - if nibbles[nl-1] == 16 { - nl-- +// hexToKeybytes turns hex nibbles into key bytes. +// This can only be used for keys of even length. +func hexToKeybytes(hex []byte) []byte { + if hasTerm(hex) { + hex = hex[:len(hex)-1] } - l := (nl + 1) / 2 - var str = make([]byte, l) - for i := range str { - b := nibbles[i*2] * 16 - if nl > i*2 { - b += nibbles[i*2+1] - } - str[i] = b + if len(hex)&1 != 0 { + panic("can't convert hex key of odd length") } - return str + key := make([]byte, (len(hex)+1)/2) + decodeNibbles(hex, key) + return key } -func decodeCompact(key []byte) []byte { - l := len(key) / 2 - var res = make([]byte, l) - for i := 0; i < l; i++ { - v1, v0 := key[2*i], key[2*i+1] - res[i] = v1*16 + v0 +func decodeNibbles(nibbles []byte, bytes []byte) { + for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 { + bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1] } - return res } // prefixLen returns the length of the common prefix of a and b. @@ -114,15 +108,7 @@ func prefixLen(a, b []byte) int { return i } +// hasTerm returns whether a hex key has the terminator flag. func hasTerm(s []byte) bool { - return s[len(s)-1] == 16 -} - -func remTerm(s []byte) []byte { - if hasTerm(s) { - b := make([]byte, len(s)-1) - copy(b, s) - return b - } - return s + return len(s) > 0 && s[len(s)-1] == 16 } diff --git a/trie/encoding_test.go b/trie/encoding_test.go index 2f125ef2f..97d8da136 100644 --- a/trie/encoding_test.go +++ b/trie/encoding_test.go @@ -17,113 +17,88 @@ package trie import ( - "encoding/hex" + "bytes" "testing" - - checker "gopkg.in/check.v1" ) -func TestEncoding(t *testing.T) { checker.TestingT(t) } - -type TrieEncodingSuite struct{} - -var _ = checker.Suite(&TrieEncodingSuite{}) - -func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) { - // even compact encode - test1 := []byte{1, 2, 3, 4, 5} - res1 := compactEncode(test1) - c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45")) - - // odd compact encode - test2 := []byte{0, 1, 2, 3, 4, 5} - res2 := compactEncode(test2) - c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45")) - - //odd terminated compact encode - test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res3 := compactEncode(test3) - c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8")) - - // even terminated compact encode - test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16} - res4 := compactEncode(test4) - c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8")) -} - -func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) { - exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} - res := compactHexDecode([]byte("verb")) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactHexEncode(c *checker.C) { - exp := []byte("verb") - res := compactHexEncode([]byte{7, 6, 6, 5, 7, 2, 6, 2, 16}) - c.Assert(res, checker.DeepEquals, exp) -} - -func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) { - // odd compact decode - exp := []byte{1, 2, 3, 4, 5} - res := compactDecode([]byte("\x11\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even compact decode - exp = []byte{0, 1, 2, 3, 4, 5} - res = compactDecode([]byte("\x00\x01\x23\x45")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x20\x0f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) - - // even terminated compact decode - exp = []byte{15, 1, 12, 11, 8 /*term*/, 16} - res = compactDecode([]byte("\x3f\x1c\xb8")) - c.Assert(res, checker.DeepEquals, exp) +func TestHexCompact(t *testing.T) { + tests := []struct{ hex, compact []byte }{ + // empty keys, with and without terminator. + {hex: []byte{}, compact: []byte{0x00}}, + {hex: []byte{16}, compact: []byte{0x20}}, + // odd length, no terminator + {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}}, + // even length, no terminator + {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}}, + // odd length, terminator + {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}}, + // even length, terminator + {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}}, + } + for _, test := range tests { + if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) { + t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact) + } + if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) { + t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex) + } + } } -func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) { - exp, _ := hex.DecodeString("012345") - res := decodeCompact([]byte{0, 1, 2, 3, 4, 5}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("012345") - res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16}) - c.Assert(res, checker.DeepEquals, exp) - - exp, _ = hex.DecodeString("abcdef") - res = decodeCompact([]byte{10, 11, 12, 13, 14, 15}) - c.Assert(res, checker.DeepEquals, exp) +func TestHexKeybytes(t *testing.T) { + tests := []struct{ key, hexIn, hexOut []byte }{ + {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}}, + {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}}, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6, 16}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + { + key: []byte{0x12, 0x34, 0x5}, + hexIn: []byte{1, 2, 3, 4, 0, 5, 16}, + hexOut: []byte{1, 2, 3, 4, 0, 5, 16}, + }, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + } + for _, test := range tests { + if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) { + t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut) + } + if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) { + t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key) + } + } } -func BenchmarkCompactEncode(b *testing.B) { - - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} +func BenchmarkHexToCompact(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} for i := 0; i < b.N; i++ { - compactEncode(testBytes) + hexToCompact(testBytes) } } -func BenchmarkCompactDecode(b *testing.B) { - testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16} +func BenchmarkCompactToHex(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} for i := 0; i < b.N; i++ { - compactDecode(testBytes) + compactToHex(testBytes) } } -func BenchmarkCompactHexDecode(b *testing.B) { +func BenchmarkKeybytesToHex(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - compactHexDecode(testBytes) + keybytesToHex(testBytes) } } -func BenchmarkDecodeCompact(b *testing.B) { +func BenchmarkHexToKeybytes(b *testing.B) { testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} for i := 0; i < b.N; i++ { - decodeCompact(testBytes) + hexToKeybytes(testBytes) } } diff --git a/trie/errors.go b/trie/errors.go index 76129a70b..e23f9d563 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -30,10 +30,6 @@ import ( // // RootHash is the original root of the trie that contains the node // -// Key is a binary-encoded key that contains the prefix that leads to the first -// missing node and optionally a suffix that hints on which further nodes should -// also be retrieved -// // PrefixLen is the nibble length of the key prefix that leads from the root to // the missing node // @@ -42,7 +38,6 @@ import ( // such hints in the error message) type MissingNodeError struct { RootHash, NodeHash common.Hash - Key []byte PrefixLen, SuffixLen int } diff --git a/trie/hasher.go b/trie/hasher.go index 98c309531..85b6b60f5 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -105,7 +105,7 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err case *shortNode: // Hash the short node's child, caching the newly hashed subtree collapsed, cached := n.copy(), n.copy() - collapsed.Key = compactEncode(n.Key) + collapsed.Key = hexToCompact(n.Key) cached.Key = common.CopyBytes(n.Key) if _, ok := n.Val.(valueNode); !ok { diff --git a/trie/iterator.go b/trie/iterator.go index 42149a7d3..26ae1d5ad 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -19,9 +19,13 @@ package trie import ( "bytes" "container/heap" + "errors" + "github.com/ethereum/go-ethereum/common" ) +var iteratorEnd = errors.New("end of iteration") + // Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { nodeIt NodeIterator @@ -30,15 +34,8 @@ type Iterator struct { Value []byte // Current data value on which the iterator is positioned on } -// NewIterator creates a new key-value iterator. -func NewIterator(trie *Trie) *Iterator { - return &Iterator{ - nodeIt: NewNodeIterator(trie), - } -} - -// FromNodeIterator creates a new key-value iterator from a node iterator -func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { +// NewIterator creates a new key-value iterator from a node iterator +func NewIterator(it NodeIterator) *Iterator { return &Iterator{ nodeIt: it, } @@ -48,7 +45,7 @@ func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { func (it *Iterator) Next() bool { for it.nodeIt.Next(true) { if it.nodeIt.Leaf() { - it.Key = decodeCompact(it.nodeIt.Path()) + it.Key = hexToKeybytes(it.nodeIt.Path()) it.Value = it.nodeIt.LeafBlob() return true } @@ -85,25 +82,24 @@ type nodeIteratorState struct { hash common.Hash // Hash of the node being iterated (nil if not standalone) node node // Trie node being iterated parent common.Hash // Hash of the first full ancestor node (nil if current is the root) - child int // Child to be processed next + index int // Child to be processed next pathlen int // Length of the path to this node } type nodeIterator struct { trie *Trie // Trie being iterated stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state - - err error // Failure set in case of an internal error in the iterator - - path []byte // Path to the current node + err error // Failure set in case of an internal error in the iterator + path []byte // Path to the current node } -// NewNodeIterator creates an post-order trie iterator. -func NewNodeIterator(trie *Trie) NodeIterator { +func newNodeIterator(trie *Trie, start []byte) NodeIterator { if trie.Hash() == emptyState { return new(nodeIterator) } - return &nodeIterator{trie: trie} + it := &nodeIterator{trie: trie} + it.seek(start) + return it } // Hash returns the hash of the current node @@ -153,6 +149,9 @@ func (it *nodeIterator) Path() []byte { // Error returns the error set in case of an internal error in the iterator func (it *nodeIterator) Error() error { + if it.err == iteratorEnd { + return nil + } return it.err } @@ -161,47 +160,54 @@ func (it *nodeIterator) Error() error { // sets the Error field to the encountered failure. If `descend` is false, // skips iterating over any subnodes of the current node. func (it *nodeIterator) Next(descend bool) bool { - // If the iterator failed previously, don't do anything if it.err != nil { return false } // Otherwise step forward with the iterator and report any errors - if err := it.step(descend); err != nil { + state, parentIndex, path, err := it.peek(descend) + if err != nil { it.err = err return false } - return it.trie != nil + it.push(state, parentIndex, path) + return true } -// step moves the iterator to the next node of the trie. -func (it *nodeIterator) step(descend bool) error { - if it.trie == nil { - // Abort if we reached the end of the iteration - return nil +func (it *nodeIterator) seek(prefix []byte) { + // The path we're looking for is the hex encoded key without terminator. + key := keybytesToHex(prefix) + key = key[:len(key)-1] + // Move forward until we're just before the closest match to key. + for { + state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path)) + if err != nil || bytes.Compare(path, key) >= 0 { + it.err = err + return + } + it.push(state, parentIndex, path) } +} + +// peek creates the next state of the iterator. +func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) { if len(it.stack) == 0 { // Initialize the iterator if we've just started. root := it.trie.Hash() - state := &nodeIteratorState{node: it.trie.root, child: -1} + state := &nodeIteratorState{node: it.trie.root, index: -1} if root != emptyRoot { state.hash = root } - it.stack = append(it.stack, state) - return nil + return state, nil, nil, nil } - if !descend { // If we're skipping children, pop the current node first - it.path = it.path[:it.stack[len(it.stack)-1].pathlen] - it.stack = it.stack[:len(it.stack)-1] + it.pop() } // Continue iteration to the next child -outer: for { if len(it.stack) == 0 { - it.trie = nil - return nil + return nil, nil, nil, iteratorEnd } parent := it.stack[len(it.stack)-1] ancestor := parent.hash @@ -209,63 +215,76 @@ outer: ancestor = parent.parent } if node, ok := parent.node.(*fullNode); ok { - // Full node, iterate over children - for parent.child++; parent.child < len(node.Children); parent.child++ { - child := node.Children[parent.child] + // Full node, move to the first non-nil child. + for i := parent.index + 1; i < len(node.Children); i++ { + child := node.Children[i] if child != nil { hash, _ := child.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: child, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - it.path = append(it.path, byte(parent.child)) - break outer + } + path := append(it.path, byte(i)) + parent.index = i - 1 + return state, &parent.index, path, nil } } } else if node, ok := parent.node.(*shortNode); ok { // Short node, return the pointer singleton child - if parent.child < 0 { - parent.child++ + if parent.index < 0 { hash, _ := node.Val.cache() - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node.Val, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) + } + var path []byte if hasTerm(node.Key) { - it.path = append(it.path, node.Key[:len(node.Key)-1]...) + path = append(it.path, node.Key[:len(node.Key)-1]...) } else { - it.path = append(it.path, node.Key...) + path = append(it.path, node.Key...) } - break + return state, &parent.index, path, nil } } else if hash, ok := parent.node.(hashNode); ok { // Hash node, resolve the hash child from the database - if parent.child < 0 { - parent.child++ + if parent.index < 0 { node, err := it.trie.resolveHash(hash, nil, nil) if err != nil { - return err + return it.stack[len(it.stack)-1], &parent.index, it.path, err } - it.stack = append(it.stack, &nodeIteratorState{ + state := &nodeIteratorState{ hash: common.BytesToHash(hash), node: node, parent: ancestor, - child: -1, + index: -1, pathlen: len(it.path), - }) - break + } + return state, &parent.index, it.path, nil } } - it.path = it.path[:parent.pathlen] - it.stack = it.stack[:len(it.stack)-1] + // No more child nodes, move back up. + it.pop() } - return nil +} + +func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) { + it.path = path + it.stack = append(it.stack, state) + if parentIndex != nil { + *parentIndex += 1 + } +} + +func (it *nodeIterator) pop() { + parent := it.stack[len(it.stack)-1] + it.path = it.path[:parent.pathlen] + it.stack = it.stack[:len(it.stack)-1] } func compareNodes(a, b NodeIterator) int { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index c101bb7b0..f161fd99d 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -17,6 +17,8 @@ package trie import ( + "bytes" + "fmt" "testing" "github.com/ethereum/go-ethereum/common" @@ -42,7 +44,7 @@ func TestIterator(t *testing.T) { trie.Commit() found := make(map[string]string) - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -72,7 +74,7 @@ func TestIteratorLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { vals[string(it.Key)].t = true } @@ -99,7 +101,7 @@ func TestNodeIteratorCoverage(t *testing.T) { // Gather all the node hashes found by the iterator hashes := make(map[common.Hash]struct{}) - for it := NewNodeIterator(trie); it.Next(true); { + for it := trie.NodeIterator(nil); it.Next(true); { if it.Hash() != (common.Hash{}) { hashes[it.Hash()] = struct{}{} } @@ -117,18 +119,20 @@ func TestNodeIteratorCoverage(t *testing.T) { } } -var testdata1 = []struct{ k, v string }{ - {"bar", "b"}, +type kvs struct{ k, v string } + +var testdata1 = []kvs{ {"barb", "ba"}, - {"bars", "bb"}, {"bard", "bc"}, + {"bars", "bb"}, + {"bar", "b"}, {"fab", "z"}, - {"foo", "a"}, {"food", "ab"}, {"foos", "aa"}, + {"foo", "a"}, } -var testdata2 = []struct{ k, v string }{ +var testdata2 = []kvs{ {"aardvark", "c"}, {"bar", "b"}, {"barb", "bd"}, @@ -140,6 +144,47 @@ var testdata2 = []struct{ k, v string }{ {"jars", "d"}, } +func TestIteratorSeek(t *testing.T) { + trie := newEmpty() + for _, val := range testdata1 { + trie.Update([]byte(val.k), []byte(val.v)) + } + + // Seek to the middle. + it := NewIterator(trie.NodeIterator([]byte("fab"))) + if err := checkIteratorOrder(testdata1[4:], it); err != nil { + t.Fatal(err) + } + + // Seek to a non-existent key. + it = NewIterator(trie.NodeIterator([]byte("barc"))) + if err := checkIteratorOrder(testdata1[1:], it); err != nil { + t.Fatal(err) + } + + // Seek beyond the end. + it = NewIterator(trie.NodeIterator([]byte("z"))) + if err := checkIteratorOrder(nil, it); err != nil { + t.Fatal(err) + } +} + +func checkIteratorOrder(want []kvs, it *Iterator) error { + for it.Next() { + if len(want) == 0 { + return fmt.Errorf("didn't expect any more values, got key %q", it.Key) + } + if !bytes.Equal(it.Key, []byte(want[0].k)) { + return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k) + } + want = want[1:] + } + if len(want) > 0 { + return fmt.Errorf("iterator ended early, want key %q", want[0]) + } + return nil +} + func TestDifferenceIterator(t *testing.T) { triea := newEmpty() for _, val := range testdata1 { @@ -154,8 +199,8 @@ func TestDifferenceIterator(t *testing.T) { trieb.Commit() found := make(map[string]string) - di, _ := NewDifferenceIterator(NewNodeIterator(triea), NewNodeIterator(trieb)) - it := NewIteratorFromNodeIterator(di) + di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + it := NewIterator(di) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -189,8 +234,8 @@ func TestUnionIterator(t *testing.T) { } trieb.Commit() - di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)}) - it := NewIteratorFromNodeIterator(di) + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) + it := NewIterator(di) all := []struct{ k, v string }{ {"aardvark", "c"}, diff --git a/trie/node.go b/trie/node.go index 4aa0cab65..a7697fc0c 100644 --- a/trie/node.go +++ b/trie/node.go @@ -139,8 +139,8 @@ func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) { return nil, err } flag := nodeFlag{hash: hash, gen: cachegen} - key := compactDecode(kbuf) - if key[len(key)-1] == 16 { + key := compactToHex(kbuf) + if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) if err != nil { diff --git a/trie/proof.go b/trie/proof.go index 06cf827ab..fb7734b86 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -38,7 +38,7 @@ import ( // absence of the key. func (t *Trie) Prove(key []byte) []rlp.RawValue { // Collect all nodes on the path to key. - key = compactHexDecode(key) + key = keybytesToHex(key) nodes := []node{} tn := t.root for len(key) > 0 && tn != nil { @@ -89,7 +89,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { // returns an error if the proof contains invalid trie nodes or the // wrong value. func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) { - key = compactHexDecode(key) + key = keybytesToHex(key) sha := sha3.NewKeccak256() wantHash := rootHash.Bytes() for i, buf := range proof { diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 113fb6a1a..37d1d4b09 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -156,12 +156,10 @@ func (t *SecureTrie) Root() []byte { return t.trie.Root() } -func (t *SecureTrie) Iterator() *Iterator { - return t.trie.Iterator() -} - -func (t *SecureTrie) NodeIterator() NodeIterator { - return NewNodeIterator(&t.trie) +// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration +// starts at the key after the given start key. +func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { + return t.trie.NodeIterator(start) } // CommitTo writes all nodes and the secure hash pre-images to the given database. diff --git a/trie/sync_test.go b/trie/sync_test.go index acae039cd..1e27cbb67 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -80,7 +80,7 @@ func checkTrieConsistency(db Database, root common.Hash) error { if err != nil { return nil // // Consider a non existent state consistent } - it := NewNodeIterator(trie) + it := trie.NodeIterator(nil) for it.Next(true) { } return it.Error() diff --git a/trie/trie.go b/trie/trie.go index 2a6044068..5759f97e3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -125,9 +125,10 @@ func New(root common.Hash, db Database) (*Trie, error) { return trie, nil } -// Iterator returns an iterator over all mappings in the trie. -func (t *Trie) Iterator() *Iterator { - return NewIterator(t) +// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at +// the key after the given start key. +func (t *Trie) NodeIterator(start []byte) NodeIterator { + return newNodeIterator(t, start) } // Get returns the value for key stored in the trie. @@ -144,7 +145,7 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryGet(key []byte) ([]byte, error) { - key = compactHexDecode(key) + key = keybytesToHex(key) value, newroot, didResolve, err := t.tryGet(t.root, key, 0) if err == nil && didResolve { t.root = newroot @@ -211,7 +212,7 @@ func (t *Trie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryUpdate(key, value []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) if len(value) != 0 { _, n, err := t.insert(t.root, nil, k, valueNode(value)) if err != nil { @@ -307,7 +308,7 @@ func (t *Trie) Delete(key []byte) { // TryDelete removes any existing value for key from the trie. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryDelete(key []byte) error { - k := compactHexDecode(key) + k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) if err != nil { return err @@ -450,7 +451,6 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { return nil, &MissingNodeError{ RootHash: t.originalRoot, NodeHash: common.BytesToHash(n), - Key: compactHexEncode(append(prefix, suffix...)), PrefixLen: len(prefix), SuffixLen: len(suffix), } diff --git a/trie/trie_test.go b/trie/trie_test.go index 01ae3a4e7..61adbba0c 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -439,7 +439,7 @@ func runRandTest(rt randTest) bool { tr = newtr case opItercheckhash: checktr, _ := New(common.Hash{}, nil) - it := tr.Iterator() + it := NewIterator(tr.NodeIterator(nil)) for it.Next() { checktr.Update(it.Key, it.Value) } |