From db4aaedcbdb409e17ea3de161e7b24a80ba0a58c Mon Sep 17 00:00:00 2001 From: obscuren Date: Thu, 8 Jan 2015 11:47:04 +0100 Subject: Moved ptrie => trie. Removed old trie --- trie/cache.go | 42 +++ trie/fullnode.go | 77 ++++++ trie/hashnode.go | 22 ++ trie/iterator.go | 152 ++++------- trie/iterator_test.go | 33 +++ trie/main_test.go | 9 - trie/node.go | 40 +++ trie/shortnode.go | 29 ++ trie/trie.go | 725 +++++++++++++++----------------------------------- trie/trie_test.go | 459 +++++++++++--------------------- trie/valuenode.go | 13 + 11 files changed, 692 insertions(+), 909 deletions(-) create mode 100644 trie/cache.go create mode 100644 trie/fullnode.go create mode 100644 trie/hashnode.go create mode 100644 trie/iterator_test.go delete mode 100644 trie/main_test.go create mode 100644 trie/node.go create mode 100644 trie/shortnode.go create mode 100644 trie/valuenode.go (limited to 'trie') diff --git a/trie/cache.go b/trie/cache.go new file mode 100644 index 000000000..e03702b25 --- /dev/null +++ b/trie/cache.go @@ -0,0 +1,42 @@ +package trie + +type Backend interface { + Get([]byte) ([]byte, error) + Put([]byte, []byte) +} + +type Cache struct { + store map[string][]byte + backend Backend +} + +func NewCache(backend Backend) *Cache { + return &Cache{make(map[string][]byte), backend} +} + +func (self *Cache) Get(key []byte) []byte { + data := self.store[string(key)] + if data == nil { + data, _ = self.backend.Get(key) + } + + return data +} + +func (self *Cache) Put(key []byte, data []byte) { + self.store[string(key)] = data +} + +func (self *Cache) Flush() { + for k, v := range self.store { + self.backend.Put([]byte(k), v) + } + + // This will eventually grow too large. We'd could + // do a make limit on storage and push out not-so-popular nodes. + //self.Reset() +} + +func (self *Cache) Reset() { + self.store = make(map[string][]byte) +} diff --git a/trie/fullnode.go b/trie/fullnode.go new file mode 100644 index 000000000..ebbe7f384 --- /dev/null +++ b/trie/fullnode.go @@ -0,0 +1,77 @@ +package trie + +import "fmt" + +type FullNode struct { + trie *Trie + nodes [17]Node +} + +func NewFullNode(t *Trie) *FullNode { + return &FullNode{trie: t} +} + +func (self *FullNode) Dirty() bool { return true } +func (self *FullNode) Value() Node { + self.nodes[16] = self.trie.trans(self.nodes[16]) + return self.nodes[16] +} +func (self *FullNode) Branches() []Node { + return self.nodes[:16] +} + +func (self *FullNode) Copy() Node { + nnode := NewFullNode(self.trie) + for i, node := range self.nodes { + if node != nil { + nnode.nodes[i] = node + } + } + + return nnode +} + +// Returns the length of non-nil nodes +func (self *FullNode) Len() (amount int) { + for _, node := range self.nodes { + if node != nil { + amount++ + } + } + + return +} + +func (self *FullNode) Hash() interface{} { + return self.trie.store(self) +} + +func (self *FullNode) RlpData() interface{} { + t := make([]interface{}, 17) + for i, node := range self.nodes { + if node != nil { + t[i] = node.Hash() + } else { + t[i] = "" + } + } + + return t +} + +func (self *FullNode) set(k byte, value Node) { + if _, ok := value.(*ValueNode); ok && k != 16 { + fmt.Println(value, k) + } + + self.nodes[int(k)] = value +} + +func (self *FullNode) branch(i byte) Node { + if self.nodes[int(i)] != nil { + self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)]) + + return self.nodes[int(i)] + } + return nil +} diff --git a/trie/hashnode.go b/trie/hashnode.go new file mode 100644 index 000000000..40ccd54c3 --- /dev/null +++ b/trie/hashnode.go @@ -0,0 +1,22 @@ +package trie + +type HashNode struct { + key []byte +} + +func NewHash(key []byte) *HashNode { + return &HashNode{key} +} + +func (self *HashNode) RlpData() interface{} { + return self.key +} + +func (self *HashNode) Hash() interface{} { + return self.key +} + +// These methods will never be called but we have to satisfy Node interface +func (self *HashNode) Value() Node { return nil } +func (self *HashNode) Dirty() bool { return true } +func (self *HashNode) Copy() Node { return self } diff --git a/trie/iterator.go b/trie/iterator.go index 1114715a6..f0dae28bb 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -1,124 +1,73 @@ package trie -/* -import ( - "bytes" - - "github.com/ethereum/go-ethereum/ethutil" -) - -type NodeType byte - -const ( - EmptyNode NodeType = iota - BranchNode - LeafNode - ExtNode -) - -func getType(node *ethutil.Value) NodeType { - if node.Len() == 0 { - return EmptyNode - } - - if node.Len() == 2 { - k := CompactDecode(node.Get(0).Str()) - if HasTerm(k) { - return LeafNode - } - - return ExtNode - } - - return BranchNode -} +import "bytes" type Iterator struct { - Path [][]byte trie *Trie Key []byte - Value *ethutil.Value + Value []byte } func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie} + return &Iterator{trie: trie, Key: make([]byte, 32)} } -func (self *Iterator) key(node *ethutil.Value, path [][]byte) []byte { - switch getType(node) { - case LeafNode: - k := RemTerm(CompactDecode(node.Get(0).Str())) +func (self *Iterator) Next() bool { + self.trie.mu.Lock() + defer self.trie.mu.Unlock() - self.Path = append(path, k) - self.Value = node.Get(1) + key := RemTerm(CompactHexDecode(string(self.Key))) + k := self.next(self.trie.root, key) - return k - case BranchNode: - if node.Get(16).Len() > 0 { - return []byte{16} - } + self.Key = []byte(DecodeCompact(k)) - for i := byte(0); i < 16; i++ { - o := self.key(self.trie.getNode(node.Get(int(i)).Raw()), append(path, []byte{i})) - if o != nil { - return append([]byte{i}, o...) - } - } - case ExtNode: - currKey := node.Get(0).Bytes() + return len(k) > 0 - return self.key(self.trie.getNode(node.Get(1).Raw()), append(path, currKey)) - } - - return nil } -func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byte { - switch typ := getType(node); typ { - case EmptyNode: +func (self *Iterator) next(node Node, key []byte) []byte { + if node == nil { return nil - case BranchNode: - if len(key) > 0 { - subNode := self.trie.getNode(node.Get(int(key[0])).Raw()) + } - o := self.next(subNode, key[1:], append(path, key[:1])) - if o != nil { - return append([]byte{key[0]}, o...) + switch node := node.(type) { + case *FullNode: + if len(key) > 0 { + k := self.next(node.branch(key[0]), key[1:]) + if k != nil { + return append([]byte{key[0]}, k...) } } - var r byte = 0 + var r byte if len(key) > 0 { r = key[0] + 1 } for i := r; i < 16; i++ { - subNode := self.trie.getNode(node.Get(int(i)).Raw()) - o := self.key(subNode, append(path, []byte{i})) - if o != nil { - return append([]byte{i}, o...) + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{i}, k...) } } - case LeafNode, ExtNode: - k := RemTerm(CompactDecode(node.Get(0).Str())) - if typ == LeafNode { - if bytes.Compare([]byte(k), []byte(key)) > 0 { - self.Value = node.Get(1) - self.Path = append(path, k) + case *ShortNode: + k := RemTerm(node.Key()) + if vnode, ok := node.Value().(*ValueNode); ok { + if bytes.Compare([]byte(k), key) > 0 { + self.Value = vnode.Val() return k } } else { - subNode := self.trie.getNode(node.Get(1).Raw()) - subKey := key[len(k):] + cnode := node.Value() + var ret []byte + skey := key[len(k):] if BeginsWith(key, k) { - ret = self.next(subNode, subKey, append(path, k)) + ret = self.next(cnode, skey) } else if bytes.Compare(k, key[:len(k)]) > 0 { - ret = self.key(node, append(path, k)) - } else { - ret = nil + ret = self.key(node) } if ret != nil { @@ -130,16 +79,33 @@ func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byt return nil } -// Get the next in keys -func (self *Iterator) Next(key string) []byte { - self.trie.mut.Lock() - defer self.trie.mut.Unlock() +func (self *Iterator) key(node Node) []byte { + switch node := node.(type) { + case *ShortNode: + // Leaf node + if vnode, ok := node.Value().(*ValueNode); ok { + k := RemTerm(node.Key()) + self.Value = vnode.Val() - k := RemTerm(CompactHexDecode(key)) - n := self.next(self.trie.getNode(self.trie.Root), k, nil) + return k + } else { + k := RemTerm(node.Key()) + return append(k, self.key(node.Value())...) + } + case *FullNode: + if node.Value() != nil { + self.Value = node.Value().(*ValueNode).Val() - self.Key = []byte(DecodeCompact(n)) + return []byte{16} + } - return self.Key + for i := 0; i < 16; i++ { + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{byte(i)}, k...) + } + } + } + + return nil } -*/ diff --git a/trie/iterator_test.go b/trie/iterator_test.go new file mode 100644 index 000000000..74d9e903c --- /dev/null +++ b/trie/iterator_test.go @@ -0,0 +1,33 @@ +package trie + +import "testing" + +func TestIterator(t *testing.T) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + v := make(map[string]bool) + for _, val := range vals { + v[val.k] = false + trie.UpdateString(val.k, val.v) + } + trie.Commit() + + it := trie.Iterator() + for it.Next() { + v[string(it.Key)] = true + } + + for k, found := range v { + if !found { + t.Error("iterator didn't find", k) + } + } +} diff --git a/trie/main_test.go b/trie/main_test.go deleted file mode 100644 index f6f64c06f..000000000 --- a/trie/main_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package trie - -import ( - "testing" - - checker "gopkg.in/check.v1" -) - -func Test(t *testing.T) { checker.TestingT(t) } diff --git a/trie/node.go b/trie/node.go new file mode 100644 index 000000000..a1f68480f --- /dev/null +++ b/trie/node.go @@ -0,0 +1,40 @@ +package trie + +import "fmt" + +var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} + +type Node interface { + Value() Node + Copy() Node // All nodes, for now, return them self + Dirty() bool + fstring(string) string + Hash() interface{} + RlpData() interface{} +} + +// Value node +func (self *ValueNode) String() string { return self.fstring("") } +func (self *FullNode) String() string { return self.fstring("") } +func (self *ShortNode) String() string { return self.fstring("") } +func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) } +func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.key) } + +// Full node +func (self *FullNode) fstring(ind string) string { + resp := fmt.Sprintf("[\n%s ", ind) + for i, node := range self.nodes { + if node == nil { + resp += fmt.Sprintf("%s: ", indices[i]) + } else { + resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" ")) + } + } + + return resp + fmt.Sprintf("\n%s] ", ind) +} + +// Short node +func (self *ShortNode) fstring(ind string) string { + return fmt.Sprintf("[ %x: %v ] ", self.key, self.value.fstring(ind+" ")) +} diff --git a/trie/shortnode.go b/trie/shortnode.go new file mode 100644 index 000000000..f132b56d9 --- /dev/null +++ b/trie/shortnode.go @@ -0,0 +1,29 @@ +package trie + +type ShortNode struct { + trie *Trie + key []byte + value Node +} + +func NewShortNode(t *Trie, key []byte, value Node) *ShortNode { + return &ShortNode{t, []byte(CompactEncode(key)), value} +} +func (self *ShortNode) Value() Node { + self.value = self.trie.trans(self.value) + + return self.value +} +func (self *ShortNode) Dirty() bool { return true } +func (self *ShortNode) Copy() Node { return NewShortNode(self.trie, self.key, self.value) } + +func (self *ShortNode) RlpData() interface{} { + return []interface{}{self.key, self.value.Hash()} +} +func (self *ShortNode) Hash() interface{} { + return self.trie.store(self) +} + +func (self *ShortNode) Key() []byte { + return CompactDecode(string(self.key)) +} diff --git a/trie/trie.go b/trie/trie.go index c9fd18e00..36f2af5d2 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -1,8 +1,8 @@ package trie -/* import ( "bytes" + "container/list" "fmt" "sync" @@ -10,618 +10,325 @@ import ( "github.com/ethereum/go-ethereum/ethutil" ) -func ParanoiaCheck(t1 *Trie) (bool, *Trie) { - t2 := New(ethutil.Config.Db, "") +func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { + t2 := New(nil, backend) - t1.NewIterator().Each(func(key string, v *ethutil.Value) { - t2.Update(key, v.Str()) - }) - - return bytes.Compare(t2.GetRoot(), t1.GetRoot()) == 0, t2 -} - -func (s *Cache) Len() int { - return len(s.nodes) -} - -// TODO -// A StateObject is an object that has a state root -// This is goig to be the object for the second level caching (the caching of object which have a state such as contracts) -type StateObject interface { - State() *Trie - Sync() - Undo() -} - -type Node struct { - Key []byte - Value *ethutil.Value - Dirty bool -} - -func NewNode(key []byte, val *ethutil.Value, dirty bool) *Node { - return &Node{Key: key, Value: val, Dirty: dirty} -} + it := t1.Iterator() + for it.Next() { + t2.Update(it.Key, it.Value) + } -func (n *Node) Copy() *Node { - return NewNode(n.Key, n.Value, n.Dirty) + return bytes.Equal(t2.Hash(), t1.Hash()), t2 } -type Cache struct { - nodes map[string]*Node - db ethutil.Database - IsDirty bool -} +type Trie struct { + mu sync.Mutex + root Node + roothash []byte + cache *Cache -func NewCache(db ethutil.Database) *Cache { - return &Cache{db: db, nodes: make(map[string]*Node)} + revisions *list.List } -func (cache *Cache) PutValue(v interface{}, force bool) interface{} { - value := ethutil.NewValue(v) - - enc := value.Encode() - if len(enc) >= 32 || force { - sha := crypto.Sha3(enc) - - cache.nodes[string(sha)] = NewNode(sha, value, true) - cache.IsDirty = true +func New(root []byte, backend Backend) *Trie { + trie := &Trie{} + trie.revisions = list.New() + trie.roothash = root + trie.cache = NewCache(backend) - return sha + if root != nil { + value := ethutil.NewValueFromBytes(trie.cache.Get(root)) + trie.root = trie.mknode(value) } - return v -} - -func (cache *Cache) Put(v interface{}) interface{} { - return cache.PutValue(v, false) + return trie } -func (cache *Cache) Get(key []byte) *ethutil.Value { - // First check if the key is the cache - if cache.nodes[string(key)] != nil { - return cache.nodes[string(key)].Value - } - - // Get the key of the database instead and cache it - data, _ := cache.db.Get(key) - // Create the cached value - value := ethutil.NewValueFromBytes(data) - - defer func() { - if r := recover(); r != nil { - fmt.Println("RECOVER GET", cache, cache.nodes) - panic("bye") - } - }() - // Create caching node - cache.nodes[string(key)] = NewNode(key, value, true) - - return value +func (self *Trie) Iterator() *Iterator { + return NewIterator(self) } -func (cache *Cache) Delete(key []byte) { - delete(cache.nodes, string(key)) - - cache.db.Delete(key) +func (self *Trie) Copy() *Trie { + return New(self.roothash, self.cache.backend) } -func (cache *Cache) Commit() { - // Don't try to commit if it isn't dirty - if !cache.IsDirty { - return - } - - for key, node := range cache.nodes { - if node.Dirty { - cache.db.Put([]byte(key), node.Value.Encode()) - node.Dirty = false +// Legacy support +func (self *Trie) Root() []byte { return self.Hash() } +func (self *Trie) Hash() []byte { + var hash []byte + if self.root != nil { + t := self.root.Hash() + if byts, ok := t.([]byte); ok && len(byts) > 0 { + hash = byts + } else { + hash = crypto.Sha3(ethutil.Encode(self.root.RlpData())) } + } else { + hash = crypto.Sha3(ethutil.Encode("")) } - cache.IsDirty = false - - // If the nodes grows beyond the 200 entries we simple empty it - // FIXME come up with something better - if len(cache.nodes) > 200 { - cache.nodes = make(map[string]*Node) - } -} -func (cache *Cache) Undo() { - for key, node := range cache.nodes { - if node.Dirty { - delete(cache.nodes, key) - } + if !bytes.Equal(hash, self.roothash) { + self.revisions.PushBack(self.roothash) + self.roothash = hash } - cache.IsDirty = false -} -// A (modified) Radix Trie implementation. The Trie implements -// a caching mechanism and will used cached values if they are -// present. If a node is not present in the cache it will try to -// fetch it from the database and store the cached value. -// Please note that the data isn't persisted unless `Sync` is -// explicitly called. -type Trie struct { - mut sync.RWMutex - prevRoot interface{} - Root interface{} - //db Database - cache *Cache + return hash } +func (self *Trie) Commit() { + self.mu.Lock() + defer self.mu.Unlock() -func copyRoot(root interface{}) interface{} { - var prevRootCopy interface{} - if b, ok := root.([]byte); ok { - prevRootCopy = ethutil.CopyBytes(b) - } else { - prevRootCopy = root - } + // Hash first + self.Hash() - return prevRootCopy + self.cache.Flush() } -func New(db ethutil.Database, Root interface{}) *Trie { - // Make absolute sure the root is copied - r := copyRoot(Root) - p := copyRoot(Root) +// Reset should only be called if the trie has been hashed +func (self *Trie) Reset() { + self.mu.Lock() + defer self.mu.Unlock() - trie := &Trie{cache: NewCache(db), Root: r, prevRoot: p} - trie.setRoot(Root) + self.cache.Reset() - return trie -} - -func (self *Trie) setRoot(root interface{}) { - switch t := root.(type) { - case string: - //if t == "" { - // root = crypto.Sha3(ethutil.Encode("")) - //} - self.Root = []byte(t) - case []byte: - self.Root = root - default: - self.Root = self.cache.PutValue(root, true) + if self.revisions.Len() > 0 { + revision := self.revisions.Remove(self.revisions.Back()).([]byte) + self.roothash = revision } + value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash)) + self.root = self.mknode(value) } -func (t *Trie) Update(key, value string) { - t.mut.Lock() - defer t.mut.Unlock() +func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } +func (self *Trie) Update(key, value []byte) Node { + self.mu.Lock() + defer self.mu.Unlock() - k := CompactHexDecode(key) + k := CompactHexDecode(string(key)) - var root interface{} - if value != "" { - root = t.UpdateState(t.Root, k, value) + if len(value) != 0 { + self.root = self.insert(self.root, k, &ValueNode{self, value}) } else { - root = t.deleteState(t.Root, k) + self.root = self.delete(self.root, k) } - t.setRoot(root) -} - -func (t *Trie) Get(key string) string { - t.mut.Lock() - defer t.mut.Unlock() - - k := CompactHexDecode(key) - c := ethutil.NewValue(t.getState(t.Root, k)) - return c.Str() + return self.root } -func (t *Trie) Delete(key string) { - t.mut.Lock() - defer t.mut.Unlock() - - k := CompactHexDecode(key) - - root := t.deleteState(t.Root, k) - t.setRoot(root) -} - -func (self *Trie) GetRoot() []byte { - switch t := self.Root.(type) { - case string: - if t == "" { - return crypto.Sha3(ethutil.Encode("")) - } - return []byte(t) - case []byte: - if len(t) == 0 { - return crypto.Sha3(ethutil.Encode("")) - } - - return t - default: - panic(fmt.Sprintf("invalid root type %T (%v)", self.Root, self.Root)) - } -} +func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } +func (self *Trie) Get(key []byte) []byte { + self.mu.Lock() + defer self.mu.Unlock() -// Simple compare function which creates a rlp value out of the evaluated objects -func (t *Trie) Cmp(trie *Trie) bool { - return ethutil.NewValue(t.Root).Cmp(ethutil.NewValue(trie.Root)) -} + k := CompactHexDecode(string(key)) -// Returns a copy of this trie -func (t *Trie) Copy() *Trie { - trie := New(t.cache.db, t.Root) - for key, node := range t.cache.nodes { - trie.cache.nodes[key] = node.Copy() + n := self.get(self.root, k) + if n != nil { + return n.(*ValueNode).Val() } - return trie + return nil } -// Save the cached value to the database. -func (t *Trie) Sync() { - t.cache.Commit() - t.prevRoot = copyRoot(t.Root) -} +func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } +func (self *Trie) Delete(key []byte) Node { + self.mu.Lock() + defer self.mu.Unlock() -func (t *Trie) Undo() { - t.cache.Undo() - t.Root = t.prevRoot -} + k := CompactHexDecode(string(key)) + self.root = self.delete(self.root, k) -func (t *Trie) Cache() *Cache { - return t.cache + return self.root } -func (t *Trie) getState(node interface{}, key []byte) interface{} { - n := ethutil.NewValue(node) - // Return the node if key is empty (= found) - if len(key) == 0 || n.IsNil() || n.Len() == 0 { - return node +func (self *Trie) insert(node Node, key []byte, value Node) Node { + if len(key) == 0 { + return value } - currentNode := t.getNode(node) - length := currentNode.Len() + if node == nil { + return NewShortNode(self, key, value) + } - if length == 0 { - return "" - } else if length == 2 { - // Decode the key - k := CompactDecode(currentNode.Get(0).Str()) - v := currentNode.Get(1).Raw() + switch node := node.(type) { + case *ShortNode: + k := node.Key() + cnode := node.Value() + if bytes.Equal(k, key) { + return NewShortNode(self, key, value) + } - if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { //CompareIntSlice(k, key[:len(k)]) { - return t.getState(v, key[len(k):]) + var n Node + matchlength := MatchingNibbleLength(key, k) + if matchlength == len(k) { + n = self.insert(cnode, key[matchlength:], value) } else { - return "" + pnode := self.insert(nil, k[matchlength+1:], cnode) + nnode := self.insert(nil, key[matchlength+1:], value) + fulln := NewFullNode(self) + fulln.set(k[matchlength], pnode) + fulln.set(key[matchlength], nnode) + n = fulln + } + if matchlength == 0 { + return n } - } else if length == 17 { - return t.getState(currentNode.Get(int(key[0])).Raw(), key[1:]) - } - - // It shouldn't come this far - panic("unexpected return") -} - -func (t *Trie) getNode(node interface{}) *ethutil.Value { - n := ethutil.NewValue(node) - - if !n.Get(0).IsNil() { - return n - } - - str := n.Str() - if len(str) == 0 { - return n - } else if len(str) < 32 { - return ethutil.NewValueFromBytes([]byte(str)) - } - - data := t.cache.Get(n.Bytes()) - - return data -} -func (t *Trie) UpdateState(node interface{}, key []byte, value string) interface{} { - return t.InsertState(node, key, value) -} + return NewShortNode(self, key[:matchlength], n) -func (t *Trie) Put(node interface{}) interface{} { - return t.cache.Put(node) + case *FullNode: + cpy := node.Copy().(*FullNode) + cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) -} + return cpy -func EmptyStringSlice(l int) []interface{} { - slice := make([]interface{}, l) - for i := 0; i < l; i++ { - slice[i] = "" + default: + panic(fmt.Sprintf("%T: invalid node: %v", node, node)) } - return slice } -func (t *Trie) InsertState(node interface{}, key []byte, value interface{}) interface{} { +func (self *Trie) get(node Node, key []byte) Node { if len(key) == 0 { - return value + return node } - // New node - n := ethutil.NewValue(node) - if node == nil || n.Len() == 0 { - newNode := []interface{}{CompactEncode(key), value} - - return t.Put(newNode) + if node == nil { + return nil } - currentNode := t.getNode(node) - // Check for "special" 2 slice type node - if currentNode.Len() == 2 { - // Decode the key - - k := CompactDecode(currentNode.Get(0).Str()) - v := currentNode.Get(1).Raw() - - // Matching key pair (ie. there's already an object with this key) - if bytes.Equal(k, key) { //CompareIntSlice(k, key) { - newNode := []interface{}{CompactEncode(key), value} - return t.Put(newNode) - } - - var newHash interface{} - matchingLength := MatchingNibbleLength(key, k) - if matchingLength == len(k) { - // Insert the hash, creating a new node - newHash = t.InsertState(v, key[matchingLength:], value) - } else { - // Expand the 2 length slice to a 17 length slice - oldNode := t.InsertState("", k[matchingLength+1:], v) - newNode := t.InsertState("", key[matchingLength+1:], value) - // Create an expanded slice - scaledSlice := EmptyStringSlice(17) - // Set the copied and new node - scaledSlice[k[matchingLength]] = oldNode - scaledSlice[key[matchingLength]] = newNode - - newHash = t.Put(scaledSlice) - } - - if matchingLength == 0 { - // End of the chain, return - return newHash - } else { - newNode := []interface{}{CompactEncode(key[:matchingLength]), newHash} - return t.Put(newNode) - } - } else { - - // Copy the current node over to the new node and replace the first nibble in the key - newNode := EmptyStringSlice(17) + switch node := node.(type) { + case *ShortNode: + k := node.Key() + cnode := node.Value() - for i := 0; i < 17; i++ { - cpy := currentNode.Get(i).Raw() - if cpy != nil { - newNode[i] = cpy - } + if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { + return self.get(cnode, key[len(k):]) } - newNode[key[0]] = t.InsertState(currentNode.Get(int(key[0])).Raw(), key[1:], value) - - return t.Put(newNode) + return nil + case *FullNode: + return self.get(node.branch(key[0]), key[1:]) + default: + panic(fmt.Sprintf("%T: invalid node: %v", node, node)) } - - panic("unexpected end") } -func (t *Trie) deleteState(node interface{}, key []byte) interface{} { - if len(key) == 0 { - return "" - } - - // New node - n := ethutil.NewValue(node) - //if node == nil || (n.Type() == reflect.String && (n.Str() == "" || n.Get(0).IsNil())) || n.Len() == 0 { - if node == nil || n.Len() == 0 { - //return nil - //fmt.Printf(" %x %d\n", n, len(n.Bytes())) - - return "" +func (self *Trie) delete(node Node, key []byte) Node { + if len(key) == 0 && node == nil { + return nil } - currentNode := t.getNode(node) - // Check for "special" 2 slice type node - if currentNode.Len() == 2 { - // Decode the key - k := CompactDecode(currentNode.Get(0).Str()) - v := currentNode.Get(1).Raw() - - // Matching key pair (ie. there's already an object with this key) - if bytes.Equal(k, key) { //CompareIntSlice(k, key) { - //fmt.Printf(" %x\n", v) - - return "" - } else if bytes.Equal(key[:len(k)], k) { //CompareIntSlice(key[:len(k)], k) { - hash := t.deleteState(v, key[len(k):]) - child := t.getNode(hash) - - var newNode []interface{} - if child.Len() == 2 { - newKey := append(k, CompactDecode(child.Get(0).Str())...) - newNode = []interface{}{CompactEncode(newKey), child.Get(1).Raw()} - } else { - newNode = []interface{}{currentNode.Get(0).Str(), hash} + switch node := node.(type) { + case *ShortNode: + k := node.Key() + cnode := node.Value() + if bytes.Equal(key, k) { + return nil + } else if bytes.Equal(key[:len(k)], k) { + child := self.delete(cnode, key[len(k):]) + + var n Node + switch child := child.(type) { + case *ShortNode: + nkey := append(k, child.Key()...) + n = NewShortNode(self, nkey, child.Value()) + case *FullNode: + sn := NewShortNode(self, node.Key(), child) + sn.key = node.key + n = sn } - //fmt.Printf("%x\n", newNode) - - return t.Put(newNode) + return n } else { return node } - } else { - // Copy the current node over to the new node and replace the first nibble in the key - n := EmptyStringSlice(17) - var newNode []interface{} - for i := 0; i < 17; i++ { - cpy := currentNode.Get(i).Raw() - if cpy != nil { - n[i] = cpy - } - } + case *FullNode: + n := node.Copy().(*FullNode) + n.set(key[0], self.delete(n.branch(key[0]), key[1:])) - n[key[0]] = t.deleteState(n[key[0]], key[1:]) - amount := -1 + pos := -1 for i := 0; i < 17; i++ { - if n[i] != "" { - if amount == -1 { - amount = i + if n.branch(byte(i)) != nil { + if pos == -1 { + pos = i } else { - amount = -2 + pos = -2 } } } - if amount == 16 { - newNode = []interface{}{CompactEncode([]byte{16}), n[amount]} - } else if amount >= 0 { - child := t.getNode(n[amount]) - if child.Len() == 17 { - newNode = []interface{}{CompactEncode([]byte{byte(amount)}), n[amount]} - } else if child.Len() == 2 { - key := append([]byte{byte(amount)}, CompactDecode(child.Get(0).Str())...) - newNode = []interface{}{CompactEncode(key), child.Get(1).Str()} - } + var nnode Node + if pos == 16 { + nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) + } else if pos >= 0 { + cnode := n.branch(byte(pos)) + switch cnode := cnode.(type) { + case *ShortNode: + // Stitch keys + k := append([]byte{byte(pos)}, cnode.Key()...) + nnode = NewShortNode(self, k, cnode.Value()) + case *FullNode: + nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) + } } else { - newNode = n + nnode = n } - //fmt.Printf("%x\n", newNode) - return t.Put(newNode) + return nnode + case nil: + return nil + default: + panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key)) } - - panic("unexpected return") -} - -type TrieIterator struct { - trie *Trie - key string - value string - - shas [][]byte - values []string - - lastNode []byte } -func (t *Trie) NewIterator() *TrieIterator { - return &TrieIterator{trie: t} -} - -func (self *Trie) Iterator() *Iterator { - return NewIterator(self) -} - -// Some time in the near future this will need refactoring :-) -// XXX Note to self, IsSlice == inline node. Str == sha3 to node -func (it *TrieIterator) workNode(currentNode *ethutil.Value) { - if currentNode.Len() == 2 { - k := CompactDecode(currentNode.Get(0).Str()) - - if currentNode.Get(1).Str() == "" { - it.workNode(currentNode.Get(1)) - } else { - if k[len(k)-1] == 16 { - it.values = append(it.values, currentNode.Get(1).Str()) - } else { - it.shas = append(it.shas, currentNode.Get(1).Bytes()) - it.getNode(currentNode.Get(1).Bytes()) - } +// casting functions and cache storing +func (self *Trie) mknode(value *ethutil.Value) Node { + l := value.Len() + switch l { + case 0: + return nil + case 2: + // A value node may consists of 2 bytes. + if value.Get(0).Len() != 0 { + return NewShortNode(self, CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1))) } - } else { - for i := 0; i < currentNode.Len(); i++ { - if i == 16 && currentNode.Get(i).Len() != 0 { - it.values = append(it.values, currentNode.Get(i).Str()) - } else { - if currentNode.Get(i).Str() == "" { - it.workNode(currentNode.Get(i)) - } else { - val := currentNode.Get(i).Str() - if val != "" { - it.shas = append(it.shas, currentNode.Get(1).Bytes()) - it.getNode([]byte(val)) - } - } - } + case 17: + fnode := NewFullNode(self) + for i := 0; i < l; i++ { + fnode.set(byte(i), self.mknode(value.Get(i))) } + return fnode + case 32: + return &HashNode{value.Bytes()} } -} -func (it *TrieIterator) getNode(node []byte) { - currentNode := it.trie.cache.Get(node) - it.workNode(currentNode) + return &ValueNode{self, value.Bytes()} } -func (it *TrieIterator) Collect() [][]byte { - if it.trie.Root == "" { - return nil - } - - it.getNode(ethutil.NewValue(it.trie.Root).Bytes()) - - return it.shas -} - -func (it *TrieIterator) Purge() int { - shas := it.Collect() - for _, sha := range shas { - it.trie.cache.Delete(sha) +func (self *Trie) trans(node Node) Node { + switch node := node.(type) { + case *HashNode: + value := ethutil.NewValueFromBytes(self.cache.Get(node.key)) + return self.mknode(value) + default: + return node } - return len(it.values) -} - -func (it *TrieIterator) Key() string { - return "" } -func (it *TrieIterator) Value() string { - return "" -} - -type EachCallback func(key string, node *ethutil.Value) +func (self *Trie) store(node Node) interface{} { + data := ethutil.Encode(node) + if len(data) >= 32 { + key := crypto.Sha3(data) + self.cache.Put(key, data) -func (it *TrieIterator) Each(cb EachCallback) { - it.fetchNode(nil, ethutil.NewValue(it.trie.Root).Bytes(), cb) -} + return key + } -func (it *TrieIterator) fetchNode(key []byte, node []byte, cb EachCallback) { - it.iterateNode(key, it.trie.cache.Get(node), cb) + return node.RlpData() } -func (it *TrieIterator) iterateNode(key []byte, currentNode *ethutil.Value, cb EachCallback) { - if currentNode.Len() == 2 { - k := CompactDecode(currentNode.Get(0).Str()) - - pk := append(key, k...) - if currentNode.Get(1).Len() != 0 && currentNode.Get(1).Str() == "" { - it.iterateNode(pk, currentNode.Get(1), cb) - } else { - if k[len(k)-1] == 16 { - cb(DecodeCompact(pk), currentNode.Get(1)) - } else { - it.fetchNode(pk, currentNode.Get(1).Bytes(), cb) - } - } - } else { - for i := 0; i < currentNode.Len(); i++ { - pk := append(key, byte(i)) - if i == 16 && currentNode.Get(i).Len() != 0 { - cb(DecodeCompact(pk), currentNode.Get(i)) - } else { - if currentNode.Get(i).Len() != 0 && currentNode.Get(i).Str() == "" { - it.iterateNode(pk, currentNode.Get(i), cb) - } else { - val := currentNode.Get(i).Str() - if val != "" { - it.fetchNode(pk, []byte(val), cb) - } - } - } - } - } +func (self *Trie) PrintRoot() { + fmt.Println(self.root) } -*/ diff --git a/trie/trie_test.go b/trie/trie_test.go index 3abe56040..ffb78d4f2 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1,345 +1,188 @@ package trie -/* import ( "bytes" - "encoding/hex" - "encoding/json" "fmt" - "io/ioutil" - "math/rand" - "net/http" "testing" - "time" - - checker "gopkg.in/check.v1" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethutil" ) -const LONG_WORD = "1234567890abcdefghijklmnopqrstuvwxxzABCEFGHIJKLMNOPQRSTUVWXYZ" - -type TrieSuite struct { - db *MemDatabase - trie *Trie -} - -type MemDatabase struct { - db map[string][]byte -} - -func NewMemDatabase() (*MemDatabase, error) { - db := &MemDatabase{db: make(map[string][]byte)} - return db, nil -} -func (db *MemDatabase) Put(key []byte, value []byte) { - db.db[string(key)] = value -} -func (db *MemDatabase) Get(key []byte) ([]byte, error) { - return db.db[string(key)], nil -} -func (db *MemDatabase) Delete(key []byte) error { - delete(db.db, string(key)) - return nil -} -func (db *MemDatabase) Print() {} -func (db *MemDatabase) Close() {} -func (db *MemDatabase) LastKnownTD() []byte { return nil } - -func NewTrie() (*MemDatabase, *Trie) { - db, _ := NewMemDatabase() - return db, New(db, "") -} - -func (s *TrieSuite) SetUpTest(c *checker.C) { - s.db, s.trie = NewTrie() -} - -func (s *TrieSuite) TestTrieSync(c *checker.C) { - s.trie.Update("dog", LONG_WORD) - c.Assert(s.db.db, checker.HasLen, 0, checker.Commentf("Expected no data in database")) - s.trie.Sync() - c.Assert(s.db.db, checker.HasLen, 3) -} - -func (s *TrieSuite) TestTrieDirtyTracking(c *checker.C) { - s.trie.Update("dog", LONG_WORD) - c.Assert(s.trie.cache.IsDirty, checker.Equals, true, checker.Commentf("Expected no data in database")) - - s.trie.Sync() - c.Assert(s.trie.cache.IsDirty, checker.Equals, false, checker.Commentf("Expected trie to be dirty")) - - s.trie.Update("test", LONG_WORD) - s.trie.cache.Undo() - c.Assert(s.trie.cache.IsDirty, checker.Equals, false) -} - -func (s *TrieSuite) TestTrieReset(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - c.Assert(s.trie.cache.nodes, checker.HasLen, 1, checker.Commentf("Expected cached nodes")) +type Db map[string][]byte - s.trie.cache.Undo() - c.Assert(s.trie.cache.nodes, checker.HasLen, 0, checker.Commentf("Expected no nodes after undo")) -} +func (self Db) Get(k []byte) ([]byte, error) { return self[string(k)], nil } +func (self Db) Put(k, v []byte) { self[string(k)] = v } -func (s *TrieSuite) TestTrieGet(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - x := s.trie.Get("cat") - c.Assert(x, checker.DeepEquals, LONG_WORD) +// Used for testing +func NewEmpty() *Trie { + return New(nil, make(Db)) } -func (s *TrieSuite) TestTrieUpdating(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - s.trie.Update("cat", LONG_WORD+"1") - x := s.trie.Get("cat") - c.Assert(x, checker.DeepEquals, LONG_WORD+"1") +func TestEmptyTrie(t *testing.T) { + trie := NewEmpty() + res := trie.Hash() + exp := crypto.Sha3(ethutil.Encode("")) + if !bytes.Equal(res, exp) { + t.Errorf("expected %x got %x", exp, res) + } } -func (s *TrieSuite) TestTrieCmp(c *checker.C) { - _, trie1 := NewTrie() - _, trie2 := NewTrie() +func TestInsert(t *testing.T) { + trie := NewEmpty() - trie1.Update("doge", LONG_WORD) - trie2.Update("doge", LONG_WORD) - c.Assert(trie1, checker.DeepEquals, trie2) + trie.UpdateString("doe", "reindeer") + trie.UpdateString("dog", "puppy") + trie.UpdateString("dogglesworth", "cat") - trie1.Update("dog", LONG_WORD) - trie2.Update("cat", LONG_WORD) - c.Assert(trie1, checker.Not(checker.DeepEquals), trie2) -} + exp := ethutil.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") + root := trie.Hash() + if !bytes.Equal(root, exp) { + t.Errorf("exp %x got %x", exp, root) + } -func (s *TrieSuite) TestTrieDelete(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - exp := s.trie.Root - s.trie.Update("dog", LONG_WORD) - s.trie.Delete("dog") - c.Assert(s.trie.Root, checker.DeepEquals, exp) - - s.trie.Update("dog", LONG_WORD) - exp = s.trie.Root - s.trie.Update("dude", LONG_WORD) - s.trie.Delete("dude") - c.Assert(s.trie.Root, checker.DeepEquals, exp) -} + trie = NewEmpty() + trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") -func (s *TrieSuite) TestTrieDeleteWithValue(c *checker.C) { - s.trie.Update("c", LONG_WORD) - exp := s.trie.Root - s.trie.Update("ca", LONG_WORD) - s.trie.Update("cat", LONG_WORD) - s.trie.Delete("ca") - s.trie.Delete("cat") - c.Assert(s.trie.Root, checker.DeepEquals, exp) + exp = ethutil.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") + root = trie.Hash() + if !bytes.Equal(root, exp) { + t.Errorf("exp %x got %x", exp, root) + } } -func (s *TrieSuite) TestTriePurge(c *checker.C) { - s.trie.Update("c", LONG_WORD) - s.trie.Update("ca", LONG_WORD) - s.trie.Update("cat", LONG_WORD) +func TestGet(t *testing.T) { + trie := NewEmpty() - lenBefore := len(s.trie.cache.nodes) - it := s.trie.NewIterator() - num := it.Purge() - c.Assert(num, checker.Equals, 3) - c.Assert(len(s.trie.cache.nodes), checker.Equals, lenBefore) -} + trie.UpdateString("doe", "reindeer") + trie.UpdateString("dog", "puppy") + trie.UpdateString("dogglesworth", "cat") -func h(str string) string { - d, err := hex.DecodeString(str) - if err != nil { - panic(err) + res := trie.GetString("dog") + if !bytes.Equal(res, []byte("puppy")) { + t.Errorf("expected puppy got %x", res) } - return string(d) -} - -func get(in string) (out string) { - if len(in) > 2 && in[:2] == "0x" { - out = h(in[2:]) - } else { - out = in + unknown := trie.GetString("unknown") + if unknown != nil { + t.Errorf("expected nil got %x", unknown) } - - return } -type TrieTest struct { - Name string - In map[string]string - Root string -} +func TestDelete(t *testing.T) { + trie := NewEmpty() -func CreateTest(name string, data []byte) (TrieTest, error) { - t := TrieTest{Name: name} - err := json.Unmarshal(data, &t) - if err != nil { - return TrieTest{}, fmt.Errorf("%v", err) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, } - - return t, nil -} - -func CreateTests(uri string, cb func(TrieTest)) map[string]TrieTest { - resp, err := http.Get(uri) - if err != nil { - panic(err) + for _, val := range vals { + if val.v != "" { + trie.UpdateString(val.k, val.v) + } else { + trie.DeleteString(val.k) + } } - defer resp.Body.Close() - data, err := ioutil.ReadAll(resp.Body) - - var objmap map[string]*json.RawMessage - err = json.Unmarshal(data, &objmap) - if err != nil { - panic(err) + hash := trie.Hash() + exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if !bytes.Equal(hash, exp) { + t.Errorf("expected %x got %x", exp, hash) } +} - tests := make(map[string]TrieTest) - for name, testData := range objmap { - test, err := CreateTest(name, *testData) - if err != nil { - panic(err) - } +func TestEmptyValues(t *testing.T) { + trie := NewEmpty() - if cb != nil { - cb(test) - } - tests[name] = test + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, } - - return tests -} - -func RandomData() [][]string { - data := [][]string{ - {"0x000000000000000000000000ec4f34c97e43fbb2816cfd95e388353c7181dab1", "0x4e616d6552656700000000000000000000000000000000000000000000000000"}, - {"0x0000000000000000000000000000000000000000000000000000000000000045", "0x22b224a1420a802ab51d326e29fa98e34c4f24ea"}, - {"0x0000000000000000000000000000000000000000000000000000000000000046", "0x67706c2076330000000000000000000000000000000000000000000000000000"}, - {"0x000000000000000000000000697c7b8c961b56f675d570498424ac8de1a918f6", "0x6f6f6f6820736f2067726561742c207265616c6c6c793f000000000000000000"}, - {"0x0000000000000000000000007ef9e639e2733cb34e4dfc576d4b23f72db776b2", "0x4655474156000000000000000000000000000000000000000000000000000000"}, - {"0x6f6f6f6820736f2067726561742c207265616c6c6c793f000000000000000000", "0x697c7b8c961b56f675d570498424ac8de1a918f6"}, - {"0x4655474156000000000000000000000000000000000000000000000000000000", "0x7ef9e639e2733cb34e4dfc576d4b23f72db776b2"}, - {"0x4e616d6552656700000000000000000000000000000000000000000000000000", "0xec4f34c97e43fbb2816cfd95e388353c7181dab1"}, + for _, val := range vals { + trie.UpdateString(val.k, val.v) } - var c [][]string - for len(data) != 0 { - e := rand.Intn(len(data)) - c = append(c, data[e]) - - copy(data[e:], data[e+1:]) - data[len(data)-1] = nil - data = data[:len(data)-1] + hash := trie.Hash() + exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if !bytes.Equal(hash, exp) { + t.Errorf("expected %x got %x", exp, hash) } - - return c } -const MaxTest = 1000 - -// This test insert data in random order and seeks to find indifferences between the different tries -func (s *TrieSuite) TestRegression(c *checker.C) { - rand.Seed(time.Now().Unix()) - - roots := make(map[string]int) - for i := 0; i < MaxTest; i++ { - _, trie := NewTrie() - data := RandomData() - - for _, test := range data { - trie.Update(test[0], test[1]) - } - trie.Delete("0x4e616d6552656700000000000000000000000000000000000000000000000000") - - roots[string(trie.Root.([]byte))] += 1 +func TestReplication(t *testing.T) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + {"somethingveryoddindeedthis is", "myothernodedata"}, } + for _, val := range vals { + trie.UpdateString(val.k, val.v) + } + trie.Commit() - c.Assert(len(roots) <= 1, checker.Equals, true) - // if len(roots) > 1 { - // for root, num := range roots { - // t.Errorf("%x => %d\n", root, num) - // } - // } -} - -func (s *TrieSuite) TestDelete(c *checker.C) { - s.trie.Update("a", "jeffreytestlongstring") - s.trie.Update("aa", "otherstring") - s.trie.Update("aaa", "othermorestring") - s.trie.Update("aabbbbccc", "hithere") - s.trie.Update("abbcccdd", "hstanoehutnaheoustnh") - s.trie.Update("rnthaoeuabbcccdd", "hstanoehutnaheoustnh") - s.trie.Update("rneuabbcccdd", "hstanoehutnaheoustnh") - s.trie.Update("rneuabboeusntahoeucccdd", "hstanoehutnaheoustnh") - s.trie.Update("rnxabboeusntahoeucccdd", "hstanoehutnaheoustnh") - s.trie.Delete("aaboaestnuhbccc") - s.trie.Delete("a") - s.trie.Update("a", "nthaonethaosentuh") - s.trie.Update("c", "shtaosntehua") - s.trie.Delete("a") - s.trie.Update("aaaa", "testmegood") - - _, t2 := NewTrie() - s.trie.NewIterator().Each(func(key string, v *ethutil.Value) { - if key == "aaaa" { - t2.Update(key, v.Str()) - } else { - t2.Update(key, v.Str()) - } - }) - - a := ethutil.NewValue(s.trie.Root).Bytes() - b := ethutil.NewValue(t2.Root).Bytes() + trie2 := New(trie.roothash, trie.cache.backend) + if string(trie2.GetString("horse")) != "stallion" { + t.Error("expected to have horse => stallion") + } - c.Assert(a, checker.DeepEquals, b) -} + hash := trie2.Hash() + exp := trie.Hash() + if !bytes.Equal(hash, exp) { + t.Errorf("root failure. expected %x got %x", exp, hash) + } -func (s *TrieSuite) TestTerminator(c *checker.C) { - key := CompactDecode("hello") - c.Assert(HasTerm(key), checker.Equals, true, checker.Commentf("Expected %v to have a terminator", key)) } -func (s *TrieSuite) TestIt(c *checker.C) { - s.trie.Update("cat", "cat") - s.trie.Update("doge", "doge") - s.trie.Update("wallace", "wallace") - it := s.trie.Iterator() - - inputs := []struct { - In, Out string - }{ - {"", "cat"}, - {"bobo", "cat"}, - {"c", "cat"}, - {"car", "cat"}, - {"catering", "doge"}, - {"w", "wallace"}, - {"wallace123", ""}, +func TestReset(t *testing.T) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, } - - for _, test := range inputs { - res := string(it.Next(test.In)) - c.Assert(res, checker.Equals, test.Out) + for _, val := range vals { + trie.UpdateString(val.k, val.v) } -} + trie.Commit() -func (s *TrieSuite) TestBeginsWith(c *checker.C) { - a := CompactDecode("hello") - b := CompactDecode("hel") - - c.Assert(BeginsWith(a, b), checker.Equals, false) - c.Assert(BeginsWith(b, a), checker.Equals, true) -} + before := ethutil.CopyBytes(trie.roothash) + trie.UpdateString("should", "revert") + trie.Hash() + // Should have no effect + trie.Hash() + trie.Hash() + // ### -func (s *TrieSuite) TestItems(c *checker.C) { - s.trie.Update("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") - exp := "d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab" + trie.Reset() + after := ethutil.CopyBytes(trie.roothash) - c.Assert(s.trie.GetRoot(), checker.DeepEquals, ethutil.Hex2Bytes(exp)) + if !bytes.Equal(before, after) { + t.Errorf("expected roots to be equal. %x - %x", before, after) + } } -func TestOtherSomething(t *testing.T) { - _, trie := NewTrie() +func TestParanoia(t *testing.T) { + t.Skip() + trie := NewEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, @@ -350,20 +193,40 @@ func TestOtherSomething(t *testing.T) { {"ether", ""}, {"dog", "puppy"}, {"shaman", ""}, + {"somethingveryoddindeedthis is", "myothernodedata"}, } for _, val := range vals { - trie.Update(val.k, val.v) + trie.UpdateString(val.k, val.v) } + trie.Commit() - exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - hash := trie.Root.([]byte) - if !bytes.Equal(hash, exp) { - t.Errorf("expected %x got %x", exp, hash) + ok, t2 := ParanoiaCheck(trie, trie.cache.backend) + if !ok { + t.Errorf("trie paranoia check failed %x %x", trie.roothash, t2.roothash) } } +// Not an actual test +func TestOutput(t *testing.T) { + t.Skip() + + base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + trie := NewEmpty() + for i := 0; i < 50; i++ { + trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") + } + fmt.Println("############################## FULL ################################") + fmt.Println(trie.root) + + trie.Commit() + fmt.Println("############################## SMALL ################################") + trie2 := New(trie.roothash, trie.cache.backend) + trie2.GetString(base + "20") + fmt.Println(trie2.root) +} + func BenchmarkGets(b *testing.B) { - _, trie := NewTrie() + trie := NewEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, {"ether", "wookiedoo"}, @@ -376,21 +239,21 @@ func BenchmarkGets(b *testing.B) { {"somethingveryoddindeedthis is", "myothernodedata"}, } for _, val := range vals { - trie.Update(val.k, val.v) + trie.UpdateString(val.k, val.v) } b.ResetTimer() for i := 0; i < b.N; i++ { - trie.Get("horse") + trie.Get([]byte("horse")) } } func BenchmarkUpdate(b *testing.B) { - _, trie := NewTrie() + trie := NewEmpty() b.ResetTimer() for i := 0; i < b.N; i++ { - trie.Update(fmt.Sprintf("aaaaaaaaaaaaaaa%d", i), "value") + trie.UpdateString(fmt.Sprintf("aaaaaaaaa%d", i), "value") } + trie.Hash() } -*/ diff --git a/trie/valuenode.go b/trie/valuenode.go new file mode 100644 index 000000000..689befb2a --- /dev/null +++ b/trie/valuenode.go @@ -0,0 +1,13 @@ +package trie + +type ValueNode struct { + trie *Trie + data []byte +} + +func (self *ValueNode) Value() Node { return self } // Best not to call :-) +func (self *ValueNode) Val() []byte { return self.data } +func (self *ValueNode) Dirty() bool { return true } +func (self *ValueNode) Copy() Node { return &ValueNode{self.trie, self.data} } +func (self *ValueNode) RlpData() interface{} { return self.data } +func (self *ValueNode) Hash() interface{} { return self.data } -- cgit