aboutsummaryrefslogtreecommitdiffstats
path: root/trie
diff options
context:
space:
mode:
Diffstat (limited to 'trie')
-rw-r--r--trie/cache.go50
-rw-r--r--trie/encoding.go76
-rw-r--r--trie/encoding_test.go59
-rw-r--r--trie/fullnode.go77
-rw-r--r--trie/hashnode.go25
-rw-r--r--trie/iterator.go124
-rw-r--r--trie/iterator_test.go33
-rw-r--r--trie/node.go44
-rw-r--r--trie/secure_trie.go36
-rw-r--r--trie/shortnode.go35
-rw-r--r--trie/slice.go53
-rw-r--r--trie/trie.go345
-rw-r--r--trie/trie_test.go329
-rw-r--r--trie/valuenode.go15
14 files changed, 1301 insertions, 0 deletions
diff --git a/trie/cache.go b/trie/cache.go
new file mode 100644
index 000000000..2143785fa
--- /dev/null
+++ b/trie/cache.go
@@ -0,0 +1,50 @@
+package trie
+
+type Backend interface {
+ Get([]byte) ([]byte, error)
+ Put([]byte, []byte)
+}
+
+type Cache struct {
+ store map[string][]byte
+ backend Backend
+}
+
+func NewCache(backend Backend) *Cache {
+ return &Cache{make(map[string][]byte), backend}
+}
+
+func (self *Cache) Get(key []byte) []byte {
+ data := self.store[string(key)]
+ if data == nil {
+ data, _ = self.backend.Get(key)
+ }
+
+ return data
+}
+
+func (self *Cache) Put(key []byte, data []byte) {
+ self.store[string(key)] = data
+}
+
+func (self *Cache) Flush() {
+ for k, v := range self.store {
+ self.backend.Put([]byte(k), v)
+ }
+
+ // This will eventually grow too large. We'd could
+ // do a make limit on storage and push out not-so-popular nodes.
+ //self.Reset()
+}
+
+func (self *Cache) Copy() *Cache {
+ cache := NewCache(self.backend)
+ for k, v := range self.store {
+ cache.store[k] = v
+ }
+ return cache
+}
+
+func (self *Cache) Reset() {
+ //self.store = make(map[string][]byte)
+}
diff --git a/trie/encoding.go b/trie/encoding.go
new file mode 100644
index 000000000..5c42c556f
--- /dev/null
+++ b/trie/encoding.go
@@ -0,0 +1,76 @@
+package trie
+
+import (
+ "bytes"
+ "encoding/hex"
+ "strings"
+)
+
+func CompactEncode(hexSlice []byte) string {
+ terminator := 0
+ if hexSlice[len(hexSlice)-1] == 16 {
+ terminator = 1
+ }
+
+ if terminator == 1 {
+ hexSlice = hexSlice[:len(hexSlice)-1]
+ }
+
+ oddlen := len(hexSlice) % 2
+ flags := byte(2*terminator + oddlen)
+ if oddlen != 0 {
+ hexSlice = append([]byte{flags}, hexSlice...)
+ } else {
+ hexSlice = append([]byte{flags, 0}, hexSlice...)
+ }
+
+ var buff bytes.Buffer
+ for i := 0; i < len(hexSlice); i += 2 {
+ buff.WriteByte(byte(16*hexSlice[i] + hexSlice[i+1]))
+ }
+
+ return buff.String()
+}
+
+func CompactDecode(str string) []byte {
+ base := CompactHexDecode(str)
+ base = base[:len(base)-1]
+ if base[0] >= 2 {
+ base = append(base, 16)
+ }
+ if base[0]%2 == 1 {
+ base = base[1:]
+ } else {
+ base = base[2:]
+ }
+
+ return base
+}
+
+func CompactHexDecode(str string) []byte {
+ base := "0123456789abcdef"
+ var hexSlice []byte
+
+ enc := hex.EncodeToString([]byte(str))
+ for _, v := range enc {
+ hexSlice = append(hexSlice, byte(strings.IndexByte(base, byte(v))))
+ }
+ hexSlice = append(hexSlice, 16)
+
+ return hexSlice
+}
+
+func DecodeCompact(key []byte) string {
+ const base = "0123456789abcdef"
+ var str string
+
+ for _, v := range key {
+ if v < 16 {
+ str += string(base[v])
+ }
+ }
+
+ res, _ := hex.DecodeString(str)
+
+ return string(res)
+}
diff --git a/trie/encoding_test.go b/trie/encoding_test.go
new file mode 100644
index 000000000..193c898f3
--- /dev/null
+++ b/trie/encoding_test.go
@@ -0,0 +1,59 @@
+package trie
+
+import (
+ checker "gopkg.in/check.v1"
+)
+
+type TrieEncodingSuite struct{}
+
+var _ = checker.Suite(&TrieEncodingSuite{})
+
+func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) {
+ // even compact encode
+ test1 := []byte{1, 2, 3, 4, 5}
+ res1 := CompactEncode(test1)
+ c.Assert(res1, checker.Equals, "\x11\x23\x45")
+
+ // odd compact encode
+ test2 := []byte{0, 1, 2, 3, 4, 5}
+ res2 := CompactEncode(test2)
+ c.Assert(res2, checker.Equals, "\x00\x01\x23\x45")
+
+ //odd terminated compact encode
+ test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
+ res3 := CompactEncode(test3)
+ c.Assert(res3, checker.Equals, "\x20\x0f\x1c\xb8")
+
+ // even terminated compact encode
+ test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16}
+ res4 := CompactEncode(test4)
+ c.Assert(res4, checker.Equals, "\x3f\x1c\xb8")
+}
+
+func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) {
+ exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
+ res := CompactHexDecode("verb")
+ c.Assert(res, checker.DeepEquals, exp)
+}
+
+func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) {
+ // odd compact decode
+ exp := []byte{1, 2, 3, 4, 5}
+ res := CompactDecode("\x11\x23\x45")
+ c.Assert(res, checker.DeepEquals, exp)
+
+ // even compact decode
+ exp = []byte{0, 1, 2, 3, 4, 5}
+ res = CompactDecode("\x00\x01\x23\x45")
+ c.Assert(res, checker.DeepEquals, exp)
+
+ // even terminated compact decode
+ exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
+ res = CompactDecode("\x20\x0f\x1c\xb8")
+ c.Assert(res, checker.DeepEquals, exp)
+
+ // even terminated compact decode
+ exp = []byte{15, 1, 12, 11, 8 /*term*/, 16}
+ res = CompactDecode("\x3f\x1c\xb8")
+ c.Assert(res, checker.DeepEquals, exp)
+}
diff --git a/trie/fullnode.go b/trie/fullnode.go
new file mode 100644
index 000000000..522fdb373
--- /dev/null
+++ b/trie/fullnode.go
@@ -0,0 +1,77 @@
+package trie
+
+import "fmt"
+
+type FullNode struct {
+ trie *Trie
+ nodes [17]Node
+}
+
+func NewFullNode(t *Trie) *FullNode {
+ return &FullNode{trie: t}
+}
+
+func (self *FullNode) Dirty() bool { return true }
+func (self *FullNode) Value() Node {
+ self.nodes[16] = self.trie.trans(self.nodes[16])
+ return self.nodes[16]
+}
+func (self *FullNode) Branches() []Node {
+ return self.nodes[:16]
+}
+
+func (self *FullNode) Copy(t *Trie) Node {
+ nnode := NewFullNode(t)
+ for i, node := range self.nodes {
+ if node != nil {
+ nnode.nodes[i] = node.Copy(t)
+ }
+ }
+
+ return nnode
+}
+
+// Returns the length of non-nil nodes
+func (self *FullNode) Len() (amount int) {
+ for _, node := range self.nodes {
+ if node != nil {
+ amount++
+ }
+ }
+
+ return
+}
+
+func (self *FullNode) Hash() interface{} {
+ return self.trie.store(self)
+}
+
+func (self *FullNode) RlpData() interface{} {
+ t := make([]interface{}, 17)
+ for i, node := range self.nodes {
+ if node != nil {
+ t[i] = node.Hash()
+ } else {
+ t[i] = ""
+ }
+ }
+
+ return t
+}
+
+func (self *FullNode) set(k byte, value Node) {
+ if _, ok := value.(*ValueNode); ok && k != 16 {
+ fmt.Println(value, k)
+ }
+
+ self.nodes[int(k)] = value
+}
+
+func (self *FullNode) branch(i byte) Node {
+ if self.nodes[int(i)] != nil {
+ self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)])
+
+ return self.nodes[int(i)]
+ }
+ return nil
+}
diff --git a/trie/hashnode.go b/trie/hashnode.go
new file mode 100644
index 000000000..8125cc3c9
--- /dev/null
+++ b/trie/hashnode.go
@@ -0,0 +1,25 @@
+package trie
+
+import "github.com/ethereum/go-ethereum/common"
+
+type HashNode struct {
+ key []byte
+ trie *Trie
+}
+
+func NewHash(key []byte, trie *Trie) *HashNode {
+ return &HashNode{key, trie}
+}
+
+func (self *HashNode) RlpData() interface{} {
+ return self.key
+}
+
+func (self *HashNode) Hash() interface{} {
+ return self.key
+}
+
+// These methods will never be called but we have to satisfy Node interface
+func (self *HashNode) Value() Node { return nil }
+func (self *HashNode) Dirty() bool { return true }
+func (self *HashNode) Copy(t *Trie) Node { return NewHash(common.CopyBytes(self.key), t) }
diff --git a/trie/iterator.go b/trie/iterator.go
new file mode 100644
index 000000000..fda7c6cbe
--- /dev/null
+++ b/trie/iterator.go
@@ -0,0 +1,124 @@
+package trie
+
+import (
+ "bytes"
+)
+
+type Iterator struct {
+ trie *Trie
+
+ Key []byte
+ Value []byte
+}
+
+func NewIterator(trie *Trie) *Iterator {
+ return &Iterator{trie: trie, Key: nil}
+}
+
+func (self *Iterator) Next() bool {
+ self.trie.mu.Lock()
+ defer self.trie.mu.Unlock()
+
+ isIterStart := false
+ if self.Key == nil {
+ isIterStart = true
+ self.Key = make([]byte, 32)
+ }
+
+ key := RemTerm(CompactHexDecode(string(self.Key)))
+ k := self.next(self.trie.root, key, isIterStart)
+
+ self.Key = []byte(DecodeCompact(k))
+
+ return len(k) > 0
+}
+
+func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
+ if node == nil {
+ return nil
+ }
+
+ switch node := node.(type) {
+ case *FullNode:
+ if len(key) > 0 {
+ k := self.next(node.branch(key[0]), key[1:], isIterStart)
+ if k != nil {
+ return append([]byte{key[0]}, k...)
+ }
+ }
+
+ var r byte
+ if len(key) > 0 {
+ r = key[0] + 1
+ }
+
+ for i := r; i < 16; i++ {
+ k := self.key(node.branch(byte(i)))
+ if k != nil {
+ return append([]byte{i}, k...)
+ }
+ }
+
+ case *ShortNode:
+ k := RemTerm(node.Key())
+ if vnode, ok := node.Value().(*ValueNode); ok {
+ switch bytes.Compare([]byte(k), key) {
+ case 0:
+ if isIterStart {
+ self.Value = vnode.Val()
+ return k
+ }
+ case 1:
+ self.Value = vnode.Val()
+ return k
+ }
+ } else {
+ cnode := node.Value()
+
+ var ret []byte
+ skey := key[len(k):]
+ if BeginsWith(key, k) {
+ ret = self.next(cnode, skey, isIterStart)
+ } else if bytes.Compare(k, key[:len(k)]) > 0 {
+ return self.key(node)
+ }
+
+ if ret != nil {
+ return append(k, ret...)
+ }
+ }
+ }
+
+ return nil
+}
+
+func (self *Iterator) key(node Node) []byte {
+ switch node := node.(type) {
+ case *ShortNode:
+ // Leaf node
+ if vnode, ok := node.Value().(*ValueNode); ok {
+ k := RemTerm(node.Key())
+ self.Value = vnode.Val()
+
+ return k
+ } else {
+ k := RemTerm(node.Key())
+ return append(k, self.key(node.Value())...)
+ }
+ case *FullNode:
+ if node.Value() != nil {
+ self.Value = node.Value().(*ValueNode).Val()
+
+ return []byte{16}
+ }
+
+ for i := 0; i < 16; i++ {
+ k := self.key(node.branch(byte(i)))
+ if k != nil {
+ return append([]byte{byte(i)}, k...)
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
new file mode 100644
index 000000000..74d9e903c
--- /dev/null
+++ b/trie/iterator_test.go
@@ -0,0 +1,33 @@
+package trie
+
+import "testing"
+
+func TestIterator(t *testing.T) {
+ trie := NewEmpty()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ v := make(map[string]bool)
+ for _, val := range vals {
+ v[val.k] = false
+ trie.UpdateString(val.k, val.v)
+ }
+ trie.Commit()
+
+ it := trie.Iterator()
+ for it.Next() {
+ v[string(it.Key)] = true
+ }
+
+ for k, found := range v {
+ if !found {
+ t.Error("iterator didn't find", k)
+ }
+ }
+}
diff --git a/trie/node.go b/trie/node.go
new file mode 100644
index 000000000..0d8a7cff9
--- /dev/null
+++ b/trie/node.go
@@ -0,0 +1,44 @@
+package trie
+
+import "fmt"
+
+var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
+
+type Node interface {
+ Value() Node
+ Copy(*Trie) Node // All nodes, for now, return them self
+ Dirty() bool
+ fstring(string) string
+ Hash() interface{}
+ RlpData() interface{}
+}
+
+// Value node
+func (self *ValueNode) String() string { return self.fstring("") }
+func (self *FullNode) String() string { return self.fstring("") }
+func (self *ShortNode) String() string { return self.fstring("") }
+func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) }
+
+//func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("< %x > ", self.key) }
+func (self *HashNode) fstring(ind string) string {
+ return fmt.Sprintf("%v", self.trie.trans(self))
+}
+
+// Full node
+func (self *FullNode) fstring(ind string) string {
+ resp := fmt.Sprintf("[\n%s ", ind)
+ for i, node := range self.nodes {
+ if node == nil {
+ resp += fmt.Sprintf("%s: <nil> ", indices[i])
+ } else {
+ resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" "))
+ }
+ }
+
+ return resp + fmt.Sprintf("\n%s] ", ind)
+}
+
+// Short node
+func (self *ShortNode) fstring(ind string) string {
+ return fmt.Sprintf("[ %x: %v ] ", self.key, self.value.fstring(ind+" "))
+}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
new file mode 100644
index 000000000..b9fa376b8
--- /dev/null
+++ b/trie/secure_trie.go
@@ -0,0 +1,36 @@
+package trie
+
+import "github.com/ethereum/go-ethereum/crypto"
+
+type SecureTrie struct {
+ *Trie
+}
+
+func NewSecure(root []byte, backend Backend) *SecureTrie {
+ return &SecureTrie{New(root, backend)}
+}
+
+func (self *SecureTrie) Update(key, value []byte) Node {
+ return self.Trie.Update(crypto.Sha3(key), value)
+}
+func (self *SecureTrie) UpdateString(key, value string) Node {
+ return self.Update([]byte(key), []byte(value))
+}
+
+func (self *SecureTrie) Get(key []byte) []byte {
+ return self.Trie.Get(crypto.Sha3(key))
+}
+func (self *SecureTrie) GetString(key string) []byte {
+ return self.Get([]byte(key))
+}
+
+func (self *SecureTrie) Delete(key []byte) Node {
+ return self.Trie.Delete(crypto.Sha3(key))
+}
+func (self *SecureTrie) DeleteString(key string) Node {
+ return self.Delete([]byte(key))
+}
+
+func (self *SecureTrie) Copy() *SecureTrie {
+ return &SecureTrie{self.Trie.Copy()}
+}
diff --git a/trie/shortnode.go b/trie/shortnode.go
new file mode 100644
index 000000000..edd490b4d
--- /dev/null
+++ b/trie/shortnode.go
@@ -0,0 +1,35 @@
+package trie
+
+import "github.com/ethereum/go-ethereum/common"
+
+type ShortNode struct {
+ trie *Trie
+ key []byte
+ value Node
+}
+
+func NewShortNode(t *Trie, key []byte, value Node) *ShortNode {
+ return &ShortNode{t, []byte(CompactEncode(key)), value}
+}
+func (self *ShortNode) Value() Node {
+ self.value = self.trie.trans(self.value)
+
+ return self.value
+}
+func (self *ShortNode) Dirty() bool { return true }
+func (self *ShortNode) Copy(t *Trie) Node {
+ node := &ShortNode{t, nil, self.value.Copy(t)}
+ node.key = common.CopyBytes(self.key)
+ return node
+}
+
+func (self *ShortNode) RlpData() interface{} {
+ return []interface{}{self.key, self.value.Hash()}
+}
+func (self *ShortNode) Hash() interface{} {
+ return self.trie.store(self)
+}
+
+func (self *ShortNode) Key() []byte {
+ return CompactDecode(string(self.key))
+}
diff --git a/trie/slice.go b/trie/slice.go
new file mode 100644
index 000000000..f53b6c749
--- /dev/null
+++ b/trie/slice.go
@@ -0,0 +1,53 @@
+package trie
+
+import (
+ "bytes"
+ "math"
+)
+
+// Helper function for comparing slices
+func CompareIntSlice(a, b []int) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i, v := range a {
+ if v != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// Returns the amount of nibbles that match each other from 0 ...
+func MatchingNibbleLength(a, b []byte) int {
+ var i, length = 0, int(math.Min(float64(len(a)), float64(len(b))))
+
+ for i < length {
+ if a[i] != b[i] {
+ break
+ }
+ i++
+ }
+
+ return i
+}
+
+func HasTerm(s []byte) bool {
+ return s[len(s)-1] == 16
+}
+
+func RemTerm(s []byte) []byte {
+ if HasTerm(s) {
+ return s[:len(s)-1]
+ }
+
+ return s
+}
+
+func BeginsWith(a, b []byte) bool {
+ if len(b) > len(a) {
+ return false
+ }
+
+ return bytes.Equal(a[:len(b)], b)
+}
diff --git a/trie/trie.go b/trie/trie.go
new file mode 100644
index 000000000..1c1112a7f
--- /dev/null
+++ b/trie/trie.go
@@ -0,0 +1,345 @@
+package trie
+
+import (
+ "bytes"
+ "container/list"
+ "fmt"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/common"
+)
+
+func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
+ t2 := New(nil, backend)
+
+ it := t1.Iterator()
+ for it.Next() {
+ t2.Update(it.Key, it.Value)
+ }
+
+ return bytes.Equal(t2.Hash(), t1.Hash()), t2
+}
+
+type Trie struct {
+ mu sync.Mutex
+ root Node
+ roothash []byte
+ cache *Cache
+
+ revisions *list.List
+}
+
+func New(root []byte, backend Backend) *Trie {
+ trie := &Trie{}
+ trie.revisions = list.New()
+ trie.roothash = root
+ if backend != nil {
+ trie.cache = NewCache(backend)
+ }
+
+ if root != nil {
+ value := common.NewValueFromBytes(trie.cache.Get(root))
+ trie.root = trie.mknode(value)
+ }
+
+ return trie
+}
+
+func (self *Trie) Iterator() *Iterator {
+ return NewIterator(self)
+}
+
+func (self *Trie) Copy() *Trie {
+ cpy := make([]byte, 32)
+ copy(cpy, self.roothash)
+ trie := New(nil, nil)
+ trie.cache = self.cache.Copy()
+ if self.root != nil {
+ trie.root = self.root.Copy(trie)
+ }
+
+ return trie
+}
+
+// Legacy support
+func (self *Trie) Root() []byte { return self.Hash() }
+func (self *Trie) Hash() []byte {
+ var hash []byte
+ if self.root != nil {
+ t := self.root.Hash()
+ if byts, ok := t.([]byte); ok && len(byts) > 0 {
+ hash = byts
+ } else {
+ hash = crypto.Sha3(common.Encode(self.root.RlpData()))
+ }
+ } else {
+ hash = crypto.Sha3(common.Encode(""))
+ }
+
+ if !bytes.Equal(hash, self.roothash) {
+ self.revisions.PushBack(self.roothash)
+ self.roothash = hash
+ }
+
+ return hash
+}
+func (self *Trie) Commit() {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ // Hash first
+ self.Hash()
+
+ self.cache.Flush()
+}
+
+// Reset should only be called if the trie has been hashed
+func (self *Trie) Reset() {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ self.cache.Reset()
+
+ if self.revisions.Len() > 0 {
+ revision := self.revisions.Remove(self.revisions.Back()).([]byte)
+ self.roothash = revision
+ }
+ value := common.NewValueFromBytes(self.cache.Get(self.roothash))
+ self.root = self.mknode(value)
+}
+
+func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) }
+func (self *Trie) Update(key, value []byte) Node {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ k := CompactHexDecode(string(key))
+
+ if len(value) != 0 {
+ self.root = self.insert(self.root, k, &ValueNode{self, value})
+ } else {
+ self.root = self.delete(self.root, k)
+ }
+
+ return self.root
+}
+
+func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) }
+func (self *Trie) Get(key []byte) []byte {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ k := CompactHexDecode(string(key))
+
+ n := self.get(self.root, k)
+ if n != nil {
+ return n.(*ValueNode).Val()
+ }
+
+ return nil
+}
+
+func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) }
+func (self *Trie) Delete(key []byte) Node {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ k := CompactHexDecode(string(key))
+ self.root = self.delete(self.root, k)
+
+ return self.root
+}
+
+func (self *Trie) insert(node Node, key []byte, value Node) Node {
+ if len(key) == 0 {
+ return value
+ }
+
+ if node == nil {
+ return NewShortNode(self, key, value)
+ }
+
+ switch node := node.(type) {
+ case *ShortNode:
+ k := node.Key()
+ cnode := node.Value()
+ if bytes.Equal(k, key) {
+ return NewShortNode(self, key, value)
+ }
+
+ var n Node
+ matchlength := MatchingNibbleLength(key, k)
+ if matchlength == len(k) {
+ n = self.insert(cnode, key[matchlength:], value)
+ } else {
+ pnode := self.insert(nil, k[matchlength+1:], cnode)
+ nnode := self.insert(nil, key[matchlength+1:], value)
+ fulln := NewFullNode(self)
+ fulln.set(k[matchlength], pnode)
+ fulln.set(key[matchlength], nnode)
+ n = fulln
+ }
+ if matchlength == 0 {
+ return n
+ }
+
+ return NewShortNode(self, key[:matchlength], n)
+
+ case *FullNode:
+ cpy := node.Copy(self).(*FullNode)
+ cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
+
+ return cpy
+
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", node, node))
+ }
+}
+
+func (self *Trie) get(node Node, key []byte) Node {
+ if len(key) == 0 {
+ return node
+ }
+
+ if node == nil {
+ return nil
+ }
+
+ switch node := node.(type) {
+ case *ShortNode:
+ k := node.Key()
+ cnode := node.Value()
+
+ if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) {
+ return self.get(cnode, key[len(k):])
+ }
+
+ return nil
+ case *FullNode:
+ return self.get(node.branch(key[0]), key[1:])
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", node, node))
+ }
+}
+
+func (self *Trie) delete(node Node, key []byte) Node {
+ if len(key) == 0 && node == nil {
+ return nil
+ }
+
+ switch node := node.(type) {
+ case *ShortNode:
+ k := node.Key()
+ cnode := node.Value()
+ if bytes.Equal(key, k) {
+ return nil
+ } else if bytes.Equal(key[:len(k)], k) {
+ child := self.delete(cnode, key[len(k):])
+
+ var n Node
+ switch child := child.(type) {
+ case *ShortNode:
+ nkey := append(k, child.Key()...)
+ n = NewShortNode(self, nkey, child.Value())
+ case *FullNode:
+ sn := NewShortNode(self, node.Key(), child)
+ sn.key = node.key
+ n = sn
+ }
+
+ return n
+ } else {
+ return node
+ }
+
+ case *FullNode:
+ n := node.Copy(self).(*FullNode)
+ n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
+
+ pos := -1
+ for i := 0; i < 17; i++ {
+ if n.branch(byte(i)) != nil {
+ if pos == -1 {
+ pos = i
+ } else {
+ pos = -2
+ }
+ }
+ }
+
+ var nnode Node
+ if pos == 16 {
+ nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
+ } else if pos >= 0 {
+ cnode := n.branch(byte(pos))
+ switch cnode := cnode.(type) {
+ case *ShortNode:
+ // Stitch keys
+ k := append([]byte{byte(pos)}, cnode.Key()...)
+ nnode = NewShortNode(self, k, cnode.Value())
+ case *FullNode:
+ nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
+ }
+ } else {
+ nnode = n
+ }
+
+ return nnode
+ case nil:
+ return nil
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key))
+ }
+}
+
+// casting functions and cache storing
+func (self *Trie) mknode(value *common.Value) Node {
+ l := value.Len()
+ switch l {
+ case 0:
+ return nil
+ case 2:
+ // A value node may consists of 2 bytes.
+ if value.Get(0).Len() != 0 {
+ return NewShortNode(self, CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1)))
+ }
+ case 17:
+ fnode := NewFullNode(self)
+ for i := 0; i < l; i++ {
+ fnode.set(byte(i), self.mknode(value.Get(i)))
+ }
+ return fnode
+ case 32:
+ return &HashNode{value.Bytes(), self}
+ }
+
+ return &ValueNode{self, value.Bytes()}
+}
+
+func (self *Trie) trans(node Node) Node {
+ switch node := node.(type) {
+ case *HashNode:
+ value := common.NewValueFromBytes(self.cache.Get(node.key))
+ return self.mknode(value)
+ default:
+ return node
+ }
+}
+
+func (self *Trie) store(node Node) interface{} {
+ data := common.Encode(node)
+ if len(data) >= 32 {
+ key := crypto.Sha3(data)
+ self.cache.Put(key, data)
+
+ return key
+ }
+
+ return node.RlpData()
+}
+
+func (self *Trie) PrintRoot() {
+ fmt.Println(self.root)
+ fmt.Printf("root=%x\n", self.Root())
+}
diff --git a/trie/trie_test.go b/trie/trie_test.go
new file mode 100644
index 000000000..1393e0c97
--- /dev/null
+++ b/trie/trie_test.go
@@ -0,0 +1,329 @@
+package trie
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/common"
+)
+
+type Db map[string][]byte
+
+func (self Db) Get(k []byte) ([]byte, error) { return self[string(k)], nil }
+func (self Db) Put(k, v []byte) { self[string(k)] = v }
+
+// Used for testing
+func NewEmpty() *Trie {
+ return New(nil, make(Db))
+}
+
+func NewEmptySecure() *SecureTrie {
+ return NewSecure(nil, make(Db))
+}
+
+func TestEmptyTrie(t *testing.T) {
+ trie := NewEmpty()
+ res := trie.Hash()
+ exp := crypto.Sha3(common.Encode(""))
+ if !bytes.Equal(res, exp) {
+ t.Errorf("expected %x got %x", exp, res)
+ }
+}
+
+func TestInsert(t *testing.T) {
+ trie := NewEmpty()
+
+ trie.UpdateString("doe", "reindeer")
+ trie.UpdateString("dog", "puppy")
+ trie.UpdateString("dogglesworth", "cat")
+
+ exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
+ root := trie.Hash()
+ if !bytes.Equal(root, exp) {
+ t.Errorf("exp %x got %x", exp, root)
+ }
+
+ trie = NewEmpty()
+ trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+
+ exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
+ root = trie.Hash()
+ if !bytes.Equal(root, exp) {
+ t.Errorf("exp %x got %x", exp, root)
+ }
+}
+
+func TestGet(t *testing.T) {
+ trie := NewEmpty()
+
+ trie.UpdateString("doe", "reindeer")
+ trie.UpdateString("dog", "puppy")
+ trie.UpdateString("dogglesworth", "cat")
+
+ res := trie.GetString("dog")
+ if !bytes.Equal(res, []byte("puppy")) {
+ t.Errorf("expected puppy got %x", res)
+ }
+
+ unknown := trie.GetString("unknown")
+ if unknown != nil {
+ t.Errorf("expected nil got %x", unknown)
+ }
+}
+
+func TestDelete(t *testing.T) {
+ trie := NewEmpty()
+
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ if val.v != "" {
+ trie.UpdateString(val.k, val.v)
+ } else {
+ trie.DeleteString(val.k)
+ }
+ }
+
+ hash := trie.Hash()
+ exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if !bytes.Equal(hash, exp) {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestEmptyValues(t *testing.T) {
+ trie := NewEmpty()
+
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+
+ hash := trie.Hash()
+ exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if !bytes.Equal(hash, exp) {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestReplication(t *testing.T) {
+ trie := NewEmpty()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+ trie.Commit()
+
+ trie2 := New(trie.roothash, trie.cache.backend)
+ if string(trie2.GetString("horse")) != "stallion" {
+ t.Error("expected to have horse => stallion")
+ }
+
+ hash := trie2.Hash()
+ exp := trie.Hash()
+ if !bytes.Equal(hash, exp) {
+ t.Errorf("root failure. expected %x got %x", exp, hash)
+ }
+
+}
+
+func TestReset(t *testing.T) {
+ trie := NewEmpty()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+ trie.Commit()
+
+ before := common.CopyBytes(trie.roothash)
+ trie.UpdateString("should", "revert")
+ trie.Hash()
+ // Should have no effect
+ trie.Hash()
+ trie.Hash()
+ // ###
+
+ trie.Reset()
+ after := common.CopyBytes(trie.roothash)
+
+ if !bytes.Equal(before, after) {
+ t.Errorf("expected roots to be equal. %x - %x", before, after)
+ }
+}
+
+func TestParanoia(t *testing.T) {
+ t.Skip()
+ trie := NewEmpty()
+
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+ trie.Commit()
+
+ ok, t2 := ParanoiaCheck(trie, trie.cache.backend)
+ if !ok {
+ t.Errorf("trie paranoia check failed %x %x", trie.roothash, t2.roothash)
+ }
+}
+
+// Not an actual test
+func TestOutput(t *testing.T) {
+ t.Skip()
+
+ base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ trie := NewEmpty()
+ for i := 0; i < 50; i++ {
+ trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
+ }
+ fmt.Println("############################## FULL ################################")
+ fmt.Println(trie.root)
+
+ trie.Commit()
+ fmt.Println("############################## SMALL ################################")
+ trie2 := New(trie.roothash, trie.cache.backend)
+ trie2.GetString(base + "20")
+ fmt.Println(trie2.root)
+}
+
+func BenchmarkGets(b *testing.B) {
+ trie := NewEmpty()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ trie.Get([]byte("horse"))
+ }
+}
+
+func BenchmarkUpdate(b *testing.B) {
+ trie := NewEmpty()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ trie.UpdateString(fmt.Sprintf("aaaaaaaaa%d", i), "value")
+ }
+ trie.Hash()
+}
+
+type kv struct {
+ k, v []byte
+ t bool
+}
+
+func TestLargeData(t *testing.T) {
+ trie := NewEmpty()
+ vals := make(map[string]*kv)
+
+ for i := byte(0); i < 255; i++ {
+ value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+ value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
+ trie.Update(value.k, value.v)
+ trie.Update(value2.k, value2.v)
+ vals[string(value.k)] = value
+ vals[string(value2.k)] = value2
+ }
+
+ it := trie.Iterator()
+ for it.Next() {
+ vals[string(it.Key)].t = true
+ }
+
+ var untouched []*kv
+ for _, value := range vals {
+ if !value.t {
+ untouched = append(untouched, value)
+ }
+ }
+
+ if len(untouched) > 0 {
+ t.Errorf("Missed %d nodes", len(untouched))
+ for _, value := range untouched {
+ t.Error(value)
+ }
+ }
+}
+
+func TestSecureDelete(t *testing.T) {
+ trie := NewEmptySecure()
+
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ if val.v != "" {
+ trie.UpdateString(val.k, val.v)
+ } else {
+ trie.DeleteString(val.k)
+ }
+ }
+
+ hash := trie.Hash()
+ exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
+ if !bytes.Equal(hash, exp) {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
diff --git a/trie/valuenode.go b/trie/valuenode.go
new file mode 100644
index 000000000..7bf8ff06e
--- /dev/null
+++ b/trie/valuenode.go
@@ -0,0 +1,15 @@
+package trie
+
+import "github.com/ethereum/go-ethereum/common"
+
+type ValueNode struct {
+ trie *Trie
+ data []byte
+}
+
+func (self *ValueNode) Value() Node { return self } // Best not to call :-)
+func (self *ValueNode) Val() []byte { return self.data }
+func (self *ValueNode) Dirty() bool { return true }
+func (self *ValueNode) Copy(t *Trie) Node { return &ValueNode{t, common.CopyBytes(self.data)} }
+func (self *ValueNode) RlpData() interface{} { return self.data }
+func (self *ValueNode) Hash() interface{} { return self.data }