diff options
-rw-r--r-- | common/types.go | 10 | ||||
-rw-r--r-- | trie/iterator.go | 14 | ||||
-rw-r--r-- | trie/iterator_test.go | 2 | ||||
-rw-r--r-- | trie/secure_trie.go | 26 | ||||
-rw-r--r-- | trie/trie.go | 58 | ||||
-rw-r--r-- | trie/trie_test.go | 51 |
6 files changed, 90 insertions, 71 deletions
diff --git a/common/types.go b/common/types.go index e0963a7c5..1fe31657b 100644 --- a/common/types.go +++ b/common/types.go @@ -27,27 +27,27 @@ func StringToAddress(s string) Address { return BytesToAddress([]byte(s)) } // Don't use the default 'String' method in case we want to overwrite // Get the string representation of the underlying hash -func (h Hash) Str() string { +func (h *Hash) Str() string { return string(h[:]) } // Sets the hash to the value of b. If b is larger than len(h) it will panic -func (h Hash) SetBytes(b []byte) { +func (h *Hash) SetBytes(b []byte) { if len(b) > len(h) { panic("unable to set bytes. too big") } // reverse loop - for i := len(b); i >= 0; i-- { + for i := len(b) - 1; i >= 0; i-- { h[i] = b[i] } } // Set string `s` to h. If s is larger than len(h) it will panic -func (h Hash) SetString(s string) { h.SetBytes([]byte(s)) } +func (h *Hash) SetString(s string) { h.SetBytes([]byte(s)) } // Sets h to other -func (h Hash) Set(other Hash) { +func (h *Hash) Set(other Hash) { for i, v := range other { h[i] = v } diff --git a/trie/iterator.go b/trie/iterator.go index fda7c6cbe..aff614f95 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -2,17 +2,19 @@ package trie import ( "bytes" + + "github.com/ethereum/go-ethereum/common" ) type Iterator struct { trie *Trie - Key []byte + Key common.Hash Value []byte } func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie, Key: nil} + return &Iterator{trie: trie} } func (self *Iterator) Next() bool { @@ -20,15 +22,15 @@ func (self *Iterator) Next() bool { defer self.trie.mu.Unlock() isIterStart := false - if self.Key == nil { + if (self.Key == common.Hash{}) { isIterStart = true - self.Key = make([]byte, 32) + //self.Key = make([]byte, 32) } - key := RemTerm(CompactHexDecode(string(self.Key))) + key := RemTerm(CompactHexDecode(self.Key.Str())) k := self.next(self.trie.root, key, isIterStart) - self.Key = []byte(DecodeCompact(k)) + self.Key = common.StringToHash(DecodeCompact(k)) return len(k) > 0 } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 74d9e903c..5f95caa68 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -22,7 +22,7 @@ func TestIterator(t *testing.T) { it := trie.Iterator() for it.Next() { - v[string(it.Key)] = true + v[it.Key.Str()] = true } for k, found := range v { diff --git a/trie/secure_trie.go b/trie/secure_trie.go index b9fa376b8..b31791cad 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -1,34 +1,38 @@ package trie -import "github.com/ethereum/go-ethereum/crypto" +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" +) type SecureTrie struct { *Trie } -func NewSecure(root []byte, backend Backend) *SecureTrie { +func NewSecure(root common.Hash, backend Backend) *SecureTrie { return &SecureTrie{New(root, backend)} } -func (self *SecureTrie) Update(key, value []byte) Node { - return self.Trie.Update(crypto.Sha3(key), value) +func (self *SecureTrie) Update(key common.Hash, value []byte) Node { + return self.Trie.Update(common.BytesToHash(crypto.Sha3(key[:])), value) } + func (self *SecureTrie) UpdateString(key, value string) Node { - return self.Update([]byte(key), []byte(value)) + return self.Update(common.StringToHash(key), []byte(value)) } -func (self *SecureTrie) Get(key []byte) []byte { - return self.Trie.Get(crypto.Sha3(key)) +func (self *SecureTrie) Get(key common.Hash) []byte { + return self.Trie.Get(common.BytesToHash(crypto.Sha3(key[:]))) } func (self *SecureTrie) GetString(key string) []byte { - return self.Get([]byte(key)) + return self.Get(common.StringToHash(key)) } -func (self *SecureTrie) Delete(key []byte) Node { - return self.Trie.Delete(crypto.Sha3(key)) +func (self *SecureTrie) Delete(key common.Hash) Node { + return self.Trie.Delete(common.BytesToHash(crypto.Sha3(key[:]))) } func (self *SecureTrie) DeleteString(key string) Node { - return self.Delete([]byte(key)) + return self.Delete(common.StringToHash(key)) } func (self *SecureTrie) Copy() *SecureTrie { diff --git a/trie/trie.go b/trie/trie.go index cb1e5618f..759718400 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -11,14 +11,15 @@ import ( ) func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { - t2 := New(nil, backend) + t2 := New(common.Hash{}, backend) it := t1.Iterator() for it.Next() { t2.Update(it.Key, it.Value) } - return bytes.Equal(t2.Hash(), t1.Hash()), t2 + a, b := t2.Hash(), t1.Hash() + return bytes.Equal(a[:], b[:]), t2 } type Trie struct { @@ -38,8 +39,8 @@ func New(root common.Hash, backend Backend) *Trie { trie.cache = NewCache(backend) } - if root != nil { - value := common.NewValueFromBytes(trie.cache.Get(root)) + if (root != common.Hash{}) { + value := common.NewValueFromBytes(trie.cache.Get(root[:])) trie.root = trie.mknode(value) } @@ -51,12 +52,13 @@ func (self *Trie) Iterator() *Iterator { } func (self *Trie) Copy() *Trie { + //cpy := make([]byte, 32) + //copy(cpy, self.roothash) + // cheap copying method var cpy common.Hash - cpy.Set(self.roothash[:]) - cpy := make([]byte, 32) - copy(cpy, self.roothash) - trie := New(nil, nil) + cpy.Set(self.roothash) + trie := New(common.Hash{}, nil) trie.cache = self.cache.Copy() if self.root != nil { trie.root = self.root.Copy(trie) @@ -66,21 +68,21 @@ func (self *Trie) Copy() *Trie { } // Legacy support -func (self *Trie) Root() []byte { return self.Hash() } -func (self *Trie) Hash() []byte { - var hash []byte +func (self *Trie) Root() common.Hash { return self.Hash() } +func (self *Trie) Hash() common.Hash { + var hash common.Hash if self.root != nil { t := self.root.Hash() - if byts, ok := t.([]byte); ok && len(byts) > 0 { - hash = byts + if h, ok := t.(common.Hash); ok && (h != common.Hash{}) { + hash = h } else { - hash = crypto.Sha3(common.Encode(self.root.RlpData())) + hash = common.BytesToHash(crypto.Sha3(common.Encode(self.root.RlpData()))) } } else { - hash = crypto.Sha3(common.Encode("")) + hash = common.BytesToHash(crypto.Sha3(common.Encode(""))) } - if !bytes.Equal(hash, self.roothash) { + if hash != self.roothash { self.revisions.PushBack(self.roothash) self.roothash = hash } @@ -105,19 +107,21 @@ func (self *Trie) Reset() { self.cache.Reset() if self.revisions.Len() > 0 { - revision := self.revisions.Remove(self.revisions.Back()).([]byte) + revision := self.revisions.Remove(self.revisions.Back()).(common.Hash) self.roothash = revision } - value := common.NewValueFromBytes(self.cache.Get(self.roothash)) + value := common.NewValueFromBytes(self.cache.Get(self.roothash[:])) self.root = self.mknode(value) } -func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } -func (self *Trie) Update(key, value []byte) Node { +func (self *Trie) UpdateString(key, value string) Node { + return self.Update(common.StringToHash(key), []byte(value)) +} +func (self *Trie) Update(key common.Hash, value []byte) Node { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(string(key)) + k := CompactHexDecode(key.Str()) if len(value) != 0 { self.root = self.insert(self.root, k, &ValueNode{self, value}) @@ -128,12 +132,12 @@ func (self *Trie) Update(key, value []byte) Node { return self.root } -func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } -func (self *Trie) Get(key []byte) []byte { +func (self *Trie) GetString(key string) []byte { return self.Get(common.StringToHash(key)) } +func (self *Trie) Get(key common.Hash) []byte { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(string(key)) + k := CompactHexDecode(key.Str()) n := self.get(self.root, k) if n != nil { @@ -143,12 +147,12 @@ func (self *Trie) Get(key []byte) []byte { return nil } -func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } -func (self *Trie) Delete(key []byte) Node { +func (self *Trie) DeleteString(key string) Node { return self.Delete(common.StringToHash(key)) } +func (self *Trie) Delete(key common.Hash) Node { self.mu.Lock() defer self.mu.Unlock() - k := CompactHexDecode(string(key)) + k := CompactHexDecode(key.Str()) self.root = self.delete(self.root, k) return self.root diff --git a/trie/trie_test.go b/trie/trie_test.go index 1393e0c97..f5d17c3da 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" ) type Db map[string][]byte @@ -16,18 +16,18 @@ func (self Db) Put(k, v []byte) { self[string(k)] = v } // Used for testing func NewEmpty() *Trie { - return New(nil, make(Db)) + return New(common.Hash{}, make(Db)) } func NewEmptySecure() *SecureTrie { - return NewSecure(nil, make(Db)) + return NewSecure(common.Hash{}, make(Db)) } func TestEmptyTrie(t *testing.T) { trie := NewEmpty() res := trie.Hash() exp := crypto.Sha3(common.Encode("")) - if !bytes.Equal(res, exp) { + if !bytes.Equal(res[:], exp[:]) { t.Errorf("expected %x got %x", exp, res) } } @@ -41,7 +41,7 @@ func TestInsert(t *testing.T) { exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") root := trie.Hash() - if !bytes.Equal(root, exp) { + if !bytes.Equal(root[:], exp[:]) { t.Errorf("exp %x got %x", exp, root) } @@ -50,7 +50,7 @@ func TestInsert(t *testing.T) { exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") root = trie.Hash() - if !bytes.Equal(root, exp) { + if !bytes.Equal(root[:], exp) { t.Errorf("exp %x got %x", exp, root) } } @@ -96,7 +96,7 @@ func TestDelete(t *testing.T) { hash := trie.Hash() exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash, exp) { + if !bytes.Equal(hash[:], exp) { t.Errorf("expected %x got %x", exp, hash) } } @@ -120,7 +120,7 @@ func TestEmptyValues(t *testing.T) { hash := trie.Hash() exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash, exp) { + if !bytes.Equal(hash[:], exp) { t.Errorf("expected %x got %x", exp, hash) } } @@ -150,7 +150,7 @@ func TestReplication(t *testing.T) { hash := trie2.Hash() exp := trie.Hash() - if !bytes.Equal(hash, exp) { + if !bytes.Equal(hash[:], exp[:]) { t.Errorf("root failure. expected %x got %x", exp, hash) } @@ -168,7 +168,9 @@ func TestReset(t *testing.T) { } trie.Commit() - before := common.CopyBytes(trie.roothash) + var before common.Hash + before.Set(trie.roothash) + trie.UpdateString("should", "revert") trie.Hash() // Should have no effect @@ -177,9 +179,11 @@ func TestReset(t *testing.T) { // ### trie.Reset() - after := common.CopyBytes(trie.roothash) - if !bytes.Equal(before, after) { + var after common.Hash + after.Set(trie.roothash) + + if before != after { t.Errorf("expected roots to be equal. %x - %x", before, after) } } @@ -248,7 +252,7 @@ func BenchmarkGets(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - trie.Get([]byte("horse")) + trie.GetString("horse") } } @@ -263,8 +267,9 @@ func BenchmarkUpdate(b *testing.B) { } type kv struct { - k, v []byte - t bool + k common.Hash + v []byte + t bool } func TestLargeData(t *testing.T) { @@ -272,17 +277,21 @@ func TestLargeData(t *testing.T) { vals := make(map[string]*kv) for i := byte(0); i < 255; i++ { - value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} - value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} + var k1 common.Hash + k1.SetBytes([]byte{i}) + var k2 common.Hash + k2.SetBytes([]byte{10, i}) + value := &kv{k1, []byte{i}, false} + value2 := &kv{k2, []byte{i}, false} trie.Update(value.k, value.v) trie.Update(value2.k, value2.v) - vals[string(value.k)] = value - vals[string(value2.k)] = value2 + vals[value.k.Str()] = value + vals[value2.k.Str()] = value2 } it := trie.Iterator() for it.Next() { - vals[string(it.Key)].t = true + vals[it.Key.Str()].t = true } var untouched []*kv @@ -323,7 +332,7 @@ func TestSecureDelete(t *testing.T) { hash := trie.Hash() exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") - if !bytes.Equal(hash, exp) { + if !bytes.Equal(hash[:], exp) { t.Errorf("expected %x got %x", exp, hash) } } |