diff options
Diffstat (limited to 'trie')
-rw-r--r-- | trie/trie.go | 32 | ||||
-rw-r--r-- | trie/trie_test.go | 7 |
2 files changed, 26 insertions, 13 deletions
diff --git a/trie/trie.go b/trie/trie.go index d5ab2035a..139e3d286 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -16,10 +16,7 @@ func ParanoiaCheck(t1 *Trie) (bool, *Trie) { t2.Update(key, v.Str()) }) - a := ethutil.NewValue(t2.Root).Bytes() - b := ethutil.NewValue(t1.Root).Bytes() - - return bytes.Compare(a, b) == 0, t2 + return bytes.Compare(t2.GetRoot(), t1.GetRoot()) == 0, t2 } func (s *Cache) Len() int { @@ -97,7 +94,7 @@ func (cache *Cache) Get(key []byte) *ethutil.Value { } }() // Create caching node - cache.nodes[string(key)] = NewNode(key, value, false) + cache.nodes[string(key)] = NewNode(key, value, true) return value } @@ -177,10 +174,12 @@ func New(db ethutil.Database, Root interface{}) *Trie { func (self *Trie) setRoot(root interface{}) { switch t := root.(type) { case string: - if t == "" { - root = crypto.Sha3(ethutil.Encode("")) - } - self.Root = root + /* + if t == "" { + root = crypto.Sha3(ethutil.Encode("")) + } + */ + self.Root = []byte(t) case []byte: self.Root = root default: @@ -223,13 +222,20 @@ func (t *Trie) Delete(key string) { } func (self *Trie) GetRoot() []byte { - switch self.Root.(type) { + switch t := self.Root.(type) { case string: - return []byte(self.Root.(string)) + if t == "" { + return crypto.Sha3(ethutil.Encode("")) + } + return []byte(t) case []byte: - return self.Root.([]byte) + if len(t) == 0 { + return crypto.Sha3(ethutil.Encode("")) + } + + return t default: - panic(fmt.Sprintf("invalid root type %T", self.Root)) + panic(fmt.Sprintf("invalid root type %T (%v)", self.Root, self.Root)) } } diff --git a/trie/trie_test.go b/trie/trie_test.go index 5559f807d..43cd6c145 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -327,6 +327,13 @@ func (s *TrieSuite) TestBeginsWith(c *checker.C) { c.Assert(BeginsWith(b, a), checker.Equals, true) } +func (s *TrieSuite) TestItems(c *checker.C) { + s.trie.Update("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + exp := "d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab" + + c.Assert(s.trie.GetRoot(), checker.DeepEquals, ethutil.Hex2Bytes(exp)) +} + /* func TestRndCase(t *testing.T) { _, trie := NewTrie() |