diff options
Diffstat (limited to 'trie/trie.go')
-rw-r--r-- | trie/trie.go | 204 |
1 files changed, 146 insertions, 58 deletions
diff --git a/trie/trie.go b/trie/trie.go index a3a383fb5..717296e27 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -19,7 +19,6 @@ package trie import ( "bytes" - "errors" "fmt" "hash" @@ -44,7 +43,10 @@ var ( emptyState = crypto.Sha3Hash(nil) ) -var ErrMissingRoot = errors.New("missing root node") +// ClearGlobalCache clears the global trie cache +func ClearGlobalCache() { + globalCache.Clear() +} // Database must be implemented by backing stores for the trie. type Database interface { @@ -67,8 +69,9 @@ type DatabaseWriter interface { // // Trie is not safe for concurrent use. type Trie struct { - root node - db Database + root node + db Database + originalRoot common.Hash *hasher } @@ -76,16 +79,19 @@ type Trie struct { // // If root is the zero hash or the sha3 hash of an empty string, the // trie is initially empty and does not require a database. Otherwise, -// New will panics if db is nil or root does not exist in the -// database. Accessing the trie loads nodes from db on demand. +// New will panic if db is nil and returns a MissingNodeError if root does +// not exist in the database. Accessing the trie loads nodes from db on demand. func New(root common.Hash, db Database) (*Trie, error) { - trie := &Trie{db: db} + trie := &Trie{db: db, originalRoot: root} if (root != common.Hash{}) && root != emptyRoot { if db == nil { panic("trie.New: cannot use existing root without a database") } if v, _ := trie.db.Get(root[:]); len(v) == 0 { - return nil, ErrMissingRoot + return nil, &MissingNodeError{ + RootHash: root, + NodeHash: root, + } } trie.root = hashNode(root.Bytes()) } @@ -100,28 +106,44 @@ func (t *Trie) Iterator() *Iterator { // Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. func (t *Trie) Get(key []byte) []byte { + res, err := t.TryGet(key) + if err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } + return res +} + +// TryGet returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryGet(key []byte) ([]byte, error) { key = compactHexDecode(key) + pos := 0 tn := t.root - for len(key) > 0 { + for pos < len(key) { switch n := tn.(type) { case shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { - return nil + if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + return nil, nil } tn = n.Val - key = key[len(n.Key):] + pos += len(n.Key) case fullNode: - tn = n[key[0]] - key = key[1:] + tn = n[key[pos]] + pos++ case nil: - return nil + return nil, nil case hashNode: - tn = t.resolveHash(n) + var err error + tn, err = t.resolveHash(n, key[:pos], key[pos:]) + if err != nil { + return nil, err + } default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - return tn.(valueNode) + return tn.(valueNode), nil } // Update associates key with value in the trie. Subsequent calls to @@ -131,17 +153,40 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller while they are // stored in the trie. func (t *Trie) Update(key, value []byte) { + if err := t.TryUpdate(key, value); err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } +} + +// TryUpdate associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +// +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryUpdate(key, value []byte) error { k := compactHexDecode(key) if len(value) != 0 { - t.root = t.insert(t.root, k, valueNode(value)) + n, err := t.insert(t.root, nil, k, valueNode(value)) + if err != nil { + return err + } + t.root = n } else { - t.root = t.delete(t.root, k) + n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n } + return nil } -func (t *Trie) insert(n node, key []byte, value node) node { +func (t *Trie) insert(n node, prefix, key []byte, value node) (node, error) { if len(key) == 0 { - return value + return value, nil } switch n := n.(type) { case shortNode: @@ -149,25 +194,40 @@ func (t *Trie) insert(n node, key []byte, value node) node { // If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { - return shortNode{n.Key, t.insert(n.Val, key[matchlen:], value)} + nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) + if err != nil { + return nil, err + } + return shortNode{n.Key, nn}, nil } // Otherwise branch out at the index where they differ. var branch fullNode - branch[n.Key[matchlen]] = t.insert(nil, n.Key[matchlen+1:], n.Val) - branch[key[matchlen]] = t.insert(nil, key[matchlen+1:], value) + var err error + branch[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + if err != nil { + return nil, err + } + branch[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + if err != nil { + return nil, err + } // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { - return branch + return branch, nil } // Otherwise, replace it with a short node leading up to the branch. - return shortNode{key[:matchlen], branch} + return shortNode{key[:matchlen], branch}, nil case fullNode: - n[key[0]] = t.insert(n[key[0]], key[1:], value) - return n + nn, err := t.insert(n[key[0]], append(prefix, key[0]), key[1:], value) + if err != nil { + return nil, err + } + n[key[0]] = nn + return n, nil case nil: - return shortNode{key, value} + return shortNode{key, value}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -176,7 +236,11 @@ func (t *Trie) insert(n node, key []byte, value node) node { // // TODO: track whether insertion changed the value and keep // n as a hash node if it didn't. - return t.insert(t.resolveHash(n), key, value) + rn, err := t.resolveHash(n, prefix, key) + if err != nil { + return nil, err + } + return t.insert(rn, prefix, key, value) default: panic(fmt.Sprintf("%T: invalid node: %v", n, n)) @@ -185,28 +249,44 @@ func (t *Trie) insert(n node, key []byte, value node) node { // Delete removes any existing value for key from the trie. func (t *Trie) Delete(key []byte) { + if err := t.TryDelete(key); err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } +} + +// TryDelete removes any existing value for key from the trie. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryDelete(key []byte) error { k := compactHexDecode(key) - t.root = t.delete(t.root, k) + n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n + return nil } // delete returns the new root of the trie with key deleted. // It reduces the trie to minimal form by simplifying // nodes on the way up after deleting recursively. -func (t *Trie) delete(n node, key []byte) node { +func (t *Trie) delete(n node, prefix, key []byte) (node, error) { switch n := n.(type) { case shortNode: matchlen := prefixLen(key, n.Key) if matchlen < len(n.Key) { - return n // don't replace n on mismatch + return n, nil // don't replace n on mismatch } if matchlen == len(key) { - return nil // remove n entirely for whole matches + return nil, nil // remove n entirely for whole matches } // The key is longer than n.Key. Remove the remaining suffix // from the subtrie. Child can never be nil here since the // subtrie must contain at least two other values with keys // longer than n.Key. - child := t.delete(n.Val, key[len(n.Key):]) + child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) + if err != nil { + return nil, err + } switch child := child.(type) { case shortNode: // Deleting from the subtrie reduced it to another @@ -215,13 +295,17 @@ func (t *Trie) delete(n node, key []byte) node { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return shortNode{concat(n.Key, child.Key...), child.Val} + return shortNode{concat(n.Key, child.Key...), child.Val}, nil default: - return shortNode{n.Key, child} + return shortNode{n.Key, child}, nil } case fullNode: - n[key[0]] = t.delete(n[key[0]], key[1:]) + nn, err := t.delete(n[key[0]], append(prefix, key[0]), key[1:]) + if err != nil { + return nil, err + } + n[key[0]] = nn // Check how many non-nil entries are left after deleting and // reduce the full node to a short node if only one entry is // left. Since n must've contained at least two children @@ -250,21 +334,24 @@ func (t *Trie) delete(n node, key []byte) node { // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. - cnode := t.resolve(n[pos]) + cnode, err := t.resolve(n[pos], prefix, []byte{byte(pos)}) + if err != nil { + return nil, err + } if cnode, ok := cnode.(shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) - return shortNode{k, cnode.Val} + return shortNode{k, cnode.Val}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return shortNode{[]byte{byte(pos)}, n[pos]} + return shortNode{[]byte{byte(pos)}, n[pos]}, nil } // n still contains at least two values and cannot be reduced. - return n + return n, nil case nil: - return nil + return nil, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -273,7 +360,11 @@ func (t *Trie) delete(n node, key []byte) node { // // TODO: track whether deletion actually hit a key and keep // n as a hash node if it didn't. - return t.delete(t.resolveHash(n), key) + rn, err := t.resolveHash(n, prefix, key) + if err != nil { + return nil, err + } + return t.delete(rn, prefix, key) default: panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) @@ -287,34 +378,31 @@ func concat(s1 []byte, s2 ...byte) []byte { return r } -func (t *Trie) resolve(n node) node { +func (t *Trie) resolve(n node, prefix, suffix []byte) (node, error) { if n, ok := n.(hashNode); ok { - return t.resolveHash(n) + return t.resolveHash(n, prefix, suffix) } - return n + return n, nil } -func (t *Trie) resolveHash(n hashNode) node { +func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { if v, ok := globalCache.Get(n); ok { - return v + return v, nil } enc, err := t.db.Get(n) if err != nil || enc == nil { - // TODO: This needs to be improved to properly distinguish errors. - // Disk I/O errors shouldn't produce nil (and cause a - // consensus failure or weird crash), but it is unclear how - // they could be handled because the entire stack above the trie isn't - // prepared to cope with missing state nodes. - if glog.V(logger.Error) { - glog.Errorf("Dangling hash node ref %x: %v", n, err) + return nil, &MissingNodeError{ + RootHash: t.originalRoot, + NodeHash: common.BytesToHash(n), + KeyPrefix: prefix, + KeySuffix: suffix, } - return nil } dec := mustDecodeNode(n, enc) if dec != nil { globalCache.Put(n, dec) } - return dec + return dec, nil } // Root returns the root hash of the trie. |