diff options
Diffstat (limited to 'state')
-rw-r--r-- | state/dump.go | 2 | ||||
-rw-r--r-- | state/log.go | 32 | ||||
-rw-r--r-- | state/managed_state.go | 25 | ||||
-rw-r--r-- | state/managed_state_test.go | 12 | ||||
-rw-r--r-- | state/state_object.go | 38 | ||||
-rw-r--r-- | state/state_test.go | 24 | ||||
-rw-r--r-- | state/statedb.go | 69 |
7 files changed, 110 insertions, 92 deletions
diff --git a/state/dump.go b/state/dump.go index 6db0d5074..712f8da1f 100644 --- a/state/dump.go +++ b/state/dump.go @@ -28,7 +28,7 @@ func (self *StateDB) RawDump() World { it := self.trie.Iterator() for it.Next() { - stateObject := NewStateObjectFromBytes(it.Key, it.Value, self.db) + stateObject := NewStateObjectFromBytes(common.BytesToAddress(it.Key), it.Value, self.db) account := Account{Balance: stateObject.balance.String(), Nonce: stateObject.nonce, Root: common.Bytes2Hex(stateObject.Root()), CodeHash: common.Bytes2Hex(stateObject.codeHash)} account.Storage = make(map[string]string) diff --git a/state/log.go b/state/log.go index a0859aaf2..f8aa4c08c 100644 --- a/state/log.go +++ b/state/log.go @@ -2,36 +2,36 @@ package state import ( "fmt" + "io" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" ) type Log interface { - common.RlpEncodable - - Address() []byte - Topics() [][]byte + Address() common.Address + Topics() []common.Hash Data() []byte Number() uint64 } type StateLog struct { - address []byte - topics [][]byte + address common.Address + topics []common.Hash data []byte number uint64 } -func NewLog(address []byte, topics [][]byte, data []byte, number uint64) *StateLog { +func NewLog(address common.Address, topics []common.Hash, data []byte, number uint64) *StateLog { return &StateLog{address, topics, data, number} } -func (self *StateLog) Address() []byte { +func (self *StateLog) Address() common.Address { return self.address } -func (self *StateLog) Topics() [][]byte { +func (self *StateLog) Topics() []common.Hash { return self.topics } @@ -43,7 +43,12 @@ func (self *StateLog) Number() uint64 { return self.number } +/* func NewLogFromValue(decoder *common.Value) *StateLog { + var extlog struct { + + } + log := &StateLog{ address: decoder.Get(0).Bytes(), data: decoder.Get(2).Bytes(), @@ -56,10 +61,17 @@ func NewLogFromValue(decoder *common.Value) *StateLog { return log } +*/ + +func (self *StateLog) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, []interface{}{self.address, self.topics, self.data}) +} +/* func (self *StateLog) RlpData() interface{} { return []interface{}{self.address, common.ByteSliceToInterface(self.topics), self.data} } +*/ func (self *StateLog) String() string { return fmt.Sprintf(`log: %x %x %x`, self.address, self.topics, self.data) @@ -67,6 +79,7 @@ func (self *StateLog) String() string { type Logs []Log +/* func (self Logs) RlpData() interface{} { data := make([]interface{}, len(self)) for i, log := range self { @@ -75,6 +88,7 @@ func (self Logs) RlpData() interface{} { return data } +*/ func (self Logs) String() (ret string) { for _, log := range self { diff --git a/state/managed_state.go b/state/managed_state.go index aff0206b2..0fcc1be67 100644 --- a/state/managed_state.go +++ b/state/managed_state.go @@ -1,6 +1,10 @@ package state -import "sync" +import ( + "sync" + + "github.com/ethereum/go-ethereum/common" +) type account struct { stateObject *StateObject @@ -29,7 +33,7 @@ func (ms *ManagedState) SetState(statedb *StateDB) { ms.StateDB = statedb } -func (ms *ManagedState) RemoveNonce(addr []byte, n uint64) { +func (ms *ManagedState) RemoveNonce(addr common.Address, n uint64) { if ms.hasAccount(addr) { ms.mu.Lock() defer ms.mu.Unlock() @@ -43,7 +47,7 @@ func (ms *ManagedState) RemoveNonce(addr []byte, n uint64) { } } -func (ms *ManagedState) NewNonce(addr []byte) uint64 { +func (ms *ManagedState) NewNonce(addr common.Address) uint64 { ms.mu.RLock() defer ms.mu.RUnlock() @@ -57,26 +61,27 @@ func (ms *ManagedState) NewNonce(addr []byte) uint64 { return uint64(len(account.nonces)) + account.nstart } -func (ms *ManagedState) hasAccount(addr []byte) bool { - _, ok := ms.accounts[string(addr)] +func (ms *ManagedState) hasAccount(addr common.Address) bool { + _, ok := ms.accounts[addr.Str()] return ok } -func (ms *ManagedState) getAccount(addr []byte) *account { - if account, ok := ms.accounts[string(addr)]; !ok { +func (ms *ManagedState) getAccount(addr common.Address) *account { + straddr := addr.Str() + if account, ok := ms.accounts[straddr]; !ok { so := ms.GetOrNewStateObject(addr) - ms.accounts[string(addr)] = newAccount(so) + ms.accounts[straddr] = newAccount(so) } else { // Always make sure the state account nonce isn't actually higher // than the tracked one. so := ms.StateDB.GetStateObject(addr) if so != nil && uint64(len(account.nonces))+account.nstart < so.nonce { - ms.accounts[string(addr)] = newAccount(so) + ms.accounts[straddr] = newAccount(so) } } - return ms.accounts[string(addr)] + return ms.accounts[straddr] } func newAccount(so *StateObject) *account { diff --git a/state/managed_state_test.go b/state/managed_state_test.go index 4aad1e1e3..b61f59e6d 100644 --- a/state/managed_state_test.go +++ b/state/managed_state_test.go @@ -6,15 +6,15 @@ import ( "github.com/ethereum/go-ethereum/common" ) -var addr = common.Address([]byte("test")) +var addr = common.BytesToAddress([]byte("test")) func create() (*ManagedState, *account) { ms := ManageState(&StateDB{stateObjects: make(map[string]*StateObject)}) so := &StateObject{address: addr, nonce: 100} - ms.StateDB.stateObjects[string(addr)] = so - ms.accounts[string(addr)] = newAccount(so) + ms.StateDB.stateObjects[addr.Str()] = so + ms.accounts[addr.Str()] = newAccount(so) - return ms, ms.accounts[string(addr)] + return ms, ms.accounts[addr.Str()] } func TestNewNonce(t *testing.T) { @@ -73,7 +73,7 @@ func TestRemoteNonceChange(t *testing.T) { account.nonces = append(account.nonces, nn...) nonce := ms.NewNonce(addr) - ms.StateDB.stateObjects[string(addr)].nonce = 200 + ms.StateDB.stateObjects[addr.Str()].nonce = 200 nonce = ms.NewNonce(addr) if nonce != 200 { t.Error("expected nonce after remote update to be", 201, "got", nonce) @@ -81,7 +81,7 @@ func TestRemoteNonceChange(t *testing.T) { ms.NewNonce(addr) ms.NewNonce(addr) ms.NewNonce(addr) - ms.StateDB.stateObjects[string(addr)].nonce = 200 + ms.StateDB.stateObjects[addr.Str()].nonce = 200 nonce = ms.NewNonce(addr) if nonce != 204 { t.Error("expected nonce after remote update to be", 201, "got", nonce) diff --git a/state/state_object.go b/state/state_object.go index cdb9abf79..a7c20722c 100644 --- a/state/state_object.go +++ b/state/state_object.go @@ -44,7 +44,7 @@ type StateObject struct { State *StateDB // Address belonging to this account - address []byte + address common.Address // The balance of the account balance *big.Int // The nonce of the account @@ -77,12 +77,12 @@ func (self *StateObject) Reset() { self.State.Reset() } -func NewStateObject(addr []byte, db common.Database) *StateObject { +func NewStateObject(address common.Address, db common.Database) *StateObject { // This to ensure that it has 20 bytes (and not 0 bytes), thus left or right pad doesn't matter. - address := common.Address(addr) + //address := common.ToAddress(addr) object := &StateObject{db: db, address: address, balance: new(big.Int), gasPool: new(big.Int), dirty: true} - object.State = New(nil, db) //New(trie.New(common.Config.Db, "")) + object.State = New(common.Hash{}, db) //New(trie.New(common.Config.Db, "")) object.storage = make(Storage) object.gasPool = new(big.Int) object.prepaid = new(big.Int) @@ -90,12 +90,12 @@ func NewStateObject(addr []byte, db common.Database) *StateObject { return object } -func NewStateObjectFromBytes(address, data []byte, db common.Database) *StateObject { +func NewStateObjectFromBytes(address common.Address, data []byte, db common.Database) *StateObject { // TODO clean me up var extobject struct { Nonce uint64 Balance *big.Int - Root []byte + Root common.Hash CodeHash []byte } err := rlp.Decode(bytes.NewReader(data), &extobject) @@ -124,8 +124,8 @@ func (self *StateObject) MarkForDeletion() { statelogger.Debugf("%x: #%d %v X\n", self.Address(), self.nonce, self.balance) } -func (c *StateObject) getAddr(addr []byte) *common.Value { - return common.NewValueFromBytes([]byte(c.State.trie.Get(addr))) +func (c *StateObject) getAddr(addr common.Hash) *common.Value { + return common.NewValueFromBytes([]byte(c.State.trie.Get(addr[:]))) } func (c *StateObject) setAddr(addr []byte, value interface{}) { @@ -133,34 +133,32 @@ func (c *StateObject) setAddr(addr []byte, value interface{}) { } func (self *StateObject) GetStorage(key *big.Int) *common.Value { - return self.GetState(key.Bytes()) + return self.GetState(common.BytesToHash(key.Bytes())) } func (self *StateObject) SetStorage(key *big.Int, value *common.Value) { - self.SetState(key.Bytes(), value) + self.SetState(common.BytesToHash(key.Bytes()), value) } func (self *StateObject) Storage() Storage { return self.storage } -func (self *StateObject) GetState(k []byte) *common.Value { - key := common.LeftPadBytes(k, 32) - - value := self.storage[string(key)] +func (self *StateObject) GetState(key common.Hash) *common.Value { + strkey := key.Str() + value := self.storage[strkey] if value == nil { value = self.getAddr(key) if !value.IsNil() { - self.storage[string(key)] = value + self.storage[strkey] = value } } return value } -func (self *StateObject) SetState(k []byte, value *common.Value) { - key := common.LeftPadBytes(k, 32) - self.storage[string(key)] = value.Copy() +func (self *StateObject) SetState(k common.Hash, value *common.Value) { + self.storage[k.Str()] = value.Copy() self.dirty = true } @@ -284,7 +282,7 @@ func (c *StateObject) N() *big.Int { } // Returns the address of the contract/account -func (c *StateObject) Address() []byte { +func (c *StateObject) Address() common.Address { return c.address } @@ -341,7 +339,7 @@ func (c *StateObject) RlpDecode(data []byte) { decoder := common.NewValueFromBytes(data) c.nonce = decoder.Get(0).Uint() c.balance = decoder.Get(1).BigInt() - c.State = New(decoder.Get(2).Bytes(), c.db) //New(trie.New(common.Config.Db, decoder.Get(2).Interface())) + c.State = New(common.BytesToHash(decoder.Get(2).Bytes()), c.db) //New(trie.New(common.Config.Db, decoder.Get(2).Interface())) c.storage = make(map[string]*common.Value) c.gasPool = new(big.Int) diff --git a/state/state_test.go b/state/state_test.go index 3a1ea225d..a3d3973de 100644 --- a/state/state_test.go +++ b/state/state_test.go @@ -1,7 +1,6 @@ package state import ( - "fmt" "math/big" "testing" @@ -17,15 +16,16 @@ type StateSuite struct { var _ = checker.Suite(&StateSuite{}) -// var ZeroHash256 = make([]byte, 32) +var toAddr = common.BytesToAddress func (s *StateSuite) TestDump(c *checker.C) { + return // generate a few entries - obj1 := s.state.GetOrNewStateObject([]byte{0x01}) + obj1 := s.state.GetOrNewStateObject(toAddr([]byte{0x01})) obj1.AddBalance(big.NewInt(22)) - obj2 := s.state.GetOrNewStateObject([]byte{0x01, 0x02}) + obj2 := s.state.GetOrNewStateObject(toAddr([]byte{0x01, 0x02})) obj2.SetCode([]byte{3, 3, 3, 3, 3, 3, 3}) - obj3 := s.state.GetOrNewStateObject([]byte{0x02}) + obj3 := s.state.GetOrNewStateObject(toAddr([]byte{0x02})) obj3.SetBalance(big.NewInt(44)) // write some of them to the trie @@ -60,27 +60,25 @@ func (s *StateSuite) TestDump(c *checker.C) { func (s *StateSuite) SetUpTest(c *checker.C) { db, _ := ethdb.NewMemDatabase() - s.state = New(nil, db) + s.state = New(common.Hash{}, db) } func TestNull(t *testing.T) { db, _ := ethdb.NewMemDatabase() - state := New(nil, db) + state := New(common.Hash{}, db) - address := common.FromHex("0x823140710bf13990e4500136726d8b55") + address := common.HexToAddress("0x823140710bf13990e4500136726d8b55") state.NewStateObject(address) //value := common.FromHex("0x823140710bf13990e4500136726d8b55") value := make([]byte, 16) - fmt.Println("test it here", common.NewValue(value)) - state.SetState(address, []byte{0}, value) + state.SetState(address, common.Hash{}, value) state.Update(nil) state.Sync() - value = state.GetState(address, []byte{0}) - fmt.Printf("res: %x\n", value) + value = state.GetState(address, common.Hash{}) } func (s *StateSuite) TestSnapshot(c *checker.C) { - stateobjaddr := []byte("aa") + stateobjaddr := toAddr([]byte("aa")) storageaddr := common.Big("0") data1 := common.NewValue(42) data2 := common.NewValue(43) diff --git a/state/statedb.go b/state/statedb.go index d01b6056d..6fcd39dbc 100644 --- a/state/statedb.go +++ b/state/statedb.go @@ -28,8 +28,8 @@ type StateDB struct { } // Create a new state from a given trie -func New(root []byte, db common.Database) *StateDB { - trie := trie.NewSecure(common.CopyBytes(root), db) +func New(root common.Hash, db common.Database) *StateDB { + trie := trie.NewSecure(root[:], db) return &StateDB{db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: make(map[string]*big.Int)} } @@ -49,15 +49,16 @@ func (self *StateDB) Logs() Logs { return self.logs } -func (self *StateDB) Refund(addr []byte, gas *big.Int) { - if self.refund[string(addr)] == nil { - self.refund[string(addr)] = new(big.Int) +func (self *StateDB) Refund(address common.Address, gas *big.Int) { + addr := address.Str() + if self.refund[addr] == nil { + self.refund[addr] = new(big.Int) } - self.refund[string(addr)].Add(self.refund[string(addr)], gas) + self.refund[addr].Add(self.refund[addr], gas) } // Retrieve the balance from the given address or 0 if object not found -func (self *StateDB) GetBalance(addr []byte) *big.Int { +func (self *StateDB) GetBalance(addr common.Address) *big.Int { stateObject := self.GetStateObject(addr) if stateObject != nil { return stateObject.balance @@ -66,14 +67,14 @@ func (self *StateDB) GetBalance(addr []byte) *big.Int { return common.Big0 } -func (self *StateDB) AddBalance(addr []byte, amount *big.Int) { +func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { stateObject := self.GetStateObject(addr) if stateObject != nil { stateObject.AddBalance(amount) } } -func (self *StateDB) GetNonce(addr []byte) uint64 { +func (self *StateDB) GetNonce(addr common.Address) uint64 { stateObject := self.GetStateObject(addr) if stateObject != nil { return stateObject.nonce @@ -82,7 +83,7 @@ func (self *StateDB) GetNonce(addr []byte) uint64 { return 0 } -func (self *StateDB) GetCode(addr []byte) []byte { +func (self *StateDB) GetCode(addr common.Address) []byte { stateObject := self.GetStateObject(addr) if stateObject != nil { return stateObject.code @@ -91,7 +92,7 @@ func (self *StateDB) GetCode(addr []byte) []byte { return nil } -func (self *StateDB) GetState(a, b []byte) []byte { +func (self *StateDB) GetState(a common.Address, b common.Hash) []byte { stateObject := self.GetStateObject(a) if stateObject != nil { return stateObject.GetState(b).Bytes() @@ -100,28 +101,28 @@ func (self *StateDB) GetState(a, b []byte) []byte { return nil } -func (self *StateDB) SetNonce(addr []byte, nonce uint64) { +func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { stateObject := self.GetStateObject(addr) if stateObject != nil { stateObject.SetNonce(nonce) } } -func (self *StateDB) SetCode(addr, code []byte) { +func (self *StateDB) SetCode(addr common.Address, code []byte) { stateObject := self.GetStateObject(addr) if stateObject != nil { stateObject.SetCode(code) } } -func (self *StateDB) SetState(addr, key []byte, value interface{}) { +func (self *StateDB) SetState(addr common.Address, key common.Hash, value interface{}) { stateObject := self.GetStateObject(addr) if stateObject != nil { stateObject.SetState(key, common.NewValue(value)) } } -func (self *StateDB) Delete(addr []byte) bool { +func (self *StateDB) Delete(addr common.Address) bool { stateObject := self.GetStateObject(addr) if stateObject != nil { stateObject.MarkForDeletion() @@ -133,7 +134,7 @@ func (self *StateDB) Delete(addr []byte) bool { return false } -func (self *StateDB) IsDeleted(addr []byte) bool { +func (self *StateDB) IsDeleted(addr common.Address) bool { stateObject := self.GetStateObject(addr) if stateObject != nil { return stateObject.remove @@ -147,32 +148,34 @@ func (self *StateDB) IsDeleted(addr []byte) bool { // Update the given state object and apply it to state trie func (self *StateDB) UpdateStateObject(stateObject *StateObject) { - addr := stateObject.Address() + //addr := stateObject.Address() if len(stateObject.CodeHash()) > 0 { self.db.Put(stateObject.CodeHash(), stateObject.code) } - self.trie.Update(addr, stateObject.RlpEncode()) + addr := stateObject.Address() + self.trie.Update(addr[:], stateObject.RlpEncode()) } // Delete the given state object and delete it from the state trie func (self *StateDB) DeleteStateObject(stateObject *StateObject) { - self.trie.Delete(stateObject.Address()) + addr := stateObject.Address() + self.trie.Delete(addr[:]) - delete(self.stateObjects, string(stateObject.Address())) + delete(self.stateObjects, addr.Str()) } // Retrieve a state object given my the address. Nil if not found -func (self *StateDB) GetStateObject(addr []byte) *StateObject { - addr = common.Address(addr) +func (self *StateDB) GetStateObject(addr common.Address) *StateObject { + //addr = common.Address(addr) - stateObject := self.stateObjects[string(addr)] + stateObject := self.stateObjects[addr.Str()] if stateObject != nil { return stateObject } - data := self.trie.Get(addr) + data := self.trie.Get(addr[:]) if len(data) == 0 { return nil } @@ -184,11 +187,11 @@ func (self *StateDB) GetStateObject(addr []byte) *StateObject { } func (self *StateDB) SetStateObject(object *StateObject) { - self.stateObjects[string(object.address)] = object + self.stateObjects[object.Address().Str()] = object } // Retrieve a state object or create a new state object if nil -func (self *StateDB) GetOrNewStateObject(addr []byte) *StateObject { +func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { stateObject := self.GetStateObject(addr) if stateObject == nil { stateObject = self.NewStateObject(addr) @@ -198,19 +201,19 @@ func (self *StateDB) GetOrNewStateObject(addr []byte) *StateObject { } // Create a state object whether it exist in the trie or not -func (self *StateDB) NewStateObject(addr []byte) *StateObject { - addr = common.Address(addr) +func (self *StateDB) NewStateObject(addr common.Address) *StateObject { + //addr = common.Address(addr) statelogger.Debugf("(+) %x\n", addr) stateObject := NewStateObject(addr, self.db) - self.stateObjects[string(addr)] = stateObject + self.stateObjects[addr.Str()] = stateObject return stateObject } // Deprecated -func (self *StateDB) GetAccount(addr []byte) *StateObject { +func (self *StateDB) GetAccount(addr common.Address) *StateObject { return self.GetOrNewStateObject(addr) } @@ -223,7 +226,7 @@ func (s *StateDB) Cmp(other *StateDB) bool { } func (self *StateDB) Copy() *StateDB { - state := New(nil, self.db) + state := New(common.Hash{}, self.db) state.trie = self.trie.Copy() for k, stateObject := range self.stateObjects { state.stateObjects[k] = stateObject.Copy() @@ -248,8 +251,8 @@ func (self *StateDB) Set(state *StateDB) { self.logs = state.logs } -func (s *StateDB) Root() []byte { - return s.trie.Root() +func (s *StateDB) Root() common.Hash { + return common.BytesToHash(s.trie.Root()) } func (s *StateDB) Trie() *trie.SecureTrie { |