diff options
-rw-r--r-- | eth/handler.go | 4 | ||||
-rw-r--r-- | eth/handler_test.go | 20 |
2 files changed, 23 insertions, 1 deletions
diff --git a/eth/handler.go b/eth/handler.go index d8c5b4b64..d04f79105 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -372,6 +372,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if err := msg.Decode(&query); err != nil { return errResp(ErrDecode, "%v: %v", msg, err) } + hashMode := query.Origin.Hash != (common.Hash{}) + // Gather headers until the fetch or network limits is reached var ( bytes common.StorageSize @@ -381,7 +383,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit && len(headers) < downloader.MaxHeaderFetch { // Retrieve the next header satisfying the query var origin *types.Header - if query.Origin.Hash != (common.Hash{}) { + if hashMode { origin = pm.blockchain.GetHeader(query.Origin.Hash) } else { origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number) diff --git a/eth/handler_test.go b/eth/handler_test.go index ab2ce54b1..148d56cc6 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -301,6 +301,15 @@ func testGetBlockHeaders(t *testing.T, protocol int) { pm.blockchain.GetBlockByNumber(1).Hash(), }, }, + // Check a corner case where requesting more can iterate past the endpoints + { + &getBlockHeadersData{Origin: hashOrNumber{Number: 2}, Amount: 5, Reverse: true}, + []common.Hash{ + pm.blockchain.GetBlockByNumber(2).Hash(), + pm.blockchain.GetBlockByNumber(1).Hash(), + pm.blockchain.GetBlockByNumber(0).Hash(), + }, + }, // Check that non existing headers aren't returned { &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, @@ -322,6 +331,17 @@ func testGetBlockHeaders(t *testing.T, protocol int) { if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil { t.Errorf("test %d: headers mismatch: %v", i, err) } + // If the test used number origins, repeat with hashes as the too + if tt.query.Origin.Hash == (common.Hash{}) { + if origin := pm.blockchain.GetBlockByNumber(tt.query.Origin.Number); origin != nil { + tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0 + + p2p.Send(peer.app, 0x03, tt.query) + if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil { + t.Errorf("test %d: headers mismatch: %v", i, err) + } + } + } } } |