diff options
author | Felix Lange <fjl@twurst.com> | 2016-10-04 18:36:02 +0800 |
---|---|---|
committer | Felix Lange <fjl@twurst.com> | 2016-10-06 21:32:16 +0800 |
commit | 1f1ea18b5414bea22332bb4fce53cc95b5c6a07d (patch) | |
tree | d1aa3051f9c4d9f33a24519c18b70f0dd2f00644 | |
parent | ab7adb0027dbcf09cf75a533be356c1e24c46c90 (diff) | |
download | dexon-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar.gz dexon-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.tar.zst dexon-1f1ea18b5414bea22332bb4fce53cc95b5c6a07d.zip |
core/state: implement reverts by journaling all changes
This commit replaces the deep-copy based state revert mechanism with a
linear complexity journal. This commit also hides several internal
StateDB methods to limit the number of ways in which calling code can
use the journal incorrectly.
As usual consultation and bug fixes to the initial implementation were
provided by @karalabe, @obscuren and @Arachnid. Thank you!
-rw-r--r-- | accounts/abi/bind/backends/simulated.go | 6 | ||||
-rw-r--r-- | cmd/evm/main.go | 32 | ||||
-rw-r--r-- | core/chain_makers.go | 2 | ||||
-rw-r--r-- | core/execution.go | 8 | ||||
-rw-r--r-- | core/state/dump.go | 2 | ||||
-rw-r--r-- | core/state/journal.go | 117 | ||||
-rw-r--r-- | core/state/managed_state_test.go | 7 | ||||
-rw-r--r-- | core/state/state_object.go | 54 | ||||
-rw-r--r-- | core/state/state_test.go | 22 | ||||
-rw-r--r-- | core/state/statedb.go | 210 | ||||
-rw-r--r-- | core/state/statedb_test.go | 313 | ||||
-rw-r--r-- | core/state/sync_test.go | 2 | ||||
-rw-r--r-- | core/tx_pool.go | 2 | ||||
-rw-r--r-- | core/vm/environment.go | 4 | ||||
-rw-r--r-- | core/vm/jit_test.go | 4 | ||||
-rw-r--r-- | core/vm/runtime/env.go | 8 | ||||
-rw-r--r-- | core/vm_env.go | 8 | ||||
-rw-r--r-- | eth/api_backend.go | 6 | ||||
-rw-r--r-- | internal/ethapi/tracer_test.go | 16 | ||||
-rw-r--r-- | light/state_test.go | 14 | ||||
-rw-r--r-- | miner/worker.go | 6 | ||||
-rw-r--r-- | tests/state_test_util.go | 22 | ||||
-rw-r--r-- | tests/util.go | 34 | ||||
-rw-r--r-- | tests/vm_test_util.go | 18 |
24 files changed, 667 insertions, 250 deletions
diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7e09abb11..74203a468 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { b.mu.Lock() defer b.mu.Unlock() + defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) + rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) return rval, err } @@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) { b.mu.Lock() defer b.mu.Unlock() + defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) + _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) return gas, err } diff --git a/cmd/evm/main.go b/cmd/evm/main.go index 09ade1577..22707c1cc 100644 --- a/cmd/evm/main.go +++ b/cmd/evm/main.go @@ -227,22 +227,22 @@ type ruleSet struct{} func (ruleSet) IsHomestead(*big.Int) bool { return true } -func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } -func (self *VMEnv) Vm() vm.Vm { return self.evm } -func (self *VMEnv) Db() vm.Database { return self.state } -func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() } -func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } -func (self *VMEnv) Origin() common.Address { return *self.transactor } -func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } -func (self *VMEnv) Coinbase() common.Address { return *self.transactor } -func (self *VMEnv) Time() *big.Int { return self.time } -func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } -func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } -func (self *VMEnv) Value() *big.Int { return self.value } -func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } -func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } -func (self *VMEnv) Depth() int { return 0 } -func (self *VMEnv) SetDepth(i int) { self.depth = i } +func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } +func (self *VMEnv) Vm() vm.Vm { return self.evm } +func (self *VMEnv) Db() vm.Database { return self.state } +func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() } +func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) } +func (self *VMEnv) Origin() common.Address { return *self.transactor } +func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } +func (self *VMEnv) Coinbase() common.Address { return *self.transactor } +func (self *VMEnv) Time() *big.Int { return self.time } +func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } +func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } +func (self *VMEnv) Value() *big.Int { return self.value } +func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } +func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } +func (self *VMEnv) Depth() int { return 0 } +func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) GetHash(n uint64) common.Hash { if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 { return self.block.Hash() diff --git a/core/chain_makers.go b/core/chain_makers.go index 0b9a5f75d..e3ad9cda0 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) { // TxNonce returns the next valid transaction nonce for the // account at addr. It panics if the account does not exist. func (b *BlockGen) TxNonce(addr common.Address) uint64 { - if !b.statedb.HasAccount(addr) { + if !b.statedb.Exist(addr) { panic("account does not exist") } return b.statedb.GetNonce(addr) diff --git a/core/execution.go b/core/execution.go index 1bc02f7fb..1cb507ee7 100644 --- a/core/execution.go +++ b/core/execution.go @@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A createAccount = true } - snapshotPreTransfer := env.MakeSnapshot() + snapshotPreTransfer := env.SnapshotDatabase() var ( from = env.Db().GetAccount(caller.Address()) to vm.Account @@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { contract.UseGas(contract.Gas) - env.SetSnapshot(snapshotPreTransfer) + env.RevertToSnapshot(snapshotPreTransfer) } return ret, addr, err @@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA return nil, common.Address{}, vm.DepthError } - snapshot := env.MakeSnapshot() + snapshot := env.SnapshotDatabase() var to vm.Account if !env.Db().Exist(*toAddr) { @@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA if err != nil { contract.UseGas(contract.Gas) - env.SetSnapshot(snapshot) + env.RevertToSnapshot(snapshot) } return ret, addr, err diff --git a/core/state/dump.go b/core/state/dump.go index 58ecd852b..8294d61b9 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump { panic(err) } - obj := NewObject(common.BytesToAddress(addr), data, nil) + obj := newObject(nil, common.BytesToAddress(addr), data, nil) account := DumpAccount{ Balance: data.Balance.String(), Nonce: data.Nonce, diff --git a/core/state/journal.go b/core/state/journal.go new file mode 100644 index 000000000..540ade6fb --- /dev/null +++ b/core/state/journal.go @@ -0,0 +1,117 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package state + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" +) + +type journalEntry interface { + undo(*StateDB) +} + +type journal []journalEntry + +type ( + // Changes to the account trie. + createObjectChange struct { + account *common.Address + } + resetObjectChange struct { + prev *StateObject + } + deleteAccountChange struct { + account *common.Address + prev bool // whether account had already suicided + prevbalance *big.Int + } + + // Changes to individual accounts. + balanceChange struct { + account *common.Address + prev *big.Int + } + nonceChange struct { + account *common.Address + prev uint64 + } + storageChange struct { + account *common.Address + key, prevalue common.Hash + } + codeChange struct { + account *common.Address + prevcode, prevhash []byte + } + + // Changes to other state values. + refundChange struct { + prev *big.Int + } + addLogChange struct { + txhash common.Hash + } +) + +func (ch createObjectChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).deleted = true + delete(s.stateObjects, *ch.account) + delete(s.stateObjectsDirty, *ch.account) +} + +func (ch resetObjectChange) undo(s *StateDB) { + s.setStateObject(ch.prev) +} + +func (ch deleteAccountChange) undo(s *StateDB) { + obj := s.GetStateObject(*ch.account) + if obj != nil { + obj.remove = ch.prev + obj.setBalance(ch.prevbalance) + } +} + +func (ch balanceChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setBalance(ch.prev) +} + +func (ch nonceChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setNonce(ch.prev) +} + +func (ch codeChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) +} + +func (ch storageChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue) +} + +func (ch refundChange) undo(s *StateDB) { + s.refund = ch.prev +} + +func (ch addLogChange) undo(s *StateDB) { + logs := s.logs[ch.txhash] + if len(logs) == 1 { + delete(s.logs, ch.txhash) + } else { + s.logs[ch.txhash] = logs[:len(logs)-1] + } +} diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index baa53428f..3f7bc2aa8 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -29,11 +29,8 @@ func create() (*ManagedState, *account) { db, _ := ethdb.NewMemDatabase() statedb, _ := New(common.Hash{}, db) ms := ManageState(statedb) - so := &StateObject{address: addr} - so.SetNonce(100) - ms.StateDB.stateObjects[addr] = so - ms.accounts[addr] = newAccount(so) - + ms.StateDB.SetNonce(addr, 100) + ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr)) return ms, ms.accounts[addr] } diff --git a/core/state/state_object.go b/core/state/state_object.go index cbd50e2a3..31ff9bcd8 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -66,6 +66,7 @@ func (self Storage) Copy() Storage { type StateObject struct { address common.Address // Ethereum address of this account data Account + db *StateDB // DB error. // State objects are used by the consensus core and VM which are @@ -99,15 +100,15 @@ type Account struct { CodeHash []byte } -// NewObject creates a state object. -func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { +// newObject creates a state object. +func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { if data.Balance == nil { data.Balance = new(big.Int) } if data.CodeHash == nil { data.CodeHash = emptyCodeHash } - return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} + return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} } // EncodeRLP implements rlp.Encoder. @@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) { } } -func (self *StateObject) MarkForDeletion() { +func (self *StateObject) markForDeletion() { self.remove = true if self.onDirty != nil { self.onDirty(self.Address()) @@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash } // SetState updates a value in account storage. -func (self *StateObject) SetState(key, value common.Hash) { +func (self *StateObject) SetState(db trie.Database, key, value common.Hash) { + self.db.journal = append(self.db.journal, storageChange{ + account: &self.address, + key: key, + prevalue: self.GetState(db, key), + }) + self.setState(key, value) +} + +func (self *StateObject) setState(key, value common.Hash) { self.cachedStorage[key] = value self.dirtyStorage[key] = value @@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) { } // UpdateRoot sets the trie root to the current root hash of -func (self *StateObject) UpdateRoot(db trie.Database) { +func (self *StateObject) updateRoot(db trie.Database) { self.updateTrie(db) self.data.Root = self.trie.Hash() } @@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) { } func (self *StateObject) SetBalance(amount *big.Int) { + self.db.journal = append(self.db.journal, balanceChange{ + account: &self.address, + prev: new(big.Int).Set(self.data.Balance), + }) + self.setBalance(amount) +} + +func (self *StateObject) setBalance(amount *big.Int) { self.data.Balance = amount if self.onDirty != nil { self.onDirty(self.Address()) @@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) { // Return the gas back to the origin. Used by the Virtual machine or Closures func (c *StateObject) ReturnGas(gas, price *big.Int) {} -func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { - stateObject := NewObject(self.address, self.data, onDirty) +func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject { + stateObject := newObject(db, self.address, self.data, onDirty) stateObject.trie = self.trie stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() @@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte { } func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { + prevcode := self.Code(self.db.db) + self.db.journal = append(self.db.journal, codeChange{ + account: &self.address, + prevhash: self.CodeHash(), + prevcode: prevcode, + }) + self.setCode(codeHash, code) +} + +func (self *StateObject) setCode(codeHash common.Hash, code []byte) { self.code = code self.data.CodeHash = codeHash[:] self.dirtyCode = true @@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { } func (self *StateObject) SetNonce(nonce uint64) { + self.db.journal = append(self.db.journal, nonceChange{ + account: &self.address, + prev: self.data.Nonce, + }) + self.setNonce(nonce) +} + +func (self *StateObject) setNonce(nonce uint64) { self.data.Nonce = nonce if self.onDirty != nil { self.onDirty(self.Address()) @@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { cb(h, value) } - it := self.trie.Iterator() + it := self.getTrie(self.db.db).Iterator() for it.Next() { // ignore cached values key := common.BytesToHash(self.trie.GetKey(it.Key)) diff --git a/core/state/state_test.go b/core/state/state_test.go index 7b9b39e06..b86d8b140 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) { obj3.SetBalance(big.NewInt(44)) // write some of them to the trie - s.state.UpdateStateObject(obj1) - s.state.UpdateStateObject(obj2) + s.state.updateStateObject(obj1) + s.state.updateStateObject(obj2) s.state.Commit() // check that dump contains the state objects that are in trie @@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { // set initial state object value s.state.SetState(stateobjaddr, storageaddr, data1) // get snapshot of current state - snapshot := s.state.Copy() + snapshot := s.state.Snapshot() // set new state object value s.state.SetState(stateobjaddr, storageaddr, data2) // restore snapshot - s.state.Set(snapshot) + s.state.RevertToSnapshot(snapshot) // get state storage value res := s.state.GetState(stateobjaddr, storageaddr) @@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { c.Assert(data1, checker.DeepEquals, res) } +func TestSnapshotEmpty(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + state, _ := New(common.Hash{}, db) + state.RevertToSnapshot(state.Snapshot()) +} + // use testing instead of checker because checker does not support // printing/logging in tests (-check.vv does not work) func TestSnapshot2(t *testing.T) { @@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) { so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) so0.remove = false so0.deleted = false - state.SetStateObject(so0) + state.setStateObject(so0) root, _ := state.Commit() state.Reset(root) @@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) { so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) so1.remove = true so1.deleted = true - state.SetStateObject(so1) + state.setStateObject(so1) so1 = state.GetStateObject(stateobjaddr1) if so1 != nil { t.Fatalf("deleted object not nil when getting") } - snapshot := state.Copy() - state.Set(snapshot) + snapshot := state.Snapshot() + state.RevertToSnapshot(snapshot) so0Restored := state.GetStateObject(stateobjaddr0) // Update lazily-loaded values before comparing. diff --git a/core/state/statedb.go b/core/state/statedb.go index 4204c456e..4f74302c3 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -20,6 +20,7 @@ package state import ( "fmt" "math/big" + "sort" "sync" "github.com/ethereum/go-ethereum/common" @@ -40,12 +41,17 @@ var StartingNonce uint64 const ( // Number of past tries to keep. The arbitrarily chosen value here // is max uncle depth + 1. - maxJournalLength = 8 + maxTrieCacheLength = 8 // Number of codehash->size associations to keep. codeSizeCacheSize = 100000 ) +type revision struct { + id int + journalIndex int +} + // StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: @@ -69,6 +75,12 @@ type StateDB struct { logs map[common.Hash]vm.Logs logSize uint + // Journal of state modifications. This is the backbone of + // Snapshot and RevertToSnapshot. + journal journal + validRevisions []revision + nextRevisionId int + lock sync.Mutex } @@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error { self.trie = tr self.stateObjects = make(map[common.Address]*StateObject) self.stateObjectsDirty = make(map[common.Address]struct{}) - self.refund = new(big.Int) self.thash = common.Hash{} self.bhash = common.Hash{} self.txIndex = 0 self.logs = make(map[common.Hash]vm.Logs) self.logSize = 0 + self.clearJournalAndRefund() return nil } @@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) { self.lock.Lock() defer self.lock.Unlock() - if len(self.pastTries) >= maxJournalLength { + if len(self.pastTries) >= maxTrieCacheLength { copy(self.pastTries, self.pastTries[1:]) self.pastTries[len(self.pastTries)-1] = t } else { @@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { } func (self *StateDB) AddLog(log *vm.Log) { + self.journal = append(self.journal, addLogChange{txhash: self.thash}) + log.TxHash = self.thash log.BlockHash = self.bhash log.TxIndex = uint(self.txIndex) @@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs { } func (self *StateDB) AddRefund(gas *big.Int) { + self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)}) self.refund.Add(self.refund, gas) } -func (self *StateDB) HasAccount(addr common.Address) bool { - return self.GetStateObject(addr) != nil -} - +// Exist reports whether the given account address exists in the state. +// Notably this also returns true for suicided accounts. func (self *StateDB) Exist(addr common.Address) bool { return self.GetStateObject(addr) != nil } @@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int { if stateObject != nil { return stateObject.Balance() } - return common.Big0 } @@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { } } +func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) { + stateObject := self.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetBalance(amount) + } +} + func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { @@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) { func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(key, value) + stateObject.SetState(self.db, key, value) } } +// Delete marks the given account as suicided. +// This clears the account balance. +// +// The account's state object is still available until the state is committed, +// GetStateObject will return a non-nil account after Delete. func (self *StateDB) Delete(addr common.Address) bool { stateObject := self.GetStateObject(addr) - if stateObject != nil { - stateObject.MarkForDeletion() - stateObject.data.Balance = new(big.Int) - return true + if stateObject == nil { + return false } - - return false + self.journal = append(self.journal, deleteAccountChange{ + account: &addr, + prev: stateObject.remove, + prevbalance: new(big.Int).Set(stateObject.Balance()), + }) + stateObject.markForDeletion() + stateObject.data.Balance = new(big.Int) + return true } // // Setting, updating & deleting state object methods // -// Update the given state object and apply it to state trie -func (self *StateDB) UpdateStateObject(stateObject *StateObject) { +// updateStateObject writes the given object to the trie. +func (self *StateDB) updateStateObject(stateObject *StateObject) { addr := stateObject.Address() data, err := rlp.EncodeToBytes(stateObject) if err != nil { @@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) { self.trie.Update(addr[:], data) } -// Delete the given state object and delete it from the state trie -func (self *StateDB) DeleteStateObject(stateObject *StateObject) { +// deleteStateObject removes the given object from the state trie. +func (self *StateDB) deleteStateObject(stateObject *StateObject) { stateObject.deleted = true - addr := stateObject.Address() self.trie.Delete(addr[:]) } @@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje return nil } // Insert into the live set. - obj := NewObject(addr, data, self.MarkStateObjectDirty) - self.SetStateObject(obj) + obj := newObject(self, addr, data, self.MarkStateObjectDirty) + self.setStateObject(obj) return obj } -func (self *StateDB) SetStateObject(object *StateObject) { +func (self *StateDB) setStateObject(object *StateObject) { self.stateObjects[object.Address()] = object } @@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { stateObject := self.GetStateObject(addr) if stateObject == nil || stateObject.deleted { - stateObject = self.CreateStateObject(addr) + stateObject, _ = self.createObject(addr) } - return stateObject } -// NewStateObject create a state object whether it exist in the trie or not -func (self *StateDB) newStateObject(addr common.Address) *StateObject { - if glog.V(logger.Core) { - glog.Infof("(+) %x\n", addr) - } - obj := NewObject(addr, Account{}, self.MarkStateObjectDirty) - obj.SetNonce(StartingNonce) // sets the object to dirty - self.stateObjects[addr] = obj - return obj -} - // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly // state object cache iteration to find a handful of modified ones. func (self *StateDB) MarkStateObjectDirty(addr common.Address) { self.stateObjectsDirty[addr] = struct{}{} } -// Creates creates a new state object and takes ownership. -func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { - // Get previous (if any) - so := self.GetStateObject(addr) - // Create a new one - newSo := self.newStateObject(addr) - - // If it existed set the balance to the new account - if so != nil { - newSo.data.Balance = so.data.Balance +// createObject creates a new state object. If there is an existing account with +// the given address, it is overwritten and returned as the second return value. +func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) { + prev = self.GetStateObject(addr) + newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) + newobj.setNonce(StartingNonce) // sets the object to dirty + if prev == nil { + if glog.V(logger.Core) { + glog.Infof("(+) %x\n", addr) + } + self.journal = append(self.journal, createObjectChange{account: &addr}) + } else { + self.journal = append(self.journal, resetObjectChange{prev: prev}) } - - return newSo -} - -func (self *StateDB) CreateAccount(addr common.Address) vm.Account { - return self.CreateStateObject(addr) + self.setStateObject(newobj) + return newobj, prev } +// CreateAccount explicitly creates a state object. If a state object with the address +// already exists the balance is carried over to the new account. +// +// CreateAccount is called during the EVM CREATE operation. The situation might arise that +// a contract does the following: // -// Setting, copying of the state methods +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // +// Carrying over the balance ensures that Ether doesn't disappear. +func (self *StateDB) CreateAccount(addr common.Address) vm.Account { + new, prev := self.createObject(addr) + if prev != nil { + new.setBalance(prev.data.Balance) + } + return new +} +// Copy creates a deep, independent copy of the state. +// Snapshots of the copied state cannot be applied to the copy. func (self *StateDB) Copy() *StateDB { self.lock.Lock() defer self.lock.Unlock() @@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB { } // Copy the dirty states and logs for addr, _ := range self.stateObjectsDirty { - state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) + state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range self.logs { @@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB { return state } -func (self *StateDB) Set(state *StateDB) { - self.lock.Lock() - defer self.lock.Unlock() +// Snapshot returns an identifier for the current revision of the state. +func (self *StateDB) Snapshot() int { + id := self.nextRevisionId + self.nextRevisionId++ + self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)}) + return id +} + +// RevertToSnapshot reverts all state changes made since the given revision. +func (self *StateDB) RevertToSnapshot(revid int) { + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(self.validRevisions), func(i int) bool { + return self.validRevisions[i].id >= revid + }) + if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := self.validRevisions[idx].journalIndex + + // Replay the journal to undo changes. + for i := len(self.journal) - 1; i >= snapshot; i-- { + self.journal[i].undo(self) + } + self.journal = self.journal[:snapshot] - self.db = state.db - self.trie = state.trie - self.pastTries = state.pastTries - self.stateObjects = state.stateObjects - self.stateObjectsDirty = state.stateObjectsDirty - self.codeSizeCache = state.codeSizeCache - self.refund = state.refund - self.logs = state.logs - self.logSize = state.logSize + // Remove invalidated snapshots from the stack. + self.validRevisions = self.validRevisions[:idx] } +// GetRefund returns the current value of the refund counter. +// The return value must not be modified by the caller and will become +// invalid at the next call to AddRefund. func (self *StateDB) GetRefund() *big.Int { return self.refund } @@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int { // It is called in between transactions to get the root hash that // goes into transaction receipts. func (s *StateDB) IntermediateRoot() common.Hash { - s.refund = new(big.Int) for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] if stateObject.remove { - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else { - stateObject.UpdateRoot(s.db) - s.UpdateStateObject(stateObject) + stateObject.updateRoot(s.db) + s.updateStateObject(stateObject) } } + // Invalidate journal because reverting across transactions is not allowed. + s.clearJournalAndRefund() return s.trie.Hash() } @@ -486,9 +534,9 @@ func (s *StateDB) IntermediateRoot() common.Hash { // DeleteSuicides should not be used for consensus related updates // under any circumstances. func (s *StateDB) DeleteSuicides() { - // Reset refund so that any used-gas calculations can use - // this method. - s.refund = new(big.Int) + // Reset refund so that any used-gas calculations can use this method. + s.clearJournalAndRefund() + for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] @@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { return root, batch } -func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { +func (s *StateDB) clearJournalAndRefund() { + s.journal = nil + s.validRevisions = s.validRevisions[:0] s.refund = new(big.Int) +} + +func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { + defer s.clearJournalAndRefund() // Commit objects to the trie. for addr, stateObject := range s.stateObjects { if stateObject.remove { // If the object has been removed, don't bother syncing it // and just mark it for deletion in the trie. - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else if _, ok := s.stateObjectsDirty[addr]; ok { // Write any contract code associated with the state object if stateObject.code != nil && stateObject.dirtyCode { @@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) return common.Hash{}, err } // Update the object in the main account trie. - s.UpdateStateObject(stateObject) + s.updateStateObject(stateObject) } delete(s.stateObjectsDirty, addr) } @@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) } return root, err } - -func (self *StateDB) Refunds() *big.Int { - return self.refund -} diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 7930b620d..e236cb8f3 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -17,11 +17,19 @@ package state import ( + "bytes" + "encoding/binary" + "fmt" + "math" "math/big" + "math/rand" + "reflect" + "strings" "testing" + "testing/quick" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" ) @@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) { // Update it with some accounts for i := byte(0); i < 255; i++ { - obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.AddBalance(big.NewInt(int64(11 * i))) - obj.SetNonce(uint64(42 * i)) + addr := common.BytesToAddress([]byte{i}) + state.AddBalance(addr, big.NewInt(int64(11*i))) + state.SetNonce(addr, uint64(42*i)) if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) + state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) } if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) + state.SetCode(addr, []byte{i, i, i, i, i}) } - state.UpdateStateObject(obj) + state.IntermediateRoot() } // Ensure that no data was leaked into the database for _, key := range db.Keys() { @@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) { transState, _ := New(common.Hash{}, transDb) finalState, _ := New(common.Hash{}, finalDb) - // Update the states with some objects - for i := byte(0); i < 255; i++ { - // Create a new state object with some data into the transition database - obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.SetBalance(big.NewInt(int64(11 * i))) - obj.SetNonce(uint64(42 * i)) + modify := func(state *StateDB, addr common.Address, i, tweak byte) { + state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) + state.SetNonce(addr, uint64(42*i+tweak)) if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0})) + state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{}) + state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak}) } if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0}) + state.SetCode(addr, []byte{i, i, i, i, i, tweak}) } - transState.UpdateStateObject(obj) + } - // Overwrite all the data with new values in the transition database - obj.SetBalance(big.NewInt(int64(11*i + 1))) - obj.SetNonce(uint64(42*i + 1)) - if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{}) - obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) - } - if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) - } - transState.UpdateStateObject(obj) + // Modify the transient state. + for i := byte(0); i < 255; i++ { + modify(transState, common.Address{byte(i)}, i, 0) + } + // Write modifications to trie. + transState.IntermediateRoot() - // Create the final state object directly in the final database - obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.SetBalance(big.NewInt(int64(11*i + 1))) - obj.SetNonce(uint64(42*i + 1)) - if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) - } - if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) - } - finalState.UpdateStateObject(obj) + // Overwrite all the data with new values in the transient database. + for i := byte(0); i < 255; i++ { + modify(transState, common.Address{byte(i)}, i, 99) + modify(finalState, common.Address{byte(i)}, i, 99) } + + // Commit and cross check the databases. if _, err := transState.Commit(); err != nil { t.Fatalf("failed to commit transition state: %v", err) } if _, err := finalState.Commit(); err != nil { t.Fatalf("failed to commit final state: %v", err) } - // Cross check the databases to ensure they are the same for _, key := range finalDb.Keys() { if _, err := transDb.Get(key); err != nil { val, _ := finalDb.Get(key) @@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) { } } } + +func TestSnapshotRandom(t *testing.T) { + config := &quick.Config{MaxCount: 1000} + err := quick.Check((*snapshotTest).run, config) + if cerr, ok := err.(*quick.CheckError); ok { + test := cerr.In[0].(*snapshotTest) + t.Errorf("%v:\n%s", test.err, test) + } else if err != nil { + t.Error(err) + } +} + +// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes +// captured by the snapshot. Instances of this test with pseudorandom content are created +// by Generate. +// +// The test works as follows: +// +// A new state is created and all actions are applied to it. Several snapshots are taken +// in between actions. The test then reverts each snapshot. For each snapshot the actions +// leading up to it are replayed on a fresh, empty state. The behaviour of all public +// accessor methods on the reverted state must match the return value of the equivalent +// methods on the replayed state. +type snapshotTest struct { + addrs []common.Address // all account addresses + actions []testAction // modifications to the state + snapshots []int // actions indexes at which snapshot is taken + err error // failure details are reported through this field +} + +type testAction struct { + name string + fn func(testAction, *StateDB) + args []int64 + noAddr bool +} + +// newTestAction creates a random action that changes state. +func newTestAction(addr common.Address, r *rand.Rand) testAction { + actions := []testAction{ + { + name: "SetBalance", + fn: func(a testAction, s *StateDB) { + s.SetBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "AddBalance", + fn: func(a testAction, s *StateDB) { + s.AddBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetNonce", + fn: func(a testAction, s *StateDB) { + s.SetNonce(addr, uint64(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetState", + fn: func(a testAction, s *StateDB) { + var key, val common.Hash + binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) + binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) + s.SetState(addr, key, val) + }, + args: make([]int64, 2), + }, + { + name: "SetCode", + fn: func(a testAction, s *StateDB) { + code := make([]byte, 16) + binary.BigEndian.PutUint64(code, uint64(a.args[0])) + binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) + s.SetCode(addr, code) + }, + args: make([]int64, 2), + }, + { + name: "CreateAccount", + fn: func(a testAction, s *StateDB) { + s.CreateAccount(addr) + }, + }, + { + name: "Delete", + fn: func(a testAction, s *StateDB) { + s.Delete(addr) + }, + }, + { + name: "AddRefund", + fn: func(a testAction, s *StateDB) { + s.AddRefund(big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + noAddr: true, + }, + { + name: "AddLog", + fn: func(a testAction, s *StateDB) { + data := make([]byte, 2) + binary.BigEndian.PutUint16(data, uint16(a.args[0])) + s.AddLog(&vm.Log{Address: addr, Data: data}) + }, + args: make([]int64, 1), + }, + } + action := actions[r.Intn(len(actions))] + var nameargs []string + if !action.noAddr { + nameargs = append(nameargs, addr.Hex()) + } + for _, i := range action.args { + action.args[i] = rand.Int63n(100) + nameargs = append(nameargs, fmt.Sprint(action.args[i])) + } + action.name += strings.Join(nameargs, ", ") + return action +} + +// Generate returns a new snapshot test of the given size. All randomness is +// derived from r. +func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { + // Generate random actions. + addrs := make([]common.Address, 50) + for i := range addrs { + addrs[i][0] = byte(i) + } + actions := make([]testAction, size) + for i := range actions { + addr := addrs[r.Intn(len(addrs))] + actions[i] = newTestAction(addr, r) + } + // Generate snapshot indexes. + nsnapshots := int(math.Sqrt(float64(size))) + if size > 0 && nsnapshots == 0 { + nsnapshots = 1 + } + snapshots := make([]int, nsnapshots) + snaplen := len(actions) / nsnapshots + for i := range snapshots { + // Try to place the snapshots some number of actions apart from each other. + snapshots[i] = (i * snaplen) + r.Intn(snaplen) + } + return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) +} + +func (test *snapshotTest) String() string { + out := new(bytes.Buffer) + sindex := 0 + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) + sindex++ + } + fmt.Fprintf(out, "%4d: %s\n", i, action.name) + } + return out.String() +} + +func (test *snapshotTest) run() bool { + // Run all actions and create snapshots. + var ( + db, _ = ethdb.NewMemDatabase() + state, _ = New(common.Hash{}, db) + snapshotRevs = make([]int, len(test.snapshots)) + sindex = 0 + ) + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + snapshotRevs[sindex] = state.Snapshot() + sindex++ + } + action.fn(action, state) + } + + // Revert all snapshots in reverse order. Each revert must yield a state + // that is equivalent to fresh state with all actions up the snapshot applied. + for sindex--; sindex >= 0; sindex-- { + checkstate, _ := New(common.Hash{}, db) + for _, action := range test.actions[:test.snapshots[sindex]] { + action.fn(action, checkstate) + } + state.RevertToSnapshot(snapshotRevs[sindex]) + if err := test.checkEqual(state, checkstate); err != nil { + test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) + return false + } + } + return true +} + +// checkEqual checks that methods of state and checkstate return the same values. +func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { + for _, addr := range test.addrs { + var err error + checkeq := func(op string, a, b interface{}) bool { + if err == nil && !reflect.DeepEqual(a, b) { + err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) + return false + } + return true + } + // Check basic accessor methods. + checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) + checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr)) + checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) + checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) + checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) + checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) + checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) + // Check storage. + if obj := state.GetStateObject(addr); obj != nil { + obj.ForEachStorage(func(key, val common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key)) + }) + checkobj := checkstate.GetStateObject(addr) + checkobj.ForEachStorage(func(key, checkval common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval) + }) + } + if err != nil { + return err + } + } + + if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 { + return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", + state.GetRefund(), checkstate.GetRefund()) + } + if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) { + return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", + state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) + } + return nil +} diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 670e1fb1b..949df7301 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) acc.code = []byte{i, i, i, i, i} } - state.UpdateStateObject(obj) + state.updateStateObject(obj) accounts = append(accounts, acc) } root, _ := state.Commit() diff --git a/core/tx_pool.go b/core/tx_pool.go index f8b11a7ce..10a110e0b 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error { // Make sure the account exist. Non existent accounts // haven't got funds and well therefor never pass. - if !currentState.HasAccount(from) { + if !currentState.Exist(from) { return ErrNonExistentAccount } diff --git a/core/vm/environment.go b/core/vm/environment.go index daf6fb90d..1038e69d5 100644 --- a/core/vm/environment.go +++ b/core/vm/environment.go @@ -36,9 +36,9 @@ type Environment interface { // The state database Db() Database // Creates a restorable snapshot - MakeSnapshot() Database + SnapshotDatabase() int // Set database to previous snapshot - SetSnapshot(Database) + RevertToSnapshot(int) // Address of the original invoker (first occurrence of the VM invoker) Origin() common.Address // The block number this VM is invoked on diff --git a/core/vm/jit_test.go b/core/vm/jit_test.go index e6922aeb7..a6de710e1 100644 --- a/core/vm/jit_test.go +++ b/core/vm/jit_test.go @@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } //func (self *Env) PrevHash() []byte { return self.parent } func (self *Env) Coinbase() common.Address { return common.Address{} } -func (self *Env) MakeSnapshot() Database { return nil } -func (self *Env) SetSnapshot(Database) {} +func (self *Env) SnapshotDatabase() int { return 0 } +func (self *Env) RevertToSnapshot(int) {} func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Db() Database { return nil } diff --git a/core/vm/runtime/env.go b/core/vm/runtime/env.go index a4793c98f..59fbaa792 100644 --- a/core/vm/runtime/env.go +++ b/core/vm/runtime/env.go @@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i } func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *Env) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *Env) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *Env) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *Env) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/core/vm_env.go b/core/vm_env.go index e541eaef4..d62eebbd9 100644 --- a/core/vm_env.go +++ b/core/vm_env.go @@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *VMEnv) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *VMEnv) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *VMEnv) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *VMEnv) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/eth/api_backend.go b/eth/api_backend.go index 4adeb0aa0..42b84bf9b 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { } func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) { - stateDb := state.(EthApiState).state.Copy() + statedb := state.(EthApiState).state addr, _ := msg.From() - from := stateDb.GetOrNewStateObject(addr) + from := statedb.GetOrNewStateObject(addr) from.SetBalance(common.MaxBig) vmError := func() error { return nil } - return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil + return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil } func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { diff --git a/internal/ethapi/tracer_test.go b/internal/ethapi/tracer_test.go index 7c831d299..127af32a8 100644 --- a/internal/ethapi/tracer_test.go +++ b/internal/ethapi/tracer_test.go @@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} } func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } //func (self *Env) PrevHash() []byte { return self.parent } -func (self *Env) Coinbase() common.Address { return common.Address{} } -func (self *Env) MakeSnapshot() vm.Database { return nil } -func (self *Env) SetSnapshot(vm.Database) {} -func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } -func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } -func (self *Env) Db() vm.Database { return nil } -func (self *Env) GasLimit() *big.Int { return self.gasLimit } -func (self *Env) VmType() vm.Type { return vm.StdVmTy } +func (self *Env) Coinbase() common.Address { return common.Address{} } +func (self *Env) SnapshotDatabase() int { return 0 } +func (self *Env) RevertToSnapshot(int) {} +func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } +func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } +func (self *Env) Db() vm.Database { return nil } +func (self *Env) GasLimit() *big.Int { return self.gasLimit } +func (self *Env) VmType() vm.Type { return vm.StdVmTy } func (self *Env) GetHash(n uint64) common.Hash { return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String()))) } diff --git a/light/state_test.go b/light/state_test.go index d4fe95022..a6b115786 100644 --- a/light/state_test.go +++ b/light/state_test.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" "golang.org/x/net/context" @@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) { sdb, _ := ethdb.NewMemDatabase() st, _ := state.New(common.Hash{}, sdb) for i := byte(0); i < 100; i++ { - so := st.GetOrNewStateObject(common.Address{i}) + addr := common.Address{i} for j := byte(0); j < 100; j++ { - val := common.Hash{i, j} - so.SetState(common.Hash{j}, val) - so.SetNonce(100) + st.SetState(addr, common.Hash{j}, common.Hash{i, j}) } - so.AddBalance(big.NewInt(int64(i))) - so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i}) - so.UpdateRoot(sdb) - st.UpdateStateObject(so) + st.SetNonce(addr, 100) + st.AddBalance(addr, big.NewInt(int64(i))) + st.SetCode(addr, []byte{i, i, i}) } root, _ := st.Commit() return root, sdb diff --git a/miner/worker.go b/miner/worker.go index ac1ef5ba3..e5348cef4 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { self.current.receipts, ), self.current.state } - return self.current.Block, self.current.state + return self.current.Block, self.current.state.Copy() } func (self *worker) start() { @@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB } func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { - snap := env.state.Copy() + snap := env.state.Snapshot() // this is a bit of a hack to force jit for the miners config := env.config.VmConfig @@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) if err != nil { - env.state.Set(snap) + env.state.RevertToSnapshot(snap) return err, nil } env.txs = append(env.txs, tx) diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 67e4bf832..3c4b42a18 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error { func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { b.StopTimer() db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) b.StartTimer() RunState(ruleSet, statedb, env, test.Exec) @@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string) func runStateTest(ruleSet RuleSet, test VmTest) error { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) // XXX Yeah, yeah... env := make(map[string]string) @@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string } // Set pre compiled contracts vm.Precompiled = vm.PrecompiledContracts() - snapshot := statedb.Copy() + snapshot := statedb.Snapshot() gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) key, _ := hex.DecodeString(tx["secretKey"]) @@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string vmenv.origin = addr ret, _, err := core.ApplyMessage(vmenv, message, gaspool) if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { - statedb.Set(snapshot) + statedb.RevertToSnapshot(snapshot) } statedb.Commit() diff --git a/tests/util.go b/tests/util.go index ffbcb9d56..8a9d09213 100644 --- a/tests/util.go +++ b/tests/util.go @@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte { return t } -func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { +func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB { + statedb, _ := state.New(common.Hash{}, db) + for addr, account := range accounts { + insertAccount(statedb, addr, account) + } + return statedb +} + +func insertAccount(state *state.StateDB, saddr string, account Account) { if common.IsHex(account.Code) { account.Code = account.Code[2:] } - code := common.Hex2Bytes(account.Code) - codeHash := crypto.Keccak256Hash(code) - obj := state.NewObject(common.HexToAddress(addr), state.Account{ - Balance: common.Big(account.Balance), - CodeHash: codeHash[:], - Nonce: common.Big(account.Nonce).Uint64(), - }, onDirty) - obj.SetCode(codeHash, code) - return obj + addr := common.HexToAddress(saddr) + state.SetCode(addr, common.Hex2Bytes(account.Code)) + state.SetNonce(addr, common.Big(account.Nonce).Uint64()) + state.SetBalance(addr, common.Big(account.Balance)) + for a, v := range account.Storage { + state.SetState(addr, common.HexToHash(a), common.HexToHash(v)) + } } type VmEnv struct { @@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *Env) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *Env) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *Env) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *Env) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index 4ad72d91c..c269f21e0 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error { func benchVmTest(test VmTest, env map[string]string, b *testing.B) { b.StopTimer() db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) b.StartTimer() RunVm(statedb, env, test.Exec) @@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error { func runVmTest(test VmTest) error { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) // XXX Yeah, yeah... env := make(map[string]string) |