Diffstat (limited to 'trie')
-rw-r--r-- | trie/database.go         | 355
-rw-r--r-- | trie/hasher.go           |  61
-rw-r--r-- | trie/iterator_test.go    | 125
-rw-r--r-- | trie/proof.go            |  47
-rw-r--r-- | trie/secure_trie.go      |  62
-rw-r--r-- | trie/secure_trie_test.go |  20
-rw-r--r-- | trie/sync.go             |  14
-rw-r--r-- | trie/sync_test.go        | 103
-rw-r--r-- | trie/trie.go             |  90
-rw-r--r-- | trie/trie_test.go        | 104
10 files changed, 709 insertions, 272 deletions
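
The centrepiece of this diff is the new trie/database.go: trie.Database is an in-memory, reference-counted write layer that sits between the trie data structures and the disk database (ethdb.Database). Trie.Commit and SecureTrie.Commit now take a LeafCallback and flush nodes into this layer instead of writing straight to disk; data only reaches persistent storage when Database.Commit(root, report) is called. A minimal usage sketch, pieced together from the tests in this diff (the in-memory ethdb backend and the key/value literals are only for illustration):

    diskdb, _ := ethdb.NewMemDatabase()
    triedb := trie.NewDatabase(diskdb)

    t, _ := trie.New(common.Hash{}, triedb)
    t.Update([]byte("some-key"), []byte("some-value"))

    root, _ := t.Commit(nil)  // collect the dirty nodes into the memory layer
    triedb.Commit(root, true) // persist everything reachable from root to diskdb
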
diff --git a/trie/database.go b/trie/database.go new file mode 100644 index 000000000..d79120813 --- /dev/null +++ b/trie/database.go @@ -0,0 +1,355 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package trie + +import ( + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" +) + +// secureKeyPrefix is the database key prefix used to store trie node preimages. +var secureKeyPrefix = []byte("secure-key-") + +// secureKeyLength is the length of the above prefix + 32byte hash. +const secureKeyLength = 11 + 32 + +// DatabaseReader wraps the Get and Has method of a backing store for the trie. +type DatabaseReader interface { + // Get retrieves the value associated with key form the database. + Get(key []byte) (value []byte, err error) + + // Has retrieves whether a key is present in the database. + Has(key []byte) (bool, error) +} + +// Database is an intermediate write layer between the trie data structures and +// the disk database. The aim is to accumulate trie writes in-memory and only +// periodically flush a couple tries to disk, garbage collecting the remainder. +type Database struct { + diskdb ethdb.Database // Persistent storage for matured trie nodes + + nodes map[common.Hash]*cachedNode // Data and references relationships of a node + preimages map[common.Hash][]byte // Preimages of nodes from the secure trie + seckeybuf [secureKeyLength]byte // Ephemeral buffer for calculating preimage keys + + gctime time.Duration // Time spent on garbage collection since last commit + gcnodes uint64 // Nodes garbage collected since last commit + gcsize common.StorageSize // Data storage garbage collected since last commit + + nodesSize common.StorageSize // Storage size of the nodes cache + preimagesSize common.StorageSize // Storage size of the preimages cache + + lock sync.RWMutex +} + +// cachedNode is all the information we know about a single cached node in the +// memory database write layer. +type cachedNode struct { + blob []byte // Cached data block of the trie node + parents int // Number of live nodes referencing this one + children map[common.Hash]int // Children referenced by this nodes +} + +// NewDatabase creates a new trie database to store ephemeral trie content before +// its written out to disk or garbage collected. +func NewDatabase(diskdb ethdb.Database) *Database { + return &Database{ + diskdb: diskdb, + nodes: map[common.Hash]*cachedNode{ + {}: {children: make(map[common.Hash]int)}, + }, + preimages: make(map[common.Hash][]byte), + } +} + +// DiskDB retrieves the persistent storage backing the trie database. 
+func (db *Database) DiskDB() DatabaseReader { + return db.diskdb +} + +// Insert writes a new trie node to the memory database if it's yet unknown. The +// method will make a copy of the slice. +func (db *Database) Insert(hash common.Hash, blob []byte) { + db.lock.Lock() + defer db.lock.Unlock() + + db.insert(hash, blob) +} + +// insert is the private locked version of Insert. +func (db *Database) insert(hash common.Hash, blob []byte) { + if _, ok := db.nodes[hash]; ok { + return + } + db.nodes[hash] = &cachedNode{ + blob: common.CopyBytes(blob), + children: make(map[common.Hash]int), + } + db.nodesSize += common.StorageSize(common.HashLength + len(blob)) +} + +// insertPreimage writes a new trie node pre-image to the memory database if it's +// yet unknown. The method will make a copy of the slice. +// +// Note, this method assumes that the database's lock is held! +func (db *Database) insertPreimage(hash common.Hash, preimage []byte) { + if _, ok := db.preimages[hash]; ok { + return + } + db.preimages[hash] = common.CopyBytes(preimage) + db.preimagesSize += common.StorageSize(common.HashLength + len(preimage)) +} + +// Node retrieves a cached trie node from memory. If it cannot be found cached, +// the method queries the persistent database for the content. +func (db *Database) Node(hash common.Hash) ([]byte, error) { + // Retrieve the node from cache if available + db.lock.RLock() + node := db.nodes[hash] + db.lock.RUnlock() + + if node != nil { + return node.blob, nil + } + // Content unavailable in memory, attempt to retrieve from disk + return db.diskdb.Get(hash[:]) +} + +// preimage retrieves a cached trie node pre-image from memory. If it cannot be +// found cached, the method queries the persistent database for the content. +func (db *Database) preimage(hash common.Hash) ([]byte, error) { + // Retrieve the node from cache if available + db.lock.RLock() + preimage := db.preimages[hash] + db.lock.RUnlock() + + if preimage != nil { + return preimage, nil + } + // Content unavailable in memory, attempt to retrieve from disk + return db.diskdb.Get(db.secureKey(hash[:])) +} + +// secureKey returns the database key for the preimage of key, as an ephemeral +// buffer. The caller must not hold onto the return value because it will become +// invalid on the next call. +func (db *Database) secureKey(key []byte) []byte { + buf := append(db.seckeybuf[:0], secureKeyPrefix...) + buf = append(buf, key...) + return buf +} + +// Nodes retrieves the hashes of all the nodes cached within the memory database. +// This method is extremely expensive and should only be used to validate internal +// states in test code. +func (db *Database) Nodes() []common.Hash { + db.lock.RLock() + defer db.lock.RUnlock() + + var hashes = make([]common.Hash, 0, len(db.nodes)) + for hash := range db.nodes { + if hash != (common.Hash{}) { // Special case for "root" references/nodes + hashes = append(hashes, hash) + } + } + return hashes +} + +// Reference adds a new reference from a parent node to a child node. +func (db *Database) Reference(child common.Hash, parent common.Hash) { + db.lock.RLock() + defer db.lock.RUnlock() + + db.reference(child, parent) +} + +// reference is the private locked version of Reference. 
+func (db *Database) reference(child common.Hash, parent common.Hash) { + // If the node does not exist, it's a node pulled from disk, skip + node, ok := db.nodes[child] + if !ok { + return + } + // If the reference already exists, only duplicate for roots + if _, ok = db.nodes[parent].children[child]; ok && parent != (common.Hash{}) { + return + } + node.parents++ + db.nodes[parent].children[child]++ +} + +// Dereference removes an existing reference from a parent node to a child node. +func (db *Database) Dereference(child common.Hash, parent common.Hash) { + db.lock.Lock() + defer db.lock.Unlock() + + nodes, storage, start := len(db.nodes), db.nodesSize, time.Now() + db.dereference(child, parent) + + db.gcnodes += uint64(nodes - len(db.nodes)) + db.gcsize += storage - db.nodesSize + db.gctime += time.Since(start) + + log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start), + "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize) +} + +// dereference is the private locked version of Dereference. +func (db *Database) dereference(child common.Hash, parent common.Hash) { + // Dereference the parent-child + node := db.nodes[parent] + + node.children[child]-- + if node.children[child] == 0 { + delete(node.children, child) + } + // If the node does not exist, it's a previously committed node. + node, ok := db.nodes[child] + if !ok { + return + } + // If there are no more references to the child, delete it and cascade + node.parents-- + if node.parents == 0 { + for hash := range node.children { + db.dereference(hash, child) + } + delete(db.nodes, child) + db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob)) + } +} + +// Commit iterates over all the children of a particular node, writes them out +// to disk, forcefully tearing down all references in both directions. +// +// As a side effect, all pre-images accumulated up to this point are also written. +func (db *Database) Commit(node common.Hash, report bool) error { + // Create a database batch to flush persistent data out. It is important that + // outside code doesn't see an inconsistent state (referenced data removed from + // memory cache during commit but not yet in persistent storage). This is ensured + // by only uncaching existing data when the database write finalizes. 
+ db.lock.RLock() + + start := time.Now() + batch := db.diskdb.NewBatch() + + // Move all of the accumulated preimages into a write batch + for hash, preimage := range db.preimages { + if err := batch.Put(db.secureKey(hash[:]), preimage); err != nil { + log.Error("Failed to commit preimage from trie database", "err", err) + db.lock.RUnlock() + return err + } + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return err + } + batch.Reset() + } + } + // Move the trie itself into the batch, flushing if enough data is accumulated + nodes, storage := len(db.nodes), db.nodesSize+db.preimagesSize + if err := db.commit(node, batch); err != nil { + log.Error("Failed to commit trie from trie database", "err", err) + db.lock.RUnlock() + return err + } + // Write batch ready, unlock for readers during persistence + if err := batch.Write(); err != nil { + log.Error("Failed to write trie to disk", "err", err) + db.lock.RUnlock() + return err + } + db.lock.RUnlock() + + // Write successful, clear out the flushed data + db.lock.Lock() + defer db.lock.Unlock() + + db.preimages = make(map[common.Hash][]byte) + db.preimagesSize = 0 + + db.uncache(node) + + logger := log.Info + if !report { + logger = log.Debug + } + logger("Persisted trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start), + "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize) + + // Reset the garbage collection statistics + db.gcnodes, db.gcsize, db.gctime = 0, 0, 0 + + return nil +} + +// commit is the private locked version of Commit. +func (db *Database) commit(hash common.Hash, batch ethdb.Batch) error { + // If the node does not exist, it's a previously committed node + node, ok := db.nodes[hash] + if !ok { + return nil + } + for child := range node.children { + if err := db.commit(child, batch); err != nil { + return err + } + } + if err := batch.Put(hash[:], node.blob); err != nil { + return err + } + // If we've reached an optimal match size, commit and start over + if batch.ValueSize() >= ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return err + } + batch.Reset() + } + return nil +} + +// uncache is the post-processing step of a commit operation where the already +// persisted trie is removed from the cache. The reason behind the two-phase +// commit is to ensure consistent data availability while moving from memory +// to disk. +func (db *Database) uncache(hash common.Hash) { + // If the node does not exist, we're done on this path + node, ok := db.nodes[hash] + if !ok { + return + } + // Otherwise uncache the node's subtries and remove the node itself too + for child := range node.children { + db.uncache(child) + } + delete(db.nodes, hash) + db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob)) +} + +// Size returns the current storage size of the memory cache in front of the +// persistent database layer. +func (db *Database) Size() common.StorageSize { + db.lock.RLock() + defer db.lock.RUnlock() + + return db.nodesSize + db.preimagesSize +} diff --git a/trie/hasher.go b/trie/hasher.go index 4719aabf6..2fc44787a 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -27,21 +27,23 @@ import ( ) type hasher struct { - tmp *bytes.Buffer - sha hash.Hash - cachegen, cachelimit uint16 + tmp *bytes.Buffer + sha hash.Hash + cachegen uint16 + cachelimit uint16 + onleaf LeafCallback } -// hashers live in a global pool. 
+// hashers live in a global db. var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} }, } -func newHasher(cachegen, cachelimit uint16) *hasher { +func newHasher(cachegen, cachelimit uint16, onleaf LeafCallback) *hasher { h := hasherPool.Get().(*hasher) - h.cachegen, h.cachelimit = cachegen, cachelimit + h.cachegen, h.cachelimit, h.onleaf = cachegen, cachelimit, onleaf return h } @@ -51,7 +53,7 @@ func returnHasherToPool(h *hasher) { // hash collapses a node down into a hash node, also returning a copy of the // original node initialized with the computed hash to replace the original one. -func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { +func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) { // If we're not storing the node, just hashing, use available cached data if hash, dirty := n.cache(); hash != nil { if db == nil { @@ -98,7 +100,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) // hashChildren replaces the children of a node with their hashes if the encoded // size of the child is larger than a hash, returning the collapsed node as well // as a replacement for the original node with the child hashes cached in. -func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { +func (h *hasher) hashChildren(original node, db *Database) (node, node, error) { var err error switch n := original.(type) { @@ -145,7 +147,10 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err } } -func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { +// store hashes the node n and if we have a storage layer specified, it writes +// the key/value pair to it and tracks any node->child references as well as any +// node->external trie references. +func (h *hasher) store(n node, db *Database, force bool) (node, error) { // Don't store hashes or empty nodes. 
if _, isHash := n.(hashNode); n == nil || isHash { return n, nil @@ -155,7 +160,6 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { if err := rlp.Encode(h.tmp, n); err != nil { panic("encode error: " + err.Error()) } - if h.tmp.Len() < 32 && !force { return n, nil // Nodes smaller than 32 bytes are stored inside their parent } @@ -167,7 +171,42 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { hash = hashNode(h.sha.Sum(nil)) } if db != nil { - return hash, db.Put(hash, h.tmp.Bytes()) + // We are pooling the trie nodes into an intermediate memory cache + db.lock.Lock() + + hash := common.BytesToHash(hash) + db.insert(hash, h.tmp.Bytes()) + + // Track all direct parent->child node references + switch n := n.(type) { + case *shortNode: + if child, ok := n.Val.(hashNode); ok { + db.reference(common.BytesToHash(child), hash) + } + case *fullNode: + for i := 0; i < 16; i++ { + if child, ok := n.Children[i].(hashNode); ok { + db.reference(common.BytesToHash(child), hash) + } + } + } + db.lock.Unlock() + + // Track external references from account->storage trie + if h.onleaf != nil { + switch n := n.(type) { + case *shortNode: + if child, ok := n.Val.(valueNode); ok { + h.onleaf(child, hash) + } + case *fullNode: + for i := 0; i < 16; i++ { + if child, ok := n.Children[i].(valueNode); ok { + h.onleaf(child, hash) + } + } + } + } } return hash, nil } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 4808d8b0c..dce1c78b5 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -42,7 +42,7 @@ func TestIterator(t *testing.T) { all[val.k] = val.v trie.Update([]byte(val.k), []byte(val.v)) } - trie.Commit() + trie.Commit(nil) found := make(map[string]string) it := NewIterator(trie.NodeIterator(nil)) @@ -109,11 +109,18 @@ func TestNodeIteratorCoverage(t *testing.T) { } // Cross check the hashes and the database itself for hash := range hashes { - if _, err := db.Get(hash.Bytes()); err != nil { + if _, err := db.Node(hash); err != nil { t.Errorf("failed to retrieve reported node %x: %v", hash, err) } } - for _, key := range db.(*ethdb.MemDatabase).Keys() { + for hash, obj := range db.nodes { + if obj != nil && hash != (common.Hash{}) { + if _, ok := hashes[hash]; !ok { + t.Errorf("state entry not reported %x", hash) + } + } + } + for _, key := range db.diskdb.(*ethdb.MemDatabase).Keys() { if _, ok := hashes[common.BytesToHash(key)]; !ok { t.Errorf("state entry not reported %x", key) } @@ -191,13 +198,13 @@ func TestDifferenceIterator(t *testing.T) { for _, val := range testdata1 { triea.Update([]byte(val.k), []byte(val.v)) } - triea.Commit() + triea.Commit(nil) trieb := newEmpty() for _, val := range testdata2 { trieb.Update([]byte(val.k), []byte(val.v)) } - trieb.Commit() + trieb.Commit(nil) found := make(map[string]string) di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) @@ -227,13 +234,13 @@ func TestUnionIterator(t *testing.T) { for _, val := range testdata1 { triea.Update([]byte(val.k), []byte(val.v)) } - triea.Commit() + triea.Commit(nil) trieb := newEmpty() for _, val := range testdata2 { trieb.Update([]byte(val.k), []byte(val.v)) } - trieb.Commit() + trieb.Commit(nil) di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) it := NewIterator(di) @@ -278,43 +285,75 @@ func TestIteratorNoDups(t *testing.T) { } // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes. 
-func TestIteratorContinueAfterError(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - tr, _ := New(common.Hash{}, db) +func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueAfterError(t, false) } +func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } + +func testIteratorContinueAfterError(t *testing.T, memonly bool) { + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + tr, _ := New(common.Hash{}, triedb) for _, val := range testdata1 { tr.Update([]byte(val.k), []byte(val.v)) } - tr.Commit() + tr.Commit(nil) + if !memonly { + triedb.Commit(tr.Hash(), true) + } wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil) - keys := db.Keys() - t.Log("node count", wantNodeCount) + var ( + diskKeys [][]byte + memKeys []common.Hash + ) + if memonly { + memKeys = triedb.Nodes() + } else { + diskKeys = diskdb.Keys() + } for i := 0; i < 20; i++ { // Create trie that will load all nodes from DB. - tr, _ := New(tr.Hash(), db) + tr, _ := New(tr.Hash(), triedb) // Remove a random node from the database. It can't be the root node // because that one is already loaded. - var rkey []byte + var ( + rkey common.Hash + rval []byte + robj *cachedNode + ) for { - if rkey = keys[rand.Intn(len(keys))]; !bytes.Equal(rkey, tr.Hash().Bytes()) { + if memonly { + rkey = memKeys[rand.Intn(len(memKeys))] + } else { + copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))]) + } + if rkey != tr.Hash() { break } } - rval, _ := db.Get(rkey) - db.Delete(rkey) - + if memonly { + robj = triedb.nodes[rkey] + delete(triedb.nodes, rkey) + } else { + rval, _ = diskdb.Get(rkey[:]) + diskdb.Delete(rkey[:]) + } // Iterate until the error is hit. seen := make(map[string]bool) it := tr.NodeIterator(nil) checkIteratorNoDups(t, it, seen) missing, ok := it.Error().(*MissingNodeError) - if !ok || !bytes.Equal(missing.NodeHash[:], rkey) { + if !ok || missing.NodeHash != rkey { t.Fatal("didn't hit missing node, got", it.Error()) } // Add the node back and continue iteration. - db.Put(rkey, rval) + if memonly { + triedb.nodes[rkey] = robj + } else { + diskdb.Put(rkey[:], rval) + } checkIteratorNoDups(t, it, seen) if it.Error() != nil { t.Fatal("unexpected error", it.Error()) @@ -328,21 +367,41 @@ func TestIteratorContinueAfterError(t *testing.T) { // Similar to the test above, this one checks that failure to create nodeIterator at a // certain key prefix behaves correctly when Next is called. The expectation is that Next // should retry seeking before returning true for the first time. -func TestIteratorContinueAfterSeekError(t *testing.T) { +func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) { + testIteratorContinueAfterSeekError(t, false) +} +func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { + testIteratorContinueAfterSeekError(t, true) +} + +func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { // Commit test trie to db, then remove the node containing "bars". 
- db, _ := ethdb.NewMemDatabase() - ctr, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + ctr, _ := New(common.Hash{}, triedb) for _, val := range testdata1 { ctr.Update([]byte(val.k), []byte(val.v)) } - root, _ := ctr.Commit() + root, _ := ctr.Commit(nil) + if !memonly { + triedb.Commit(root, true) + } barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e") - barNode, _ := db.Get(barNodeHash[:]) - db.Delete(barNodeHash[:]) - + var ( + barNodeBlob []byte + barNodeObj *cachedNode + ) + if memonly { + barNodeObj = triedb.nodes[barNodeHash] + delete(triedb.nodes, barNodeHash) + } else { + barNodeBlob, _ = diskdb.Get(barNodeHash[:]) + diskdb.Delete(barNodeHash[:]) + } // Create a new iterator that seeks to "bars". Seeking can't proceed because // the node is missing. - tr, _ := New(root, db) + tr, _ := New(root, triedb) it := tr.NodeIterator([]byte("bars")) missing, ok := it.Error().(*MissingNodeError) if !ok { @@ -350,10 +409,12 @@ func TestIteratorContinueAfterSeekError(t *testing.T) { } else if missing.NodeHash != barNodeHash { t.Fatal("wrong node missing") } - // Reinsert the missing node. - db.Put(barNodeHash[:], barNode[:]) - + if memonly { + triedb.nodes[barNodeHash] = barNodeObj + } else { + diskdb.Put(barNodeHash[:], barNodeBlob) + } // Check that iteration produces the right set of values. if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil { t.Fatal(err) diff --git a/trie/proof.go b/trie/proof.go index 5e886a259..508e4a6cf 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -22,20 +22,19 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" ) -// Prove constructs a merkle proof for key. The result contains all -// encoded nodes on the path to the value at key. The value itself is -// also included in the last node and can be retrieved by verifying -// the proof. +// Prove constructs a merkle proof for key. The result contains all encoded nodes +// on the path to the value at key. The value itself is also included in the last +// node and can be retrieved by verifying the proof. // -// If the trie does not contain a value for key, the returned proof -// contains all nodes of the longest existing prefix of the key -// (at least the root node), ending with the node that proves the -// absence of the key. -func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { +// If the trie does not contain a value for key, the returned proof contains all +// nodes of the longest existing prefix of the key (at least the root node), ending +// with the node that proves the absence of the key. +func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error { // Collect all nodes on the path to key. key = keybytesToHex(key) nodes := []node{} @@ -66,7 +65,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - hasher := newHasher(0, 0) + hasher := newHasher(0, 0, nil) for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. @@ -89,19 +88,29 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { return nil } -// VerifyProof checks merkle proofs. 
The given proof must contain the -// value for key in a trie with the given root hash. VerifyProof -// returns an error if the proof contains invalid trie nodes or the -// wrong value. +// Prove constructs a merkle proof for key. The result contains all encoded nodes +// on the path to the value at key. The value itself is also included in the last +// node and can be retrieved by verifying the proof. +// +// If the trie does not contain a value for key, the returned proof contains all +// nodes of the longest existing prefix of the key (at least the root node), ending +// with the node that proves the absence of the key. +func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error { + return t.trie.Prove(key, fromLevel, proofDb) +} + +// VerifyProof checks merkle proofs. The given proof must contain the value for +// key in a trie with the given root hash. VerifyProof returns an error if the +// proof contains invalid trie nodes or the wrong value. func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) { key = keybytesToHex(key) - wantHash := rootHash[:] + wantHash := rootHash for i := 0; ; i++ { - buf, _ := proofDb.Get(wantHash) + buf, _ := proofDb.Get(wantHash[:]) if buf == nil { - return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i + return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash), i } - n, err := decodeNode(wantHash, buf, 0) + n, err := decodeNode(wantHash[:], buf, 0) if err != nil { return nil, fmt.Errorf("bad proof node %d: %v", i, err), i } @@ -112,7 +121,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (valu return nil, nil, i case hashNode: key = keyrest - wantHash = cld + copy(wantHash[:], cld) case valueNode: return cld, nil, i + 1 } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 20c303f31..3881ee18a 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -23,10 +23,6 @@ import ( "github.com/ethereum/go-ethereum/log" ) -var secureKeyPrefix = []byte("secure-key-") - -const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash - // SecureTrie wraps a trie with key hashing. In a secure trie, all // access operations hash the key using keccak256. This prevents // calling code from creating long chains of nodes that @@ -39,25 +35,25 @@ const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash // SecureTrie is not safe for concurrent use. type SecureTrie struct { trie Trie - hashKeyBuf [secureKeyLength]byte - secKeyBuf [200]byte + hashKeyBuf [common.HashLength]byte secKeyCache map[string][]byte secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch } -// NewSecure creates a trie with an existing root node from db. +// NewSecure creates a trie with an existing root node from a backing database +// and optional intermediate in-memory node pool. // // If root is the zero hash or the sha3 hash of an empty string, the // trie is initially empty. Otherwise, New will panic if db is nil // and returns MissingNodeError if the root node cannot be found. // -// Accessing the trie loads nodes from db on demand. +// Accessing the trie loads nodes from the database or node pool on demand. // Loaded nodes are kept around until their 'cache generation' expires. // A new cache generation is created by each call to Commit. // cachelimit sets the number of past cache generations to keep. 
-func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) { +func NewSecure(root common.Hash, db *Database, cachelimit uint16) (*SecureTrie, error) { if db == nil { - panic("NewSecure called with nil database") + panic("trie.NewSecure called without a database") } trie, err := New(root, db) if err != nil { @@ -135,7 +131,7 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { return key } - key, _ := t.trie.db.Get(t.secKey(shaKey)) + key, _ := t.trie.db.preimage(common.BytesToHash(shaKey)) return key } @@ -144,8 +140,19 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { // // Committing flushes nodes from memory. Subsequent Get calls will load nodes // from the database. -func (t *SecureTrie) Commit() (root common.Hash, err error) { - return t.CommitTo(t.trie.db) +func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) { + // Write all the pre-images to the actual disk database + if len(t.getSecKeyCache()) > 0 { + t.trie.db.lock.Lock() + for hk, key := range t.secKeyCache { + t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key) + } + t.trie.db.lock.Unlock() + + t.secKeyCache = make(map[string][]byte) + } + // Commit the trie to its intermediate node database + return t.trie.Commit(onleaf) } func (t *SecureTrie) Hash() common.Hash { @@ -167,38 +174,11 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { return t.trie.NodeIterator(start) } -// CommitTo writes all nodes and the secure hash pre-images to the given database. -// Nodes are stored with their sha3 hash as the key. -// -// Committing flushes nodes from memory. Subsequent Get calls will load nodes from -// the trie's database. Calling code must ensure that the changes made to db are -// written back to the trie's attached database before using the trie. -func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - if len(t.getSecKeyCache()) > 0 { - for hk, key := range t.secKeyCache { - if err := db.Put(t.secKey([]byte(hk)), key); err != nil { - return common.Hash{}, err - } - } - t.secKeyCache = make(map[string][]byte) - } - return t.trie.CommitTo(db) -} - -// secKey returns the database key for the preimage of key, as an ephemeral buffer. -// The caller must not hold onto the return value because it will become -// invalid on the next call to hashKey or secKey. -func (t *SecureTrie) secKey(key []byte) []byte { - buf := append(t.secKeyBuf[:0], secureKeyPrefix...) - buf = append(buf, key...) - return buf -} - // hashKey returns the hash of key as an ephemeral buffer. // The caller must not hold onto the return value because it will become // invalid on the next call to hashKey or secKey. func (t *SecureTrie) hashKey(key []byte) []byte { - h := newHasher(0, 0) + h := newHasher(0, 0, nil) h.sha.Reset() h.sha.Write(key) buf := h.sha.Sum(t.hashKeyBuf[:0]) diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index d74102e2a..aedf5a1cd 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -28,16 +28,20 @@ import ( ) func newEmptySecure() *SecureTrie { - db, _ := ethdb.NewMemDatabase() - trie, _ := NewSecure(common.Hash{}, db, 0) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + trie, _ := NewSecure(common.Hash{}, triedb, 0) return trie } // makeTestSecureTrie creates a large enough secure trie for testing. 
-func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { +func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie - db, _ := ethdb.NewMemDatabase() - trie, _ := NewSecure(common.Hash{}, db, 0) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + trie, _ := NewSecure(common.Hash{}, triedb, 0) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -58,10 +62,10 @@ func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { trie.Update(key, val) } } - trie.Commit() + trie.Commit(nil) // Return the generated trie - return db, trie, content + return triedb, trie, content } func TestSecureDelete(t *testing.T) { @@ -137,7 +141,7 @@ func TestSecureTrieConcurrency(t *testing.T) { tries[index].Update(key, val) } } - tries[index].Commit() + tries[index].Commit(nil) }(i) } // Wait for all threads to finish diff --git a/trie/sync.go b/trie/sync.go index fea10051f..b573a9f73 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -42,7 +43,7 @@ type request struct { depth int // Depth level within the trie the node is located to prioritise DFS deps int // Number of dependencies before allowed to commit this node - callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch + callback LeafCallback // Callback to invoke if a leaf node it reached on this branch } // SyncResult is a simple list to return missing nodes along with their request @@ -67,11 +68,6 @@ func newSyncMemBatch() *syncMemBatch { } } -// TrieSyncLeafCallback is a callback type invoked when a trie sync reaches a -// leaf node. It's used by state syncing to check if the leaf node requires some -// further data syncing. -type TrieSyncLeafCallback func(leaf []byte, parent common.Hash) error - // TrieSync is the main state trie synchronisation scheduler, which provides yet // unknown trie hashes to retrieve, accepts node data associated with said hashes // and reconstructs the trie step by step until all is done. @@ -83,7 +79,7 @@ type TrieSync struct { } // NewTrieSync creates a new trie data download scheduler. -func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLeafCallback) *TrieSync { +func NewTrieSync(root common.Hash, database DatabaseReader, callback LeafCallback) *TrieSync { ts := &TrieSync{ database: database, membatch: newSyncMemBatch(), @@ -95,7 +91,7 @@ func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLea } // AddSubTrie registers a new trie to the sync code, rooted at the designated parent. -func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback TrieSyncLeafCallback) { +func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback LeafCallback) { // Short circuit if the trie is empty or already known if root == emptyRoot { return @@ -217,7 +213,7 @@ func (s *TrieSync) Process(results []SyncResult) (bool, int, error) { // Commit flushes the data stored in the internal membatch out to persistent // storage, returning th enumber of items written and any occurred error. 
-func (s *TrieSync) Commit(dbw DatabaseWriter) (int, error) { +func (s *TrieSync) Commit(dbw ethdb.Putter) (int, error) { // Dump the membatch into a database dbw for i, key := range s.membatch.order { if err := dbw.Put(key[:], s.membatch.batch[key]); err != nil { diff --git a/trie/sync_test.go b/trie/sync_test.go index ec16a25bd..4a720612b 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -25,10 +25,11 @@ import ( ) // makeTestTrie create a sample test trie to test node-wise reconstruction. -func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { +func makeTestTrie() (*Database, *Trie, map[string][]byte) { // Create an empty trie - db, _ := ethdb.NewMemDatabase() - trie, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + trie, _ := New(common.Hash{}, triedb) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -49,15 +50,15 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { trie.Update(key, val) } } - trie.Commit() + trie.Commit(nil) // Return the generated trie - return db, trie, content + return triedb, trie, content } // checkTrieContents cross references a reconstructed trie with an expected data // content map. -func checkTrieContents(t *testing.T, db Database, root []byte, content map[string][]byte) { +func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) { // Check root availability and trie contents trie, err := New(common.BytesToHash(root), db) if err != nil { @@ -74,7 +75,7 @@ func checkTrieContents(t *testing.T, db Database, root []byte, content map[strin } // checkTrieConsistency checks that all nodes in a trie are indeed present. -func checkTrieConsistency(db Database, root common.Hash) error { +func checkTrieConsistency(db *Database, root common.Hash) error { // Create and iterate a trie rooted in a subnode trie, err := New(root, db) if err != nil { @@ -88,12 +89,18 @@ func checkTrieConsistency(db Database, root common.Hash) error { // Tests that an empty trie is not scheduled for syncing. func TestEmptyTrieSync(t *testing.T) { - emptyA, _ := New(common.Hash{}, nil) - emptyB, _ := New(emptyRoot, nil) + diskdbA, _ := ethdb.NewMemDatabase() + triedbA := NewDatabase(diskdbA) + + diskdbB, _ := ethdb.NewMemDatabase() + triedbB := NewDatabase(diskdbB) + + emptyA, _ := New(common.Hash{}, triedbA) + emptyB, _ := New(emptyRoot, triedbB) for i, trie := range []*Trie{emptyA, emptyB} { - db, _ := ethdb.NewMemDatabase() - if req := NewTrieSync(common.BytesToHash(trie.Root()), db, nil).Missing(1); len(req) != 0 { + diskdb, _ := ethdb.NewMemDatabase() + if req := NewTrieSync(trie.Hash(), diskdb, nil).Missing(1); len(req) != 0 { t.Errorf("test %d: content requested for empty trie: %v", i, req) } } @@ -109,14 +116,15 @@ func testIterativeTrieSync(t *testing.T, batch int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := append([]common.Hash{}, sched.Missing(batch)...) 
for len(queue) > 0 { results := make([]SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -125,13 +133,13 @@ func testIterativeTrieSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = append(queue[:0], sched.Missing(batch)...) } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that the trie scheduler can correctly reconstruct the state even if only @@ -141,15 +149,16 @@ func TestIterativeDelayedTrieSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := append([]common.Hash{}, sched.Missing(10000)...) for len(queue) > 0 { // Sync only half of the scheduled nodes results := make([]SyncResult, len(queue)/2+1) for i, hash := range queue[:len(results)] { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -158,13 +167,13 @@ func TestIterativeDelayedTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = append(queue[len(results):], sched.Missing(10000)...) 
} // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that given a root hash, a trie can sync iteratively on a single thread, @@ -178,8 +187,9 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := make(map[common.Hash]struct{}) for _, hash := range sched.Missing(batch) { @@ -189,7 +199,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { // Fetch all the queued nodes in a random order results := make([]SyncResult, 0, len(queue)) for hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -199,7 +209,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = make(map[common.Hash]struct{}) @@ -208,7 +218,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { } } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that the trie scheduler can correctly reconstruct the state even if only @@ -218,8 +228,9 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := make(map[common.Hash]struct{}) for _, hash := range sched.Missing(10000) { @@ -229,7 +240,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { // Sync only half of the scheduled nodes, even those in random order results := make([]SyncResult, 0, len(queue)/2+1) for hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -243,7 +254,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } for _, result := range results { @@ -254,7 +265,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { } } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that a trie sync will not request nodes multiple times, even if they @@ -264,8 +275,9 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, 
_ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := append([]common.Hash{}, sched.Missing(0)...) requested := make(map[common.Hash]struct{}) @@ -273,7 +285,7 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) { for len(queue) > 0 { results := make([]SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -287,13 +299,13 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = append(queue[:0], sched.Missing(0)...) } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that at any point in time during a sync, only complete sub-tries are in @@ -303,8 +315,9 @@ func TestIncompleteTrieSync(t *testing.T) { srcDb, srcTrie, _ := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) added := []common.Hash{} queue := append([]common.Hash{}, sched.Missing(1)...) @@ -312,7 +325,7 @@ func TestIncompleteTrieSync(t *testing.T) { // Fetch a batch of trie nodes results := make([]SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -322,7 +335,7 @@ func TestIncompleteTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } for _, result := range results { @@ -330,7 +343,7 @@ func TestIncompleteTrieSync(t *testing.T) { } // Check that all known sub-tries in the synced trie are complete for _, root := range added { - if err := checkTrieConsistency(dstDb, root); err != nil { + if err := checkTrieConsistency(triedb, root); err != nil { t.Fatalf("trie inconsistent: %v", err) } } @@ -340,12 +353,12 @@ func TestIncompleteTrieSync(t *testing.T) { // Sanity check that removing any node from the database is detected for _, node := range added[1:] { key := node.Bytes() - value, _ := dstDb.Get(key) + value, _ := diskdb.Get(key) - dstDb.Delete(key) - if err := checkTrieConsistency(dstDb, added[0]); err == nil { + diskdb.Delete(key) + if err := checkTrieConsistency(triedb, added[0]); err == nil { t.Fatalf("trie inconsistency not caught, missing: %x", key) } - dstDb.Put(key, value) + diskdb.Put(key, value) } } diff --git a/trie/trie.go b/trie/trie.go index 8fe98d835..e37a1ae10 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -22,16 +22,17 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto/sha3" + 
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/rcrowley/go-metrics" ) var ( - // This is the known root hash of an empty trie. + // emptyRoot is the known root hash of an empty trie. emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - // This is the known hash of an empty state trie entry. - emptyState common.Hash + + // emptyState is the known hash of an empty state trie entry. + emptyState = crypto.Keccak256Hash(nil) ) var ( @@ -53,29 +54,10 @@ func CacheUnloads() int64 { return cacheUnloadCounter.Count() } -func init() { - sha3.NewKeccak256().Sum(emptyState[:0]) -} - -// Database must be implemented by backing stores for the trie. -type Database interface { - DatabaseReader - DatabaseWriter -} - -// DatabaseReader wraps the Get method of a backing store for the trie. -type DatabaseReader interface { - Get(key []byte) (value []byte, err error) - Has(key []byte) (bool, error) -} - -// DatabaseWriter wraps the Put method of a backing store for the trie. -type DatabaseWriter interface { - // Put stores the mapping key->value in the database. - // Implementations must not hold onto the value bytes, the trie - // will reuse the slice across calls to Put. - Put(key, value []byte) error -} +// LeafCallback is a callback type invoked when a trie operation reaches a leaf +// node. It's used by state sync and commit to allow handling external references +// between account and storage tries. +type LeafCallback func(leaf []byte, parent common.Hash) error // Trie is a Merkle Patricia Trie. // The zero value is an empty trie with no database. @@ -83,8 +65,8 @@ type DatabaseWriter interface { // // Trie is not safe for concurrent use. type Trie struct { + db *Database root node - db Database originalRoot common.Hash // Cache generation values. @@ -111,12 +93,15 @@ func (t *Trie) newFlag() nodeFlag { // trie is initially empty and does not require a database. Otherwise, // New will panic if db is nil and returns a MissingNodeError if root does // not exist in the database. Accessing the trie loads nodes from db on demand. -func New(root common.Hash, db Database) (*Trie, error) { - trie := &Trie{db: db, originalRoot: root} +func New(root common.Hash, db *Database) (*Trie, error) { + if db == nil { + panic("trie.New called without a database") + } + trie := &Trie{ + db: db, + originalRoot: root, + } if (root != common.Hash{}) && root != emptyRoot { - if db == nil { - panic("trie.New: cannot use existing root without a database") - } rootnode, err := trie.resolveHash(root[:], nil) if err != nil { return nil, err @@ -447,12 +432,13 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { cacheMissCounter.Inc(1) - enc, err := t.db.Get(n) + hash := common.BytesToHash(n) + + enc, err := t.db.Node(hash) if err != nil || enc == nil { - return nil, &MissingNodeError{NodeHash: common.BytesToHash(n), Path: prefix} + return nil, &MissingNodeError{NodeHash: hash, Path: prefix} } - dec := mustDecodeNode(n, enc, t.cachegen) - return dec, nil + return mustDecodeNode(n, enc, t.cachegen), nil } // Root returns the root hash of the trie. @@ -462,32 +448,18 @@ func (t *Trie) Root() []byte { return t.Hash().Bytes() } // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. 
func (t *Trie) Hash() common.Hash { - hash, cached, _ := t.hashRoot(nil) + hash, cached, _ := t.hashRoot(nil, nil) t.root = cached return common.BytesToHash(hash.(hashNode)) } -// Commit writes all nodes to the trie's database. -// Nodes are stored with their sha3 hash as the key. -// -// Committing flushes nodes from memory. -// Subsequent Get calls will load nodes from the database. -func (t *Trie) Commit() (root common.Hash, err error) { +// Commit writes all nodes to the trie's memory database, tracking the internal +// and external (for account tries) references. +func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) { if t.db == nil { - panic("Commit called on trie with nil database") + panic("commit called on trie with nil database") } - return t.CommitTo(t.db) -} - -// CommitTo writes all nodes to the given database. -// Nodes are stored with their sha3 hash as the key. -// -// Committing flushes nodes from memory. Subsequent Get calls will -// load nodes from the trie's database. Calling code must ensure that -// the changes made to db are written back to the trie's attached -// database before using the trie. -func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - hash, cached, err := t.hashRoot(db) + hash, cached, err := t.hashRoot(t.db, onleaf) if err != nil { return common.Hash{}, err } @@ -496,11 +468,11 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { return common.BytesToHash(hash.(hashNode)), nil } -func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { +func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil, nil } - h := newHasher(t.cachegen, t.cachelimit) + h := newHasher(t.cachegen, t.cachelimit, onleaf) defer returnHasherToPool(h) return h.hash(t.root, db, true) } diff --git a/trie/trie_test.go b/trie/trie_test.go index 1e28c3bc4..997222628 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -43,8 +43,8 @@ func init() { // Used for testing func newEmpty() *Trie { - db, _ := ethdb.NewMemDatabase() - trie, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + trie, _ := New(common.Hash{}, NewDatabase(diskdb)) return trie } @@ -68,8 +68,8 @@ func TestNull(t *testing.T) { } func TestMissingRoot(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db) + diskdb, _ := ethdb.NewMemDatabase() + trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(diskdb)) if trie != nil { t.Error("New returned non-nil trie for invalid root") } @@ -78,70 +78,75 @@ func TestMissingRoot(t *testing.T) { } } -func TestMissingNode(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - trie, _ := New(common.Hash{}, db) +func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) } +func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } + +func testMissingNode(t *testing.T, memonly bool) { + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + trie, _ := New(common.Hash{}, triedb) updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer") updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") - root, _ := trie.Commit() + root, _ := trie.Commit(nil) + if !memonly { + triedb.Commit(root, true) + } - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err := trie.TryGet([]byte("120000")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - 
trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("120099")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryDelete([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } - db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")) + hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9") + if memonly { + delete(triedb.nodes, hash) + } else { + diskdb.Delete(hash[:]) + } - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("120000")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("120099")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryDelete([]byte("123456")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) @@ -165,7 +170,7 @@ func TestInsert(t *testing.T) { updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") - root, err := trie.Commit() + root, err := trie.Commit(nil) if err != nil { t.Fatalf("commit error: %v", err) } @@ -194,7 +199,7 @@ func TestGet(t *testing.T) { if i == 1 { return } - trie.Commit() + trie.Commit(nil) } } @@ -263,7 +268,7 @@ func TestReplication(t *testing.T) { for _, val := range vals { updateString(trie, val.k, val.v) } - exp, err := trie.Commit() + exp, err := trie.Commit(nil) if err != nil { t.Fatalf("commit error: %v", err) } @@ -278,7 +283,7 @@ func TestReplication(t *testing.T) { t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v) } } - hash, err := trie2.Commit() + hash, err := trie2.Commit(nil) if err != nil { t.Fatalf("commit error: %v", err) } @@ -314,7 +319,7 @@ func TestLargeValue(t *testing.T) { } type countingDB struct { - Database + ethdb.Database gets map[string]int } @@ -332,19 +337,20 @@ func TestCacheUnload(t *testing.T) { key2 := "---some other branch" updateString(trie, key1, "this is the branch of key1.") updateString(trie, key2, "this is the branch of key2.") - root, _ := trie.Commit() + + root, _ := trie.Commit(nil) + trie.db.Commit(root, true) // Commit the trie repeatedly and access key1. // The branch containing it is loaded from DB exactly two times: // in the 0th and 6th iteration. 
- db := &countingDB{Database: trie.db, gets: make(map[string]int)} - trie, _ = New(root, db) + db := &countingDB{Database: trie.db.diskdb, gets: make(map[string]int)} + trie, _ = New(root, NewDatabase(db)) trie.SetCacheLimit(5) for i := 0; i < 12; i++ { getString(trie, key1) - trie.Commit() + trie.Commit(nil) } - // Check that it got loaded two times. for dbkey, count := range db.gets { if count != 2 { @@ -407,8 +413,10 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value { } func runRandTest(rt randTest) bool { - db, _ := ethdb.NewMemDatabase() - tr, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + tr, _ := New(common.Hash{}, triedb) values := make(map[string]string) // tracks content of the trie for i, step := range rt { @@ -426,23 +434,23 @@ func runRandTest(rt randTest) bool { rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want) } case opCommit: - _, rt[i].err = tr.Commit() + _, rt[i].err = tr.Commit(nil) case opHash: tr.Hash() case opReset: - hash, err := tr.Commit() + hash, err := tr.Commit(nil) if err != nil { rt[i].err = err return false } - newtr, err := New(hash, db) + newtr, err := New(hash, triedb) if err != nil { rt[i].err = err return false } tr = newtr case opItercheckhash: - checktr, _ := New(common.Hash{}, nil) + checktr, _ := New(common.Hash{}, triedb) it := NewIterator(tr.NodeIterator(nil)) for it.Next() { checktr.Update(it.Key, it.Value) @@ -524,7 +532,7 @@ func benchGet(b *testing.B, commit bool) { } binary.LittleEndian.PutUint64(k, benchElemCount/2) if commit { - trie.Commit() + trie.Commit(nil) } b.ResetTimer() @@ -534,7 +542,7 @@ func benchGet(b *testing.B, commit bool) { b.StopTimer() if commit { - ldb := trie.db.(*ethdb.LDBDatabase) + ldb := trie.db.diskdb.(*ethdb.LDBDatabase) ldb.Close() os.RemoveAll(ldb.Path()) } @@ -585,16 +593,16 @@ func BenchmarkHash(b *testing.B) { trie.Hash() } -func tempDB() (string, Database) { +func tempDB() (string, *Database) { dir, err := ioutil.TempDir("", "trie-bench") if err != nil { panic(fmt.Sprintf("can't create temporary directory: %v", err)) } - db, err := ethdb.NewLDBDatabase(dir, 256, 0) + diskdb, err := ethdb.NewLDBDatabase(dir, 256, 0) if err != nil { panic(fmt.Sprintf("can't create temporary database: %v", err)) } - return dir, db + return dir, NewDatabase(diskdb) } func getString(trie *Trie, k string) []byte { |
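
Beyond Commit, the new memory layer exposes reference counting so a caller can keep several in-flight tries alive and garbage collect the ones it no longer needs. A hedged sketch of how that API might be driven from outside the package; the use of the empty hash as the anchor mirrors the special-case entry that NewDatabase pre-populates for "root" references, and the surrounding flow is assumed rather than shown in this diff:

    // Pin the freshly committed trie in memory by referencing its root from the
    // pre-populated empty-hash entry.
    triedb.Reference(root, common.Hash{})

    // ... later, when this state is no longer interesting ...
    triedb.Dereference(root, common.Hash{})
    // Nodes whose reference count drops to zero are removed from the memory
    // cache (and cascaded to their children); anything already persisted to
    // disk or still referenced elsewhere is left untouched.
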