diff options
Diffstat (limited to 'les/handler.go')
-rw-r--r-- | les/handler.go | 193 |
1 files changed, 111 insertions, 82 deletions
diff --git a/les/handler.go b/les/handler.go index 8cd37c7ab..5c93133fb 100644 --- a/les/handler.go +++ b/les/handler.go @@ -18,7 +18,6 @@ package les import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -78,6 +77,7 @@ type BlockChain interface { GetHeaderByHash(hash common.Hash) *types.Header CurrentHeader() *types.Header GetTd(hash common.Hash, number uint64) *big.Int + State() (*state.StateDB, error) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) Rollback(chain []common.Hash) GetHeaderByNumber(number uint64) *types.Header @@ -579,17 +579,19 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { for _, req := range req.Reqs { // Retrieve the requested state entry, stopping if enough was found if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { - if trie, _ := trie.New(header.Root, pm.chainDb); trie != nil { - sdata := trie.Get(req.AccKey) - var acc state.Account - if err := rlp.DecodeBytes(sdata, &acc); err == nil { - entry, _ := pm.chainDb.Get(acc.CodeHash) - if bytes+len(entry) >= softResponseLimit { - break - } - data = append(data, entry) - bytes += len(entry) - } + statedb, err := pm.blockchain.State() + if err != nil { + continue + } + account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey)) + if err != nil { + continue + } + code, _ := statedb.Database().TrieDB().Node(common.BytesToHash(account.CodeHash)) + + data = append(data, code) + if bytes += len(code); bytes >= softResponseLimit { + break } } } @@ -701,25 +703,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrRequestRejected, "") } for _, req := range req.Reqs { - if bytes >= softResponseLimit { - break - } // Retrieve the requested state entry, stopping if enough was found if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { - if tr, _ := trie.New(header.Root, pm.chainDb); tr != nil { - if len(req.AccKey) > 0 { - sdata := tr.Get(req.AccKey) - tr = nil - var acc state.Account - if err := rlp.DecodeBytes(sdata, &acc); err == nil { - tr, _ = trie.New(acc.Root, pm.chainDb) - } + statedb, err := pm.blockchain.State() + if err != nil { + continue + } + var trie state.Trie + if len(req.AccKey) > 0 { + account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey)) + if err != nil { + continue } - if tr != nil { - var proof light.NodeList - tr.Prove(req.Key, 0, &proof) - proofs = append(proofs, proof) - bytes += proof.DataSize() + trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root) + } else { + trie, _ = statedb.Database().OpenTrie(header.Root) + } + if trie != nil { + var proof light.NodeList + trie.Prove(req.Key, 0, &proof) + + proofs = append(proofs, proof) + if bytes += proof.DataSize(); bytes >= softResponseLimit { + break } } } @@ -740,9 +746,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } // Gather state data until the fetch or network limits is reached var ( - lastBHash common.Hash - lastAccKey []byte - tr, str *trie.Trie + lastBHash common.Hash + statedb *state.StateDB + root common.Hash ) reqCnt := len(req.Reqs) if reject(uint64(reqCnt), MaxProofsFetch) { @@ -752,35 +758,36 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { nodes := light.NewNodeSet() for _, req := range req.Reqs { - if nodes.DataSize() >= softResponseLimit { - break - } - if tr == nil || req.BHash != lastBHash { + // Look up the state belonging to the request + if statedb == nil || req.BHash != lastBHash { + statedb, root, lastBHash = nil, common.Hash{}, req.BHash + if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { - tr, _ = trie.New(header.Root, pm.chainDb) - } else { - tr = nil + statedb, _ = pm.blockchain.State() + root = header.Root } - lastBHash = req.BHash - str = nil } - if tr != nil { - if len(req.AccKey) > 0 { - if str == nil || !bytes.Equal(req.AccKey, lastAccKey) { - sdata := tr.Get(req.AccKey) - str = nil - var acc state.Account - if err := rlp.DecodeBytes(sdata, &acc); err == nil { - str, _ = trie.New(acc.Root, pm.chainDb) - } - lastAccKey = common.CopyBytes(req.AccKey) - } - if str != nil { - str.Prove(req.Key, req.FromLevel, nodes) - } - } else { - tr.Prove(req.Key, req.FromLevel, nodes) + if statedb == nil { + continue + } + // Pull the account or storage trie of the request + var trie state.Trie + if len(req.AccKey) > 0 { + account, err := pm.getAccount(statedb, root, common.BytesToHash(req.AccKey)) + if err != nil { + continue } + trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root) + } else { + trie, _ = statedb.Database().OpenTrie(root) + } + if trie == nil { + continue + } + // Prove the user's request from the account or stroage trie + trie.Prove(req.Key, req.FromLevel, nodes) + if nodes.DataSize() >= softResponseLimit { + break } } proofs := nodes.NodeList() @@ -849,23 +856,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { return errResp(ErrRequestRejected, "") } - trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix) for _, req := range req.Reqs { - if bytes >= softResponseLimit { - break - } - if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.ChtV1Frequency-1) if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) { - if tr, _ := trie.New(root, trieDb); tr != nil { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) - var proof light.NodeList - tr.Prove(encNumber[:], 0, &proof) - proofs = append(proofs, ChtResp{Header: header, Proof: proof}) - bytes += proof.DataSize() + estHeaderRlpSize + statedb, err := pm.blockchain.State() + if err != nil { + continue } + trie, err := statedb.Database().OpenTrie(root) + if err != nil { + continue + } + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) + + var proof light.NodeList + trie.Prove(encNumber[:], 0, &proof) + + proofs = append(proofs, ChtResp{Header: header, Proof: proof}) + if bytes += proof.DataSize() + estHeaderRlpSize; bytes >= softResponseLimit { + break + } + } } } @@ -897,25 +910,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { lastIdx uint64 lastType uint root common.Hash - tr *trie.Trie + statedb *state.StateDB + trie state.Trie ) nodes := light.NewNodeSet() for _, req := range req.Reqs { - if nodes.DataSize()+auxBytes >= softResponseLimit { - break - } - if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { - var prefix string - root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx) - if root != (common.Hash{}) { - if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil { - tr = t + if trie == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { + statedb, trie, lastType, lastIdx = nil, nil, req.HelperTrieType, req.TrieIdx + + if root, _ = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx); root != (common.Hash{}) { + if statedb, _ = pm.blockchain.State(); statedb != nil { + trie, _ = statedb.Database().OpenTrie(root) } } - lastType = req.HelperTrieType - lastIdx = req.TrieIdx } if req.AuxReq == auxRoot { var data []byte @@ -925,8 +934,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { auxData = append(auxData, data) auxBytes += len(data) } else { - if tr != nil { - tr.Prove(req.Key, req.FromLevel, nodes) + if trie != nil { + trie.Prove(req.Key, req.FromLevel, nodes) } if req.AuxReq != 0 { data := pm.getHelperTrieAuxData(req) @@ -934,6 +943,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { auxBytes += len(data) } } + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } } proofs := nodes.NodeList() bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) @@ -1090,6 +1102,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return nil } +// getAccount retrieves an account from the state based at root. +func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) { + trie, err := trie.New(root, statedb.Database().TrieDB()) + if err != nil { + return state.Account{}, err + } + blob, err := trie.TryGet(hash[:]) + if err != nil { + return state.Account{}, err + } + var account state.Account + if err = rlp.DecodeBytes(blob, &account); err != nil { + return state.Account{}, err + } + return account, nil +} + // getHelperTrie returns the post-processed trie root for the given trie ID and section index func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { switch id { |