diff options
Diffstat (limited to 'core/state/statedb.go')
-rw-r--r-- | core/state/statedb.go | 210 |
1 files changed, 130 insertions, 80 deletions
diff --git a/core/state/statedb.go b/core/state/statedb.go index 4204c456e..4f74302c3 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -20,6 +20,7 @@ package state import ( "fmt" "math/big" + "sort" "sync" "github.com/ethereum/go-ethereum/common" @@ -40,12 +41,17 @@ var StartingNonce uint64 const ( // Number of past tries to keep. The arbitrarily chosen value here // is max uncle depth + 1. - maxJournalLength = 8 + maxTrieCacheLength = 8 // Number of codehash->size associations to keep. codeSizeCacheSize = 100000 ) +type revision struct { + id int + journalIndex int +} + // StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: @@ -69,6 +75,12 @@ type StateDB struct { logs map[common.Hash]vm.Logs logSize uint + // Journal of state modifications. This is the backbone of + // Snapshot and RevertToSnapshot. + journal journal + validRevisions []revision + nextRevisionId int + lock sync.Mutex } @@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error { self.trie = tr self.stateObjects = make(map[common.Address]*StateObject) self.stateObjectsDirty = make(map[common.Address]struct{}) - self.refund = new(big.Int) self.thash = common.Hash{} self.bhash = common.Hash{} self.txIndex = 0 self.logs = make(map[common.Hash]vm.Logs) self.logSize = 0 + self.clearJournalAndRefund() return nil } @@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) { self.lock.Lock() defer self.lock.Unlock() - if len(self.pastTries) >= maxJournalLength { + if len(self.pastTries) >= maxTrieCacheLength { copy(self.pastTries, self.pastTries[1:]) self.pastTries[len(self.pastTries)-1] = t } else { @@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { } func (self *StateDB) AddLog(log *vm.Log) { + self.journal = append(self.journal, addLogChange{txhash: self.thash}) + log.TxHash = self.thash log.BlockHash = self.bhash log.TxIndex = uint(self.txIndex) @@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs { } func (self *StateDB) AddRefund(gas *big.Int) { + self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)}) self.refund.Add(self.refund, gas) } -func (self *StateDB) HasAccount(addr common.Address) bool { - return self.GetStateObject(addr) != nil -} - +// Exist reports whether the given account address exists in the state. +// Notably this also returns true for suicided accounts. func (self *StateDB) Exist(addr common.Address) bool { return self.GetStateObject(addr) != nil } @@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int { if stateObject != nil { return stateObject.Balance() } - return common.Big0 } @@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { } } +func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) { + stateObject := self.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetBalance(amount) + } +} + func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { @@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) { func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(key, value) + stateObject.SetState(self.db, key, value) } } +// Delete marks the given account as suicided. +// This clears the account balance. +// +// The account's state object is still available until the state is committed, +// GetStateObject will return a non-nil account after Delete. func (self *StateDB) Delete(addr common.Address) bool { stateObject := self.GetStateObject(addr) - if stateObject != nil { - stateObject.MarkForDeletion() - stateObject.data.Balance = new(big.Int) - return true + if stateObject == nil { + return false } - - return false + self.journal = append(self.journal, deleteAccountChange{ + account: &addr, + prev: stateObject.remove, + prevbalance: new(big.Int).Set(stateObject.Balance()), + }) + stateObject.markForDeletion() + stateObject.data.Balance = new(big.Int) + return true } // // Setting, updating & deleting state object methods // -// Update the given state object and apply it to state trie -func (self *StateDB) UpdateStateObject(stateObject *StateObject) { +// updateStateObject writes the given object to the trie. +func (self *StateDB) updateStateObject(stateObject *StateObject) { addr := stateObject.Address() data, err := rlp.EncodeToBytes(stateObject) if err != nil { @@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) { self.trie.Update(addr[:], data) } -// Delete the given state object and delete it from the state trie -func (self *StateDB) DeleteStateObject(stateObject *StateObject) { +// deleteStateObject removes the given object from the state trie. +func (self *StateDB) deleteStateObject(stateObject *StateObject) { stateObject.deleted = true - addr := stateObject.Address() self.trie.Delete(addr[:]) } @@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje return nil } // Insert into the live set. - obj := NewObject(addr, data, self.MarkStateObjectDirty) - self.SetStateObject(obj) + obj := newObject(self, addr, data, self.MarkStateObjectDirty) + self.setStateObject(obj) return obj } -func (self *StateDB) SetStateObject(object *StateObject) { +func (self *StateDB) setStateObject(object *StateObject) { self.stateObjects[object.Address()] = object } @@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { stateObject := self.GetStateObject(addr) if stateObject == nil || stateObject.deleted { - stateObject = self.CreateStateObject(addr) + stateObject, _ = self.createObject(addr) } - return stateObject } -// NewStateObject create a state object whether it exist in the trie or not -func (self *StateDB) newStateObject(addr common.Address) *StateObject { - if glog.V(logger.Core) { - glog.Infof("(+) %x\n", addr) - } - obj := NewObject(addr, Account{}, self.MarkStateObjectDirty) - obj.SetNonce(StartingNonce) // sets the object to dirty - self.stateObjects[addr] = obj - return obj -} - // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly // state object cache iteration to find a handful of modified ones. func (self *StateDB) MarkStateObjectDirty(addr common.Address) { self.stateObjectsDirty[addr] = struct{}{} } -// Creates creates a new state object and takes ownership. -func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { - // Get previous (if any) - so := self.GetStateObject(addr) - // Create a new one - newSo := self.newStateObject(addr) - - // If it existed set the balance to the new account - if so != nil { - newSo.data.Balance = so.data.Balance +// createObject creates a new state object. If there is an existing account with +// the given address, it is overwritten and returned as the second return value. +func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) { + prev = self.GetStateObject(addr) + newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) + newobj.setNonce(StartingNonce) // sets the object to dirty + if prev == nil { + if glog.V(logger.Core) { + glog.Infof("(+) %x\n", addr) + } + self.journal = append(self.journal, createObjectChange{account: &addr}) + } else { + self.journal = append(self.journal, resetObjectChange{prev: prev}) } - - return newSo -} - -func (self *StateDB) CreateAccount(addr common.Address) vm.Account { - return self.CreateStateObject(addr) + self.setStateObject(newobj) + return newobj, prev } +// CreateAccount explicitly creates a state object. If a state object with the address +// already exists the balance is carried over to the new account. +// +// CreateAccount is called during the EVM CREATE operation. The situation might arise that +// a contract does the following: // -// Setting, copying of the state methods +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // +// Carrying over the balance ensures that Ether doesn't disappear. +func (self *StateDB) CreateAccount(addr common.Address) vm.Account { + new, prev := self.createObject(addr) + if prev != nil { + new.setBalance(prev.data.Balance) + } + return new +} +// Copy creates a deep, independent copy of the state. +// Snapshots of the copied state cannot be applied to the copy. func (self *StateDB) Copy() *StateDB { self.lock.Lock() defer self.lock.Unlock() @@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB { } // Copy the dirty states and logs for addr, _ := range self.stateObjectsDirty { - state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) + state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range self.logs { @@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB { return state } -func (self *StateDB) Set(state *StateDB) { - self.lock.Lock() - defer self.lock.Unlock() +// Snapshot returns an identifier for the current revision of the state. +func (self *StateDB) Snapshot() int { + id := self.nextRevisionId + self.nextRevisionId++ + self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)}) + return id +} + +// RevertToSnapshot reverts all state changes made since the given revision. +func (self *StateDB) RevertToSnapshot(revid int) { + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(self.validRevisions), func(i int) bool { + return self.validRevisions[i].id >= revid + }) + if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := self.validRevisions[idx].journalIndex + + // Replay the journal to undo changes. + for i := len(self.journal) - 1; i >= snapshot; i-- { + self.journal[i].undo(self) + } + self.journal = self.journal[:snapshot] - self.db = state.db - self.trie = state.trie - self.pastTries = state.pastTries - self.stateObjects = state.stateObjects - self.stateObjectsDirty = state.stateObjectsDirty - self.codeSizeCache = state.codeSizeCache - self.refund = state.refund - self.logs = state.logs - self.logSize = state.logSize + // Remove invalidated snapshots from the stack. + self.validRevisions = self.validRevisions[:idx] } +// GetRefund returns the current value of the refund counter. +// The return value must not be modified by the caller and will become +// invalid at the next call to AddRefund. func (self *StateDB) GetRefund() *big.Int { return self.refund } @@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int { // It is called in between transactions to get the root hash that // goes into transaction receipts. func (s *StateDB) IntermediateRoot() common.Hash { - s.refund = new(big.Int) for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] if stateObject.remove { - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else { - stateObject.UpdateRoot(s.db) - s.UpdateStateObject(stateObject) + stateObject.updateRoot(s.db) + s.updateStateObject(stateObject) } } + // Invalidate journal because reverting across transactions is not allowed. + s.clearJournalAndRefund() return s.trie.Hash() } @@ -486,9 +534,9 @@ func (s *StateDB) IntermediateRoot() common.Hash { // DeleteSuicides should not be used for consensus related updates // under any circumstances. func (s *StateDB) DeleteSuicides() { - // Reset refund so that any used-gas calculations can use - // this method. - s.refund = new(big.Int) + // Reset refund so that any used-gas calculations can use this method. + s.clearJournalAndRefund() + for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] @@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { return root, batch } -func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { +func (s *StateDB) clearJournalAndRefund() { + s.journal = nil + s.validRevisions = s.validRevisions[:0] s.refund = new(big.Int) +} + +func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { + defer s.clearJournalAndRefund() // Commit objects to the trie. for addr, stateObject := range s.stateObjects { if stateObject.remove { // If the object has been removed, don't bother syncing it // and just mark it for deletion in the trie. - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else if _, ok := s.stateObjectsDirty[addr]; ok { // Write any contract code associated with the state object if stateObject.code != nil && stateObject.dirtyCode { @@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) return common.Hash{}, err } // Update the object in the main account trie. - s.UpdateStateObject(stateObject) + s.updateStateObject(stateObject) } delete(s.stateObjectsDirty, addr) } @@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) } return root, err } - -func (self *StateDB) Refunds() *big.Int { - return self.refund -} |