aboutsummaryrefslogtreecommitdiffstats
path: root/trie
diff options
context:
space:
mode:
authorJeffrey Wilcke <jeffrey@ethereum.org>2016-10-19 19:35:49 +0800
committerGitHub <noreply@github.com>2016-10-19 19:35:49 +0800
commit25ac04a444d82f42138fc06e651c1ef9bac935dc (patch)
tree4b4925de834f243a05c73e661e77bc60287aeb9d /trie
parent8e52c2e754cdb343d0eb880a33251e1ba593d327 (diff)
parent8d56bf5ceb74a7ed45c986450848a89e2df61189 (diff)
downloaddexon-25ac04a444d82f42138fc06e651c1ef9bac935dc.tar.gz
dexon-25ac04a444d82f42138fc06e651c1ef9bac935dc.tar.zst
dexon-25ac04a444d82f42138fc06e651c1ef9bac935dc.zip
Merge pull request #3153 from fjl/trie-unload-fix
trie: improve cache unloading mechanism
Diffstat (limited to 'trie')
-rw-r--r--trie/hasher.go33
-rw-r--r--trie/node.go26
-rw-r--r--trie/proof.go2
-rw-r--r--trie/sync.go6
-rw-r--r--trie/trie.go21
-rw-r--r--trie/trie_test.go104
6 files changed, 126 insertions, 66 deletions
diff --git a/trie/hasher.go b/trie/hasher.go
index e395e00d7..b6223bf32 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -58,7 +58,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
return hash, n, nil
}
if n.canUnload(h.cachegen, h.cachelimit) {
- // Evict the node from cache. All of its subnodes will have a lower or equal
+ // Unload the node from cache. All of its subnodes will have a lower or equal
// cache generation number.
return hash, hash, nil
}
@@ -75,23 +75,20 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
if err != nil {
return hashNode{}, n, err
}
- // Cache the hash of the ndoe for later reuse.
- if hash, ok := hashed.(hashNode); ok && !force {
- switch cached := cached.(type) {
- case *shortNode:
- cached = cached.copy()
- cached.flags.hash = hash
- if db != nil {
- cached.flags.dirty = false
- }
- return hashed, cached, nil
- case *fullNode:
- cached = cached.copy()
- cached.flags.hash = hash
- if db != nil {
- cached.flags.dirty = false
- }
- return hashed, cached, nil
+ // Cache the hash of the ndoe for later reuse and remove
+ // the dirty flag in commit mode. It's fine to assign these values directly
+ // without copying the node first because hashChildren copies it.
+ cachedHash, _ := hashed.(hashNode)
+ switch cn := cached.(type) {
+ case *shortNode:
+ cn.flags.hash = cachedHash
+ if db != nil {
+ cn.flags.dirty = false
+ }
+ case *fullNode:
+ cn.flags.hash = cachedHash
+ if db != nil {
+ cn.flags.dirty = false
}
}
return hashed, cached, nil
diff --git a/trie/node.go b/trie/node.go
index de9752c93..4aa0cab65 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -104,8 +104,8 @@ func (n valueNode) fstring(ind string) string {
return fmt.Sprintf("%x ", []byte(n))
}
-func mustDecodeNode(hash, buf []byte) node {
- n, err := decodeNode(hash, buf)
+func mustDecodeNode(hash, buf []byte, cachegen uint16) node {
+ n, err := decodeNode(hash, buf, cachegen)
if err != nil {
panic(fmt.Sprintf("node %x: %v", hash, err))
}
@@ -113,7 +113,7 @@ func mustDecodeNode(hash, buf []byte) node {
}
// decodeNode parses the RLP encoding of a trie node.
-func decodeNode(hash, buf []byte) (node, error) {
+func decodeNode(hash, buf []byte, cachegen uint16) (node, error) {
if len(buf) == 0 {
return nil, io.ErrUnexpectedEOF
}
@@ -123,22 +123,22 @@ func decodeNode(hash, buf []byte) (node, error) {
}
switch c, _ := rlp.CountValues(elems); c {
case 2:
- n, err := decodeShort(hash, buf, elems)
+ n, err := decodeShort(hash, buf, elems, cachegen)
return n, wrapError(err, "short")
case 17:
- n, err := decodeFull(hash, buf, elems)
+ n, err := decodeFull(hash, buf, elems, cachegen)
return n, wrapError(err, "full")
default:
return nil, fmt.Errorf("invalid number of list elements: %v", c)
}
}
-func decodeShort(hash, buf, elems []byte) (node, error) {
+func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
kbuf, rest, err := rlp.SplitString(elems)
if err != nil {
return nil, err
}
- flag := nodeFlag{hash: hash}
+ flag := nodeFlag{hash: hash, gen: cachegen}
key := compactDecode(kbuf)
if key[len(key)-1] == 16 {
// value node
@@ -148,17 +148,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
}
return &shortNode{key, append(valueNode{}, val...), flag}, nil
}
- r, _, err := decodeRef(rest)
+ r, _, err := decodeRef(rest, cachegen)
if err != nil {
return nil, wrapError(err, "val")
}
return &shortNode{key, r, flag}, nil
}
-func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
- n := &fullNode{flags: nodeFlag{hash: hash}}
+func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) {
+ n := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}}
for i := 0; i < 16; i++ {
- cld, rest, err := decodeRef(elems)
+ cld, rest, err := decodeRef(elems, cachegen)
if err != nil {
return n, wrapError(err, fmt.Sprintf("[%d]", i))
}
@@ -176,7 +176,7 @@ func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
const hashLen = len(common.Hash{})
-func decodeRef(buf []byte) (node, []byte, error) {
+func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) {
kind, val, rest, err := rlp.Split(buf)
if err != nil {
return nil, buf, err
@@ -189,7 +189,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
return nil, buf, err
}
- n, err := decodeNode(nil, buf)
+ n, err := decodeNode(nil, buf, cachegen)
return n, rest, err
case kind == rlp.String && len(val) == 0:
// empty node
diff --git a/trie/proof.go b/trie/proof.go
index f193b52df..bea5e5c09 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -101,7 +101,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
if !bytes.Equal(sha.Sum(nil), wantHash) {
return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
}
- n, err := decodeNode(wantHash, buf)
+ n, err := decodeNode(wantHash, buf, 0)
if err != nil {
return nil, fmt.Errorf("bad proof node %d: %v", i, err)
}
diff --git a/trie/sync.go b/trie/sync.go
index 400dff903..30caf6980 100644
--- a/trie/sync.go
+++ b/trie/sync.go
@@ -82,7 +82,7 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c
}
key := root.Bytes()
blob, _ := s.database.Get(key)
- if local, err := decodeNode(key, blob); local != nil && err == nil {
+ if local, err := decodeNode(key, blob, 0); local != nil && err == nil {
return
}
// Assemble the new sub-trie sync request
@@ -158,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) {
continue
}
// Decode the node data content and update the request
- node, err := decodeNode(item.Hash[:], item.Data)
+ node, err := decodeNode(item.Hash[:], item.Data, 0)
if err != nil {
return i, err
}
@@ -246,7 +246,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
if node, ok := (*child.node).(hashNode); ok {
// Try to resolve the node from the local database
blob, _ := s.database.Get(node)
- if local, err := decodeNode(node[:], blob); local != nil && err == nil {
+ if local, err := decodeNode(node[:], blob, 0); local != nil && err == nil {
*child.node = local
continue
}
diff --git a/trie/trie.go b/trie/trie.go
index 5a4b6185d..914bf20fa 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -105,13 +105,11 @@ func New(root common.Hash, db Database) (*Trie, error) {
if db == nil {
panic("trie.New: cannot use existing root without a database")
}
- if v, _ := trie.db.Get(root[:]); len(v) == 0 {
- return nil, &MissingNodeError{
- RootHash: root,
- NodeHash: root,
- }
+ rootnode, err := trie.resolveHash(root[:], nil, nil)
+ if err != nil {
+ return nil, err
}
- trie.root = hashNode(root.Bytes())
+ trie.root = rootnode
}
return trie, nil
}
@@ -158,14 +156,15 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
if err == nil && didResolve {
n = n.copy()
n.Val = newnode
+ n.flags.gen = t.cachegen
}
return value, n, didResolve, err
case *fullNode:
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve {
n = n.copy()
+ n.flags.gen = t.cachegen
n.Children[key[pos]] = newnode
-
}
return value, n, didResolve, err
case hashNode:
@@ -261,7 +260,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return false, n, err
}
n = n.copy()
- n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
+ n.flags = t.newFlag()
+ n.Children[key[0]] = nn
return true, n, nil
case nil:
@@ -345,7 +345,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, n, err
}
n = n.copy()
- n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
+ n.flags = t.newFlag()
+ n.Children[key[0]] = nn
// Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is
@@ -443,7 +444,7 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) {
SuffixLen: len(suffix),
}
}
- dec := mustDecodeNode(n, enc)
+ dec := mustDecodeNode(n, enc, t.cachegen)
return dec, nil
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 32fbe6801..14ac5a666 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -300,25 +300,6 @@ func TestReplication(t *testing.T) {
}
}
-// Not an actual test
-func TestOutput(t *testing.T) {
- t.Skip()
-
- base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
- trie := newEmpty()
- for i := 0; i < 50; i++ {
- updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
- }
- fmt.Println("############################## FULL ################################")
- fmt.Println(trie.root)
-
- trie.Commit()
- fmt.Println("############################## SMALL ################################")
- trie2, _ := New(trie.Hash(), trie.db)
- getString(trie2, base+"20")
- fmt.Println(trie2.root)
-}
-
func TestLargeValue(t *testing.T) {
trie := newEmpty()
trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
@@ -326,14 +307,56 @@ func TestLargeValue(t *testing.T) {
trie.Hash()
}
+type countingDB struct {
+ Database
+ gets map[string]int
+}
+
+func (db *countingDB) Get(key []byte) ([]byte, error) {
+ db.gets[string(key)]++
+ return db.Database.Get(key)
+}
+
+// TestCacheUnload checks that decoded nodes are unloaded after a
+// certain number of commit operations.
+func TestCacheUnload(t *testing.T) {
+ // Create test trie with two branches.
+ trie := newEmpty()
+ key1 := "---------------------------------"
+ key2 := "---some other branch"
+ updateString(trie, key1, "this is the branch of key1.")
+ updateString(trie, key2, "this is the branch of key2.")
+ root, _ := trie.Commit()
+
+ // Commit the trie repeatedly and access key1.
+ // The branch containing it is loaded from DB exactly two times:
+ // in the 0th and 6th iteration.
+ db := &countingDB{Database: trie.db, gets: make(map[string]int)}
+ trie, _ = New(root, db)
+ trie.SetCacheLimit(5)
+ for i := 0; i < 12; i++ {
+ getString(trie, key1)
+ trie.Commit()
+ }
+
+ // Check that it got loaded two times.
+ for dbkey, count := range db.gets {
+ if count != 2 {
+ t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
+ }
+ }
+}
+
+// randTest performs random trie operations.
+// Instances of this test are created by Generate.
+type randTest []randTestStep
+
type randTestStep struct {
op int
key []byte // for opUpdate, opDelete, opGet
value []byte // for opUpdate
}
-type randTest []randTestStep
-
const (
opUpdate = iota
opDelete
@@ -342,6 +365,7 @@ const (
opHash
opReset
opItercheckhash
+ opCheckCacheInvariant
opMax // boundary value, not an actual op
)
@@ -437,6 +461,44 @@ func runRandTest(rt randTest) bool {
fmt.Println("hashes not equal")
return false
}
+ case opCheckCacheInvariant:
+ return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
+ }
+ }
+ return true
+}
+
+func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool {
+ var children []node
+ var flag nodeFlag
+ switch n := n.(type) {
+ case *shortNode:
+ flag = n.flags
+ children = []node{n.Val}
+ case *fullNode:
+ flag = n.flags
+ children = n.Children[:]
+ default:
+ return true
+ }
+
+ showerror := func() {
+ fmt.Printf("at depth %d node %s", depth, spew.Sdump(n))
+ fmt.Printf("parent: %s", spew.Sdump(parent))
+ }
+ if flag.gen > parentCachegen {
+ fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
+ showerror()
+ return false
+ }
+ if depth > 0 && !parentDirty && flag.dirty {
+ fmt.Printf("cache invariant violation: child is dirty but parent isn't\n")
+ showerror()
+ return false
+ }
+ for _, child := range children {
+ if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) {
+ return false
}
}
return true