From ca376ead88a5a26626a90abdb62f4de7f6313822 Mon Sep 17 00:00:00 2001 From: Felföldi Zsolt Date: Tue, 24 Oct 2017 15:19:09 +0200 Subject: les, light: LES/2 protocol version (#14970) This PR implements the new LES protocol version extensions: * new and more efficient Merkle proofs reply format (when replying to a multiple Merkle proofs request, we just send a single set of trie nodes containing all necessary nodes) * BBT (BloomBitsTrie) works similarly to the existing CHT and contains the bloombits search data to speed up log searches * GetTxStatusMsg returns the inclusion position or the pending/queued/unknown state of a transaction referenced by hash * an optional signature of new block data (number/hash/td) can be included in AnnounceMsg to provide an option for "very light clients" (mobile/embedded devices) to skip expensive Ethash check and accept multiple signatures of somewhat trusted servers (still a lot better than trusting a single server completely and retrieving everything through RPC). The new client mode is not implemented in this PR, just the protocol extension. --- les/api_backend.go | 9 +- les/backend.go | 79 ++++++++---- les/bloombits.go | 84 +++++++++++++ les/handler.go | 354 +++++++++++++++++++++++++++++++++++++++++++++++----- les/handler_test.go | 162 ++++++++++++++++++++++-- les/helper_test.go | 10 +- les/odr.go | 46 +++++-- les/odr_requests.go | 310 +++++++++++++++++++++++++++++++++++++-------- les/odr_test.go | 11 +- les/peer.go | 133 ++++++++++++++++---- les/protocol.go | 64 ++++++++-- les/request_test.go | 11 +- les/retrieve.go | 5 +- les/server.go | 198 ++++++++++++----------------- 14 files changed, 1202 insertions(+), 274 deletions(-) create mode 100644 les/bloombits.go (limited to 'les') diff --git a/les/api_backend.go b/les/api_backend.go index 0d2d31b67..56f617a7d 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -174,8 +174,15 @@ func (b *LesApiBackend) AccountManager() *accounts.Manager { } func (b *LesApiBackend) BloomStatus() (uint64, uint64) { - return params.BloomBitsBlocks, 0 + if b.eth.bloomIndexer == nil { + return 0, 0 + } + sections, _, _ := b.eth.bloomIndexer.Sections() + return light.BloomTrieFrequency, sections } func (b *LesApiBackend) ServiceFilter(ctx context.Context, session *bloombits.MatcherSession) { + for i := 0; i < bloomFilterThreads; i++ { + go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests) + } } diff --git a/les/backend.go b/les/backend.go index 4c33417c0..3a68d13eb 100644 --- a/les/backend.go +++ b/les/backend.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/bloombits" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth/downloader" @@ -61,6 +62,9 @@ type LightEthereum struct { // DB interfaces chainDb ethdb.Database // Block chain database + bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests + bloomIndexer, chtIndexer, bloomTrieIndexer *core.ChainIndexer + ApiBackend *LesApiBackend eventMux *event.TypeMux @@ -87,47 +91,61 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { peers := newPeerSet() quitSync := make(chan struct{}) - eth := &LightEthereum{ - chainConfig: chainConfig, - chainDb: chainDb, - eventMux: ctx.EventMux, - peers: peers, - reqDist: newRequestDistributor(peers, quitSync), - accountManager: ctx.AccountManager, - engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), - shutdownChan: make(chan bool), - networkId: config.NetworkId, + leth := &LightEthereum{ + chainConfig: chainConfig, + chainDb: chainDb, + eventMux: ctx.EventMux, + peers: peers, + reqDist: newRequestDistributor(peers, quitSync), + accountManager: ctx.AccountManager, + engine: eth.CreateConsensusEngine(ctx, config, chainConfig, chainDb), + shutdownChan: make(chan bool), + networkId: config.NetworkId, + bloomRequests: make(chan chan *bloombits.Retrieval), + bloomIndexer: eth.NewBloomIndexer(chainDb, light.BloomTrieFrequency), + chtIndexer: light.NewChtIndexer(chainDb, true), + bloomTrieIndexer: light.NewBloomTrieIndexer(chainDb, true), } - eth.relay = NewLesTxRelay(peers, eth.reqDist) - eth.serverPool = newServerPool(chainDb, quitSync, ð.wg) - eth.retriever = newRetrieveManager(peers, eth.reqDist, eth.serverPool) - eth.odr = NewLesOdr(chainDb, eth.retriever) - if eth.blockchain, err = light.NewLightChain(eth.odr, eth.chainConfig, eth.engine); err != nil { + leth.relay = NewLesTxRelay(peers, leth.reqDist) + leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg) + leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) + leth.odr = NewLesOdr(chainDb, leth.chtIndexer, leth.bloomTrieIndexer, leth.bloomIndexer, leth.retriever) + if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine); err != nil { return nil, err } + leth.bloomIndexer.Start(leth.blockchain) // Rewind the chain in case of an incompatible config upgrade. if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) - eth.blockchain.SetHead(compat.RewindTo) + leth.blockchain.SetHead(compat.RewindTo) core.WriteChainConfig(chainDb, genesisHash, chainConfig) } - eth.txPool = light.NewTxPool(eth.chainConfig, eth.blockchain, eth.relay) - if eth.protocolManager, err = NewProtocolManager(eth.chainConfig, true, config.NetworkId, eth.eventMux, eth.engine, eth.peers, eth.blockchain, nil, chainDb, eth.odr, eth.relay, quitSync, ð.wg); err != nil { + leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) + if leth.protocolManager, err = NewProtocolManager(leth.chainConfig, true, ClientProtocolVersions, config.NetworkId, leth.eventMux, leth.engine, leth.peers, leth.blockchain, nil, chainDb, leth.odr, leth.relay, quitSync, &leth.wg); err != nil { return nil, err } - eth.ApiBackend = &LesApiBackend{eth, nil} + leth.ApiBackend = &LesApiBackend{leth, nil} gpoParams := config.GPO if gpoParams.Default == nil { gpoParams.Default = config.GasPrice } - eth.ApiBackend.gpo = gasprice.NewOracle(eth.ApiBackend, gpoParams) - return eth, nil + leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) + return leth, nil } -func lesTopic(genesisHash common.Hash) discv5.Topic { - return discv5.Topic("LES@" + common.Bytes2Hex(genesisHash.Bytes()[0:8])) +func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic { + var name string + switch protocolVersion { + case lpv1: + name = "LES" + case lpv2: + name = "LES2" + default: + panic(nil) + } + return discv5.Topic(name + common.Bytes2Hex(genesisHash.Bytes()[0:8])) } type LightDummyAPI struct{} @@ -200,9 +218,13 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { // Start implements node.Service, starting all internal goroutines needed by the // Ethereum protocol implementation. func (s *LightEthereum) Start(srvr *p2p.Server) error { + s.startBloomHandlers() log.Warn("Light client mode is an experimental feature") s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId) - s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash())) + // search the topic belonging to the oldest supported protocol because + // servers always advertise all supported protocols + protocolVersion := ClientProtocolVersions[len(ClientProtocolVersions)-1] + s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) s.protocolManager.Start() return nil } @@ -211,6 +233,15 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error { // Ethereum protocol. func (s *LightEthereum) Stop() error { s.odr.Stop() + if s.bloomIndexer != nil { + s.bloomIndexer.Close() + } + if s.chtIndexer != nil { + s.chtIndexer.Close() + } + if s.bloomTrieIndexer != nil { + s.bloomTrieIndexer.Close() + } s.blockchain.Stop() s.protocolManager.Stop() s.txPool.Stop() diff --git a/les/bloombits.go b/les/bloombits.go new file mode 100644 index 000000000..dff83d349 --- /dev/null +++ b/les/bloombits.go @@ -0,0 +1,84 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "time" + + "github.com/ethereum/go-ethereum/common/bitutil" + "github.com/ethereum/go-ethereum/light" +) + +const ( + // bloomServiceThreads is the number of goroutines used globally by an Ethereum + // instance to service bloombits lookups for all running filters. + bloomServiceThreads = 16 + + // bloomFilterThreads is the number of goroutines used locally per filter to + // multiplex requests onto the global servicing goroutines. + bloomFilterThreads = 3 + + // bloomRetrievalBatch is the maximum number of bloom bit retrievals to service + // in a single batch. + bloomRetrievalBatch = 16 + + // bloomRetrievalWait is the maximum time to wait for enough bloom bit requests + // to accumulate request an entire batch (avoiding hysteresis). + bloomRetrievalWait = time.Microsecond * 100 +) + +// startBloomHandlers starts a batch of goroutines to accept bloom bit database +// retrievals from possibly a range of filters and serving the data to satisfy. +func (eth *LightEthereum) startBloomHandlers() { + for i := 0; i < bloomServiceThreads; i++ { + go func() { + for { + select { + case <-eth.shutdownChan: + return + + case request := <-eth.bloomRequests: + task := <-request + task.Bitsets = make([][]byte, len(task.Sections)) + compVectors, err := light.GetBloomBits(task.Context, eth.odr, task.Bit, task.Sections) + if err == nil { + for i, _ := range task.Sections { + if blob, err := bitutil.DecompressBytes(compVectors[i], int(light.BloomTrieFrequency/8)); err == nil { + task.Bitsets[i] = blob + } else { + task.Error = err + } + } + } else { + task.Error = err + } + request <- task + } + } + }() + } +} + +const ( + // bloomConfirms is the number of confirmation blocks before a bloom section is + // considered probably final and its rotated bits are calculated. + bloomConfirms = 256 + + // bloomThrottling is the time to wait between processing two consecutive index + // sections. It's useful during chain upgrades to prevent disk overload. + bloomThrottling = 100 * time.Millisecond +) diff --git a/les/handler.go b/les/handler.go index df7eb6af5..de07b7244 100644 --- a/les/handler.go +++ b/les/handler.go @@ -18,6 +18,7 @@ package les import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -35,6 +36,7 @@ import ( "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" @@ -50,13 +52,14 @@ const ( ethVersion = 63 // equivalent eth version for the downloader - MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request - MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request - MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request - MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request - MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxHeaderProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request - MaxTxSend = 64 // Amount of transactions to be send per request + MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request + MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request + MaxReceiptFetch = 128 // Amount of transaction receipts to allow fetching per request + MaxCodeFetch = 64 // Amount of contract codes to allow fetching per request + MaxProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxHelperTrieProofsFetch = 64 // Amount of merkle proofs to be fetched per retrieval request + MaxTxSend = 64 // Amount of transactions to be send per request + MaxTxStatus = 256 // Amount of transactions to queried per request disableClientRemovePeer = false ) @@ -86,8 +89,7 @@ type BlockChain interface { } type txPool interface { - // AddRemotes should add the given transactions to the pool. - AddRemotes([]*types.Transaction) error + AddOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []core.TxStatusData } type ProtocolManager struct { @@ -125,7 +127,7 @@ type ProtocolManager struct { // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable // with the ethereum network. -func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) { +func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protocolVersions []uint, networkId uint64, mux *event.TypeMux, engine consensus.Engine, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, txrelay *LesTxRelay, quitSync chan struct{}, wg *sync.WaitGroup) (*ProtocolManager, error) { // Create the protocol manager with the base fields manager := &ProtocolManager{ lightSync: lightSync, @@ -147,15 +149,16 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network manager.retriever = odr.retriever manager.reqDist = odr.retriever.dist } + // Initiate a sub-protocol for every implemented version we can handle - manager.SubProtocols = make([]p2p.Protocol, 0, len(ProtocolVersions)) - for i, version := range ProtocolVersions { + manager.SubProtocols = make([]p2p.Protocol, 0, len(protocolVersions)) + for _, version := range protocolVersions { // Compatible, initialize the sub-protocol version := version // Closure for the run manager.SubProtocols = append(manager.SubProtocols, p2p.Protocol{ Name: "les", Version: version, - Length: ProtocolLengths[i], + Length: ProtocolLengths[version], Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { var entry *poolEntry peer := manager.newPeer(int(version), networkId, p, rw) @@ -315,7 +318,7 @@ func (pm *ProtocolManager) handle(p *peer) error { } } -var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsMsg, SendTxMsg, GetHeaderProofsMsg} +var reqList = []uint64{GetBlockHeadersMsg, GetBlockBodiesMsg, GetCodeMsg, GetReceiptsMsg, GetProofsV1Msg, SendTxMsg, SendTxV2Msg, GetTxStatusMsg, GetHeaderProofsMsg, GetProofsV2Msg, GetHelperTrieProofsMsg} // handleMsg is invoked whenever an inbound message is received from a remote // peer. The remote connection is torn down upon returning any error. @@ -362,11 +365,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Block header query, collect the requested headers and reply case AnnounceMsg: p.Log().Trace("Received announce message") + if p.requestAnnounceType == announceTypeNone { + return errResp(ErrUnexpectedResponse, "") + } var req announceData if err := msg.Decode(&req); err != nil { return errResp(ErrDecode, "%v: %v", msg, err) } + + if p.requestAnnounceType == announceTypeSigned { + if err := req.checkSignature(p.pubKey); err != nil { + p.Log().Trace("Invalid announcement signature", "err", err) + return err + } + p.Log().Trace("Valid announcement signature") + } + p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth) if pm.fetcher != nil { pm.fetcher.announce(p, &req) @@ -655,7 +670,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { Obj: resp.Receipts, } - case GetProofsMsg: + case GetProofsV1Msg: p.Log().Trace("Received proofs request") // Decode the retrieval message var req struct { @@ -690,9 +705,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } } if tr != nil { - proof := tr.Prove(req.Key) + var proof light.NodeList + tr.Prove(req.Key, 0, &proof) proofs = append(proofs, proof) - bytes += len(proof) + bytes += proof.DataSize() } } } @@ -701,7 +717,67 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) return p.SendProofs(req.ReqID, bv, proofs) - case ProofsMsg: + case GetProofsV2Msg: + p.Log().Trace("Received les/2 proofs request") + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []ProofReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + lastBHash common.Hash + lastAccKey []byte + tr, str *trie.Trie + ) + reqCnt := len(req.Reqs) + if reject(uint64(reqCnt), MaxProofsFetch) { + return errResp(ErrRequestRejected, "") + } + + nodes := light.NewNodeSet() + + for _, req := range req.Reqs { + if nodes.DataSize() >= softResponseLimit { + break + } + if tr == nil || req.BHash != lastBHash { + if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + tr, _ = trie.New(header.Root, pm.chainDb) + } else { + tr = nil + } + lastBHash = req.BHash + str = nil + } + if tr != nil { + if len(req.AccKey) > 0 { + if str == nil || !bytes.Equal(req.AccKey, lastAccKey) { + sdata := tr.Get(req.AccKey) + str = nil + var acc state.Account + if err := rlp.DecodeBytes(sdata, &acc); err == nil { + str, _ = trie.New(acc.Root, pm.chainDb) + } + lastAccKey = common.CopyBytes(req.AccKey) + } + if str != nil { + str.Prove(req.Key, req.FromLevel, nodes) + } + } else { + tr.Prove(req.Key, req.FromLevel, nodes) + } + } + } + proofs := nodes.NodeList() + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendProofsV2(req.ReqID, bv, proofs) + + case ProofsV1Msg: if pm.odr == nil { return errResp(ErrUnexpectedResponse, "") } @@ -710,14 +786,35 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // A batch of merkle proofs arrived to one of our previous requests var resp struct { ReqID, BV uint64 - Data [][]rlp.RawValue + Data []light.NodeList } if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } p.fcServer.GotReply(resp.ReqID, resp.BV) deliverMsg = &Msg{ - MsgType: MsgProofs, + MsgType: MsgProofsV1, + ReqID: resp.ReqID, + Obj: resp.Data, + } + + case ProofsV2Msg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received les/2 proofs response") + // A batch of merkle proofs arrived to one of our previous requests + var resp struct { + ReqID, BV uint64 + Data light.NodeList + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgProofsV2, ReqID: resp.ReqID, Obj: resp.Data, } @@ -738,22 +835,25 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { proofs []ChtResp ) reqCnt := len(req.Reqs) - if reject(uint64(reqCnt), MaxHeaderProofsFetch) { + if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { return errResp(ErrRequestRejected, "") } + trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix) for _, req := range req.Reqs { if bytes >= softResponseLimit { break } if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { - if root := getChtRoot(pm.chainDb, req.ChtNum); root != (common.Hash{}) { - if tr, _ := trie.New(root, pm.chainDb); tr != nil { + sectionHead := core.GetCanonicalHash(pm.chainDb, (req.ChtNum+1)*light.ChtV1Frequency-1) + if root := light.GetChtRoot(pm.chainDb, req.ChtNum, sectionHead); root != (common.Hash{}) { + if tr, _ := trie.New(root, trieDb); tr != nil { var encNumber [8]byte binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) - proof := tr.Prove(encNumber[:]) + var proof light.NodeList + tr.Prove(encNumber[:], 0, &proof) proofs = append(proofs, ChtResp{Header: header, Proof: proof}) - bytes += len(proof) + estHeaderRlpSize + bytes += proof.DataSize() + estHeaderRlpSize } } } @@ -762,6 +862,73 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) return p.SendHeaderProofs(req.ReqID, bv, proofs) + case GetHelperTrieProofsMsg: + p.Log().Trace("Received helper trie proof request") + // Decode the retrieval message + var req struct { + ReqID uint64 + Reqs []HelperTrieReq + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + auxBytes int + auxData [][]byte + ) + reqCnt := len(req.Reqs) + if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { + return errResp(ErrRequestRejected, "") + } + + var ( + lastIdx uint64 + lastType uint + root common.Hash + tr *trie.Trie + ) + + nodes := light.NewNodeSet() + + for _, req := range req.Reqs { + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } + if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { + var prefix string + root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx) + if root != (common.Hash{}) { + if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil { + tr = t + } + } + lastType = req.HelperTrieType + lastIdx = req.TrieIdx + } + if req.AuxReq == auxRoot { + var data []byte + if root != (common.Hash{}) { + data = root[:] + } + auxData = append(auxData, data) + auxBytes += len(data) + } else { + if tr != nil { + tr.Prove(req.Key, req.FromLevel, nodes) + } + if req.AuxReq != 0 { + data := pm.getHelperTrieAuxData(req) + auxData = append(auxData, data) + auxBytes += len(data) + } + } + } + proofs := nodes.NodeList() + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendHelperTrieProofs(req.ReqID, bv, HelperTrieResps{Proofs: proofs, AuxData: auxData}) + case HeaderProofsMsg: if pm.odr == nil { return errResp(ErrUnexpectedResponse, "") @@ -782,9 +949,30 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { Obj: resp.Data, } + case HelperTrieProofsMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received helper trie proof response") + var resp struct { + ReqID, BV uint64 + Data HelperTrieResps + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + + p.fcServer.GotReply(resp.ReqID, resp.BV) + deliverMsg = &Msg{ + MsgType: MsgHelperTrieProofs, + ReqID: resp.ReqID, + Obj: resp.Data, + } + case SendTxMsg: if pm.txpool == nil { - return errResp(ErrUnexpectedResponse, "") + return errResp(ErrRequestRejected, "") } // Transactions arrived, parse all of them and deliver to the pool var txs []*types.Transaction @@ -796,13 +984,82 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrRequestRejected, "") } - if err := pm.txpool.AddRemotes(txs); err != nil { - return errResp(ErrUnexpectedResponse, "msg: %v", err) + txHashes := make([]common.Hash, len(txs)) + for i, tx := range txs { + txHashes[i] = tx.Hash() } + pm.addOrGetTxStatus(txs, txHashes) _, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + case SendTxV2Msg: + if pm.txpool == nil { + return errResp(ErrRequestRejected, "") + } + // Transactions arrived, parse all of them and deliver to the pool + var req struct { + ReqID uint64 + Txs []*types.Transaction + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.Txs) + if reject(uint64(reqCnt), MaxTxSend) { + return errResp(ErrRequestRejected, "") + } + + txHashes := make([]common.Hash, len(req.Txs)) + for i, tx := range req.Txs { + txHashes[i] = tx.Hash() + } + + res := pm.addOrGetTxStatus(req.Txs, txHashes) + + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendTxStatus(req.ReqID, bv, res) + + case GetTxStatusMsg: + if pm.txpool == nil { + return errResp(ErrUnexpectedResponse, "") + } + // Transactions arrived, parse all of them and deliver to the pool + var req struct { + ReqID uint64 + TxHashes []common.Hash + } + if err := msg.Decode(&req); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + reqCnt := len(req.TxHashes) + if reject(uint64(reqCnt), MaxTxStatus) { + return errResp(ErrRequestRejected, "") + } + + res := pm.addOrGetTxStatus(nil, req.TxHashes) + + bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) + pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost) + return p.SendTxStatus(req.ReqID, bv, res) + + case TxStatusMsg: + if pm.odr == nil { + return errResp(ErrUnexpectedResponse, "") + } + + p.Log().Trace("Received tx status response") + var resp struct { + ReqID, BV uint64 + Status []core.TxStatusData + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + + p.fcServer.GotReply(resp.ReqID, resp.BV) + default: p.Log().Trace("Received unknown message", "code", msg.Code) return errResp(ErrInvalidMsgCode, "%v", msg.Code) @@ -820,6 +1077,47 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return nil } +// getHelperTrie returns the post-processed trie root for the given trie ID and section index +func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { + switch id { + case htCanonical: + sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.ChtFrequency-1) + return light.GetChtV2Root(pm.chainDb, idx, sectionHead), light.ChtTablePrefix + case htBloomBits: + sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1) + return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix + } + return common.Hash{}, "" +} + +// getHelperTrieAuxData returns requested auxiliary data for the given HelperTrie request +func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte { + if req.HelperTrieType == htCanonical && req.AuxReq == auxHeader { + if len(req.Key) != 8 { + return nil + } + blockNum := binary.BigEndian.Uint64(req.Key) + hash := core.GetCanonicalHash(pm.chainDb, blockNum) + return core.GetHeaderRLP(pm.chainDb, hash, blockNum) + } + return nil +} + +func (pm *ProtocolManager) addOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []core.TxStatusData { + status := pm.txpool.AddOrGetTxStatus(txs, txHashes) + for i, _ := range status { + blockHash, blockNum, txIndex := core.GetTxLookupEntry(pm.chainDb, txHashes[i]) + if blockHash != (common.Hash{}) { + enc, err := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: blockHash, BlockIndex: blockNum, Index: txIndex}) + if err != nil { + panic(err) + } + status[i] = core.TxStatusData{Status: core.TxStatusIncluded, Data: enc} + } + } + return status +} + // NodeInfo retrieves some protocol metadata about the running host node. func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo { return ð.EthNodeInfo{ diff --git a/les/handler_test.go b/les/handler_test.go index b1f1aa095..a094cdc84 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -17,7 +17,10 @@ package les import ( + "bytes" + "math/big" "math/rand" + "runtime" "testing" "github.com/ethereum/go-ethereum/common" @@ -26,7 +29,9 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -39,9 +44,29 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{} return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data}) } +func testCheckProof(t *testing.T, exp *light.NodeSet, got light.NodeList) { + if exp.KeyCount() > len(got) { + t.Errorf("proof has fewer nodes than expected") + return + } + if exp.KeyCount() < len(got) { + t.Errorf("proof has more nodes than expected") + return + } + for _, node := range got { + n, _ := exp.Get(crypto.Keccak256(node)) + if !bytes.Equal(n, node) { + t.Errorf("proof contents mismatch") + return + } + } +} + // Tests that block headers can be retrieved from a remote chain based on user queries. func TestGetBlockHeadersLes1(t *testing.T) { testGetBlockHeaders(t, 1) } +func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } + func testGetBlockHeaders(t *testing.T, protocol int) { db, _ := ethdb.NewMemDatabase() pm := newTestProtocolManagerMust(t, false, downloader.MaxHashFetch+15, nil, nil, nil, db) @@ -171,6 +196,8 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Tests that block contents can be retrieved from a remote chain based on their hashes. func TestGetBlockBodiesLes1(t *testing.T) { testGetBlockBodies(t, 1) } +func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) } + func testGetBlockBodies(t *testing.T, protocol int) { db, _ := ethdb.NewMemDatabase() pm := newTestProtocolManagerMust(t, false, downloader.MaxBlockFetch+15, nil, nil, nil, db) @@ -247,6 +274,8 @@ func testGetBlockBodies(t *testing.T, protocol int) { // Tests that the contract codes can be retrieved based on account addresses. func TestGetCodeLes1(t *testing.T) { testGetCode(t, 1) } +func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) } + func testGetCode(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -280,6 +309,8 @@ func testGetCode(t *testing.T, protocol int) { // Tests that the transaction receipts can be retrieved based on hashes. func TestGetReceiptLes1(t *testing.T) { testGetReceipt(t, 1) } +func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) } + func testGetReceipt(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -307,6 +338,8 @@ func testGetReceipt(t *testing.T, protocol int) { // Tests that trie merkle proofs can be retrieved func TestGetProofsLes1(t *testing.T) { testGetProofs(t, 1) } +func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) } + func testGetProofs(t *testing.T, protocol int) { // Assemble the test environment db, _ := ethdb.NewMemDatabase() @@ -315,8 +348,11 @@ func testGetProofs(t *testing.T, protocol int) { peer, _ := newTestPeer(t, "peer", protocol, pm, true) defer peer.close() - var proofreqs []ProofReq - var proofs [][]rlp.RawValue + var ( + proofreqs []ProofReq + proofsV1 [][]rlp.RawValue + ) + proofsV2 := light.NewNodeSet() accounts := []common.Address{testBankAddress, acc1Addr, acc2Addr, {}} for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { @@ -331,14 +367,124 @@ func testGetProofs(t *testing.T, protocol int) { } proofreqs = append(proofreqs, req) - proof := trie.Prove(crypto.Keccak256(acc[:])) - proofs = append(proofs, proof) + switch protocol { + case 1: + var proof light.NodeList + trie.Prove(crypto.Keccak256(acc[:]), 0, &proof) + proofsV1 = append(proofsV1, proof) + case 2: + trie.Prove(crypto.Keccak256(acc[:]), 0, proofsV2) + } } } // Send the proof request and verify the response - cost := peer.GetRequestCost(GetProofsMsg, len(proofreqs)) - sendRequest(peer.app, GetProofsMsg, 42, cost, proofreqs) - if err := expectResponse(peer.app, ProofsMsg, 42, testBufLimit, proofs); err != nil { - t.Errorf("proofs mismatch: %v", err) + switch protocol { + case 1: + cost := peer.GetRequestCost(GetProofsV1Msg, len(proofreqs)) + sendRequest(peer.app, GetProofsV1Msg, 42, cost, proofreqs) + if err := expectResponse(peer.app, ProofsV1Msg, 42, testBufLimit, proofsV1); err != nil { + t.Errorf("proofs mismatch: %v", err) + } + case 2: + cost := peer.GetRequestCost(GetProofsV2Msg, len(proofreqs)) + sendRequest(peer.app, GetProofsV2Msg, 42, cost, proofreqs) + msg, err := peer.app.ReadMsg() + if err != nil { + t.Errorf("Message read error: %v", err) + } + var resp struct { + ReqID, BV uint64 + Data light.NodeList + } + if err := msg.Decode(&resp); err != nil { + t.Errorf("reply decode error: %v", err) + } + if msg.Code != ProofsV2Msg { + t.Errorf("Message code mismatch") + } + if resp.ReqID != 42 { + t.Errorf("ReqID mismatch") + } + if resp.BV != testBufLimit { + t.Errorf("BV mismatch") + } + testCheckProof(t, proofsV2, resp.Data) + } +} + +func TestTransactionStatusLes2(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + pm := newTestProtocolManagerMust(t, false, 0, nil, nil, nil, db) + chain := pm.blockchain.(*core.BlockChain) + txpool := core.NewTxPool(core.DefaultTxPoolConfig, params.TestChainConfig, chain) + pm.txpool = txpool + peer, _ := newTestPeer(t, "peer", 2, pm, true) + defer peer.close() + + var reqID uint64 + + test := func(tx *types.Transaction, send bool, expStatus core.TxStatusData) { + reqID++ + if send { + cost := peer.GetRequestCost(SendTxV2Msg, 1) + sendRequest(peer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx}) + } else { + cost := peer.GetRequestCost(GetTxStatusMsg, 1) + sendRequest(peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()}) + } + if err := expectResponse(peer.app, TxStatusMsg, reqID, testBufLimit, []core.TxStatusData{expStatus}); err != nil { + t.Errorf("transaction status mismatch") + } + } + + signer := types.HomesteadSigner{} + + // test error status by sending an underpriced transaction + tx0, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, nil, nil), signer, testBankKey) + test(tx0, true, core.TxStatusData{Status: core.TxStatusError, Data: []byte("transaction underpriced")}) + + tx1, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + test(tx1, false, core.TxStatusData{Status: core.TxStatusUnknown}) // query before sending, should be unknown + test(tx1, true, core.TxStatusData{Status: core.TxStatusPending}) // send valid processable tx, should return pending + test(tx1, true, core.TxStatusData{Status: core.TxStatusPending}) // adding it again should not return an error + + tx2, _ := types.SignTx(types.NewTransaction(1, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + tx3, _ := types.SignTx(types.NewTransaction(2, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey) + // send transactions in the wrong order, tx3 should be queued + test(tx3, true, core.TxStatusData{Status: core.TxStatusQueued}) + test(tx2, true, core.TxStatusData{Status: core.TxStatusPending}) + // query again, now tx3 should be pending too + test(tx3, false, core.TxStatusData{Status: core.TxStatusPending}) + + // generate and add a block with tx1 and tx2 included + gchain, _ := core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 1, func(i int, block *core.BlockGen) { + block.AddTx(tx1) + block.AddTx(tx2) + }) + if _, err := chain.InsertChain(gchain); err != nil { + panic(err) + } + + // check if their status is included now + block1hash := core.GetCanonicalHash(db, 1) + tx1pos, _ := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}) + tx2pos, _ := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}) + test(tx1, false, core.TxStatusData{Status: core.TxStatusIncluded, Data: tx1pos}) + test(tx2, false, core.TxStatusData{Status: core.TxStatusIncluded, Data: tx2pos}) + + // create a reorg that rolls them back + gchain, _ = core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 2, func(i int, block *core.BlockGen) {}) + if _, err := chain.InsertChain(gchain); err != nil { + panic(err) + } + // wait until TxPool processes the reorg + for { + if pending, _ := txpool.Stats(); pending == 3 { + break + } + runtime.Gosched() } + // check if their status is pending again + test(tx1, false, core.TxStatusData{Status: core.TxStatusPending}) + test(tx2, false, core.TxStatusData{Status: core.TxStatusPending}) } diff --git a/les/helper_test.go b/les/helper_test.go index b33454e1d..a06f84cca 100644 --- a/les/helper_test.go +++ b/les/helper_test.go @@ -43,7 +43,7 @@ import ( var ( testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey) - testBankFunds = big.NewInt(1000000) + testBankFunds = big.NewInt(1000000000000000000) acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") @@ -156,7 +156,13 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor chain = blockchain } - pm, err := NewProtocolManager(gspec.Config, lightSync, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup)) + var protocolVersions []uint + if lightSync { + protocolVersions = ClientProtocolVersions + } else { + protocolVersions = ServerProtocolVersions + } + pm, err := NewProtocolManager(gspec.Config, lightSync, protocolVersions, NetworkId, evmux, engine, peers, chain, nil, db, odr, nil, make(chan struct{}), new(sync.WaitGroup)) if err != nil { return nil, err } diff --git a/les/odr.go b/les/odr.go index 3f7584b48..986630dbf 100644 --- a/les/odr.go +++ b/les/odr.go @@ -19,6 +19,7 @@ package les import ( "context" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" @@ -26,33 +27,56 @@ import ( // LesOdr implements light.OdrBackend type LesOdr struct { - db ethdb.Database - stop chan struct{} - retriever *retrieveManager + db ethdb.Database + chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer + retriever *retrieveManager + stop chan struct{} } -func NewLesOdr(db ethdb.Database, retriever *retrieveManager) *LesOdr { +func NewLesOdr(db ethdb.Database, chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer, retriever *retrieveManager) *LesOdr { return &LesOdr{ - db: db, - retriever: retriever, - stop: make(chan struct{}), + db: db, + chtIndexer: chtIndexer, + bloomTrieIndexer: bloomTrieIndexer, + bloomIndexer: bloomIndexer, + retriever: retriever, + stop: make(chan struct{}), } } +// Stop cancels all pending retrievals func (odr *LesOdr) Stop() { close(odr.stop) } +// Database returns the backing database func (odr *LesOdr) Database() ethdb.Database { return odr.db } +// ChtIndexer returns the CHT chain indexer +func (odr *LesOdr) ChtIndexer() *core.ChainIndexer { + return odr.chtIndexer +} + +// BloomTrieIndexer returns the bloom trie chain indexer +func (odr *LesOdr) BloomTrieIndexer() *core.ChainIndexer { + return odr.bloomTrieIndexer +} + +// BloomIndexer returns the bloombits chain indexer +func (odr *LesOdr) BloomIndexer() *core.ChainIndexer { + return odr.bloomIndexer +} + const ( MsgBlockBodies = iota MsgCode MsgReceipts - MsgProofs + MsgProofsV1 + MsgProofsV2 MsgHeaderProofs + MsgHelperTrieProofs ) // Msg encodes a LES message that delivers reply data for a request @@ -64,7 +88,7 @@ type Msg struct { // Retrieve tries to fetch an object from the LES network. // If the network retrieval was successful, it stores the object in local db. -func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { +func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err error) { lreq := LesRequest(req) reqID := genReqID() @@ -84,9 +108,9 @@ func (self *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err err }, } - if err = self.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(self.db, msg) }); err == nil { + if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { // retrieved from network, store in db - req.StoreResult(self.db) + req.StoreResult(odr.db) } else { log.Debug("Failed to retrieve data from network", "err", err) } diff --git a/les/odr_requests.go b/les/odr_requests.go index 1f853b341..937a4f1d9 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -36,13 +36,15 @@ import ( var ( errInvalidMessageType = errors.New("invalid message type") - errMultipleEntries = errors.New("multiple response entries") + errInvalidEntryCount = errors.New("invalid number of response entries") errHeaderUnavailable = errors.New("header unavailable") errTxHashMismatch = errors.New("transaction hash mismatch") errUncleHashMismatch = errors.New("uncle hash mismatch") errReceiptHashMismatch = errors.New("receipt hash mismatch") errDataHashMismatch = errors.New("data hash mismatch") errCHTHashMismatch = errors.New("cht hash mismatch") + errCHTNumberMismatch = errors.New("cht number mismatch") + errUselessNodes = errors.New("useless nodes in merkle proof nodeset") ) type LesOdrRequest interface { @@ -64,6 +66,8 @@ func LesRequest(req light.OdrRequest) LesOdrRequest { return (*CodeRequest)(r) case *light.ChtRequest: return (*ChtRequest)(r) + case *light.BloomRequest: + return (*BloomRequest)(r) default: return nil } @@ -101,7 +105,7 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { } bodies := msg.Obj.([]*types.Body) if len(bodies) != 1 { - return errMultipleEntries + return errInvalidEntryCount } body := bodies[0] @@ -157,7 +161,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { } receipts := msg.Obj.([]types.Receipts) if len(receipts) != 1 { - return errMultipleEntries + return errInvalidEntryCount } receipt := receipts[0] @@ -186,7 +190,14 @@ type TrieRequest light.TrieRequest // GetCost returns the cost of the given ODR request according to the serving // peer's cost table (implementation of LesOdrRequest) func (r *TrieRequest) GetCost(peer *peer) uint64 { - return peer.GetRequestCost(GetProofsMsg, 1) + switch peer.version { + case lpv1: + return peer.GetRequestCost(GetProofsV1Msg, 1) + case lpv2: + return peer.GetRequestCost(GetProofsV2Msg, 1) + default: + panic(nil) + } } // CanSend tells if a certain peer is suitable for serving the given request @@ -197,12 +208,12 @@ func (r *TrieRequest) CanSend(peer *peer) bool { // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *TrieRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting trie proof", "root", r.Id.Root, "key", r.Key) - req := &ProofReq{ + req := ProofReq{ BHash: r.Id.BlockHash, AccKey: r.Id.AccKey, Key: r.Key, } - return peer.RequestProofs(reqID, r.GetCost(peer), []*ProofReq{req}) + return peer.RequestProofs(reqID, r.GetCost(peer), []ProofReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -211,20 +222,38 @@ func (r *TrieRequest) Request(reqID uint64, peer *peer) error { func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error { log.Debug("Validating trie proof", "root", r.Id.Root, "key", r.Key) - // Ensure we have a correct message with a single proof - if msg.MsgType != MsgProofs { + switch msg.MsgType { + case MsgProofsV1: + proofs := msg.Obj.([]light.NodeList) + if len(proofs) != 1 { + return errInvalidEntryCount + } + nodeSet := proofs[0].NodeSet() + // Verify the proof and store if checks out + if _, err, _ := trie.VerifyProof(r.Id.Root, r.Key, nodeSet); err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + r.Proof = nodeSet + return nil + + case MsgProofsV2: + proofs := msg.Obj.(light.NodeList) + // Verify the proof and store if checks out + nodeSet := proofs.NodeSet() + reads := &readTraceDB{db: nodeSet} + if _, err, _ := trie.VerifyProof(r.Id.Root, r.Key, reads); err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + // check if all nodes have been read by VerifyProof + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + r.Proof = nodeSet + return nil + + default: return errInvalidMessageType } - proofs := msg.Obj.([][]rlp.RawValue) - if len(proofs) != 1 { - return errMultipleEntries - } - // Verify the proof and store if checks out - if _, err := trie.VerifyProof(r.Id.Root, r.Key, proofs[0]); err != nil { - return fmt.Errorf("merkle proof verification failed: %v", err) - } - r.Proof = proofs[0] - return nil } type CodeReq struct { @@ -249,11 +278,11 @@ func (r *CodeRequest) CanSend(peer *peer) bool { // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *CodeRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting code data", "hash", r.Hash) - req := &CodeReq{ + req := CodeReq{ BHash: r.Id.BlockHash, AccKey: r.Id.AccKey, } - return peer.RequestCode(reqID, r.GetCost(peer), []*CodeReq{req}) + return peer.RequestCode(reqID, r.GetCost(peer), []CodeReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -268,7 +297,7 @@ func (r *CodeRequest) Validate(db ethdb.Database, msg *Msg) error { } reply := msg.Obj.([][]byte) if len(reply) != 1 { - return errMultipleEntries + return errInvalidEntryCount } data := reply[0] @@ -280,10 +309,36 @@ func (r *CodeRequest) Validate(db ethdb.Database, msg *Msg) error { return nil } +const ( + // helper trie type constants + htCanonical = iota // Canonical hash trie + htBloomBits // BloomBits trie + + // applicable for all helper trie requests + auxRoot = 1 + // applicable for htCanonical + auxHeader = 2 +) + +type HelperTrieReq struct { + HelperTrieType uint + TrieIdx uint64 + Key []byte + FromLevel, AuxReq uint +} + +type HelperTrieResps struct { // describes all responses, not just a single one + Proofs light.NodeList + AuxData [][]byte +} + +// legacy LES/1 type ChtReq struct { - ChtNum, BlockNum, FromLevel uint64 + ChtNum, BlockNum uint64 + FromLevel uint } +// legacy LES/1 type ChtResp struct { Header *types.Header Proof []rlp.RawValue @@ -295,7 +350,14 @@ type ChtRequest light.ChtRequest // GetCost returns the cost of the given ODR request according to the serving // peer's cost table (implementation of LesOdrRequest) func (r *ChtRequest) GetCost(peer *peer) uint64 { - return peer.GetRequestCost(GetHeaderProofsMsg, 1) + switch peer.version { + case lpv1: + return peer.GetRequestCost(GetHeaderProofsMsg, 1) + case lpv2: + return peer.GetRequestCost(GetHelperTrieProofsMsg, 1) + default: + panic(nil) + } } // CanSend tells if a certain peer is suitable for serving the given request @@ -303,17 +365,21 @@ func (r *ChtRequest) CanSend(peer *peer) bool { peer.lock.RLock() defer peer.lock.RUnlock() - return r.ChtNum <= (peer.headInfo.Number-light.ChtConfirmations)/light.ChtFrequency + return peer.headInfo.Number >= light.HelperTrieConfirmations && r.ChtNum <= (peer.headInfo.Number-light.HelperTrieConfirmations)/light.ChtFrequency } // Request sends an ODR request to the LES network (implementation of LesOdrRequest) func (r *ChtRequest) Request(reqID uint64, peer *peer) error { peer.Log().Debug("Requesting CHT", "cht", r.ChtNum, "block", r.BlockNum) - req := &ChtReq{ - ChtNum: r.ChtNum, - BlockNum: r.BlockNum, + var encNum [8]byte + binary.BigEndian.PutUint64(encNum[:], r.BlockNum) + req := HelperTrieReq{ + HelperTrieType: htCanonical, + TrieIdx: r.ChtNum, + Key: encNum[:], + AuxReq: auxHeader, } - return peer.RequestHeaderProofs(reqID, r.GetCost(peer), []*ChtReq{req}) + return peer.RequestHelperTrieProofs(reqID, r.GetCost(peer), []HelperTrieReq{req}) } // Valid processes an ODR request reply message from the LES network @@ -322,35 +388,179 @@ func (r *ChtRequest) Request(reqID uint64, peer *peer) error { func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error { log.Debug("Validating CHT", "cht", r.ChtNum, "block", r.BlockNum) - // Ensure we have a correct message with a single proof element - if msg.MsgType != MsgHeaderProofs { + switch msg.MsgType { + case MsgHeaderProofs: // LES/1 backwards compatibility + proofs := msg.Obj.([]ChtResp) + if len(proofs) != 1 { + return errInvalidEntryCount + } + proof := proofs[0] + + // Verify the CHT + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) + + value, err, _ := trie.VerifyProof(r.ChtRoot, encNumber[:], light.NodeList(proof.Proof).NodeSet()) + if err != nil { + return err + } + var node light.ChtNode + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != proof.Header.Hash() { + return errCHTHashMismatch + } + // Verifications passed, store and return + r.Header = proof.Header + r.Proof = light.NodeList(proof.Proof).NodeSet() + r.Td = node.Td + case MsgHelperTrieProofs: + resp := msg.Obj.(HelperTrieResps) + if len(resp.AuxData) != 1 { + return errInvalidEntryCount + } + nodeSet := resp.Proofs.NodeSet() + headerEnc := resp.AuxData[0] + if len(headerEnc) == 0 { + return errHeaderUnavailable + } + header := new(types.Header) + if err := rlp.DecodeBytes(headerEnc, header); err != nil { + return errHeaderUnavailable + } + + // Verify the CHT + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) + + reads := &readTraceDB{db: nodeSet} + value, err, _ := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) + if err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + + var node light.ChtNode + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != header.Hash() { + return errCHTHashMismatch + } + if r.BlockNum != header.Number.Uint64() { + return errCHTNumberMismatch + } + // Verifications passed, store and return + r.Header = header + r.Proof = nodeSet + r.Td = node.Td + default: return errInvalidMessageType } - proofs := msg.Obj.([]ChtResp) - if len(proofs) != 1 { - return errMultipleEntries - } - proof := proofs[0] + return nil +} - // Verify the CHT - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) +type BloomReq struct { + BloomTrieNum, BitIdx, SectionIdx, FromLevel uint64 +} - value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], proof.Proof) - if err != nil { - return err +// ODR request type for requesting headers by Canonical Hash Trie, see LesOdrRequest interface +type BloomRequest light.BloomRequest + +// GetCost returns the cost of the given ODR request according to the serving +// peer's cost table (implementation of LesOdrRequest) +func (r *BloomRequest) GetCost(peer *peer) uint64 { + return peer.GetRequestCost(GetHelperTrieProofsMsg, len(r.SectionIdxList)) +} + +// CanSend tells if a certain peer is suitable for serving the given request +func (r *BloomRequest) CanSend(peer *peer) bool { + peer.lock.RLock() + defer peer.lock.RUnlock() + + if peer.version < lpv2 { + return false } - var node light.ChtNode - if err := rlp.DecodeBytes(value, &node); err != nil { - return err + return peer.headInfo.Number >= light.HelperTrieConfirmations && r.BloomTrieNum <= (peer.headInfo.Number-light.HelperTrieConfirmations)/light.BloomTrieFrequency +} + +// Request sends an ODR request to the LES network (implementation of LesOdrRequest) +func (r *BloomRequest) Request(reqID uint64, peer *peer) error { + peer.Log().Debug("Requesting BloomBits", "bloomTrie", r.BloomTrieNum, "bitIdx", r.BitIdx, "sections", r.SectionIdxList) + reqs := make([]HelperTrieReq, len(r.SectionIdxList)) + + var encNumber [10]byte + binary.BigEndian.PutUint16(encNumber[0:2], uint16(r.BitIdx)) + + for i, sectionIdx := range r.SectionIdxList { + binary.BigEndian.PutUint64(encNumber[2:10], sectionIdx) + reqs[i] = HelperTrieReq{ + HelperTrieType: htBloomBits, + TrieIdx: r.BloomTrieNum, + Key: common.CopyBytes(encNumber[:]), + } } - if node.Hash != proof.Header.Hash() { - return errCHTHashMismatch + return peer.RequestHelperTrieProofs(reqID, r.GetCost(peer), reqs) +} + +// Valid processes an ODR request reply message from the LES network +// returns true and stores results in memory if the message was a valid reply +// to the request (implementation of LesOdrRequest) +func (r *BloomRequest) Validate(db ethdb.Database, msg *Msg) error { + log.Debug("Validating BloomBits", "bloomTrie", r.BloomTrieNum, "bitIdx", r.BitIdx, "sections", r.SectionIdxList) + + // Ensure we have a correct message with a single proof element + if msg.MsgType != MsgHelperTrieProofs { + return errInvalidMessageType + } + resps := msg.Obj.(HelperTrieResps) + proofs := resps.Proofs + nodeSet := proofs.NodeSet() + reads := &readTraceDB{db: nodeSet} + + r.BloomBits = make([][]byte, len(r.SectionIdxList)) + + // Verify the proofs + var encNumber [10]byte + binary.BigEndian.PutUint16(encNumber[0:2], uint16(r.BitIdx)) + + for i, idx := range r.SectionIdxList { + binary.BigEndian.PutUint64(encNumber[2:10], idx) + value, err, _ := trie.VerifyProof(r.BloomTrieRoot, encNumber[:], reads) + if err != nil { + return err + } + r.BloomBits[i] = value } - // Verifications passed, store and return - r.Header = proof.Header - r.Proof = proof.Proof - r.Td = node.Td + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + r.Proofs = nodeSet return nil } + +// readTraceDB stores the keys of database reads. We use this to check that received node +// sets contain only the trie nodes necessary to make proofs pass. +type readTraceDB struct { + db trie.DatabaseReader + reads map[string]struct{} +} + +// Get returns a stored node +func (db *readTraceDB) Get(k []byte) ([]byte, error) { + if db.reads == nil { + db.reads = make(map[string]struct{}) + } + db.reads[string(k)] = struct{}{} + return db.db.Get(k) +} + +// Has returns true if the node set contains the given key +func (db *readTraceDB) Has(key []byte) (bool, error) { + _, err := db.Get(key) + return err == nil, nil +} diff --git a/les/odr_test.go b/les/odr_test.go index f56c4036d..865f5d83e 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/params" @@ -39,6 +40,8 @@ type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.Chain func TestOdrGetBlockLes1(t *testing.T) { testOdr(t, 1, 1, odrGetBlock) } +func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, odrGetBlock) } + func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var block *types.Block if bc != nil { @@ -55,6 +58,8 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon func TestOdrGetReceiptsLes1(t *testing.T) { testOdr(t, 1, 1, odrGetReceipts) } +func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, odrGetReceipts) } + func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts if bc != nil { @@ -71,6 +76,8 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain func TestOdrAccountsLes1(t *testing.T) { testOdr(t, 1, 1, odrAccounts) } +func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, odrAccounts) } + func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} @@ -100,6 +107,8 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon func TestOdrContractCallLes1(t *testing.T) { testOdr(t, 1, 2, odrContractCall) } +func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, odrContractCall) } + type callmsg struct { types.Message } @@ -154,7 +163,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { rm := newRetrieveManager(peers, dist, nil) db, _ := ethdb.NewMemDatabase() ldb, _ := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, rm) + odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) diff --git a/les/peer.go b/les/peer.go index 3ba2df3fe..104afb6dc 100644 --- a/les/peer.go +++ b/les/peer.go @@ -18,6 +18,8 @@ package les import ( + "crypto/ecdsa" + "encoding/binary" "errors" "fmt" "math/big" @@ -25,9 +27,11 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" + "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rlp" ) @@ -40,14 +44,23 @@ var ( const maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam) +const ( + announceTypeNone = iota + announceTypeSimple + announceTypeSigned +) + type peer struct { *p2p.Peer + pubKey *ecdsa.PublicKey rw p2p.MsgReadWriter version int // Protocol version negotiated network uint64 // Network ID being on + announceType, requestAnnounceType uint64 + id string headInfo *announceData @@ -68,9 +81,11 @@ type peer struct { func newPeer(version int, network uint64, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { id := p.ID() + pubKey, _ := id.Pubkey() return &peer{ Peer: p, + pubKey: pubKey, rw: rw, version: version, network: network, @@ -197,16 +212,31 @@ func (p *peer) SendReceiptsRLP(reqID, bv uint64, receipts []rlp.RawValue) error return sendResponse(p.rw, ReceiptsMsg, reqID, bv, receipts) } -// SendProofs sends a batch of merkle proofs, corresponding to the ones requested. +// SendProofs sends a batch of legacy LES/1 merkle proofs, corresponding to the ones requested. func (p *peer) SendProofs(reqID, bv uint64, proofs proofsData) error { - return sendResponse(p.rw, ProofsMsg, reqID, bv, proofs) + return sendResponse(p.rw, ProofsV1Msg, reqID, bv, proofs) } -// SendHeaderProofs sends a batch of header proofs, corresponding to the ones requested. +// SendProofsV2 sends a batch of merkle proofs, corresponding to the ones requested. +func (p *peer) SendProofsV2(reqID, bv uint64, proofs light.NodeList) error { + return sendResponse(p.rw, ProofsV2Msg, reqID, bv, proofs) +} + +// SendHeaderProofs sends a batch of legacy LES/1 header proofs, corresponding to the ones requested. func (p *peer) SendHeaderProofs(reqID, bv uint64, proofs []ChtResp) error { return sendResponse(p.rw, HeaderProofsMsg, reqID, bv, proofs) } +// SendHelperTrieProofs sends a batch of HelperTrie proofs, corresponding to the ones requested. +func (p *peer) SendHelperTrieProofs(reqID, bv uint64, resp HelperTrieResps) error { + return sendResponse(p.rw, HelperTrieProofsMsg, reqID, bv, resp) +} + +// SendTxStatus sends a batch of transaction status records, corresponding to the ones requested. +func (p *peer) SendTxStatus(reqID, bv uint64, status []core.TxStatusData) error { + return sendResponse(p.rw, TxStatusMsg, reqID, bv, status) +} + // RequestHeadersByHash fetches a batch of blocks' headers corresponding to the // specified header query, based on the hash of an origin block. func (p *peer) RequestHeadersByHash(reqID, cost uint64, origin common.Hash, amount int, skip int, reverse bool) error { @@ -230,7 +260,7 @@ func (p *peer) RequestBodies(reqID, cost uint64, hashes []common.Hash) error { // RequestCode fetches a batch of arbitrary data from a node's known state // data, corresponding to the specified hashes. -func (p *peer) RequestCode(reqID, cost uint64, reqs []*CodeReq) error { +func (p *peer) RequestCode(reqID, cost uint64, reqs []CodeReq) error { p.Log().Debug("Fetching batch of codes", "count", len(reqs)) return sendRequest(p.rw, GetCodeMsg, reqID, cost, reqs) } @@ -242,20 +272,58 @@ func (p *peer) RequestReceipts(reqID, cost uint64, hashes []common.Hash) error { } // RequestProofs fetches a batch of merkle proofs from a remote node. -func (p *peer) RequestProofs(reqID, cost uint64, reqs []*ProofReq) error { +func (p *peer) RequestProofs(reqID, cost uint64, reqs []ProofReq) error { p.Log().Debug("Fetching batch of proofs", "count", len(reqs)) - return sendRequest(p.rw, GetProofsMsg, reqID, cost, reqs) + switch p.version { + case lpv1: + return sendRequest(p.rw, GetProofsV1Msg, reqID, cost, reqs) + case lpv2: + return sendRequest(p.rw, GetProofsV2Msg, reqID, cost, reqs) + default: + panic(nil) + } + +} + +// RequestHelperTrieProofs fetches a batch of HelperTrie merkle proofs from a remote node. +func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, reqs []HelperTrieReq) error { + p.Log().Debug("Fetching batch of HelperTrie proofs", "count", len(reqs)) + switch p.version { + case lpv1: + reqsV1 := make([]ChtReq, len(reqs)) + for i, req := range reqs { + if req.HelperTrieType != htCanonical || req.AuxReq != auxHeader || len(req.Key) != 8 { + return fmt.Errorf("Request invalid in LES/1 mode") + } + blockNum := binary.BigEndian.Uint64(req.Key) + // convert HelperTrie request to old CHT request + reqsV1[i] = ChtReq{ChtNum: (req.TrieIdx+1)*(light.ChtFrequency/light.ChtV1Frequency) - 1, BlockNum: blockNum, FromLevel: req.FromLevel} + } + return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqsV1) + case lpv2: + return sendRequest(p.rw, GetHelperTrieProofsMsg, reqID, cost, reqs) + default: + panic(nil) + } } -// RequestHeaderProofs fetches a batch of header merkle proofs from a remote node. -func (p *peer) RequestHeaderProofs(reqID, cost uint64, reqs []*ChtReq) error { - p.Log().Debug("Fetching batch of header proofs", "count", len(reqs)) - return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqs) +// RequestTxStatus fetches a batch of transaction status records from a remote node. +func (p *peer) RequestTxStatus(reqID, cost uint64, txHashes []common.Hash) error { + p.Log().Debug("Requesting transaction status", "count", len(txHashes)) + return sendRequest(p.rw, GetTxStatusMsg, reqID, cost, txHashes) } +// SendTxStatus sends a batch of transactions to be added to the remote transaction pool. func (p *peer) SendTxs(reqID, cost uint64, txs types.Transactions) error { p.Log().Debug("Fetching batch of transactions", "count", len(txs)) - return p2p.Send(p.rw, SendTxMsg, txs) + switch p.version { + case lpv1: + return p2p.Send(p.rw, SendTxMsg, txs) // old message format does not include reqID + case lpv2: + return sendRequest(p.rw, SendTxV2Msg, reqID, cost, txs) + default: + panic(nil) + } } type keyValueEntry struct { @@ -289,7 +357,7 @@ func (l keyValueList) decode() keyValueMap { func (m keyValueMap) get(key string, val interface{}) error { enc, ok := m[key] if !ok { - return errResp(ErrHandshakeMissingKey, "%s", key) + return errResp(ErrMissingKey, "%s", key) } if val == nil { return nil @@ -348,6 +416,9 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis list := server.fcCostStats.getCurrentList() send = send.add("flowControl/MRC", list) p.fcCosts = list.decode() + } else { + p.requestAnnounceType = announceTypeSimple // set to default until "very light" client mode is implemented + send = send.add("announceType", p.requestAnnounceType) } recvList, err := p.sendReceiveHandshake(send) if err != nil { @@ -392,6 +463,9 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis /*if recv.get("serveStateSince", nil) == nil { return errResp(ErrUselessPeer, "wanted client, got server") }*/ + if recv.get("announceType", &p.announceType) != nil { + p.announceType = announceTypeSimple + } p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) } else { if recv.get("serveChainSince", nil) != nil { @@ -456,11 +530,15 @@ func newPeerSet() *peerSet { // notify adds a service to be notified about added or removed peers func (ps *peerSet) notify(n peerSetNotify) { ps.lock.Lock() - defer ps.lock.Unlock() - ps.notifyList = append(ps.notifyList, n) + peers := make([]*peer, 0, len(ps.peers)) for _, p := range ps.peers { - go n.registerPeer(p) + peers = append(peers, p) + } + ps.lock.Unlock() + + for _, p := range peers { + n.registerPeer(p) } } @@ -468,8 +546,6 @@ func (ps *peerSet) notify(n peerSetNotify) { // peer is already known. func (ps *peerSet) Register(p *peer) error { ps.lock.Lock() - defer ps.lock.Unlock() - if ps.closed { return errClosed } @@ -478,8 +554,12 @@ func (ps *peerSet) Register(p *peer) error { } ps.peers[p.id] = p p.sendQueue = newExecQueue(100) - for _, n := range ps.notifyList { - go n.registerPeer(p) + peers := make([]peerSetNotify, len(ps.notifyList)) + copy(peers, ps.notifyList) + ps.lock.Unlock() + + for _, n := range peers { + n.registerPeer(p) } return nil } @@ -488,19 +568,22 @@ func (ps *peerSet) Register(p *peer) error { // actions to/from that particular entity. It also initiates disconnection at the networking layer. func (ps *peerSet) Unregister(id string) error { ps.lock.Lock() - defer ps.lock.Unlock() - if p, ok := ps.peers[id]; !ok { + ps.lock.Unlock() return errNotRegistered } else { - for _, n := range ps.notifyList { - go n.unregisterPeer(p) + delete(ps.peers, id) + peers := make([]peerSetNotify, len(ps.notifyList)) + copy(peers, ps.notifyList) + ps.lock.Unlock() + + for _, n := range peers { + n.unregisterPeer(p) } p.sendQueue.quit() p.Peer.Disconnect(p2p.DiscUselessPeer) + return nil } - delete(ps.peers, id) - return nil } // AllPeerIDs returns a list of all registered peer IDs diff --git a/les/protocol.go b/les/protocol.go index 33d930ee0..146b02030 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -18,24 +18,34 @@ package les import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "errors" "fmt" "io" "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/rlp" ) // Constants to match up protocol versions and messages const ( lpv1 = 1 + lpv2 = 2 ) -// Supported versions of the les protocol (first is primary). -var ProtocolVersions = []uint{lpv1} +// Supported versions of the les protocol (first is primary) +var ( + ClientProtocolVersions = []uint{lpv2, lpv1} + ServerProtocolVersions = []uint{lpv2, lpv1} +) // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = []uint64{15} +var ProtocolLengths = map[uint]uint64{lpv1: 15, lpv2: 22} const ( NetworkId = 1 @@ -53,13 +63,21 @@ const ( BlockBodiesMsg = 0x05 GetReceiptsMsg = 0x06 ReceiptsMsg = 0x07 - GetProofsMsg = 0x08 - ProofsMsg = 0x09 + GetProofsV1Msg = 0x08 + ProofsV1Msg = 0x09 GetCodeMsg = 0x0a CodeMsg = 0x0b SendTxMsg = 0x0c GetHeaderProofsMsg = 0x0d HeaderProofsMsg = 0x0e + // Protocol messages belonging to LPV2 + GetProofsV2Msg = 0x0f + ProofsV2Msg = 0x10 + GetHelperTrieProofsMsg = 0x11 + HelperTrieProofsMsg = 0x12 + SendTxV2Msg = 0x13 + GetTxStatusMsg = 0x14 + TxStatusMsg = 0x15 ) type errCode int @@ -79,7 +97,7 @@ const ( ErrUnexpectedResponse ErrInvalidResponse ErrTooManyTimeouts - ErrHandshakeMissingKey + ErrMissingKey ) func (e errCode) String() string { @@ -101,7 +119,13 @@ var errorToString = map[int]string{ ErrUnexpectedResponse: "Unexpected response", ErrInvalidResponse: "Invalid response", ErrTooManyTimeouts: "Too many request timeouts", - ErrHandshakeMissingKey: "Key missing from handshake message", + ErrMissingKey: "Key missing from list", +} + +type announceBlock struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced + Td *big.Int // Total difficulty of one particular block being announced } // announceData is the network packet for the block announcements. @@ -113,6 +137,32 @@ type announceData struct { Update keyValueList } +// sign adds a signature to the block announcement by the given privKey +func (a *announceData) sign(privKey *ecdsa.PrivateKey) { + rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td}) + sig, _ := crypto.Sign(crypto.Keccak256(rlp), privKey) + a.Update = a.Update.add("sign", sig) +} + +// checkSignature verifies if the block announcement has a valid signature by the given pubKey +func (a *announceData) checkSignature(pubKey *ecdsa.PublicKey) error { + var sig []byte + if err := a.Update.decode().get("sign", &sig); err != nil { + return err + } + rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td}) + recPubkey, err := secp256k1.RecoverPubkey(crypto.Keccak256(rlp), sig) + if err != nil { + return err + } + pbytes := elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y) + if bytes.Equal(pbytes, recPubkey) { + return nil + } else { + return errors.New("Wrong signature") + } +} + type blockInfo struct { Hash common.Hash // Hash of one particular block being announced Number uint64 // Number of one particular block being announced diff --git a/les/request_test.go b/les/request_test.go index 6b594462d..c13625de8 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" ) @@ -38,24 +39,32 @@ type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) ligh func TestBlockAccessLes1(t *testing.T) { testAccess(t, 1, tfBlockAccess) } +func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } + func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.BlockRequest{Hash: bhash, Number: number} } func TestReceiptsAccessLes1(t *testing.T) { testAccess(t, 1, tfReceiptsAccess) } +func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } + func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.ReceiptsRequest{Hash: bhash, Number: number} } func TestTrieEntryAccessLes1(t *testing.T) { testAccess(t, 1, tfTrieEntryAccess) } +func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } + func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} } func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) } +func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } + func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash)) if header.Number.Uint64() < testContractDeployed { @@ -73,7 +82,7 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { rm := newRetrieveManager(peers, dist, nil) db, _ := ethdb.NewMemDatabase() ldb, _ := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, rm) + odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) diff --git a/les/retrieve.go b/les/retrieve.go index b060e0b0d..dd15b56ac 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -22,6 +22,7 @@ import ( "context" "crypto/rand" "encoding/binary" + "fmt" "sync" "time" @@ -111,12 +112,14 @@ func newRetrieveManager(peers *peerSet, dist *requestDistributor, serverPool pee // that is delivered through the deliver function and successfully validated by the // validator callback. It returns when a valid answer is delivered or the context is // cancelled. -func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc) error { +func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *distReq, val validatorFunc, shutdown chan struct{}) error { sentReq := rm.sendReq(reqID, req, val) select { case <-sentReq.stopCh: case <-ctx.Done(): sentReq.stop(ctx.Err()) + case <-shutdown: + sentReq.stop(fmt.Errorf("Client is shutting down")) } return sentReq.getError() } diff --git a/les/server.go b/les/server.go index 8b2730714..d8f93cd87 100644 --- a/les/server.go +++ b/les/server.go @@ -18,10 +18,11 @@ package les import ( + "crypto/ecdsa" "encoding/binary" + "fmt" "math" "sync" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -34,7 +35,6 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" ) type LesServer struct { @@ -42,23 +42,55 @@ type LesServer struct { fcManager *flowcontrol.ClientManager // nil if our node is client only fcCostStats *requestCostStats defParams *flowcontrol.ServerParams - lesTopic discv5.Topic + lesTopics []discv5.Topic + privateKey *ecdsa.PrivateKey quitSync chan struct{} + + chtIndexer, bloomTrieIndexer *core.ChainIndexer } func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { quitSync := make(chan struct{}) - pm, err := NewProtocolManager(eth.BlockChain().Config(), false, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup)) + pm, err := NewProtocolManager(eth.BlockChain().Config(), false, ServerProtocolVersions, config.NetworkId, eth.EventMux(), eth.Engine(), newPeerSet(), eth.BlockChain(), eth.TxPool(), eth.ChainDb(), nil, nil, quitSync, new(sync.WaitGroup)) if err != nil { return nil, err } - pm.blockLoop() + + lesTopics := make([]discv5.Topic, len(ServerProtocolVersions)) + for i, pv := range ServerProtocolVersions { + lesTopics[i] = lesTopic(eth.BlockChain().Genesis().Hash(), pv) + } srv := &LesServer{ - protocolManager: pm, - quitSync: quitSync, - lesTopic: lesTopic(eth.BlockChain().Genesis().Hash()), + protocolManager: pm, + quitSync: quitSync, + lesTopics: lesTopics, + chtIndexer: light.NewChtIndexer(eth.ChainDb(), false), + bloomTrieIndexer: light.NewBloomTrieIndexer(eth.ChainDb(), false), + } + logger := log.New() + + chtV1SectionCount, _, _ := srv.chtIndexer.Sections() // indexer still uses LES/1 4k section size for backwards server compatibility + chtV2SectionCount := chtV1SectionCount / (light.ChtFrequency / light.ChtV1Frequency) + if chtV2SectionCount != 0 { + // convert to LES/2 section + chtLastSection := chtV2SectionCount - 1 + // convert last LES/2 section index back to LES/1 index for chtIndexer.SectionHead + chtLastSectionV1 := (chtLastSection+1)*(light.ChtFrequency/light.ChtV1Frequency) - 1 + chtSectionHead := srv.chtIndexer.SectionHead(chtLastSectionV1) + chtRoot := light.GetChtV2Root(pm.chainDb, chtLastSection, chtSectionHead) + logger.Info("CHT", "section", chtLastSection, "sectionHead", fmt.Sprintf("%064x", chtSectionHead), "root", fmt.Sprintf("%064x", chtRoot)) + } + + bloomTrieSectionCount, _, _ := srv.bloomTrieIndexer.Sections() + if bloomTrieSectionCount != 0 { + bloomTrieLastSection := bloomTrieSectionCount - 1 + bloomTrieSectionHead := srv.bloomTrieIndexer.SectionHead(bloomTrieLastSection) + bloomTrieRoot := light.GetBloomTrieRoot(pm.chainDb, bloomTrieLastSection, bloomTrieSectionHead) + logger.Info("BloomTrie", "section", bloomTrieLastSection, "sectionHead", fmt.Sprintf("%064x", bloomTrieSectionHead), "root", fmt.Sprintf("%064x", bloomTrieRoot)) } + + srv.chtIndexer.Start(eth.BlockChain()) pm.server = srv srv.defParams = &flowcontrol.ServerParams{ @@ -77,17 +109,28 @@ func (s *LesServer) Protocols() []p2p.Protocol { // Start starts the LES server func (s *LesServer) Start(srvr *p2p.Server) { s.protocolManager.Start() - go func() { - logger := log.New("topic", s.lesTopic) - logger.Info("Starting topic registration") - defer logger.Info("Terminated topic registration") + for _, topic := range s.lesTopics { + topic := topic + go func() { + logger := log.New("topic", topic) + logger.Info("Starting topic registration") + defer logger.Info("Terminated topic registration") + + srvr.DiscV5.RegisterTopic(topic, s.quitSync) + }() + } + s.privateKey = srvr.PrivateKey + s.protocolManager.blockLoop() +} - srvr.DiscV5.RegisterTopic(s.lesTopic, s.quitSync) - }() +func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) { + bloomIndexer.AddChildIndexer(s.bloomTrieIndexer) } // Stop stops the LES service func (s *LesServer) Stop() { + s.chtIndexer.Close() + // bloom trie indexer is closed by parent bloombits indexer s.fcCostStats.store() s.fcManager.Stop() go func() { @@ -273,10 +316,7 @@ func (pm *ProtocolManager) blockLoop() { pm.wg.Add(1) headCh := make(chan core.ChainHeadEvent, 10) headSub := pm.blockchain.SubscribeChainHeadEvent(headCh) - newCht := make(chan struct{}, 10) - newCht <- struct{}{} go func() { - var mu sync.Mutex var lastHead *types.Header lastBroadcastTd := common.Big0 for { @@ -299,26 +339,37 @@ func (pm *ProtocolManager) blockLoop() { log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg) announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} + var ( + signed bool + signedAnnounce announceData + ) + for _, p := range peers { - select { - case p.announceChn <- announce: - default: - pm.removePeer(p.id) + switch p.announceType { + + case announceTypeSimple: + select { + case p.announceChn <- announce: + default: + pm.removePeer(p.id) + } + + case announceTypeSigned: + if !signed { + signedAnnounce = announce + signedAnnounce.sign(pm.server.privateKey) + signed = true + } + + select { + case p.announceChn <- signedAnnounce: + default: + pm.removePeer(p.id) + } } } } } - newCht <- struct{}{} - case <-newCht: - go func() { - mu.Lock() - more := makeCht(pm.chainDb) - mu.Unlock() - if more { - time.Sleep(time.Millisecond * 10) - newCht <- struct{}{} - } - }() case <-pm.quitSync: headSub.Unsubscribe() pm.wg.Done() @@ -327,86 +378,3 @@ func (pm *ProtocolManager) blockLoop() { } }() } - -var ( - lastChtKey = []byte("LastChtNumber") // chtNum (uint64 big endian) - chtPrefix = []byte("cht") // chtPrefix + chtNum (uint64 big endian) -> trie root hash -) - -func getChtRoot(db ethdb.Database, num uint64) common.Hash { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - data, _ := db.Get(append(chtPrefix, encNumber[:]...)) - return common.BytesToHash(data) -} - -func storeChtRoot(db ethdb.Database, num uint64, root common.Hash) { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - db.Put(append(chtPrefix, encNumber[:]...), root[:]) -} - -func makeCht(db ethdb.Database) bool { - headHash := core.GetHeadBlockHash(db) - headNum := core.GetBlockNumber(db, headHash) - - var newChtNum uint64 - if headNum > light.ChtConfirmations { - newChtNum = (headNum - light.ChtConfirmations) / light.ChtFrequency - } - - var lastChtNum uint64 - data, _ := db.Get(lastChtKey) - if len(data) == 8 { - lastChtNum = binary.BigEndian.Uint64(data[:]) - } - if newChtNum <= lastChtNum { - return false - } - - var t *trie.Trie - if lastChtNum > 0 { - var err error - t, err = trie.New(getChtRoot(db, lastChtNum), db) - if err != nil { - lastChtNum = 0 - } - } - if lastChtNum == 0 { - t, _ = trie.New(common.Hash{}, db) - } - - for num := lastChtNum * light.ChtFrequency; num < (lastChtNum+1)*light.ChtFrequency; num++ { - hash := core.GetCanonicalHash(db, num) - if hash == (common.Hash{}) { - panic("Canonical hash not found") - } - td := core.GetTd(db, hash, num) - if td == nil { - panic("TD not found") - } - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], num) - var node light.ChtNode - node.Hash = hash - node.Td = td - data, _ := rlp.EncodeToBytes(node) - t.Update(encNumber[:], data) - } - - root, err := t.Commit() - if err != nil { - lastChtNum = 0 - } else { - lastChtNum++ - - log.Trace("Generated CHT", "number", lastChtNum, "root", root.Hex()) - - storeChtRoot(db, lastChtNum, root) - var data [8]byte - binary.BigEndian.PutUint64(data[:], lastChtNum) - db.Put(lastChtKey, data[:]) - } - - return newChtNum > lastChtNum -} -- cgit