aboutsummaryrefslogtreecommitdiffstats
path: root/trie
diff options
context:
space:
mode:
authorFelix Lange <fjl@twurst.com>2016-10-15 00:04:33 +0800
committerPéter Szilágyi <peterke@gmail.com>2016-10-15 00:04:33 +0800
commit40cdcf1183df235e4b32cfdbf6182a00a0e49f24 (patch)
tree571daa75d1590a47dfe98b67ab1a3350e46ea527 /trie
parentc2ddfb343a22958324c0c26dae789d3937eece4f (diff)
downloadgo-tangerine-40cdcf1183df235e4b32cfdbf6182a00a0e49f24.tar.gz
go-tangerine-40cdcf1183df235e4b32cfdbf6182a00a0e49f24.tar.zst
go-tangerine-40cdcf1183df235e4b32cfdbf6182a00a0e49f24.zip
trie, core/state: improve memory usage and performance (#3135)
* trie: store nodes as pointers This avoids memory copies when unwrapping node interface values. name old time/op new time/op delta Get 388ns ± 8% 215ns ± 2% -44.56% (p=0.000 n=15+15) GetDB 363ns ± 3% 202ns ± 2% -44.21% (p=0.000 n=15+15) UpdateBE 1.57µs ± 2% 1.29µs ± 3% -17.80% (p=0.000 n=13+15) UpdateLE 1.92µs ± 2% 1.61µs ± 2% -16.25% (p=0.000 n=14+14) HashBE 2.16µs ± 6% 2.18µs ± 6% ~ (p=0.436 n=15+15) HashLE 7.43µs ± 3% 7.21µs ± 3% -2.96% (p=0.000 n=15+13) * trie: close temporary databases in GetDB benchmark * trie: don't keep []byte from DB load around Nodes decoded from a DB load kept hashes and values as sub-slices of the DB value. This can be a problem because loading from leveldb often returns []byte with a cap that's larger than necessary, increasing memory usage. * trie: unload old cached nodes * trie, core/state: use cache unloading for account trie * trie: use explicit private flags (fixes Go 1.5 reflection issue). * trie: fixup cachegen overflow at request of nick * core/state: rename journal size constant
Diffstat (limited to 'trie')
-rw-r--r--trie/hasher.go78
-rw-r--r--trie/iterator.go12
-rw-r--r--trie/node.go63
-rw-r--r--trie/node_test.go58
-rw-r--r--trie/proof.go10
-rw-r--r--trie/secure_trie.go13
-rw-r--r--trie/secure_trie_test.go4
-rw-r--r--trie/sync.go4
-rw-r--r--trie/trie.go114
-rw-r--r--trie/trie_test.go10
10 files changed, 235 insertions, 131 deletions
diff --git a/trie/hasher.go b/trie/hasher.go
index 87e02fb85..e395e00d7 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -27,8 +27,9 @@ import (
)
type hasher struct {
- tmp *bytes.Buffer
- sha hash.Hash
+ tmp *bytes.Buffer
+ sha hash.Hash
+ cachegen, cachelimit uint16
}
// hashers live in a global pool.
@@ -38,8 +39,10 @@ var hasherPool = sync.Pool{
},
}
-func newHasher() *hasher {
- return hasherPool.Get().(*hasher)
+func newHasher(cachegen, cachelimit uint16) *hasher {
+ h := hasherPool.Get().(*hasher)
+ h.cachegen, h.cachelimit = cachegen, cachelimit
+ return h
}
func returnHasherToPool(h *hasher) {
@@ -50,8 +53,18 @@ func returnHasherToPool(h *hasher) {
// original node initialzied with the computed hash to replace the original one.
func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) {
// If we're not storing the node, just hashing, use avaialble cached data
- if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) {
- return hash, n, nil
+ if hash, dirty := n.cache(); hash != nil {
+ if db == nil {
+ return hash, n, nil
+ }
+ if n.canUnload(h.cachegen, h.cachelimit) {
+ // Evict the node from cache. All of its subnodes will have a lower or equal
+ // cache generation number.
+ return hash, hash, nil
+ }
+ if !dirty {
+ return hash, n, nil
+ }
}
// Trie not processed yet or needs storage, walk the children
collapsed, cached, err := h.hashChildren(n, db)
@@ -62,19 +75,21 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
if err != nil {
return hashNode{}, n, err
}
- // Cache the hash and RLP blob of the ndoe for later reuse
+ // Cache the hash of the ndoe for later reuse.
if hash, ok := hashed.(hashNode); ok && !force {
switch cached := cached.(type) {
- case shortNode:
- cached.hash = hash
+ case *shortNode:
+ cached = cached.copy()
+ cached.flags.hash = hash
if db != nil {
- cached.dirty = false
+ cached.flags.dirty = false
}
return hashed, cached, nil
- case fullNode:
- cached.hash = hash
+ case *fullNode:
+ cached = cached.copy()
+ cached.flags.hash = hash
if db != nil {
- cached.dirty = false
+ cached.flags.dirty = false
}
return hashed, cached, nil
}
@@ -89,40 +104,42 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
var err error
switch n := original.(type) {
- case shortNode:
+ case *shortNode:
// Hash the short node's child, caching the newly hashed subtree
- cached := n
- cached.Key = common.CopyBytes(cached.Key)
+ collapsed, cached := n.copy(), n.copy()
+ collapsed.Key = compactEncode(n.Key)
+ cached.Key = common.CopyBytes(n.Key)
- n.Key = compactEncode(n.Key)
if _, ok := n.Val.(valueNode); !ok {
- if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil {
- return n, original, err
+ collapsed.Val, cached.Val, err = h.hash(n.Val, db, false)
+ if err != nil {
+ return original, original, err
}
}
- if n.Val == nil {
- n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings.
+ if collapsed.Val == nil {
+ collapsed.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings.
}
- return n, cached, nil
+ return collapsed, cached, nil
- case fullNode:
+ case *fullNode:
// Hash the full node's children, caching the newly hashed subtrees
- cached := fullNode{dirty: n.dirty}
+ collapsed, cached := n.copy(), n.copy()
for i := 0; i < 16; i++ {
if n.Children[i] != nil {
- if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil {
- return n, original, err
+ collapsed.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false)
+ if err != nil {
+ return original, original, err
}
} else {
- n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings.
+ collapsed.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings.
}
}
cached.Children[16] = n.Children[16]
- if n.Children[16] == nil {
- n.Children[16] = valueNode(nil)
+ if collapsed.Children[16] == nil {
+ collapsed.Children[16] = valueNode(nil)
}
- return n, cached, nil
+ return collapsed, cached, nil
default:
// Value and hash nodes don't have children so they're left as were
@@ -140,6 +157,7 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
if err := rlp.Encode(h.tmp, n); err != nil {
panic("encode error: " + err.Error())
}
+
if h.tmp.Len() < 32 && !force {
return n, nil // Nodes smaller than 32 bytes are stored inside their parent
}
diff --git a/trie/iterator.go b/trie/iterator.go
index 8cad51aff..afde6e19e 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -56,11 +56,11 @@ func (it *Iterator) makeKey() []byte {
key := it.keyBuf[:0]
for _, se := range it.nodeIt.stack {
switch node := se.node.(type) {
- case fullNode:
+ case *fullNode:
if se.child <= 16 {
key = append(key, byte(se.child))
}
- case shortNode:
+ case *shortNode:
if hasTerm(node.Key) {
key = append(key, node.Key[:len(node.Key)-1]...)
} else {
@@ -148,7 +148,7 @@ func (it *NodeIterator) step() error {
if (ancestor == common.Hash{}) {
ancestor = parent.parent
}
- if node, ok := parent.node.(fullNode); ok {
+ if node, ok := parent.node.(*fullNode); ok {
// Full node, traverse all children, then the node itself
if parent.child >= len(node.Children) {
break
@@ -156,7 +156,7 @@ func (it *NodeIterator) step() error {
for parent.child++; parent.child < len(node.Children); parent.child++ {
if current := node.Children[parent.child]; current != nil {
it.stack = append(it.stack, &nodeIteratorState{
- hash: common.BytesToHash(node.hash),
+ hash: common.BytesToHash(node.flags.hash),
node: current,
parent: ancestor,
child: -1,
@@ -164,14 +164,14 @@ func (it *NodeIterator) step() error {
break
}
}
- } else if node, ok := parent.node.(shortNode); ok {
+ } else if node, ok := parent.node.(*shortNode); ok {
// Short node, traverse the pointer singleton child, then the node itself
if parent.child >= 0 {
break
}
parent.child++
it.stack = append(it.stack, &nodeIteratorState{
- hash: common.BytesToHash(node.hash),
+ hash: common.BytesToHash(node.flags.hash),
node: node.Val,
parent: ancestor,
child: -1,
diff --git a/trie/node.go b/trie/node.go
index b97d370be..de9752c93 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -30,42 +30,60 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b
type node interface {
fstring(string) string
cache() (hashNode, bool)
+ canUnload(cachegen, cachelimit uint16) bool
}
type (
fullNode struct {
Children [17]node // Actual trie node data to encode/decode (needs custom encoder)
- hash hashNode // Cached hash of the node to prevent rehashing (may be nil)
- dirty bool // Cached flag whether the node's new or already stored
+ flags nodeFlag
}
shortNode struct {
Key []byte
Val node
- hash hashNode // Cached hash of the node to prevent rehashing (may be nil)
- dirty bool // Cached flag whether the node's new or already stored
+ flags nodeFlag
}
hashNode []byte
valueNode []byte
)
// EncodeRLP encodes a full node into the consensus RLP format.
-func (n fullNode) EncodeRLP(w io.Writer) error {
+func (n *fullNode) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, n.Children)
}
-// Cache accessors to retrieve precalculated values (avoid lengthy type switches).
-func (n fullNode) cache() (hashNode, bool) { return n.hash, n.dirty }
-func (n shortNode) cache() (hashNode, bool) { return n.hash, n.dirty }
-func (n hashNode) cache() (hashNode, bool) { return nil, true }
-func (n valueNode) cache() (hashNode, bool) { return nil, true }
+func (n *fullNode) copy() *fullNode { copy := *n; return &copy }
+func (n *shortNode) copy() *shortNode { copy := *n; return &copy }
+
+// nodeFlag contains caching-related metadata about a node.
+type nodeFlag struct {
+ hash hashNode // cached hash of the node (may be nil)
+ gen uint16 // cache generation counter
+ dirty bool // whether the node has changes that must be written to the database
+}
+
+// canUnload tells whether a node can be unloaded.
+func (n *nodeFlag) canUnload(cachegen, cachelimit uint16) bool {
+ return !n.dirty && cachegen-n.gen >= cachelimit
+}
+
+func (n *fullNode) canUnload(gen, limit uint16) bool { return n.flags.canUnload(gen, limit) }
+func (n *shortNode) canUnload(gen, limit uint16) bool { return n.flags.canUnload(gen, limit) }
+func (n hashNode) canUnload(uint16, uint16) bool { return false }
+func (n valueNode) canUnload(uint16, uint16) bool { return false }
+
+func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
+func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
+func (n hashNode) cache() (hashNode, bool) { return nil, true }
+func (n valueNode) cache() (hashNode, bool) { return nil, true }
// Pretty printing.
-func (n fullNode) String() string { return n.fstring("") }
-func (n shortNode) String() string { return n.fstring("") }
-func (n hashNode) String() string { return n.fstring("") }
-func (n valueNode) String() string { return n.fstring("") }
+func (n *fullNode) String() string { return n.fstring("") }
+func (n *shortNode) String() string { return n.fstring("") }
+func (n hashNode) String() string { return n.fstring("") }
+func (n valueNode) String() string { return n.fstring("") }
-func (n fullNode) fstring(ind string) string {
+func (n *fullNode) fstring(ind string) string {
resp := fmt.Sprintf("[\n%s ", ind)
for i, node := range n.Children {
if node == nil {
@@ -76,7 +94,7 @@ func (n fullNode) fstring(ind string) string {
}
return resp + fmt.Sprintf("\n%s] ", ind)
}
-func (n shortNode) fstring(ind string) string {
+func (n *shortNode) fstring(ind string) string {
return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" "))
}
func (n hashNode) fstring(ind string) string {
@@ -120,6 +138,7 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
if err != nil {
return nil, err
}
+ flag := nodeFlag{hash: hash}
key := compactDecode(kbuf)
if key[len(key)-1] == 16 {
// value node
@@ -127,17 +146,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
if err != nil {
return nil, fmt.Errorf("invalid value node: %v", err)
}
- return shortNode{key, valueNode(val), hash, false}, nil
+ return &shortNode{key, append(valueNode{}, val...), flag}, nil
}
r, _, err := decodeRef(rest)
if err != nil {
return nil, wrapError(err, "val")
}
- return shortNode{key, r, hash, false}, nil
+ return &shortNode{key, r, flag}, nil
}
-func decodeFull(hash, buf, elems []byte) (fullNode, error) {
- n := fullNode{hash: hash}
+func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
+ n := &fullNode{flags: nodeFlag{hash: hash}}
for i := 0; i < 16; i++ {
cld, rest, err := decodeRef(elems)
if err != nil {
@@ -150,7 +169,7 @@ func decodeFull(hash, buf, elems []byte) (fullNode, error) {
return n, err
}
if len(val) > 0 {
- n.Children[16] = valueNode(val)
+ n.Children[16] = append(valueNode{}, val...)
}
return n, nil
}
@@ -176,7 +195,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
// empty node
return nil, rest, nil
case kind == rlp.String && len(val) == 32:
- return hashNode(val), rest, nil
+ return append(hashNode{}, val...), rest, nil
default:
return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
}
diff --git a/trie/node_test.go b/trie/node_test.go
new file mode 100644
index 000000000..7ad1ff9e7
--- /dev/null
+++ b/trie/node_test.go
@@ -0,0 +1,58 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import "testing"
+
+func TestCanUnload(t *testing.T) {
+ tests := []struct {
+ flag nodeFlag
+ cachegen, cachelimit uint16
+ want bool
+ }{
+ {
+ flag: nodeFlag{dirty: true, gen: 0},
+ want: false,
+ },
+ {
+ flag: nodeFlag{dirty: false, gen: 0},
+ cachegen: 0, cachelimit: 0,
+ want: true,
+ },
+ {
+ flag: nodeFlag{dirty: false, gen: 65534},
+ cachegen: 65535, cachelimit: 1,
+ want: true,
+ },
+ {
+ flag: nodeFlag{dirty: false, gen: 65534},
+ cachegen: 0, cachelimit: 1,
+ want: true,
+ },
+ {
+ flag: nodeFlag{dirty: false, gen: 1},
+ cachegen: 65535, cachelimit: 1,
+ want: true,
+ },
+ }
+
+ for _, test := range tests {
+ if got := test.flag.canUnload(test.cachegen, test.cachelimit); got != test.want {
+ t.Errorf("%+v\n got %t, want %t", test, got, test.want)
+ }
+ }
+}
diff --git a/trie/proof.go b/trie/proof.go
index 116c13a1b..f193b52df 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -44,7 +44,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
tn := t.root
for len(key) > 0 && tn != nil {
switch n := tn.(type) {
- case shortNode:
+ case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
// The trie doesn't contain the key.
tn = nil
@@ -53,7 +53,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
key = key[len(n.Key):]
}
nodes = append(nodes, n)
- case fullNode:
+ case *fullNode:
tn = n.Children[key[0]]
key = key[1:]
nodes = append(nodes, n)
@@ -70,7 +70,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
- hasher := newHasher()
+ hasher := newHasher(0, 0)
proof := make([]rlp.RawValue, 0, len(nodes))
for i, n := range nodes {
// Don't bother checking for errors here since hasher panics
@@ -130,13 +130,13 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
func get(tn node, key []byte) ([]byte, node) {
for len(key) > 0 {
switch n := tn.(type) {
- case shortNode:
+ case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
return nil, nil
}
tn = n.Val
key = key[len(n.Key):]
- case fullNode:
+ case *fullNode:
tn = n.Children[key[0]]
key = key[1:]
case hashNode:
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 2a8b57214..4d9ebe4d3 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -49,8 +49,12 @@ type SecureTrie struct {
// If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found.
+//
// Accessing the trie loads nodes from db on demand.
-func NewSecure(root common.Hash, db Database) (*SecureTrie, error) {
+// Loaded nodes are kept around until their 'cache generation' expires.
+// A new cache generation is created by each call to Commit.
+// cachelimit sets the number of past cache generations to keep.
+func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) {
if db == nil {
panic("NewSecure called with nil database")
}
@@ -58,9 +62,8 @@ func NewSecure(root common.Hash, db Database) (*SecureTrie, error) {
if err != nil {
return nil, err
}
- return &SecureTrie{
- trie: *trie,
- }, nil
+ trie.SetCacheLimit(cachelimit)
+ return &SecureTrie{trie: *trie}, nil
}
// Get returns the value for key stored in the trie.
@@ -191,7 +194,7 @@ func (t *SecureTrie) secKey(key []byte) []byte {
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.
func (t *SecureTrie) hashKey(key []byte) []byte {
- h := newHasher()
+ h := newHasher(0, 0)
h.sha.Reset()
h.sha.Write(key)
buf := h.sha.Sum(t.hashKeyBuf[:0])
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
index 3171b8c31..159640fda 100644
--- a/trie/secure_trie_test.go
+++ b/trie/secure_trie_test.go
@@ -29,7 +29,7 @@ import (
func newEmptySecure() *SecureTrie {
db, _ := ethdb.NewMemDatabase()
- trie, _ := NewSecure(common.Hash{}, db)
+ trie, _ := NewSecure(common.Hash{}, db, 0)
return trie
}
@@ -37,7 +37,7 @@ func newEmptySecure() *SecureTrie {
func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) {
// Create an empty trie
db, _ := ethdb.NewMemDatabase()
- trie, _ := NewSecure(common.Hash{}, db)
+ trie, _ := NewSecure(common.Hash{}, db, 0)
// Fill it with some arbitrary data
content := make(map[string][]byte)
diff --git a/trie/sync.go b/trie/sync.go
index 6e9e029b9..3de758536 100644
--- a/trie/sync.go
+++ b/trie/sync.go
@@ -212,12 +212,12 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
children := []child{}
switch node := (*req.object).(type) {
- case shortNode:
+ case *shortNode:
children = []child{{
node: &node.Val,
depth: req.depth + len(node.Key),
}}
- case fullNode:
+ case *fullNode:
for i := 0; i < 17; i++ {
if node.Children[i] != nil {
children = append(children, child{
diff --git a/trie/trie.go b/trie/trie.go
index 55598af98..65005bae8 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -62,6 +62,23 @@ type Trie struct {
root node
db Database
originalRoot common.Hash
+
+ // Cache generation values.
+ // cachegen increase by one with each commit operation.
+ // new nodes are tagged with the current generation and unloaded
+ // when their generation is older than than cachegen-cachelimit.
+ cachegen, cachelimit uint16
+}
+
+// SetCacheLimit sets the number of 'cache generations' to keep.
+// A cache generations is created by a call to Commit.
+func (t *Trie) SetCacheLimit(l uint16) {
+ t.cachelimit = l
+}
+
+// newFlag returns the cache flag value for a newly created node.
+func (t *Trie) newFlag() nodeFlag {
+ return nodeFlag{dirty: true, gen: t.cachegen}
}
// New creates a trie with an existing root node from db.
@@ -120,27 +137,25 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
return nil, nil, false, nil
case valueNode:
return n, n, false, nil
- case shortNode:
+ case *shortNode:
if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
// key not found in trie
return nil, n, false, nil
}
value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
if err == nil && didResolve {
+ n = n.copy()
n.Val = newnode
- return value, n, didResolve, err
- } else {
- return value, origNode, didResolve, err
}
- case fullNode:
- child := n.Children[key[pos]]
- value, newnode, didResolve, err = t.tryGet(child, key, pos+1)
+ return value, n, didResolve, err
+ case *fullNode:
+ value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve {
+ n = n.copy()
n.Children[key[pos]] = newnode
- return value, n, didResolve, err
- } else {
- return value, origNode, didResolve, err
+
}
+ return value, n, didResolve, err
case hashNode:
child, err := t.resolveHash(n, key[:pos], key[pos:])
if err != nil {
@@ -199,22 +214,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, value, nil
}
switch n := n.(type) {
- case shortNode:
+ case *shortNode:
matchlen := prefixLen(key, n.Key)
// If the whole key matches, keep this short node as is
// and only update the value.
if matchlen == len(n.Key) {
dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
- if err != nil {
- return false, nil, err
+ if !dirty || err != nil {
+ return false, n, err
}
- if !dirty {
- return false, n, nil
- }
- return true, shortNode{n.Key, nn, nil, true}, nil
+ return true, &shortNode{n.Key, nn, t.newFlag()}, nil
}
// Otherwise branch out at the index where they differ.
- branch := fullNode{dirty: true}
+ branch := &fullNode{flags: t.newFlag()}
var err error
_, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
if err != nil {
@@ -229,21 +241,19 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, branch, nil
}
// Otherwise, replace it with a short node leading up to the branch.
- return true, shortNode{key[:matchlen], branch, nil, true}, nil
+ return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
- case fullNode:
+ case *fullNode:
dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
- if err != nil {
- return false, nil, err
+ if !dirty || err != nil {
+ return false, n, err
}
- if !dirty {
- return false, n, nil
- }
- n.Children[key[0]], n.hash, n.dirty = nn, nil, true
+ n = n.copy()
+ n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
return true, n, nil
case nil:
- return true, shortNode{key, value, nil, true}, nil
+ return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
@@ -254,11 +264,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return false, nil, err
}
dirty, nn, err := t.insert(rn, prefix, key, value)
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, rn, nil
+ if !dirty || err != nil {
+ return false, rn, err
}
return true, nn, nil
@@ -291,7 +298,7 @@ func (t *Trie) TryDelete(key []byte) error {
// nodes on the way up after deleting recursively.
func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
switch n := n.(type) {
- case shortNode:
+ case *shortNode:
matchlen := prefixLen(key, n.Key)
if matchlen < len(n.Key) {
return false, n, nil // don't replace n on mismatch
@@ -304,34 +311,29 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
// subtrie must contain at least two other values with keys
// longer than n.Key.
dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, n, nil
+ if !dirty || err != nil {
+ return false, n, err
}
switch child := child.(type) {
- case shortNode:
+ case *shortNode:
// Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which
// always creates a new slice) instead of append to
// avoid modifying n.Key since it might be shared with
// other nodes.
- return true, shortNode{concat(n.Key, child.Key...), child.Val, nil, true}, nil
+ return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil
default:
- return true, shortNode{n.Key, child, nil, true}, nil
+ return true, &shortNode{n.Key, child, t.newFlag()}, nil
}
- case fullNode:
+ case *fullNode:
dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, n, nil
+ if !dirty || err != nil {
+ return false, n, err
}
- n.Children[key[0]], n.hash, n.dirty = nn, nil, true
+ n = n.copy()
+ n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
// Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is
@@ -365,14 +367,14 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
if err != nil {
return false, nil, err
}
- if cnode, ok := cnode.(shortNode); ok {
+ if cnode, ok := cnode.(*shortNode); ok {
k := append([]byte{byte(pos)}, cnode.Key...)
- return true, shortNode{k, cnode.Val, nil, true}, nil
+ return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
}
}
// Otherwise, n is replaced by a one-nibble short node
// containing the child.
- return true, shortNode{[]byte{byte(pos)}, n.Children[pos], nil, true}, nil
+ return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
}
// n still contains at least two values and cannot be reduced.
return true, n, nil
@@ -392,11 +394,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, nil, err
}
dirty, nn, err := t.delete(rn, prefix, key)
- if err != nil {
- return false, nil, err
- }
- if !dirty {
- return false, rn, nil
+ if !dirty || err != nil {
+ return false, rn, err
}
return true, nn, nil
@@ -471,6 +470,7 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
return (common.Hash{}), err
}
t.root = cached
+ t.cachegen++
return common.BytesToHash(hash.(hashNode)), nil
}
@@ -478,7 +478,7 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) {
if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil, nil
}
- h := newHasher()
+ h := newHasher(t.cachegen, t.cachelimit)
defer returnHasherToPool(h)
return h.hash(t.root, db, true)
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 87a7ec258..32fbe6801 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -460,8 +460,7 @@ const benchElemCount = 20000
func benchGet(b *testing.B, commit bool) {
trie := new(Trie)
if commit {
- dir, tmpdb := tempDB()
- defer os.RemoveAll(dir)
+ _, tmpdb := tempDB()
trie, _ = New(common.Hash{}, tmpdb)
}
k := make([]byte, 32)
@@ -478,6 +477,13 @@ func benchGet(b *testing.B, commit bool) {
for i := 0; i < b.N; i++ {
trie.Get(k)
}
+ b.StopTimer()
+
+ if commit {
+ ldb := trie.db.(*ethdb.LDBDatabase)
+ ldb.Close()
+ os.RemoveAll(ldb.Path())
+ }
}
func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {