diff options
42 files changed, 961 insertions, 983 deletions
diff --git a/.travis.yml b/.travis.yml index 556397b0a..b3b5edb27 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,9 +5,6 @@ matrix: include: - os: linux dist: trusty - go: 1.4.2 - - os: linux - dist: trusty go: 1.5.4 - os: linux dist: trusty @@ -4,8 +4,8 @@ Official golang implementation of the Ethereum protocol | Linux | OSX | ARM | Windows | Tests ----------|---------|-----|-----|---------|------ -develop | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Linux%20Go%20develop%20branch)](https://build.ethdev.com/builders/Linux%20Go%20develop%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Linux%20Go%20develop%20branch)](https://build.ethdev.com/builders/OSX%20Go%20develop%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=ARM%20Go%20develop%20branch)](https://build.ethdev.com/builders/ARM%20Go%20develop%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Windows%20Go%20develop%20branch)](https://build.ethdev.com/builders/Windows%20Go%20develop%20branch/builds/-1) | [![Buildr+Status](https://travis-ci.org/ethereum/go-ethereum.svg?branch=develop)](https://travis-ci.org/ethereum/go-ethereum) [![codecov.io](http://codecov.io/github/ethereum/go-ethereum/coverage.svg?branch=develop)](http://codecov.io/github/ethereum/go-ethereum?branch=develop) -master | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Linux%20Go%20master%20branch)](https://build.ethdev.com/builders/Linux%20Go%20master%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=OSX%20Go%20master%20branch)](https://build.ethdev.com/builders/OSX%20Go%20master%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=ARM%20Go%20master%20branch)](https://build.ethdev.com/builders/ARM%20Go%20master%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Windows%20Go%20master%20branch)](https://build.ethdev.com/builders/Windows%20Go%20master%20branch/builds/-1) | [![Buildr+Status](https://travis-ci.org/ethereum/go-ethereum.svg?branch=master)](https://travis-ci.org/ethereum/go-ethereum) [![codecov.io](http://codecov.io/github/ethereum/go-ethereum/coverage.svg?branch=master)](http://codecov.io/github/ethereum/go-ethereum?branch=master) +develop | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Linux%20Go%20develop%20branch)](https://build.ethdev.com/builders/Linux%20Go%20develop%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Linux%20Go%20develop%20branch)](https://build.ethdev.com/builders/OSX%20Go%20develop%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=ARM%20Go%20develop%20branch)](https://build.ethdev.com/builders/ARM%20Go%20develop%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Windows%20Go%20develop%20branch)](https://build.ethdev.com/builders/Windows%20Go%20develop%20branch/builds/-1) | [![Buildr+Status](https://travis-ci.org/ethereum/go-ethereum.svg?branch=develop)](https://travis-ci.org/ethereum/go-ethereum) [![codecov.io](https://codecov.io/github/ethereum/go-ethereum/coverage.svg?branch=develop)](https://codecov.io/github/ethereum/go-ethereum?branch=develop) +master | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Linux%20Go%20master%20branch)](https://build.ethdev.com/builders/Linux%20Go%20master%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=OSX%20Go%20master%20branch)](https://build.ethdev.com/builders/OSX%20Go%20master%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=ARM%20Go%20master%20branch)](https://build.ethdev.com/builders/ARM%20Go%20master%20branch/builds/-1) | [![Build+Status](https://build.ethdev.com/buildstatusimage?builder=Windows%20Go%20master%20branch)](https://build.ethdev.com/builders/Windows%20Go%20master%20branch/builds/-1) | [![Buildr+Status](https://travis-ci.org/ethereum/go-ethereum.svg?branch=master)](https://travis-ci.org/ethereum/go-ethereum) [![codecov.io](https://codecov.io/github/ethereum/go-ethereum/coverage.svg?branch=master)](https://codecov.io/github/ethereum/go-ethereum?branch=master) [![API Reference]( https://camo.githubusercontent.com/915b7be44ada53c290eb157634330494ebe3e30a/68747470733a2f2f676f646f632e6f72672f6769746875622e636f6d2f676f6c616e672f6764646f3f7374617475732e737667 @@ -17,7 +17,7 @@ https://camo.githubusercontent.com/915b7be44ada53c290eb157634330494ebe3e30a/6874 The following builds are built automatically by our build servers after each push to the [develop](https://github.com/ethereum/go-ethereum/tree/develop) branch. * [Docker](https://registry.hub.docker.com/u/ethereum/client-go/) -* [OS X](http://build.ethdev.com/builds/OSX%20Go%20develop%20branch/Mist-OSX-latest.dmg) +* [OS X](https://build.ethdev.com/builds/OSX%20Go%20develop%20branch/Mist-OSX-latest.dmg) * Ubuntu [trusty](https://build.ethdev.com/builds/Linux%20Go%20develop%20deb%20i386-trusty/latest/) | [utopic](https://build.ethdev.com/builds/Linux%20Go%20develop%20deb%20i386-utopic/latest/) @@ -283,9 +283,9 @@ for more details on configuring your environment, managing project dependencies ## License The go-ethereum library (i.e. all code outside of the `cmd` directory) is licensed under the -[GNU Lesser General Public License v3.0](http://www.gnu.org/licenses/lgpl-3.0.en.html), also +[GNU Lesser General Public License v3.0](https://www.gnu.org/licenses/lgpl-3.0.en.html), also included in our repository in the `COPYING.LESSER` file. The go-ethereum binaries (i.e. all code inside of the `cmd` directory) is licensed under the -[GNU General Public License v3.0](http://www.gnu.org/licenses/gpl-3.0.en.html), also included +[GNU General Public License v3.0](https://www.gnu.org/licenses/gpl-3.0.en.html), also included in our repository in the `COPYING` file. diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 1b07b2f68..c127cd7a9 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -77,7 +77,7 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { return append(method.Id(), arguments...), nil } -// toGoSliceType prses the input and casts it to the proper slice defined by the ABI +// toGoSliceType parses the input and casts it to the proper slice defined by the ABI // argument in T. func toGoSlice(i int, t Argument, output []byte) (interface{}, error) { index := i * 32 diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 29b4e8ea3..7e09abb11 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -135,11 +135,8 @@ func (b *SimulatedBackend) StorageAt(ctx context.Context, contract common.Addres return nil, errBlockNumberUnsupported } statedb, _ := b.blockchain.State() - if obj := statedb.GetStateObject(contract); obj != nil { - val := obj.GetState(key) - return val[:], nil - } - return nil, nil + val := statedb.GetState(contract, key) + return val[:], nil } // TransactionReceipt returns the receipt of a transaction. diff --git a/build/ci.go b/build/ci.go index 3011a6976..87e8b6275 100644 --- a/build/ci.go +++ b/build/ci.go @@ -227,6 +227,9 @@ func doTest(cmdline []string) { // Run the actual tests. gotest := goTool("test") + // Test a single package at a time. CI builders are slow + // and some tests run into timeouts under load. + gotest.Args = append(gotest.Args, "-p", "1") if *coverage { gotest.Args = append(gotest.Args, "-covermode=atomic", "-cover") } diff --git a/build/update-license.go b/build/update-license.go index f83a2b34b..96667be15 100644 --- a/build/update-license.go +++ b/build/update-license.go @@ -49,7 +49,6 @@ var ( // don't relicense vendored sources "crypto/sha3/", "crypto/ecies/", "logger/glog/", "crypto/secp256k1/curve.go", - "trie/arc.go", // don't license generated files "contracts/chequebook/contract/", "contracts/ens/contract/", diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 3b521a0e1..584afc804 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -23,6 +23,7 @@ import ( "os" "os/signal" "regexp" + "runtime" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -52,10 +53,16 @@ func openLogFile(Datadir string, filename string) *os.File { // is redirected to a different file. func Fatalf(format string, args ...interface{}) { w := io.MultiWriter(os.Stdout, os.Stderr) - outf, _ := os.Stdout.Stat() - errf, _ := os.Stderr.Stat() - if outf != nil && errf != nil && os.SameFile(outf, errf) { - w = os.Stderr + if runtime.GOOS == "windows" { + // The SameFile check below doesn't work on Windows. + // stdout is unlikely to get redirected though, so just print there. + w = os.Stdout + } else { + outf, _ := os.Stdout.Stat() + errf, _ := os.Stderr.Stat() + if outf != nil && errf != nil && os.SameFile(outf, errf) { + w = os.Stderr + } } fmt.Fprintf(w, "Fatal: "+format+"\n", args...) logger.Flush() diff --git a/core/blockchain.go b/core/blockchain.go index 888c98dce..1fbcdfc6f 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -93,10 +93,11 @@ type BlockChain struct { currentBlock *types.Block // Current head of the block chain currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!) - bodyCache *lru.Cache // Cache for the most recent block bodies - bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format - blockCache *lru.Cache // Cache for the most recent entire blocks - futureBlocks *lru.Cache // future blocks are blocks added for later processing + stateCache *state.StateDB // State database to reuse between imports (contains state cache) + bodyCache *lru.Cache // Cache for the most recent block bodies + bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format + blockCache *lru.Cache // Cache for the most recent entire blocks + futureBlocks *lru.Cache // future blocks are blocks added for later processing quit chan struct{} // blockchain quit channel running int32 // running must be called atomically @@ -196,7 +197,15 @@ func (self *BlockChain) loadLastState() error { self.currentFastBlock = block } } - // Issue a status log and return + // Initialize a statedb cache to ensure singleton account bloom filter generation + statedb, err := state.New(self.currentBlock.Root(), self.chainDb) + if err != nil { + return err + } + self.stateCache = statedb + self.stateCache.GetAccount(common.Address{}) + + // Issue a status log for the user headerTd := self.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64()) blockTd := self.GetTd(self.currentBlock.Hash(), self.currentBlock.NumberU64()) fastTd := self.GetTd(self.currentFastBlock.Hash(), self.currentFastBlock.NumberU64()) @@ -348,7 +357,12 @@ func (self *BlockChain) AuxValidator() pow.PoW { return self.pow } // State returns a new mutable state based on the current HEAD block. func (self *BlockChain) State() (*state.StateDB, error) { - return state.New(self.CurrentBlock().Root(), self.chainDb) + return self.StateAt(self.CurrentBlock().Root()) +} + +// StateAt returns a new mutable state based on a particular point in time. +func (self *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { + return self.stateCache.New(root) } // Reset purges the entire blockchain, restoring it to its genesis state. @@ -826,7 +840,6 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { tstart = time.Now() nonceChecked = make([]bool, len(chain)) - statedb *state.StateDB ) // Start the parallel nonce verifier. @@ -893,29 +906,30 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { // Create a new statedb using the parent block and report an // error if it fails. - if statedb == nil { - statedb, err = state.New(self.GetBlock(block.ParentHash(), block.NumberU64()-1).Root(), self.chainDb) - } else { - err = statedb.Reset(chain[i-1].Root()) + switch { + case i == 0: + err = self.stateCache.Reset(self.GetBlock(block.ParentHash(), block.NumberU64()-1).Root()) + default: + err = self.stateCache.Reset(chain[i-1].Root()) } if err != nil { reportBlock(block, err) return i, err } // Process block using the parent state as reference point. - receipts, logs, usedGas, err := self.processor.Process(block, statedb, self.config.VmConfig) + receipts, logs, usedGas, err := self.processor.Process(block, self.stateCache, self.config.VmConfig) if err != nil { reportBlock(block, err) return i, err } // Validate the state using the default validator - err = self.Validator().ValidateState(block, self.GetBlock(block.ParentHash(), block.NumberU64()-1), statedb, receipts, usedGas) + err = self.Validator().ValidateState(block, self.GetBlock(block.ParentHash(), block.NumberU64()-1), self.stateCache, receipts, usedGas) if err != nil { reportBlock(block, err) return i, err } // Write state changes to database - _, err = statedb.Commit() + _, err = self.stateCache.Commit() if err != nil { return i, err } diff --git a/core/chain_makers_test.go b/core/chain_makers_test.go index f52b09ad9..5fc255c71 100644 --- a/core/chain_makers_test.go +++ b/core/chain_makers_test.go @@ -79,7 +79,7 @@ func ExampleGenerateChain() { evmux := &event.TypeMux{} blockchain, _ := NewBlockChain(db, MakeChainConfig(), FakePow{}, evmux) if i, err := blockchain.InsertChain(chain); err != nil { - fmt.Printf("insert error (block %d): %v\n", i, err) + fmt.Printf("insert error (block %d): %v\n", chain[i].NumberU64(), err) return } diff --git a/core/state/dump.go b/core/state/dump.go index a328b0537..58ecd852b 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -21,9 +21,10 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" ) -type Account struct { +type DumpAccount struct { Balance string `json:"balance"` Nonce uint64 `json:"nonce"` Root string `json:"root"` @@ -32,40 +33,41 @@ type Account struct { Storage map[string]string `json:"storage"` } -type World struct { - Root string `json:"root"` - Accounts map[string]Account `json:"accounts"` +type Dump struct { + Root string `json:"root"` + Accounts map[string]DumpAccount `json:"accounts"` } -func (self *StateDB) RawDump() World { - world := World{ +func (self *StateDB) RawDump() Dump { + dump := Dump{ Root: common.Bytes2Hex(self.trie.Root()), - Accounts: make(map[string]Account), + Accounts: make(map[string]DumpAccount), } it := self.trie.Iterator() for it.Next() { addr := self.trie.GetKey(it.Key) - stateObject, err := DecodeObject(common.BytesToAddress(addr), self.db, it.Value) - if err != nil { + var data Account + if err := rlp.DecodeBytes(it.Value, &data); err != nil { panic(err) } - account := Account{ - Balance: stateObject.balance.String(), - Nonce: stateObject.nonce, - Root: common.Bytes2Hex(stateObject.Root()), - CodeHash: common.Bytes2Hex(stateObject.codeHash), - Code: common.Bytes2Hex(stateObject.Code()), + obj := NewObject(common.BytesToAddress(addr), data, nil) + account := DumpAccount{ + Balance: data.Balance.String(), + Nonce: data.Nonce, + Root: common.Bytes2Hex(data.Root[:]), + CodeHash: common.Bytes2Hex(data.CodeHash), + Code: common.Bytes2Hex(obj.Code(self.db)), Storage: make(map[string]string), } - storageIt := stateObject.trie.Iterator() + storageIt := obj.getTrie(self.db).Iterator() for storageIt.Next() { account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) } - world.Accounts[common.Bytes2Hex(addr)] = account + dump.Accounts[common.Bytes2Hex(addr)] = account } - return world + return dump } func (self *StateDB) Dump() []byte { @@ -76,12 +78,3 @@ func (self *StateDB) Dump() []byte { return json } - -// Debug stuff -func (self *StateObject) CreateOutputForDiff() { - fmt.Printf("%x %x %x %x\n", self.Address(), self.Root(), self.balance.Bytes(), self.nonce) - it := self.trie.Iterator() - for it.Next() { - fmt.Printf("%x %x\n", it.Key, it.Value) - } -} diff --git a/core/state/iterator.go b/core/state/iterator.go index 9d8a69b7c..14265b277 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -76,7 +76,7 @@ func (it *NodeIterator) step() error { } // Initialize the iterator if we've just started if it.stateIt == nil { - it.stateIt = trie.NewNodeIterator(it.state.trie.Trie) + it.stateIt = it.state.trie.NodeIterator() } // If we had data nodes previously, we surely have at least state nodes if it.dataIt != nil { diff --git a/core/state/managed_state.go b/core/state/managed_state.go index f8e2f2b87..ad73dc0dc 100644 --- a/core/state/managed_state.go +++ b/core/state/managed_state.go @@ -33,14 +33,14 @@ type ManagedState struct { mu sync.RWMutex - accounts map[string]*account + accounts map[common.Address]*account } // ManagedState returns a new managed state with the statedb as it's backing layer func ManageState(statedb *StateDB) *ManagedState { return &ManagedState{ StateDB: statedb.Copy(), - accounts: make(map[string]*account), + accounts: make(map[common.Address]*account), } } @@ -103,7 +103,7 @@ func (ms *ManagedState) SetNonce(addr common.Address, nonce uint64) { so := ms.GetOrNewStateObject(addr) so.SetNonce(nonce) - ms.accounts[addr.Str()] = newAccount(so) + ms.accounts[addr] = newAccount(so) } // HasAccount returns whether the given address is managed or not @@ -114,29 +114,28 @@ func (ms *ManagedState) HasAccount(addr common.Address) bool { } func (ms *ManagedState) hasAccount(addr common.Address) bool { - _, ok := ms.accounts[addr.Str()] + _, ok := ms.accounts[addr] return ok } // populate the managed state func (ms *ManagedState) getAccount(addr common.Address) *account { - straddr := addr.Str() - if account, ok := ms.accounts[straddr]; !ok { + if account, ok := ms.accounts[addr]; !ok { so := ms.GetOrNewStateObject(addr) - ms.accounts[straddr] = newAccount(so) + ms.accounts[addr] = newAccount(so) } else { // Always make sure the state account nonce isn't actually higher // than the tracked one. so := ms.StateDB.GetStateObject(addr) - if so != nil && uint64(len(account.nonces))+account.nstart < so.nonce { - ms.accounts[straddr] = newAccount(so) + if so != nil && uint64(len(account.nonces))+account.nstart < so.Nonce() { + ms.accounts[addr] = newAccount(so) } } - return ms.accounts[straddr] + return ms.accounts[addr] } func newAccount(so *StateObject) *account { - return &account{so, so.nonce, nil} + return &account{so, so.Nonce(), nil} } diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index 0b53a42c5..baa53428f 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -29,11 +29,12 @@ func create() (*ManagedState, *account) { db, _ := ethdb.NewMemDatabase() statedb, _ := New(common.Hash{}, db) ms := ManageState(statedb) - so := &StateObject{address: addr, nonce: 100} - ms.StateDB.stateObjects[addr.Str()] = so - ms.accounts[addr.Str()] = newAccount(so) + so := &StateObject{address: addr} + so.SetNonce(100) + ms.StateDB.stateObjects[addr] = so + ms.accounts[addr] = newAccount(so) - return ms, ms.accounts[addr.Str()] + return ms, ms.accounts[addr] } func TestNewNonce(t *testing.T) { @@ -92,7 +93,7 @@ func TestRemoteNonceChange(t *testing.T) { account.nonces = append(account.nonces, nn...) nonce := ms.NewNonce(addr) - ms.StateDB.stateObjects[addr.Str()].nonce = 200 + ms.StateDB.stateObjects[addr].data.Nonce = 200 nonce = ms.NewNonce(addr) if nonce != 200 { t.Error("expected nonce after remote update to be", 201, "got", nonce) @@ -100,7 +101,7 @@ func TestRemoteNonceChange(t *testing.T) { ms.NewNonce(addr) ms.NewNonce(addr) ms.NewNonce(addr) - ms.StateDB.stateObjects[addr.Str()].nonce = 200 + ms.StateDB.stateObjects[addr].data.Nonce = 200 nonce = ms.NewNonce(addr) if nonce != 204 { t.Error("expected nonce after remote update to be", 201, "got", nonce) diff --git a/core/state/state_object.go b/core/state/state_object.go index 769c63d42..a54620d55 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -57,143 +57,194 @@ func (self Storage) Copy() Storage { return cpy } +// StateObject represents an Ethereum account which is being modified. +// +// The usage pattern is as follows: +// First you need to obtain a state object. +// Account values can be accessed and modified through the object. +// Finally, call CommitTrie to write the modified storage trie into a database. type StateObject struct { - db trie.Database // State database for storing state changes - trie *trie.SecureTrie - - // Address belonging to this account - address common.Address - // The balance of the account - balance *big.Int - // The nonce of the account - nonce uint64 - // The code hash if code is present (i.e. a contract) - codeHash []byte - // The code for this account - code Code - // Temporarily initialisation code - initCode Code - // Cached storage (flushed when updated) - storage Storage - - // Mark for deletion + address common.Address // Ethereum address of this account + data Account + + // DB error. + // State objects are used by the consensus core and VM which are + // unable to deal with database-level errors. Any error that occurs + // during a database read is memoized here and will eventually be returned + // by StateDB.Commit. + dbErr error + + // Write caches. + trie *trie.SecureTrie // storage trie, which becomes non-nil on first access + code Code // contract bytecode, which gets set when code is loaded + storage Storage // Cached storage (flushed when updated) + + // Cache flags. // When an object is marked for deletion it will be delete from the trie // during the "update" phase of the state transition - remove bool - deleted bool - dirty bool + dirtyCode bool // true if the code was updated + remove bool + deleted bool + onDirty func(addr common.Address) // Callback method to mark a state object newly dirty } -func NewStateObject(address common.Address, db trie.Database) *StateObject { - object := &StateObject{ - db: db, - address: address, - balance: new(big.Int), - dirty: true, - codeHash: emptyCodeHash, - storage: make(Storage), - } - object.trie, _ = trie.NewSecure(common.Hash{}, db) - return object +// Account is the Ethereum consensus representation of accounts. +// These objects are stored in the main account trie. +type Account struct { + Nonce uint64 + Balance *big.Int + Root common.Hash // merkle root of the storage trie + CodeHash []byte } -func (self *StateObject) MarkForDeletion() { - self.remove = true - self.dirty = true - - if glog.V(logger.Core) { - glog.Infof("%x: #%d %v X\n", self.Address(), self.nonce, self.balance) +// NewObject creates a state object. +func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { + if data.Balance == nil { + data.Balance = new(big.Int) } + if data.CodeHash == nil { + data.CodeHash = emptyCodeHash + } + return &StateObject{address: address, data: data, storage: make(Storage), onDirty: onDirty} } -func (c *StateObject) getAddr(addr common.Hash) common.Hash { - var ret []byte - rlp.DecodeBytes(c.trie.Get(addr[:]), &ret) - return common.BytesToHash(ret) +// EncodeRLP implements rlp.Encoder. +func (c *StateObject) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, c.data) } -func (c *StateObject) setAddr(addr, value common.Hash) { - v, err := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - if err != nil { - // if RLPing failed we better panic and not fail silently. This would be considered a consensus issue - panic(err) +// setError remembers the first non-nil error it is called with. +func (self *StateObject) setError(err error) { + if self.dbErr == nil { + self.dbErr = err } - c.trie.Update(addr[:], v) } -func (self *StateObject) Storage() Storage { - return self.storage +func (self *StateObject) MarkForDeletion() { + self.remove = true + if self.onDirty != nil { + self.onDirty(self.Address()) + self.onDirty = nil + } + if glog.V(logger.Core) { + glog.Infof("%x: #%d %v X\n", self.Address(), self.Nonce(), self.Balance()) + } } -func (self *StateObject) GetState(key common.Hash) common.Hash { - value, exists := self.storage[key] - if !exists { - value = self.getAddr(key) - if (value != common.Hash{}) { - self.storage[key] = value +func (c *StateObject) getTrie(db trie.Database) *trie.SecureTrie { + if c.trie == nil { + var err error + c.trie, err = trie.NewSecure(c.data.Root, db) + if err != nil { + c.trie, _ = trie.NewSecure(common.Hash{}, db) + c.setError(fmt.Errorf("can't create storage trie: %v", err)) } } + return c.trie +} +// GetState returns a value in account storage. +func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash { + value, exists := self.storage[key] + if exists { + return value + } + // Load from DB in case it is missing. + tr := self.getTrie(db) + var ret []byte + rlp.DecodeBytes(tr.Get(key[:]), &ret) + value = common.BytesToHash(ret) + if (value != common.Hash{}) { + self.storage[key] = value + } return value } +// SetState updates a value in account storage. func (self *StateObject) SetState(key, value common.Hash) { self.storage[key] = value - self.dirty = true + if self.onDirty != nil { + self.onDirty(self.Address()) + self.onDirty = nil + } } -// Update updates the current cached storage to the trie -func (self *StateObject) Update() { +// updateTrie writes cached storage modifications into the object's storage trie. +func (self *StateObject) updateTrie(db trie.Database) { + tr := self.getTrie(db) for key, value := range self.storage { if (value == common.Hash{}) { - self.trie.Delete(key[:]) + tr.Delete(key[:]) continue } - self.setAddr(key, value) + // Encoding []byte cannot fail, ok to ignore the error. + v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) + tr.Update(key[:], v) } } +// UpdateRoot sets the trie root to the current root hash of +func (self *StateObject) UpdateRoot(db trie.Database) { + self.updateTrie(db) + self.data.Root = self.trie.Hash() +} + +// CommitTrie the storage trie of the object to dwb. +// This updates the trie root. +func (self *StateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error { + self.updateTrie(db) + if self.dbErr != nil { + fmt.Println("dbErr:", self.dbErr) + return self.dbErr + } + root, err := self.trie.CommitTo(dbw) + if err == nil { + self.data.Root = root + } + return err +} + func (c *StateObject) AddBalance(amount *big.Int) { - c.SetBalance(new(big.Int).Add(c.balance, amount)) + if amount.Cmp(common.Big0) == 0 { + return + } + c.SetBalance(new(big.Int).Add(c.Balance(), amount)) if glog.V(logger.Core) { - glog.Infof("%x: #%d %v (+ %v)\n", c.Address(), c.nonce, c.balance, amount) + glog.Infof("%x: #%d %v (+ %v)\n", c.Address(), c.Nonce(), c.Balance(), amount) } } func (c *StateObject) SubBalance(amount *big.Int) { - c.SetBalance(new(big.Int).Sub(c.balance, amount)) + if amount.Cmp(common.Big0) == 0 { + return + } + c.SetBalance(new(big.Int).Sub(c.Balance(), amount)) if glog.V(logger.Core) { - glog.Infof("%x: #%d %v (- %v)\n", c.Address(), c.nonce, c.balance, amount) + glog.Infof("%x: #%d %v (- %v)\n", c.Address(), c.Nonce(), c.Balance(), amount) } } -func (c *StateObject) SetBalance(amount *big.Int) { - c.balance = amount - c.dirty = true -} - -func (c *StateObject) St() Storage { - return c.storage +func (self *StateObject) SetBalance(amount *big.Int) { + self.data.Balance = amount + if self.onDirty != nil { + self.onDirty(self.Address()) + self.onDirty = nil + } } // Return the gas back to the origin. Used by the Virtual machine or Closures func (c *StateObject) ReturnGas(gas, price *big.Int) {} -func (self *StateObject) Copy() *StateObject { - stateObject := NewStateObject(self.Address(), self.db) - stateObject.balance.Set(self.balance) - stateObject.codeHash = common.CopyBytes(self.codeHash) - stateObject.nonce = self.nonce +func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { + stateObject := NewObject(self.address, self.data, onDirty) stateObject.trie = self.trie - stateObject.code = common.CopyBytes(self.code) - stateObject.initCode = common.CopyBytes(self.initCode) + stateObject.code = self.code stateObject.storage = self.storage.Copy() stateObject.remove = self.remove - stateObject.dirty = self.dirty + stateObject.dirtyCode = self.dirtyCode stateObject.deleted = self.deleted - return stateObject } @@ -201,40 +252,55 @@ func (self *StateObject) Copy() *StateObject { // Attribute accessors // -func (self *StateObject) Balance() *big.Int { - return self.balance -} - // Returns the address of the contract/account func (c *StateObject) Address() common.Address { return c.address } -func (self *StateObject) Trie() *trie.SecureTrie { - return self.trie +// Code returns the contract code associated with this object, if any. +func (self *StateObject) Code(db trie.Database) []byte { + if self.code != nil { + return self.code + } + if bytes.Equal(self.CodeHash(), emptyCodeHash) { + return nil + } + code, err := db.Get(self.CodeHash()) + if err != nil { + self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) + } + self.code = code + return code } -func (self *StateObject) Root() []byte { - return self.trie.Root() +func (self *StateObject) SetCode(code []byte) { + self.code = code + self.data.CodeHash = crypto.Keccak256(code) + self.dirtyCode = true + if self.onDirty != nil { + self.onDirty(self.Address()) + self.onDirty = nil + } } -func (self *StateObject) Code() []byte { - return self.code +func (self *StateObject) SetNonce(nonce uint64) { + self.data.Nonce = nonce + if self.onDirty != nil { + self.onDirty(self.Address()) + self.onDirty = nil + } } -func (self *StateObject) SetCode(code []byte) { - self.code = code - self.codeHash = crypto.Keccak256(code) - self.dirty = true +func (self *StateObject) CodeHash() []byte { + return self.data.CodeHash } -func (self *StateObject) SetNonce(nonce uint64) { - self.nonce = nonce - self.dirty = true +func (self *StateObject) Balance() *big.Int { + return self.data.Balance } func (self *StateObject) Nonce() uint64 { - return self.nonce + return self.data.Nonce } // Never called, but must be present to allow StateObject to be used @@ -259,39 +325,3 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { } } } - -type extStateObject struct { - Nonce uint64 - Balance *big.Int - Root common.Hash - CodeHash []byte -} - -// EncodeRLP implements rlp.Encoder. -func (c *StateObject) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, []interface{}{c.nonce, c.balance, c.Root(), c.codeHash}) -} - -// DecodeObject decodes an RLP-encoded state object. -func DecodeObject(address common.Address, db trie.Database, data []byte) (*StateObject, error) { - var ( - obj = &StateObject{address: address, db: db, storage: make(Storage)} - ext extStateObject - err error - ) - if err = rlp.DecodeBytes(data, &ext); err != nil { - return nil, err - } - if obj.trie, err = trie.NewSecure(ext.Root, db); err != nil { - return nil, err - } - if !bytes.Equal(ext.CodeHash, emptyCodeHash) { - if obj.code, err = db.Get(ext.CodeHash); err != nil { - return nil, fmt.Errorf("can't get code for hash %x: %v", ext.CodeHash, err) - } - } - obj.nonce = ext.Nonce - obj.balance = ext.Balance - obj.codeHash = ext.CodeHash - return obj, nil -} diff --git a/core/state/state_test.go b/core/state/state_test.go index ce86a5b76..fcdc38588 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -146,22 +146,23 @@ func TestSnapshot2(t *testing.T) { // db, trie are already non-empty values so0 := state.GetStateObject(stateobjaddr0) - so0.balance = big.NewInt(42) - so0.nonce = 43 + so0.SetBalance(big.NewInt(42)) + so0.SetNonce(43) so0.SetCode([]byte{'c', 'a', 'f', 'e'}) - so0.remove = true + so0.remove = false so0.deleted = false - so0.dirty = false state.SetStateObject(so0) + root, _ := state.Commit() + state.Reset(root) + // and one with deleted == true so1 := state.GetStateObject(stateobjaddr1) - so1.balance = big.NewInt(52) - so1.nonce = 53 + so1.SetBalance(big.NewInt(52)) + so1.SetNonce(53) so1.SetCode([]byte{'c', 'a', 'f', 'e', '2'}) so1.remove = true so1.deleted = true - so1.dirty = true state.SetStateObject(so1) so1 = state.GetStateObject(stateobjaddr1) @@ -173,43 +174,50 @@ func TestSnapshot2(t *testing.T) { state.Set(snapshot) so0Restored := state.GetStateObject(stateobjaddr0) - so1Restored := state.GetStateObject(stateobjaddr1) + // Update lazily-loaded values before comparing. + so0Restored.GetState(db, storageaddr) + so0Restored.Code(db) // non-deleted is equal (restored) compareStateObjects(so0Restored, so0, t) + // deleted should be nil, both before and after restore of state copy + so1Restored := state.GetStateObject(stateobjaddr1) if so1Restored != nil { - t.Fatalf("deleted object not nil after restoring snapshot") + t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored) } } func compareStateObjects(so0, so1 *StateObject, t *testing.T) { - if so0.address != so1.address { + if so0.Address() != so1.Address() { t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address) } - if so0.balance.Cmp(so1.balance) != 0 { - t.Fatalf("Balance mismatch: have %v, want %v", so0.balance, so1.balance) + if so0.Balance().Cmp(so1.Balance()) != 0 { + t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance()) } - if so0.nonce != so1.nonce { - t.Fatalf("Nonce mismatch: have %v, want %v", so0.nonce, so1.nonce) + if so0.Nonce() != so1.Nonce() { + t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce()) } - if !bytes.Equal(so0.codeHash, so1.codeHash) { - t.Fatalf("CodeHash mismatch: have %v, want %v", so0.codeHash, so1.codeHash) + if so0.data.Root != so1.data.Root { + t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:]) + } + if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) { + t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash()) } if !bytes.Equal(so0.code, so1.code) { t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) } - if !bytes.Equal(so0.initCode, so1.initCode) { - t.Fatalf("InitCode mismatch: have %v, want %v", so0.initCode, so1.initCode) - } + if len(so1.storage) != len(so0.storage) { + t.Errorf("Storage size mismatch: have %d, want %d", len(so1.storage), len(so0.storage)) + } for k, v := range so1.storage { if so0.storage[k] != v { - t.Fatalf("Storage key %s mismatch: have %v, want %v", k, so0.storage[k], v) + t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.storage[k], v) } } for k, v := range so0.storage { if so1.storage[k] != v { - t.Fatalf("Storage key %s mismatch: have %v, want none.", k, v) + t.Errorf("Storage key %x mismatch: have %v, want none.", k, v) } } @@ -219,7 +227,4 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) { if so0.deleted != so1.deleted { t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted) } - if so0.dirty != so1.dirty { - t.Fatalf("Dirty mismatch: have %v, want %v", so0.dirty, so1.dirty) - } } diff --git a/core/state/statedb.go b/core/state/statedb.go index 3e25e0c16..5c51e3b59 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -20,6 +20,7 @@ package state import ( "fmt" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" @@ -28,29 +29,46 @@ import ( "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" + lru "github.com/hashicorp/golang-lru" ) // The starting nonce determines the default nonce when new accounts are being // created. var StartingNonce uint64 +const ( + // Number of past tries to keep. The arbitrarily chosen value here + // is max uncle depth + 1. + maxJournalLength = 8 + + // Number of codehash->size associations to keep. + codeSizeCacheSize = 100000 +) + // StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: // * Contracts // * Accounts type StateDB struct { - db ethdb.Database - trie *trie.SecureTrie + db ethdb.Database + trie *trie.SecureTrie + pastTries []*trie.SecureTrie + codeSizeCache *lru.Cache - stateObjects map[string]*StateObject + // This map holds 'live' objects, which will get modified while processing a state transition. + stateObjects map[common.Address]*StateObject + stateObjectsDirty map[common.Address]struct{} + // The refund counter, also used by state transitioning. refund *big.Int thash, bhash common.Hash txIndex int logs map[common.Hash]vm.Logs logSize uint + + lock sync.Mutex } // Create a new state from a given trie @@ -59,35 +77,84 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) { if err != nil { return nil, err } + csc, _ := lru.New(codeSizeCacheSize) return &StateDB{ - db: db, - trie: tr, - stateObjects: make(map[string]*StateObject), - refund: new(big.Int), - logs: make(map[common.Hash]vm.Logs), + db: db, + trie: tr, + codeSizeCache: csc, + stateObjects: make(map[common.Address]*StateObject), + stateObjectsDirty: make(map[common.Address]struct{}), + refund: new(big.Int), + logs: make(map[common.Hash]vm.Logs), + }, nil +} + +// New creates a new statedb by reusing any journalled tries to avoid costly +// disk io. +func (self *StateDB) New(root common.Hash) (*StateDB, error) { + self.lock.Lock() + defer self.lock.Unlock() + + tr, err := self.openTrie(root) + if err != nil { + return nil, err + } + return &StateDB{ + db: self.db, + trie: tr, + codeSizeCache: self.codeSizeCache, + stateObjects: make(map[common.Address]*StateObject), + stateObjectsDirty: make(map[common.Address]struct{}), + refund: new(big.Int), + logs: make(map[common.Hash]vm.Logs), }, nil } // Reset clears out all emphemeral state objects from the state db, but keeps // the underlying state trie to avoid reloading data for the next operations. func (self *StateDB) Reset(root common.Hash) error { - var ( - err error - tr = self.trie - ) - if self.trie.Hash() != root { - if tr, err = trie.NewSecure(root, self.db); err != nil { - return err + self.lock.Lock() + defer self.lock.Unlock() + + tr, err := self.openTrie(root) + if err != nil { + return err + } + self.trie = tr + self.stateObjects = make(map[common.Address]*StateObject) + self.stateObjectsDirty = make(map[common.Address]struct{}) + self.refund = new(big.Int) + self.thash = common.Hash{} + self.bhash = common.Hash{} + self.txIndex = 0 + self.logs = make(map[common.Hash]vm.Logs) + self.logSize = 0 + + return nil +} + +// openTrie creates a trie. It uses an existing trie if one is available +// from the journal if available. +func (self *StateDB) openTrie(root common.Hash) (*trie.SecureTrie, error) { + for i := len(self.pastTries) - 1; i >= 0; i-- { + if self.pastTries[i].Hash() == root { + tr := *self.pastTries[i] + return &tr, nil } } - *self = StateDB{ - db: self.db, - trie: tr, - stateObjects: make(map[string]*StateObject), - refund: new(big.Int), - logs: make(map[common.Hash]vm.Logs), + return trie.NewSecure(root, self.db) +} + +func (self *StateDB) pushTrie(t *trie.SecureTrie) { + self.lock.Lock() + defer self.lock.Unlock() + + if len(self.pastTries) >= maxJournalLength { + copy(self.pastTries, self.pastTries[1:]) + self.pastTries[len(self.pastTries)-1] = t + } else { + self.pastTries = append(self.pastTries, t) } - return nil } func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { @@ -137,7 +204,7 @@ func (self *StateDB) GetAccount(addr common.Address) vm.Account { func (self *StateDB) GetBalance(addr common.Address) *big.Int { stateObject := self.GetStateObject(addr) if stateObject != nil { - return stateObject.balance + return stateObject.Balance() } return common.Big0 @@ -146,7 +213,7 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int { func (self *StateDB) GetNonce(addr common.Address) uint64 { stateObject := self.GetStateObject(addr) if stateObject != nil { - return stateObject.nonce + return stateObject.Nonce() } return StartingNonce @@ -155,18 +222,35 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 { func (self *StateDB) GetCode(addr common.Address) []byte { stateObject := self.GetStateObject(addr) if stateObject != nil { - return stateObject.code + code := stateObject.Code(self.db) + key := common.BytesToHash(stateObject.CodeHash()) + self.codeSizeCache.Add(key, len(code)) + return code } - return nil } +func (self *StateDB) GetCodeSize(addr common.Address) int { + stateObject := self.GetStateObject(addr) + if stateObject == nil { + return 0 + } + key := common.BytesToHash(stateObject.CodeHash()) + if cached, ok := self.codeSizeCache.Get(key); ok { + return cached.(int) + } + size := len(stateObject.Code(self.db)) + if stateObject.dbErr == nil { + self.codeSizeCache.Add(key, size) + } + return size +} + func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { stateObject := self.GetStateObject(a) if stateObject != nil { - return stateObject.GetState(b) + return stateObject.GetState(self.db, b) } - return common.Hash{} } @@ -214,8 +298,7 @@ func (self *StateDB) Delete(addr common.Address) bool { stateObject := self.GetStateObject(addr) if stateObject != nil { stateObject.MarkForDeletion() - stateObject.balance = new(big.Int) - + stateObject.data.Balance = new(big.Int) return true } @@ -242,35 +325,36 @@ func (self *StateDB) DeleteStateObject(stateObject *StateObject) { addr := stateObject.Address() self.trie.Delete(addr[:]) - //delete(self.stateObjects, addr.Str()) } -// Retrieve a state object given my the address. Nil if not found +// Retrieve a state object given my the address. Returns nil if not found. func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObject) { - stateObject = self.stateObjects[addr.Str()] - if stateObject != nil { - if stateObject.deleted { - stateObject = nil + // Prefer 'live' objects. + if obj := self.stateObjects[addr]; obj != nil { + if obj.deleted { + return nil } - - return stateObject + return obj } - data := self.trie.Get(addr[:]) - if len(data) == 0 { + // Load the object from the database. + enc := self.trie.Get(addr[:]) + if len(enc) == 0 { return nil } - stateObject, err := DecodeObject(addr, self.db, data) - if err != nil { + var data Account + if err := rlp.DecodeBytes(enc, &data); err != nil { glog.Errorf("can't decode object at %x: %v", addr[:], err) return nil } - self.SetStateObject(stateObject) - return stateObject + // Insert into the live set. + obj := NewObject(addr, data, self.MarkStateObjectDirty) + self.SetStateObject(obj) + return obj } func (self *StateDB) SetStateObject(object *StateObject) { - self.stateObjects[object.Address().Str()] = object + self.stateObjects[object.Address()] = object } // Retrieve a state object or create a new state object if nil @@ -288,15 +372,19 @@ func (self *StateDB) newStateObject(addr common.Address) *StateObject { if glog.V(logger.Core) { glog.Infof("(+) %x\n", addr) } + obj := NewObject(addr, Account{}, self.MarkStateObjectDirty) + obj.SetNonce(StartingNonce) // sets the object to dirty + self.stateObjects[addr] = obj + return obj +} - stateObject := NewStateObject(addr, self.db) - stateObject.SetNonce(StartingNonce) - self.stateObjects[addr.Str()] = stateObject - - return stateObject +// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly +// state object cache iteration to find a handful of modified ones. +func (self *StateDB) MarkStateObjectDirty(addr common.Address) { + self.stateObjectsDirty[addr] = struct{}{} } -// Creates creates a new state object and takes ownership. This is different from "NewStateObject" +// Creates creates a new state object and takes ownership. func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { // Get previous (if any) so := self.GetStateObject(addr) @@ -305,7 +393,7 @@ func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { // If it existed set the balance to the new account if so != nil { - newSo.balance = so.balance + newSo.data.Balance = so.data.Balance } return newSo @@ -320,28 +408,43 @@ func (self *StateDB) CreateAccount(addr common.Address) vm.Account { // func (self *StateDB) Copy() *StateDB { - // ignore error - we assume state-to-be-copied always exists - state, _ := New(common.Hash{}, self.db) - state.trie = self.trie - for k, stateObject := range self.stateObjects { - state.stateObjects[k] = stateObject.Copy() + self.lock.Lock() + defer self.lock.Unlock() + + // Copy all the basic fields, initialize the memory ones + state := &StateDB{ + db: self.db, + trie: self.trie, + pastTries: self.pastTries, + codeSizeCache: self.codeSizeCache, + stateObjects: make(map[common.Address]*StateObject, len(self.stateObjectsDirty)), + stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), + refund: new(big.Int).Set(self.refund), + logs: make(map[common.Hash]vm.Logs, len(self.logs)), + logSize: self.logSize, + } + // Copy the dirty states and logs + for addr, _ := range self.stateObjectsDirty { + state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) + state.stateObjectsDirty[addr] = struct{}{} } - - state.refund.Set(self.refund) - for hash, logs := range self.logs { state.logs[hash] = make(vm.Logs, len(logs)) copy(state.logs[hash], logs) } - state.logSize = self.logSize - return state } func (self *StateDB) Set(state *StateDB) { + self.lock.Lock() + defer self.lock.Unlock() + + self.db = state.db self.trie = state.trie + self.pastTries = state.pastTries self.stateObjects = state.stateObjects - + self.stateObjectsDirty = state.stateObjectsDirty + self.codeSizeCache = state.codeSizeCache self.refund = state.refund self.logs = state.logs self.logSize = state.logSize @@ -356,15 +459,13 @@ func (self *StateDB) GetRefund() *big.Int { // goes into transaction receipts. func (s *StateDB) IntermediateRoot() common.Hash { s.refund = new(big.Int) - for _, stateObject := range s.stateObjects { - if stateObject.dirty { - if stateObject.remove { - s.DeleteStateObject(stateObject) - } else { - stateObject.Update() - s.UpdateStateObject(stateObject) - } - stateObject.dirty = false + for addr, _ := range s.stateObjectsDirty { + stateObject := s.stateObjects[addr] + if stateObject.remove { + s.DeleteStateObject(stateObject) + } else { + stateObject.UpdateRoot(s.db) + s.UpdateStateObject(stateObject) } } return s.trie.Hash() @@ -379,15 +480,15 @@ func (s *StateDB) DeleteSuicides() { // Reset refund so that any used-gas calculations can use // this method. s.refund = new(big.Int) - for _, stateObject := range s.stateObjects { - if stateObject.dirty { - // If the object has been removed by a suicide - // flag the object as deleted. - if stateObject.remove { - stateObject.deleted = true - } - stateObject.dirty = false + for addr, _ := range s.stateObjectsDirty { + stateObject := s.stateObjects[addr] + + // If the object has been removed by a suicide + // flag the object as deleted. + if stateObject.remove { + stateObject.deleted = true } + delete(s.stateObjectsDirty, addr) } } @@ -406,46 +507,40 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { return root, batch } -func (s *StateDB) commit(db trie.DatabaseWriter) (common.Hash, error) { +func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { s.refund = new(big.Int) - for _, stateObject := range s.stateObjects { + // Commit objects to the trie. + for addr, stateObject := range s.stateObjects { if stateObject.remove { // If the object has been removed, don't bother syncing it // and just mark it for deletion in the trie. s.DeleteStateObject(stateObject) - } else { + } else if _, ok := s.stateObjectsDirty[addr]; ok { // Write any contract code associated with the state object - if len(stateObject.code) > 0 { - if err := db.Put(stateObject.codeHash, stateObject.code); err != nil { + if stateObject.code != nil && stateObject.dirtyCode { + if err := dbw.Put(stateObject.CodeHash(), stateObject.code); err != nil { return common.Hash{}, err } + stateObject.dirtyCode = false } - // Write any storage changes in the state object to its trie. - stateObject.Update() - - // Commit the trie of the object to the batch. - // This updates the trie root internally, so - // getting the root hash of the storage trie - // through UpdateStateObject is fast. - if _, err := stateObject.trie.CommitTo(db); err != nil { + // Write any storage changes in the state object to its storage trie. + if err := stateObject.CommitTrie(s.db, dbw); err != nil { return common.Hash{}, err } - // Update the object in the account trie. + // Update the object in the main account trie. s.UpdateStateObject(stateObject) } - stateObject.dirty = false + delete(s.stateObjectsDirty, addr) + } + // Write trie changes. + root, err = s.trie.CommitTo(dbw) + if err == nil { + s.pushTrie(s.trie) } - return s.trie.CommitTo(db) + return root, err } func (self *StateDB) Refunds() *big.Int { return self.refund } - -// Debug stuff -func (self *StateDB) CreateOutputForDiff() { - for _, stateObject := range self.stateObjects { - stateObject.CreateOutputForDiff() - } -} diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 715645c6c..c768781a4 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -62,9 +62,6 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { } root, _ := state.Commit() - // Remove any potentially cached data from the test state creation - trie.ClearGlobalCache() - // Return the generated state return db, root, accounts } @@ -72,9 +69,6 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { // checkStateAccounts cross references a reconstructed state with an expected // account array. func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { - // Remove any potentially cached data from the state synchronisation - trie.ClearGlobalCache() - // Check root availability and state contents state, err := New(root, db) if err != nil { @@ -98,9 +92,6 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou // checkStateConsistency checks that all nodes in a state trie are indeed present. func checkStateConsistency(db ethdb.Database, root common.Hash) error { - // Remove any potentially cached data from the test state creation or previous checks - trie.ClearGlobalCache() - // Create and iterate a state trie rooted in a sub-node if _, err := db.Get(root.Bytes()); err != nil { return nil // Consider a non existent state consistent diff --git a/core/vm/environment.go b/core/vm/environment.go index 747627565..4bd03de7e 100644 --- a/core/vm/environment.go +++ b/core/vm/environment.go @@ -94,6 +94,7 @@ type Database interface { GetNonce(common.Address) uint64 SetNonce(common.Address, uint64) + GetCodeSize(common.Address) int GetCode(common.Address) []byte SetCode(common.Address, []byte) diff --git a/core/vm/instructions.go b/core/vm/instructions.go index a95ba26c5..849a8463c 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -363,7 +363,7 @@ func opCalldataCopy(instr instruction, pc *uint64, env Environment, contract *Co func opExtCodeSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *Stack) { addr := common.BigToAddress(stack.pop()) - l := big.NewInt(int64(len(env.Db().GetCode(addr)))) + l := big.NewInt(int64(env.Db().GetCodeSize(addr))) stack.push(l) } diff --git a/eth/api.go b/eth/api.go index f4bce47b8..c2fdbe99c 100644 --- a/eth/api.go +++ b/eth/api.go @@ -288,14 +288,14 @@ func NewPublicDebugAPI(eth *Ethereum) *PublicDebugAPI { } // DumpBlock retrieves the entire state of the database at a given block. -func (api *PublicDebugAPI) DumpBlock(number uint64) (state.World, error) { +func (api *PublicDebugAPI) DumpBlock(number uint64) (state.Dump, error) { block := api.eth.BlockChain().GetBlockByNumber(number) if block == nil { - return state.World{}, fmt.Errorf("block #%d not found", number) + return state.Dump{}, fmt.Errorf("block #%d not found", number) } - stateDb, err := state.New(block.Root(), api.eth.ChainDb()) + stateDb, err := api.eth.BlockChain().StateAt(block.Root()) if err != nil { - return state.World{}, err + return state.Dump{}, err } return stateDb.RawDump(), nil } @@ -406,7 +406,7 @@ func (api *PrivateDebugAPI) traceBlock(block *types.Block, logConfig *vm.LogConf if err := core.ValidateHeader(api.config, blockchain.AuxValidator(), block.Header(), blockchain.GetHeader(block.ParentHash(), block.NumberU64()-1), true, false); err != nil { return false, structLogger.StructLogs(), err } - statedb, err := state.New(blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1).Root(), api.eth.ChainDb()) + statedb, err := blockchain.StateAt(blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1).Root()) if err != nil { return false, structLogger.StructLogs(), err } @@ -501,7 +501,7 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, txHash common. if parent == nil { return nil, fmt.Errorf("block parent %x not found", block.ParentHash()) } - stateDb, err := state.New(parent.Root(), api.eth.ChainDb()) + stateDb, err := api.eth.BlockChain().StateAt(parent.Root()) if err != nil { return nil, err } diff --git a/eth/api_backend.go b/eth/api_backend.go index 4f8f06529..4adeb0aa0 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -81,7 +81,7 @@ func (b *EthApiBackend) StateAndHeaderByNumber(blockNr rpc.BlockNumber) (ethapi. if header == nil { return nil, nil, nil } - stateDb, err := state.New(header.Root, b.eth.chainDb) + stateDb, err := b.eth.BlockChain().StateAt(header.Root) return EthApiState{stateDb}, header, err } diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go index ffa8228cc..aa7796f32 100644 --- a/ethclient/ethclient.go +++ b/ethclient/ethclient.go @@ -269,7 +269,7 @@ func (ec *Client) NonceAt(ctx context.Context, account common.Address, blockNumb // FilterLogs executes a filter query. func (ec *Client) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]vm.Log, error) { var result []vm.Log - err := ec.c.CallContext(ctx, &result, "eth_getFilterLogs", toFilterArg(q)) + err := ec.c.CallContext(ctx, &result, "eth_getLogs", toFilterArg(q)) return result, err } @@ -281,7 +281,7 @@ func (ec *Client) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuer func toFilterArg(q ethereum.FilterQuery) interface{} { arg := map[string]interface{}{ "fromBlock": toBlockNumArg(q.FromBlock), - "endBlock": toBlockNumArg(q.ToBlock), + "toBlock": toBlockNumArg(q.ToBlock), "addresses": q.Addresses, "topics": q.Topics, } diff --git a/ethdb/database.go b/ethdb/database.go index 2e951927c..479c54b60 100644 --- a/ethdb/database.go +++ b/ethdb/database.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/metrics" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/filter" "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/opt" @@ -84,6 +85,7 @@ func NewLDBDatabase(file string, cache int, handles int) (*LDBDatabase, error) { OpenFilesCacheCapacity: handles, BlockCacheCapacity: cache / 2 * opt.MiB, WriteBuffer: cache / 4 * opt.MiB, // Two of these are used internally + Filter: filter.NewBloomFilter(10), }) if _, corrupted := err.(*errors.ErrCorrupted); corrupted { db, err = leveldb.RecoverFile(file, nil) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 0b1384f58..9a97be25f 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -454,6 +454,8 @@ type CallArgs struct { } func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (string, *big.Int, error) { + defer func(start time.Time) { glog.V(logger.Debug).Infof("call took %v", time.Since(start)) }(time.Now()) + state, header, err := s.b.StateAndHeaderByNumber(blockNr) if state == nil || err != nil { return "0x", common.Big0, err @@ -1280,8 +1282,8 @@ func (api *PrivateDebugAPI) ChaindbProperty(property string) (string, error) { } // SetHead rewinds the head of the blockchain to a previous block. -func (api *PrivateDebugAPI) SetHead(number uint64) { - api.b.SetHead(number) +func (api *PrivateDebugAPI) SetHead(number rpc.HexNumber) { + api.b.SetHead(uint64(number.Int64())) } // PublicNetAPI offers network related RPC methods diff --git a/light/state.go b/light/state.go index e18f9cdc5..4f2177238 100644 --- a/light/state.go +++ b/light/state.go @@ -261,7 +261,9 @@ func (self *LightState) Copy() *LightState { state := NewLightState(common.Hash{}, self.odr) state.trie = self.trie for k, stateObject := range self.stateObjects { - state.stateObjects[k] = stateObject.Copy() + if stateObject.dirty { + state.stateObjects[k] = stateObject.Copy() + } } return state diff --git a/light/state_object.go b/light/state_object.go index 030653c77..1e9c7f4b1 100644 --- a/light/state_object.go +++ b/light/state_object.go @@ -79,8 +79,6 @@ type StateObject struct { codeHash []byte // The code for this account code Code - // Temporarily initialisation code - initCode Code // Cached storage (flushed when updated) storage Storage @@ -188,8 +186,7 @@ func (self *StateObject) Copy() *StateObject { stateObject.codeHash = common.CopyBytes(self.codeHash) stateObject.nonce = self.nonce stateObject.trie = self.trie - stateObject.code = common.CopyBytes(self.code) - stateObject.initCode = common.CopyBytes(self.initCode) + stateObject.code = self.code stateObject.storage = self.storage.Copy() stateObject.remove = self.remove stateObject.dirty = self.dirty diff --git a/light/state_test.go b/light/state_test.go index 2c2e6daea..d7014a2dc 100644 --- a/light/state_test.go +++ b/light/state_test.go @@ -42,7 +42,6 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { case *TrieRequest: t, _ := trie.New(req.root, odr.sdb) req.proof = t.Prove(req.key) - trie.ClearGlobalCache() case *NodeDataRequest: req.data, _ = odr.sdb.Get(req.hash[:]) } @@ -62,7 +61,7 @@ func makeTestState() (common.Hash, ethdb.Database) { } so.AddBalance(big.NewInt(int64(i))) so.SetCode([]byte{i, i, i}) - so.Update() + so.UpdateRoot(sdb) st.UpdateStateObject(so) } root, _ := st.Commit() @@ -75,7 +74,6 @@ func TestLightStateOdr(t *testing.T) { odr := &testOdr{sdb: sdb, ldb: ldb} ls := NewLightState(root, odr) ctx := context.Background() - trie.ClearGlobalCache() for i := byte(0); i < 100; i++ { addr := common.Address{i} @@ -160,7 +158,6 @@ func TestLightStateSetCopy(t *testing.T) { odr := &testOdr{sdb: sdb, ldb: ldb} ls := NewLightState(root, odr) ctx := context.Background() - trie.ClearGlobalCache() for i := byte(0); i < 100; i++ { addr := common.Address{i} @@ -237,7 +234,6 @@ func TestLightStateDelete(t *testing.T) { odr := &testOdr{sdb: sdb, ldb: ldb} ls := NewLightState(root, odr) ctx := context.Background() - trie.ClearGlobalCache() addr := common.Address{42} diff --git a/miner/worker.go b/miner/worker.go index 1676036d8..ac1ef5ba3 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -361,7 +361,7 @@ func (self *worker) push(work *Work) { // makeCurrent creates a new environment for the current cycle. func (self *worker) makeCurrent(parent *types.Block, header *types.Header) error { - state, err := state.New(parent.Root(), self.eth.ChainDb()) + state, err := self.chain.StateAt(parent.Root()) if err != nil { return err } diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go index 42f615f36..f9ba93613 100644 --- a/p2p/nat/nat.go +++ b/p2p/nat/nat.go @@ -105,7 +105,7 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na glog.V(logger.Debug).Infof("deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m) m.DeleteMapping(protocol, extport, intport) }() - if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil { + if err := m.AddMapping(protocol, extport, intport, name, mapTimeout); err != nil { glog.V(logger.Debug).Infof("network port %s:%d could not be mapped: %v\n", protocol, intport, err) } else { glog.V(logger.Info).Infof("mapped network port %s:%d -> %d (%s) using %s\n", protocol, extport, intport, name, m) @@ -118,7 +118,7 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na } case <-refresh.C: glog.V(logger.Detail).Infof("refresh port mapping %s:%d -> %d (%s) using %s\n", protocol, extport, intport, name, m) - if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil { + if err := m.AddMapping(protocol, extport, intport, name, mapTimeout); err != nil { glog.V(logger.Debug).Infof("network port %s:%d could not be mapped: %v\n", protocol, intport, err) } refresh.Reset(mapUpdateInterval) diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 36fa30881..67e4bf832 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -97,7 +97,7 @@ func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *test db, _ := ethdb.NewMemDatabase() statedb, _ := state.New(common.Hash{}, db) for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account) + obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) statedb.SetStateObject(obj) for a, v := range account.Storage { obj.SetState(common.HexToHash(a), common.HexToHash(v)) @@ -136,7 +136,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error { db, _ := ethdb.NewMemDatabase() statedb, _ := state.New(common.Hash{}, db) for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account) + obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) statedb.SetStateObject(obj) for a, v := range account.Storage { obj.SetState(common.HexToHash(a), common.HexToHash(v)) @@ -187,7 +187,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error { } for addr, value := range account.Storage { - v := obj.GetState(common.HexToHash(addr)) + v := statedb.GetState(obj.Address(), common.HexToHash(addr)) vexp := common.HexToHash(value) if v != vexp { diff --git a/tests/util.go b/tests/util.go index 79c3bfad1..08fac2dd1 100644 --- a/tests/util.go +++ b/tests/util.go @@ -103,16 +103,17 @@ func (self Log) Topics() [][]byte { return t } -func StateObjectFromAccount(db ethdb.Database, addr string, account Account) *state.StateObject { - obj := state.NewStateObject(common.HexToAddress(addr), db) - obj.SetBalance(common.Big(account.Balance)) - +func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { if common.IsHex(account.Code) { account.Code = account.Code[2:] } - obj.SetCode(common.Hex2Bytes(account.Code)) - obj.SetNonce(common.Big(account.Nonce).Uint64()) - + code := common.Hex2Bytes(account.Code) + obj := state.NewObject(common.HexToAddress(addr), state.Account{ + Balance: common.Big(account.Balance), + CodeHash: crypto.Keccak256(code), + Nonce: common.Big(account.Nonce).Uint64(), + }, onDirty) + obj.SetCode(code) return obj } diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index 37f0af33c..4ad72d91c 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -103,7 +103,7 @@ func benchVmTest(test VmTest, env map[string]string, b *testing.B) { db, _ := ethdb.NewMemDatabase() statedb, _ := state.New(common.Hash{}, db) for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account) + obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) statedb.SetStateObject(obj) for a, v := range account.Storage { obj.SetState(common.HexToHash(a), common.HexToHash(v)) @@ -154,7 +154,7 @@ func runVmTest(test VmTest) error { db, _ := ethdb.NewMemDatabase() statedb, _ := state.New(common.Hash{}, db) for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account) + obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) statedb.SetStateObject(obj) for a, v := range account.Storage { obj.SetState(common.HexToHash(a), common.HexToHash(v)) @@ -205,11 +205,9 @@ func runVmTest(test VmTest) error { if obj == nil { continue } - for addr, value := range account.Storage { - v := obj.GetState(common.HexToHash(addr)) + v := statedb.GetState(obj.Address(), common.HexToHash(addr)) vexp := common.HexToHash(value) - if v != vexp { return fmt.Errorf("(%x: %s) storage failed. Expected %x, got %x (%v %v)\n", obj.Address().Bytes()[0:4], addr, vexp, v, vexp.Big(), v.Big()) } diff --git a/trie/arc.go b/trie/arc.go deleted file mode 100644 index fc7a3259f..000000000 --- a/trie/arc.go +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) 2015 Hans Alexander Gugel <alexander.gugel@gmail.com> -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// This file contains a modified version of package arc from -// https://github.com/alexanderGugel/arc -// -// It implements the ARC (Adaptive Replacement Cache) algorithm as detailed in -// https://www.usenix.org/legacy/event/fast03/tech/full_papers/megiddo/megiddo.pdf - -package trie - -import ( - "container/list" - "sync" -) - -type arc struct { - p int - c int - t1 *list.List - b1 *list.List - t2 *list.List - b2 *list.List - cache map[string]*entry - mutex sync.Mutex -} - -type entry struct { - key hashNode - value node - ll *list.List - el *list.Element -} - -// newARC returns a new Adaptive Replacement Cache with the -// given capacity. -func newARC(c int) *arc { - return &arc{ - c: c, - t1: list.New(), - b1: list.New(), - t2: list.New(), - b2: list.New(), - cache: make(map[string]*entry, c), - } -} - -// Clear clears the cache -func (a *arc) Clear() { - a.mutex.Lock() - defer a.mutex.Unlock() - a.p = 0 - a.t1 = list.New() - a.b1 = list.New() - a.t2 = list.New() - a.b2 = list.New() - a.cache = make(map[string]*entry, a.c) -} - -// Put inserts a new key-value pair into the cache. -// This optimizes future access to this entry (side effect). -func (a *arc) Put(key hashNode, value node) bool { - a.mutex.Lock() - defer a.mutex.Unlock() - ent, ok := a.cache[string(key)] - if ok != true { - ent = &entry{key: key, value: value} - a.req(ent) - a.cache[string(key)] = ent - } else { - ent.value = value - a.req(ent) - } - return ok -} - -// Get retrieves a previously via Set inserted entry. -// This optimizes future access to this entry (side effect). -func (a *arc) Get(key hashNode) (value node, ok bool) { - a.mutex.Lock() - defer a.mutex.Unlock() - ent, ok := a.cache[string(key)] - if ok { - a.req(ent) - return ent.value, ent.value != nil - } - return nil, false -} - -func (a *arc) req(ent *entry) { - if ent.ll == a.t1 || ent.ll == a.t2 { - // Case I - ent.setMRU(a.t2) - } else if ent.ll == a.b1 { - // Case II - // Cache Miss in t1 and t2 - - // Adaptation - var d int - if a.b1.Len() >= a.b2.Len() { - d = 1 - } else { - d = a.b2.Len() / a.b1.Len() - } - a.p = a.p + d - if a.p > a.c { - a.p = a.c - } - - a.replace(ent) - ent.setMRU(a.t2) - } else if ent.ll == a.b2 { - // Case III - // Cache Miss in t1 and t2 - - // Adaptation - var d int - if a.b2.Len() >= a.b1.Len() { - d = 1 - } else { - d = a.b1.Len() / a.b2.Len() - } - a.p = a.p - d - if a.p < 0 { - a.p = 0 - } - - a.replace(ent) - ent.setMRU(a.t2) - } else if ent.ll == nil { - // Case IV - - if a.t1.Len()+a.b1.Len() == a.c { - // Case A - if a.t1.Len() < a.c { - a.delLRU(a.b1) - a.replace(ent) - } else { - a.delLRU(a.t1) - } - } else if a.t1.Len()+a.b1.Len() < a.c { - // Case B - if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() >= a.c { - if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() == 2*a.c { - a.delLRU(a.b2) - } - a.replace(ent) - } - } - - ent.setMRU(a.t1) - } -} - -func (a *arc) delLRU(list *list.List) { - lru := list.Back() - list.Remove(lru) - delete(a.cache, string(lru.Value.(*entry).key)) -} - -func (a *arc) replace(ent *entry) { - if a.t1.Len() > 0 && ((a.t1.Len() > a.p) || (ent.ll == a.b2 && a.t1.Len() == a.p)) { - lru := a.t1.Back().Value.(*entry) - lru.value = nil - lru.setMRU(a.b1) - } else { - lru := a.t2.Back().Value.(*entry) - lru.value = nil - lru.setMRU(a.b2) - } -} - -func (e *entry) setLRU(list *list.List) { - e.detach() - e.ll = list - e.el = e.ll.PushBack(e) -} - -func (e *entry) setMRU(list *list.List) { - e.detach() - e.ll = list - e.el = e.ll.PushFront(e) -} - -func (e *entry) detach() { - if e.ll != nil { - e.ll.Remove(e.el) - } -} diff --git a/trie/hasher.go b/trie/hasher.go new file mode 100644 index 000000000..87e02fb85 --- /dev/null +++ b/trie/hasher.go @@ -0,0 +1,157 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package trie + +import ( + "bytes" + "hash" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/rlp" +) + +type hasher struct { + tmp *bytes.Buffer + sha hash.Hash +} + +// hashers live in a global pool. +var hasherPool = sync.Pool{ + New: func() interface{} { + return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} + }, +} + +func newHasher() *hasher { + return hasherPool.Get().(*hasher) +} + +func returnHasherToPool(h *hasher) { + hasherPool.Put(h) +} + +// hash collapses a node down into a hash node, also returning a copy of the +// original node initialzied with the computed hash to replace the original one. +func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { + // If we're not storing the node, just hashing, use avaialble cached data + if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) { + return hash, n, nil + } + // Trie not processed yet or needs storage, walk the children + collapsed, cached, err := h.hashChildren(n, db) + if err != nil { + return hashNode{}, n, err + } + hashed, err := h.store(collapsed, db, force) + if err != nil { + return hashNode{}, n, err + } + // Cache the hash and RLP blob of the ndoe for later reuse + if hash, ok := hashed.(hashNode); ok && !force { + switch cached := cached.(type) { + case shortNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + case fullNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + } + } + return hashed, cached, nil +} + +// hashChildren replaces the children of a node with their hashes if the encoded +// size of the child is larger than a hash, returning the collapsed node as well +// as a replacement for the original node with the child hashes cached in. +func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { + var err error + + switch n := original.(type) { + case shortNode: + // Hash the short node's child, caching the newly hashed subtree + cached := n + cached.Key = common.CopyBytes(cached.Key) + + n.Key = compactEncode(n.Key) + if _, ok := n.Val.(valueNode); !ok { + if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil { + return n, original, err + } + } + if n.Val == nil { + n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings. + } + return n, cached, nil + + case fullNode: + // Hash the full node's children, caching the newly hashed subtrees + cached := fullNode{dirty: n.dirty} + + for i := 0; i < 16; i++ { + if n.Children[i] != nil { + if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil { + return n, original, err + } + } else { + n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings. + } + } + cached.Children[16] = n.Children[16] + if n.Children[16] == nil { + n.Children[16] = valueNode(nil) + } + return n, cached, nil + + default: + // Value and hash nodes don't have children so they're left as were + return n, original, nil + } +} + +func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { + // Don't store hashes or empty nodes. + if _, isHash := n.(hashNode); n == nil || isHash { + return n, nil + } + // Generate the RLP encoding of the node + h.tmp.Reset() + if err := rlp.Encode(h.tmp, n); err != nil { + panic("encode error: " + err.Error()) + } + if h.tmp.Len() < 32 && !force { + return n, nil // Nodes smaller than 32 bytes are stored inside their parent + } + // Larger nodes are replaced by their hash and stored in the database. + hash, _ := n.cache() + if hash == nil { + h.sha.Reset() + h.sha.Write(h.tmp.Bytes()) + hash = hashNode(h.sha.Sum(nil)) + } + if db != nil { + return hash, db.Put(hash, h.tmp.Bytes()) + } + return hash, nil +} diff --git a/trie/iterator.go b/trie/iterator.go index 88c4cee7f..8cad51aff 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -16,18 +16,13 @@ package trie -import ( - "bytes" - "fmt" +import "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/logger/glog" -) - -// Iterator is a key-value trie iterator to traverse the data contents. +// Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { - trie *Trie + trie *Trie + nodeIt *NodeIterator + keyBuf []byte Key []byte // Current data key on which the iterator is positioned on Value []byte // Current data value on which the iterator is positioned on @@ -35,119 +30,45 @@ type Iterator struct { // NewIterator creates a new key-value iterator. func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie, Key: nil} -} - -// Next moves the iterator forward with one key-value entry. -func (self *Iterator) Next() bool { - isIterStart := false - if self.Key == nil { - isIterStart = true - self.Key = make([]byte, 32) + return &Iterator{ + trie: trie, + nodeIt: NewNodeIterator(trie), + keyBuf: make([]byte, 0, 64), + Key: nil, } - - key := remTerm(compactHexDecode(self.Key)) - k := self.next(self.trie.root, key, isIterStart) - - self.Key = []byte(decodeCompact(k)) - - return len(k) > 0 } -func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byte { - if node == nil { - return nil - } - - switch node := node.(type) { - case fullNode: - if len(key) > 0 { - k := self.next(node.Children[key[0]], key[1:], isIterStart) - if k != nil { - return append([]byte{key[0]}, k...) - } - } - - var r byte - if len(key) > 0 { - r = key[0] + 1 - } - - for i := r; i < 16; i++ { - k := self.key(node.Children[i]) - if k != nil { - return append([]byte{i}, k...) - } +// Next moves the iterator forward one key-value entry. +func (it *Iterator) Next() bool { + for it.nodeIt.Next() { + if it.nodeIt.Leaf { + it.Key = it.makeKey() + it.Value = it.nodeIt.LeafBlob + return true } - - case shortNode: - k := remTerm(node.Key) - if vnode, ok := node.Val.(valueNode); ok { - switch bytes.Compare([]byte(k), key) { - case 0: - if isIterStart { - self.Value = vnode - return k - } - case 1: - self.Value = vnode - return k - } - } else { - cnode := node.Val - - var ret []byte - skey := key[len(k):] - if bytes.HasPrefix(key, k) { - ret = self.next(cnode, skey, isIterStart) - } else if bytes.Compare(k, key[:len(k)]) > 0 { - return self.key(node) - } - - if ret != nil { - return append(k, ret...) - } - } - - case hashNode: - rn, err := self.trie.resolveHash(node, nil, nil) - if err != nil && glog.V(logger.Error) { - glog.Errorf("Unhandled trie error: %v", err) - } - return self.next(rn, key, isIterStart) } - return nil + it.Key = nil + it.Value = nil + return false } -func (self *Iterator) key(node interface{}) []byte { - switch node := node.(type) { - case shortNode: - // Leaf node - k := remTerm(node.Key) - if vnode, ok := node.Val.(valueNode); ok { - self.Value = vnode - return k - } - return append(k, self.key(node.Val)...) - case fullNode: - if node.Children[16] != nil { - self.Value = node.Children[16].(valueNode) - return []byte{16} - } - for i := 0; i < 16; i++ { - k := self.key(node.Children[i]) - if k != nil { - return append([]byte{byte(i)}, k...) +func (it *Iterator) makeKey() []byte { + key := it.keyBuf[:0] + for _, se := range it.nodeIt.stack { + switch node := se.node.(type) { + case fullNode: + if se.child <= 16 { + key = append(key, byte(se.child)) + } + case shortNode: + if hasTerm(node.Key) { + key = append(key, node.Key[:len(node.Key)-1]...) + } else { + key = append(key, node.Key...) } } - case hashNode: - rn, err := self.trie.resolveHash(node, nil, nil) - if err != nil && glog.V(logger.Error) { - glog.Errorf("Unhandled trie error: %v", err) - } - return self.key(rn) } - return nil + return decodeCompact(key) } // nodeIteratorState represents the iteration state at one particular node of the @@ -199,25 +120,27 @@ func (it *NodeIterator) Next() bool { // step moves the iterator to the next node of the trie. func (it *NodeIterator) step() error { - // Abort if we reached the end of the iteration if it.trie == nil { + // Abort if we reached the end of the iteration return nil } - // Initialize the iterator if we've just started, or pop off the old node otherwise if len(it.stack) == 0 { - // Always start with a collapsed root + // Initialize the iterator if we've just started. root := it.trie.Hash() - it.stack = append(it.stack, &nodeIteratorState{node: hashNode(root[:]), child: -1}) - if it.stack[0].node == nil { - return fmt.Errorf("root node missing: %x", it.trie.Hash()) + state := &nodeIteratorState{node: it.trie.root, child: -1} + if root != emptyRoot { + state.hash = root } + it.stack = append(it.stack, state) } else { + // Continue iterating at the previous node otherwise. it.stack = it.stack[:len(it.stack)-1] if len(it.stack) == 0 { it.trie = nil return nil } } + // Continue iteration to the next child for { parent := it.stack[len(it.stack)-1] @@ -232,7 +155,12 @@ func (it *NodeIterator) step() error { } for parent.child++; parent.child < len(node.Children); parent.child++ { if current := node.Children[parent.child]; current != nil { - it.stack = append(it.stack, &nodeIteratorState{node: current, parent: ancestor, child: -1}) + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(node.hash), + node: current, + parent: ancestor, + child: -1, + }) break } } @@ -242,7 +170,12 @@ func (it *NodeIterator) step() error { break } parent.child++ - it.stack = append(it.stack, &nodeIteratorState{node: node.Val, parent: ancestor, child: -1}) + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(node.hash), + node: node.Val, + parent: ancestor, + child: -1, + }) } else if hash, ok := parent.node.(hashNode); ok { // Hash node, resolve the hash child from the database, then the node itself if parent.child >= 0 { @@ -254,7 +187,12 @@ func (it *NodeIterator) step() error { if err != nil { return err } - it.stack = append(it.stack, &nodeIteratorState{hash: common.BytesToHash(hash), node: node, parent: ancestor, child: -1}) + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(hash), + node: node, + parent: ancestor, + child: -1, + }) } else { break } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index dc8276116..2bcc3700e 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -34,21 +34,60 @@ func TestIterator(t *testing.T) { {"dog", "puppy"}, {"somethingveryoddindeedthis is", "myothernodedata"}, } - v := make(map[string]bool) + all := make(map[string]string) for _, val := range vals { - v[val.k] = false + all[val.k] = val.v trie.Update([]byte(val.k), []byte(val.v)) } trie.Commit() + found := make(map[string]string) it := NewIterator(trie) for it.Next() { - v[string(it.Key)] = true + found[string(it.Key)] = string(it.Value) } - for k, found := range v { - if !found { - t.Error("iterator didn't find", k) + for k, v := range all { + if found[k] != v { + t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v) + } + } +} + +type kv struct { + k, v []byte + t bool +} + +func TestIteratorLargeData(t *testing.T) { + trie := newEmpty() + vals := make(map[string]*kv) + + for i := byte(0); i < 255; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + + it := NewIterator(trie) + for it.Next() { + vals[string(it.Key)].t = true + } + + var untouched []*kv + for _, value := range vals { + if !value.t { + untouched = append(untouched, value) + } + } + + if len(untouched) > 0 { + t.Errorf("Missed %d nodes", len(untouched)) + for _, value := range untouched { + t.Error(value) } } } diff --git a/trie/proof.go b/trie/proof.go index 5135de047..116c13a1b 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -70,15 +70,13 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - if t.hasher == nil { - t.hasher = newHasher() - } + hasher := newHasher() proof := make([]rlp.RawValue, 0, len(nodes)) for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. - n, _, _ = t.hasher.hashChildren(n, nil) - hn, _ := t.hasher.store(n, nil, false) + n, _, _ = hasher.hashChildren(n, nil) + hn, _ := hasher.store(n, nil, false) if _, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the // root node), it becomes a proof element. diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 1d027c102..2a8b57214 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -17,16 +17,15 @@ package trie import ( - "hash" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" ) var secureKeyPrefix = []byte("secure-key-") +const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash + // SecureTrie wraps a trie with key hashing. In a secure trie, all // access operations hash the key using keccak256. This prevents // calling code from creating long chains of nodes that @@ -38,12 +37,11 @@ var secureKeyPrefix = []byte("secure-key-") // // SecureTrie is not safe for concurrent use. type SecureTrie struct { - *Trie - - hash hash.Hash - hashKeyBuf []byte - secKeyBuf []byte - secKeyCache map[string][]byte + trie Trie + hashKeyBuf [secureKeyLength]byte + secKeyBuf [200]byte + secKeyCache map[string][]byte + secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch } // NewSecure creates a trie with an existing root node from db. @@ -61,8 +59,7 @@ func NewSecure(root common.Hash, db Database) (*SecureTrie, error) { return nil, err } return &SecureTrie{ - Trie: trie, - secKeyCache: make(map[string][]byte), + trie: *trie, }, nil } @@ -80,7 +77,7 @@ func (t *SecureTrie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { - return t.Trie.TryGet(t.hashKey(key)) + return t.trie.TryGet(t.hashKey(key)) } // Update associates key with value in the trie. Subsequent calls to @@ -105,11 +102,11 @@ func (t *SecureTrie) Update(key, value []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryUpdate(key, value []byte) error { hk := t.hashKey(key) - err := t.Trie.TryUpdate(hk, value) + err := t.trie.TryUpdate(hk, value) if err != nil { return err } - t.secKeyCache[string(hk)] = common.CopyBytes(key) + t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) return nil } @@ -124,17 +121,17 @@ func (t *SecureTrie) Delete(key []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryDelete(key []byte) error { hk := t.hashKey(key) - delete(t.secKeyCache, string(hk)) - return t.Trie.TryDelete(hk) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.TryDelete(hk) } // GetKey returns the sha3 preimage of a hashed key that was // previously used to store a value. func (t *SecureTrie) GetKey(shaKey []byte) []byte { - if key, ok := t.secKeyCache[string(shaKey)]; ok { + if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { return key } - key, _ := t.Trie.db.Get(t.secKey(shaKey)) + key, _ := t.trie.db.Get(t.secKey(shaKey)) return key } @@ -144,7 +141,23 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { // Committing flushes nodes from memory. Subsequent Get calls will load nodes // from the database. func (t *SecureTrie) Commit() (root common.Hash, err error) { - return t.CommitTo(t.db) + return t.CommitTo(t.trie.db) +} + +func (t *SecureTrie) Hash() common.Hash { + return t.trie.Hash() +} + +func (t *SecureTrie) Root() []byte { + return t.trie.Root() +} + +func (t *SecureTrie) Iterator() *Iterator { + return t.trie.Iterator() +} + +func (t *SecureTrie) NodeIterator() *NodeIterator { + return NewNodeIterator(&t.trie) } // CommitTo writes all nodes and the secure hash pre-images to the given database. @@ -154,7 +167,7 @@ func (t *SecureTrie) Commit() (root common.Hash, err error) { // the trie's database. Calling code must ensure that the changes made to db are // written back to the trie's attached database before using the trie. func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - if len(t.secKeyCache) > 0 { + if len(t.getSecKeyCache()) > 0 { for hk, key := range t.secKeyCache { if err := db.Put(t.secKey([]byte(hk)), key); err != nil { return common.Hash{}, err @@ -162,27 +175,37 @@ func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { } t.secKeyCache = make(map[string][]byte) } - n, clean, err := t.hashRoot(db) - if err != nil { - return (common.Hash{}), err - } - t.root = clean - return common.BytesToHash(n.(hashNode)), nil + return t.trie.CommitTo(db) } +// secKey returns the database key for the preimage of key, as an ephemeral buffer. +// The caller must not hold onto the return value because it will become +// invalid on the next call to hashKey or secKey. func (t *SecureTrie) secKey(key []byte) []byte { - t.secKeyBuf = append(t.secKeyBuf[:0], secureKeyPrefix...) - t.secKeyBuf = append(t.secKeyBuf, key...) - return t.secKeyBuf + buf := append(t.secKeyBuf[:0], secureKeyPrefix...) + buf = append(buf, key...) + return buf } +// hashKey returns the hash of key as an ephemeral buffer. +// The caller must not hold onto the return value because it will become +// invalid on the next call to hashKey or secKey. func (t *SecureTrie) hashKey(key []byte) []byte { - if t.hash == nil { - t.hash = sha3.NewKeccak256() - t.hashKeyBuf = make([]byte, 32) + h := newHasher() + h.sha.Reset() + h.sha.Write(key) + buf := h.sha.Sum(t.hashKeyBuf[:0]) + returnHasherToPool(h) + return buf +} + +// getSecKeyCache returns the current secure key cache, creating a new one if +// ownership changed (i.e. the current secure trie is a copy of another owning +// the actual cache). +func (t *SecureTrie) getSecKeyCache() map[string][]byte { + if t != t.secKeyCacheOwner { + t.secKeyCacheOwner = t + t.secKeyCache = make(map[string][]byte) } - t.hash.Reset() - t.hash.Write(key) - t.hashKeyBuf = t.hash.Sum(t.hashKeyBuf[:0]) - return t.hashKeyBuf + return t.secKeyCache } diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index 0be5b3d15..3171b8c31 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -18,6 +18,8 @@ package trie import ( "bytes" + "runtime" + "sync" "testing" "github.com/ethereum/go-ethereum/common" @@ -31,6 +33,37 @@ func newEmptySecure() *SecureTrie { return trie } +// makeTestSecureTrie creates a large enough secure trie for testing. +func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { + // Create an empty trie + db, _ := ethdb.NewMemDatabase() + trie, _ := NewSecure(common.Hash{}, db) + + // Fill it with some arbitrary data + content := make(map[string][]byte) + for i := byte(0); i < 255; i++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + // Add some other data to inflate th trie + for j := byte(3); j < 13; j++ { + key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} + content[string(key)] = val + trie.Update(key, val) + } + } + trie.Commit() + + // Return the generated trie + return db, trie, content +} + func TestSecureDelete(t *testing.T) { trie := newEmptySecure() vals := []struct{ k, v string }{ @@ -72,3 +105,41 @@ func TestSecureGetKey(t *testing.T) { t.Errorf("GetKey returned %q, want %q", k, key) } } + +func TestSecureTrieConcurrency(t *testing.T) { + // Create an initial trie and copy if for concurrent access + _, trie, _ := makeTestSecureTrie() + + threads := runtime.NumCPU() + tries := make([]*SecureTrie, threads) + for i := 0; i < threads; i++ { + cpy := *trie + tries[i] = &cpy + } + // Start a batch of goroutines interactng with the trie + pend := new(sync.WaitGroup) + pend.Add(threads) + for i := 0; i < threads; i++ { + go func(index int) { + defer pend.Done() + + for j := byte(0); j < 255; j++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} + tries[index].Update(key, val) + + key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} + tries[index].Update(key, val) + + // Add some other data to inflate the trie + for k := byte(3); k < 13; k++ { + key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} + tries[index].Update(key, val) + } + } + tries[index].Commit() + }(i) + } + // Wait for all threads to finish + pend.Wait() +} diff --git a/trie/sync_test.go b/trie/sync_test.go index a81f7650e..a763dc564 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -51,9 +51,6 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { } trie.Commit() - // Remove any potentially cached data from the test trie creation - globalCache.Clear() - // Return the generated trie return db, trie, content } @@ -61,9 +58,6 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { // checkTrieContents cross references a reconstructed trie with an expected data // content map. func checkTrieContents(t *testing.T, db Database, root []byte, content map[string][]byte) { - // Remove any potentially cached data from the trie synchronisation - globalCache.Clear() - // Check root availability and trie contents trie, err := New(common.BytesToHash(root), db) if err != nil { @@ -81,9 +75,6 @@ func checkTrieContents(t *testing.T, db Database, root []byte, content map[strin // checkTrieConsistency checks that all nodes in a trie are indeed present. func checkTrieConsistency(db Database, root common.Hash) error { - // Remove any potentially cached data from the test trie creation or previous checks - globalCache.Clear() - // Create and iterate a trie rooted in a subnode trie, err := New(root, db) if err != nil { diff --git a/trie/trie.go b/trie/trie.go index a530e7b2a..93e189e2e 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -20,22 +20,14 @@ package trie import ( "bytes" "fmt" - "hash" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" - "github.com/ethereum/go-ethereum/rlp" ) -const defaultCacheCapacity = 800 - var ( - // The global cache stores decoded trie nodes by hash as they get loaded. - globalCache = newARC(defaultCacheCapacity) - // This is the known root hash of an empty trie. emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") @@ -43,11 +35,6 @@ var ( emptyState = crypto.Keccak256Hash(nil) ) -// ClearGlobalCache clears the global trie cache -func ClearGlobalCache() { - globalCache.Clear() -} - // Database must be implemented by backing stores for the trie. type Database interface { DatabaseWriter @@ -72,7 +59,6 @@ type Trie struct { root node db Database originalRoot common.Hash - *hasher } // New creates a trie with an existing root node from db. @@ -118,32 +104,50 @@ func (t *Trie) Get(key []byte) []byte { // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryGet(key []byte) ([]byte, error) { key = compactHexDecode(key) - pos := 0 - tn := t.root - for pos < len(key) { - switch n := tn.(type) { - case shortNode: - if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { - return nil, nil - } - tn = n.Val - pos += len(n.Key) - case fullNode: - tn = n.Children[key[pos]] - pos++ - case nil: - return nil, nil - case hashNode: - var err error - tn, err = t.resolveHash(n, key[:pos], key[pos:]) - if err != nil { - return nil, err - } - default: - panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + value, newroot, didResolve, err := t.tryGet(t.root, key, 0) + if err == nil && didResolve { + t.root = newroot + } + return value, err +} + +func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { + switch n := (origNode).(type) { + case nil: + return nil, nil, false, nil + case valueNode: + return n, n, false, nil + case shortNode: + if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + // key not found in trie + return nil, n, false, nil + } + value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) + if err == nil && didResolve { + n.Val = newnode + return value, n, didResolve, err + } else { + return value, origNode, didResolve, err + } + case fullNode: + child := n.Children[key[pos]] + value, newnode, didResolve, err = t.tryGet(child, key, pos+1) + if err == nil && didResolve { + n.Children[key[pos]] = newnode + return value, n, didResolve, err + } else { + return value, origNode, didResolve, err + } + case hashNode: + child, err := t.resolveHash(n, key[:pos], key[pos:]) + if err != nil { + return nil, n, true, err } + value, newnode, _, err := t.tryGet(child, key, pos) + return value, newnode, true, err + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) } - return tn.(valueNode), nil } // Update associates key with value in the trie. Subsequent calls to @@ -410,9 +414,6 @@ func (t *Trie) resolve(n node, prefix, suffix []byte) (node, error) { } func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { - if v, ok := globalCache.Get(n); ok { - return v, nil - } enc, err := t.db.Get(n) if err != nil || enc == nil { return nil, &MissingNodeError{ @@ -424,9 +425,6 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { } } dec := mustDecodeNode(n, enc) - if dec != nil { - globalCache.Put(n, dec) - } return dec, nil } @@ -474,127 +472,7 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil, nil } - if t.hasher == nil { - t.hasher = newHasher() - } - return t.hasher.hash(t.root, db, true) -} - -type hasher struct { - tmp *bytes.Buffer - sha hash.Hash -} - -func newHasher() *hasher { - return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} -} - -// hash collapses a node down into a hash node, also returning a copy of the -// original node initialzied with the computed hash to replace the original one. -func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { - // If we're not storing the node, just hashing, use avaialble cached data - if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) { - return hash, n, nil - } - // Trie not processed yet or needs storage, walk the children - collapsed, cached, err := h.hashChildren(n, db) - if err != nil { - return hashNode{}, n, err - } - hashed, err := h.store(collapsed, db, force) - if err != nil { - return hashNode{}, n, err - } - // Cache the hash and RLP blob of the ndoe for later reuse - if hash, ok := hashed.(hashNode); ok && !force { - switch cached := cached.(type) { - case shortNode: - cached.hash = hash - if db != nil { - cached.dirty = false - } - return hashed, cached, nil - case fullNode: - cached.hash = hash - if db != nil { - cached.dirty = false - } - return hashed, cached, nil - } - } - return hashed, cached, nil -} - -// hashChildren replaces the children of a node with their hashes if the encoded -// size of the child is larger than a hash, returning the collapsed node as well -// as a replacement for the original node with the child hashes cached in. -func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { - var err error - - switch n := original.(type) { - case shortNode: - // Hash the short node's child, caching the newly hashed subtree - cached := n - cached.Key = common.CopyBytes(cached.Key) - - n.Key = compactEncode(n.Key) - if _, ok := n.Val.(valueNode); !ok { - if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil { - return n, original, err - } - } - if n.Val == nil { - n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings. - } - return n, cached, nil - - case fullNode: - // Hash the full node's children, caching the newly hashed subtrees - cached := fullNode{dirty: n.dirty} - - for i := 0; i < 16; i++ { - if n.Children[i] != nil { - if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil { - return n, original, err - } - } else { - n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings. - } - } - cached.Children[16] = n.Children[16] - if n.Children[16] == nil { - n.Children[16] = valueNode(nil) - } - return n, cached, nil - - default: - // Value and hash nodes don't have children so they're left as were - return n, original, nil - } -} - -func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { - // Don't store hashes or empty nodes. - if _, isHash := n.(hashNode); n == nil || isHash { - return n, nil - } - // Generate the RLP encoding of the node - h.tmp.Reset() - if err := rlp.Encode(h.tmp, n); err != nil { - panic("encode error: " + err.Error()) - } - if h.tmp.Len() < 32 && !force { - return n, nil // Nodes smaller than 32 bytes are stored inside their parent - } - // Larger nodes are replaced by their hash and stored in the database. - hash, _ := n.cache() - if hash == nil { - h.sha.Reset() - h.sha.Write(h.tmp.Bytes()) - hash = hashNode(h.sha.Sum(nil)) - } - if db != nil { - return hash, db.Put(hash, h.tmp.Bytes()) - } - return hash, nil + h := newHasher() + defer returnHasherToPool(h) + return h.hash(t.root, db, true) } diff --git a/trie/trie_test.go b/trie/trie_test.go index 121ba24c1..5a3ea1be9 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -76,8 +76,6 @@ func TestMissingNode(t *testing.T) { updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") root, _ := trie.Commit() - ClearGlobalCache() - trie, _ = New(root, db) _, err := trie.TryGet([]byte("120000")) if err != nil { @@ -109,7 +107,6 @@ func TestMissingNode(t *testing.T) { } db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")) - ClearGlobalCache() trie, _ = New(root, db) _, err = trie.TryGet([]byte("120000")) @@ -362,44 +359,6 @@ func TestLargeValue(t *testing.T) { } -type kv struct { - k, v []byte - t bool -} - -func TestLargeData(t *testing.T) { - trie := newEmpty() - vals := make(map[string]*kv) - - for i := byte(0); i < 255; i++ { - value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} - value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} - trie.Update(value.k, value.v) - trie.Update(value2.k, value2.v) - vals[string(value.k)] = value - vals[string(value2.k)] = value2 - } - - it := NewIterator(trie) - for it.Next() { - vals[string(it.Key)].t = true - } - - var untouched []*kv - for _, value := range vals { - if !value.t { - untouched = append(untouched, value) - } - } - - if len(untouched) > 0 { - t.Errorf("Missed %d nodes", len(untouched)) - for _, value := range untouched { - t.Error(value) - } - } -} - func BenchmarkGet(b *testing.B) { benchGet(b, false) } func BenchmarkGetDB(b *testing.B) { benchGet(b, true) } func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) } |