aboutsummaryrefslogtreecommitdiffstats
path: root/ptrie
diff options
context:
space:
mode:
Diffstat (limited to 'ptrie')
-rw-r--r--ptrie/fullnode.go59
-rw-r--r--ptrie/hashnode.go22
-rw-r--r--ptrie/node.go40
-rw-r--r--ptrie/shortnode.go31
-rw-r--r--ptrie/trie.go286
-rw-r--r--ptrie/trie_test.go138
-rw-r--r--ptrie/valuenode.go13
7 files changed, 589 insertions, 0 deletions
diff --git a/ptrie/fullnode.go b/ptrie/fullnode.go
new file mode 100644
index 000000000..2b1a62789
--- /dev/null
+++ b/ptrie/fullnode.go
@@ -0,0 +1,59 @@
+package ptrie
+
+type FullNode struct {
+ trie *Trie
+ nodes [17]Node
+}
+
+func NewFullNode(t *Trie) *FullNode {
+ return &FullNode{trie: t}
+}
+
+func (self *FullNode) Dirty() bool { return true }
+func (self *FullNode) Value() Node {
+ self.nodes[16] = self.trie.trans(self.nodes[16])
+ return self.nodes[16]
+}
+
+func (self *FullNode) Copy() Node { return self }
+
+// Returns the length of non-nil nodes
+func (self *FullNode) Len() (amount int) {
+ for _, node := range self.nodes {
+ if node != nil {
+ amount++
+ }
+ }
+
+ return
+}
+
+func (self *FullNode) Hash() interface{} {
+ return self.trie.store(self)
+}
+
+func (self *FullNode) RlpData() interface{} {
+ t := make([]interface{}, 17)
+ for i, node := range self.nodes {
+ if node != nil {
+ t[i] = node.Hash()
+ } else {
+ t[i] = ""
+ }
+ }
+
+ return t
+}
+
+func (self *FullNode) set(k byte, value Node) {
+ self.nodes[int(k)] = value
+}
+
+func (self *FullNode) get(i byte) Node {
+ if self.nodes[int(i)] != nil {
+ self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)])
+
+ return self.nodes[int(i)]
+ }
+ return nil
+}
diff --git a/ptrie/hashnode.go b/ptrie/hashnode.go
new file mode 100644
index 000000000..4c17569d7
--- /dev/null
+++ b/ptrie/hashnode.go
@@ -0,0 +1,22 @@
+package ptrie
+
+type HashNode struct {
+ key []byte
+}
+
+func NewHash(key []byte) *HashNode {
+ return &HashNode{key}
+}
+
+func (self *HashNode) RlpData() interface{} {
+ return self.key
+}
+
+func (self *HashNode) Hash() interface{} {
+ return self.key
+}
+
+// These methods will never be called but we have to satisfy Node interface
+func (self *HashNode) Value() Node { return nil }
+func (self *HashNode) Dirty() bool { return true }
+func (self *HashNode) Copy() Node { return self }
diff --git a/ptrie/node.go b/ptrie/node.go
new file mode 100644
index 000000000..2c85dbce7
--- /dev/null
+++ b/ptrie/node.go
@@ -0,0 +1,40 @@
+package ptrie
+
+import "fmt"
+
+var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
+
+type Node interface {
+ Value() Node
+ Copy() Node // All nodes, for now, return them self
+ Dirty() bool
+ fstring(string) string
+ Hash() interface{}
+ RlpData() interface{}
+}
+
+// Value node
+func (self *ValueNode) String() string { return self.fstring("") }
+func (self *FullNode) String() string { return self.fstring("") }
+func (self *ShortNode) String() string { return self.fstring("") }
+func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%s ", self.data) }
+func (self *HashNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.key) }
+
+// Full node
+func (self *FullNode) fstring(ind string) string {
+ resp := fmt.Sprintf("[\n%s ", ind)
+ for i, node := range self.nodes {
+ if node == nil {
+ resp += fmt.Sprintf("%s: <nil> ", indices[i])
+ } else {
+ resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+" "))
+ }
+ }
+
+ return resp + fmt.Sprintf("\n%s] ", ind)
+}
+
+// Short node
+func (self *ShortNode) fstring(ind string) string {
+ return fmt.Sprintf("[ %s: %v ] ", self.key, self.value.fstring(ind+" "))
+}
diff --git a/ptrie/shortnode.go b/ptrie/shortnode.go
new file mode 100644
index 000000000..49319c555
--- /dev/null
+++ b/ptrie/shortnode.go
@@ -0,0 +1,31 @@
+package ptrie
+
+import "github.com/ethereum/go-ethereum/trie"
+
+type ShortNode struct {
+ trie *Trie
+ key []byte
+ value Node
+}
+
+func NewShortNode(t *Trie, key []byte, value Node) *ShortNode {
+ return &ShortNode{t, []byte(trie.CompactEncode(key)), value}
+}
+func (self *ShortNode) Value() Node {
+ self.value = self.trie.trans(self.value)
+
+ return self.value
+}
+func (self *ShortNode) Dirty() bool { return true }
+func (self *ShortNode) Copy() Node { return self }
+
+func (self *ShortNode) RlpData() interface{} {
+ return []interface{}{self.key, self.value.Hash()}
+}
+func (self *ShortNode) Hash() interface{} {
+ return self.trie.store(self)
+}
+
+func (self *ShortNode) Key() []byte {
+ return trie.CompactDecode(string(self.key))
+}
diff --git a/ptrie/trie.go b/ptrie/trie.go
new file mode 100644
index 000000000..3e642b334
--- /dev/null
+++ b/ptrie/trie.go
@@ -0,0 +1,286 @@
+package ptrie
+
+import (
+ "bytes"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethutil"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+type Backend interface {
+ Get([]byte) []byte
+ Set([]byte, []byte)
+}
+
+type Cache map[string][]byte
+
+func (self Cache) Get(key []byte) []byte {
+ return self[string(key)]
+}
+func (self Cache) Set(key []byte, data []byte) {
+ self[string(key)] = data
+}
+
+type Trie struct {
+ mu sync.Mutex
+ root Node
+ roothash []byte
+ backend Backend
+}
+
+func NewEmpty() *Trie {
+ return &Trie{sync.Mutex{}, nil, nil, make(Cache)}
+}
+
+func New(root []byte, backend Backend) *Trie {
+ trie := &Trie{}
+ trie.roothash = root
+ trie.backend = backend
+
+ value := ethutil.NewValueFromBytes(trie.backend.Get(root))
+ trie.root = trie.mknode(value)
+
+ return trie
+}
+
+func (self *Trie) Hash() []byte {
+ var hash []byte
+ if self.root != nil {
+ t := self.root.Hash()
+ if byts, ok := t.([]byte); ok {
+ hash = byts
+ } else {
+ hash = crypto.Sha3(ethutil.Encode(self.root.RlpData()))
+ }
+ } else {
+ hash = crypto.Sha3(ethutil.Encode(self.root))
+ }
+
+ self.roothash = hash
+
+ return hash
+}
+
+func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) }
+func (self *Trie) Update(key, value []byte) Node {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ k := trie.CompactHexDecode(string(key))
+
+ if len(value) != 0 {
+ self.root = self.insert(self.root, k, &ValueNode{self, value})
+ } else {
+ self.root = self.delete(self.root, k)
+ }
+
+ return self.root
+}
+
+func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) }
+func (self *Trie) Get(key []byte) []byte {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ k := trie.CompactHexDecode(string(key))
+
+ n := self.get(self.root, k)
+ if n != nil {
+ return n.(*ValueNode).Val()
+ }
+
+ return nil
+}
+
+func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) }
+func (self *Trie) Delete(key []byte) Node {
+ self.mu.Lock()
+ defer self.mu.Unlock()
+
+ k := trie.CompactHexDecode(string(key))
+ self.root = self.delete(self.root, k)
+
+ return self.root
+}
+
+func (self *Trie) insert(node Node, key []byte, value Node) Node {
+ if len(key) == 0 {
+ return value
+ }
+
+ if node == nil {
+ return NewShortNode(self, key, value)
+ }
+
+ switch node := node.(type) {
+ case *ShortNode:
+ k := node.Key()
+ cnode := node.Value()
+ if bytes.Equal(k, key) {
+ return NewShortNode(self, key, value)
+ }
+
+ var n Node
+ matchlength := trie.MatchingNibbleLength(key, k)
+ if matchlength == len(k) {
+ n = self.insert(cnode, key[matchlength:], value)
+ } else {
+ pnode := self.insert(nil, k[matchlength+1:], cnode)
+ nnode := self.insert(nil, key[matchlength+1:], value)
+ fulln := NewFullNode(self)
+ fulln.set(k[matchlength], pnode)
+ fulln.set(key[matchlength], nnode)
+ n = fulln
+ }
+ if matchlength == 0 {
+ return n
+ }
+
+ return NewShortNode(self, key[:matchlength], n)
+
+ case *FullNode:
+ cpy := node.Copy().(*FullNode)
+ cpy.set(key[0], self.insert(node.get(key[0]), key[1:], value))
+
+ return cpy
+
+ default:
+ panic("Invalid node")
+ }
+}
+
+func (self *Trie) get(node Node, key []byte) Node {
+ if len(key) == 0 {
+ return node
+ }
+
+ if node == nil {
+ return nil
+ }
+
+ switch node := node.(type) {
+ case *ShortNode:
+ k := node.Key()
+ cnode := node.Value()
+
+ if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) {
+ return self.get(cnode, key[len(k):])
+ }
+
+ return nil
+ case *FullNode:
+ return self.get(node.get(key[0]), key[1:])
+ default:
+ panic("Invalid node")
+ }
+}
+
+func (self *Trie) delete(node Node, key []byte) Node {
+ if len(key) == 0 {
+ return nil
+ }
+
+ switch node := node.(type) {
+ case *ShortNode:
+ k := node.Key()
+ cnode := node.Value()
+ if bytes.Equal(key, k) {
+ return nil
+ } else if bytes.Equal(key[:len(k)], k) {
+ child := self.delete(cnode, key[len(k):])
+
+ var n Node
+ switch child := child.(type) {
+ case *ShortNode:
+ nkey := append(k, child.Key()...)
+ n = NewShortNode(self, nkey, child.Value())
+ case *FullNode:
+ n = NewShortNode(self, node.key, child)
+ }
+
+ return n
+ } else {
+ return node
+ }
+
+ case *FullNode:
+ n := node.Copy().(*FullNode)
+ n.set(key[0], self.delete(n.get(key[0]), key[1:]))
+
+ pos := -1
+ for i := 0; i < 17; i++ {
+ if n.get(byte(i)) != nil {
+ if pos == -1 {
+ pos = i
+ } else {
+ pos = -2
+ }
+ }
+ }
+
+ var nnode Node
+ if pos == 16 {
+ nnode = NewShortNode(self, []byte{16}, n.get(byte(pos)))
+ } else if pos >= 0 {
+ cnode := n.get(byte(pos))
+ switch cnode := cnode.(type) {
+ case *ShortNode:
+ // Stitch keys
+ k := append([]byte{byte(pos)}, cnode.Key()...)
+ nnode = NewShortNode(self, k, cnode.Value())
+ case *FullNode:
+ nnode = NewShortNode(self, []byte{byte(pos)}, n.get(byte(pos)))
+ }
+ } else {
+ nnode = n
+ }
+
+ return nnode
+
+ default:
+ panic("Invalid node")
+ }
+}
+
+// casting functions and cache storing
+func (self *Trie) mknode(value *ethutil.Value) Node {
+ l := value.Len()
+ switch l {
+ case 2:
+ return NewShortNode(self, trie.CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1)))
+ case 17:
+ fnode := NewFullNode(self)
+ for i := 0; i < l; i++ {
+ fnode.set(byte(i), self.mknode(value.Get(i)))
+ }
+ return fnode
+ case 32:
+ return &HashNode{value.Bytes()}
+ default:
+ return &ValueNode{self, value.Bytes()}
+ }
+}
+
+func (self *Trie) trans(node Node) Node {
+ switch node := node.(type) {
+ case *HashNode:
+ value := ethutil.NewValueFromBytes(self.backend.Get(node.key))
+ return self.mknode(value)
+ default:
+ return node
+ }
+}
+
+func (self *Trie) store(node Node) interface{} {
+ data := ethutil.Encode(node)
+ if len(data) >= 32 {
+ key := crypto.Sha3(data)
+ self.backend.Set(key, data)
+
+ return key
+ }
+
+ return node.RlpData()
+}
diff --git a/ptrie/trie_test.go b/ptrie/trie_test.go
new file mode 100644
index 000000000..29380d3f0
--- /dev/null
+++ b/ptrie/trie_test.go
@@ -0,0 +1,138 @@
+package ptrie
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/ethutil"
+)
+
+func TestInsert(t *testing.T) {
+ trie := NewEmpty()
+
+ trie.UpdateString("doe", "reindeer")
+ trie.UpdateString("dog", "puppy")
+ trie.UpdateString("dogglesworth", "cat")
+
+ exp := ethutil.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
+ root := trie.Hash()
+ if !bytes.Equal(root, exp) {
+ t.Errorf("exp %x got %x", exp, root)
+ }
+
+ trie = NewEmpty()
+ trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+
+ exp = ethutil.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
+ root = trie.Hash()
+ if !bytes.Equal(root, exp) {
+ t.Errorf("exp %x got %x", exp, root)
+ }
+}
+
+func TestGet(t *testing.T) {
+ trie := NewEmpty()
+
+ trie.UpdateString("doe", "reindeer")
+ trie.UpdateString("dog", "puppy")
+ trie.UpdateString("dogglesworth", "cat")
+
+ res := trie.GetString("dog")
+ if !bytes.Equal(res, []byte("puppy")) {
+ t.Errorf("expected puppy got %x", res)
+ }
+
+ unknown := trie.GetString("unknown")
+ if unknown != nil {
+ t.Errorf("expected nil got %x", unknown)
+ }
+}
+
+func TestDelete(t *testing.T) {
+ trie := NewEmpty()
+
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+
+ hash := trie.Hash()
+ exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if !bytes.Equal(hash, exp) {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestReplication(t *testing.T) {
+ trie := NewEmpty()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+ trie.Hash()
+
+ trie2 := New(trie.roothash, trie.backend)
+ if string(trie2.GetString("horse")) != "stallion" {
+ t.Error("expected to have harse => stallion")
+ }
+
+ hash := trie2.Hash()
+ exp := trie.Hash()
+ if !bytes.Equal(hash, exp) {
+ t.Errorf("root failure. expected %x got %x", exp, hash)
+ }
+
+}
+
+func BenchmarkGets(b *testing.B) {
+ trie := NewEmpty()
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.UpdateString(val.k, val.v)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ trie.Get([]byte("horse"))
+ }
+}
+
+func BenchmarkUpdate(b *testing.B) {
+ trie := NewEmpty()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ trie.UpdateString(fmt.Sprintf("aaaaaaaaaaaaaaa%d", j), "value")
+ }
+ trie.Hash()
+}
diff --git a/ptrie/valuenode.go b/ptrie/valuenode.go
new file mode 100644
index 000000000..c226621a7
--- /dev/null
+++ b/ptrie/valuenode.go
@@ -0,0 +1,13 @@
+package ptrie
+
+type ValueNode struct {
+ trie *Trie
+ data []byte
+}
+
+func (self *ValueNode) Value() Node { return self } // Best not to call :-)
+func (self *ValueNode) Val() []byte { return self.data }
+func (self *ValueNode) Dirty() bool { return true }
+func (self *ValueNode) Copy() Node { return self }
+func (self *ValueNode) RlpData() interface{} { return self.data }
+func (self *ValueNode) Hash() interface{} { return self.data }