diff options
Diffstat (limited to 'eth/protocol.go')
-rw-r--r-- | eth/protocol.go | 100 |
1 files changed, 55 insertions, 45 deletions
diff --git a/eth/protocol.go b/eth/protocol.go index 7c5d09489..b67e5aaea 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -3,7 +3,7 @@ package eth import ( "bytes" "fmt" - "math" + "io" "math/big" "github.com/ethereum/go-ethereum/core/types" @@ -95,14 +95,13 @@ func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPoo blockPool: blockPool, rw: rw, peer: peer, - id: (string)(peer.Identity().Pubkey()), + id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]), } err = self.handleStatus() if err == nil { for { err = self.handle() if err != nil { - fmt.Println(err) self.blockPool.RemovePeer(self.id) break } @@ -117,7 +116,7 @@ func (self *ethProtocol) handle() error { return err } if msg.Size > ProtocolMaxMsgSize { - return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) + return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) } // make sure that the payload has been fully consumed defer msg.Discard() @@ -125,76 +124,87 @@ func (self *ethProtocol) handle() error { switch msg.Code { case StatusMsg: - return ProtocolError(ErrExtraStatusMsg, "") + return self.protoError(ErrExtraStatusMsg, "") case TxMsg: // TODO: rework using lazy RLP stream var txs []*types.Transaction if err := msg.Decode(&txs); err != nil { - return ProtocolError(ErrDecode, "%v", err) + return self.protoError(ErrDecode, "msg %v: %v", msg, err) } self.txPool.AddTransactions(txs) case GetBlockHashesMsg: var request getBlockHashesMsgData if err := msg.Decode(&request); err != nil { - return ProtocolError(ErrDecode, "%v", err) + return self.protoError(ErrDecode, "->msg %v: %v", msg, err) } hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount) return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...) case BlockHashesMsg: // TODO: redo using lazy decode , this way very inefficient on known chains - msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + msgStream := rlp.NewStream(msg.Payload) var err error + var i int + iter := func() (hash []byte, ok bool) { hash, err = msgStream.Bytes() if err == nil { + i++ ok = true + } else { + if err != io.EOF { + self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err) + } } return } + self.blockPool.AddBlockHashes(iter, self.id) - if err != nil && err != rlp.EOL { - return ProtocolError(ErrDecode, "%v", err) - } case GetBlocksMsg: - var blockHashes [][]byte - if err := msg.Decode(&blockHashes); err != nil { - return ProtocolError(ErrDecode, "%v", err) - } - max := int(math.Min(float64(len(blockHashes)), blockHashesBatchSize)) + msgStream := rlp.NewStream(msg.Payload) var blocks []interface{} - for i, hash := range blockHashes { - if i >= max { - break + var i int + for { + i++ + var hash []byte + if err := msgStream.Decode(&hash); err != nil { + if err == io.EOF { + break + } else { + return self.protoError(ErrDecode, "msg %v: %v", msg, err) + } } block := self.chainManager.GetBlock(hash) if block != nil { - blocks = append(blocks, block.RlpData()) + blocks = append(blocks, block) + } + if i == blockHashesBatchSize { + break } } return self.rw.EncodeMsg(BlocksMsg, blocks...) case BlocksMsg: - msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + msgStream := rlp.NewStream(msg.Payload) for { - var block *types.Block + var block types.Block if err := msgStream.Decode(&block); err != nil { - if err == rlp.EOL { + if err == io.EOF { break } else { - return ProtocolError(ErrDecode, "%v", err) + return self.protoError(ErrDecode, "msg %v: %v", msg, err) } } - self.blockPool.AddBlock(block, self.id) + self.blockPool.AddBlock(&block, self.id) } case NewBlockMsg: var request newBlockMsgData if err := msg.Decode(&request); err != nil { - return ProtocolError(ErrDecode, "%v", err) + return self.protoError(ErrDecode, "msg %v: %v", msg, err) } hash := request.Block.Hash() // to simplify backend interface adding a new block @@ -202,12 +212,12 @@ func (self *ethProtocol) handle() error { // (or selected as new best peer) if self.blockPool.AddPeer(request.TD, hash, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) { called := true - iter := func() (hash []byte, ok bool) { + iter := func() ([]byte, bool) { if called { called = false return hash, true } else { - return + return nil, false } } self.blockPool.AddBlockHashes(iter, self.id) @@ -215,14 +225,14 @@ func (self *ethProtocol) handle() error { } default: - return ProtocolError(ErrInvalidMsgCode, "%v", msg.Code) + return self.protoError(ErrInvalidMsgCode, "%v", msg.Code) } return nil } type statusMsgData struct { - ProtocolVersion uint - NetworkId uint + ProtocolVersion uint32 + NetworkId uint32 TD *big.Int CurrentBlock []byte GenesisBlock []byte @@ -253,56 +263,56 @@ func (self *ethProtocol) handleStatus() error { } if msg.Code != StatusMsg { - return ProtocolError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) + return self.protoError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) } if msg.Size > ProtocolMaxMsgSize { - return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) + return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) } var status statusMsgData if err := msg.Decode(&status); err != nil { - return ProtocolError(ErrDecode, "%v", err) + return self.protoError(ErrDecode, "msg %v: %v", msg, err) } _, _, genesisBlock := self.chainManager.Status() if bytes.Compare(status.GenesisBlock, genesisBlock) != 0 { - return ProtocolError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock) + return self.protoError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock) } if status.NetworkId != NetworkId { - return ProtocolError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId) + return self.protoError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId) } if ProtocolVersion != status.ProtocolVersion { - return ProtocolError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion) + return self.protoError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion) } self.peer.Infof("Peer is [eth] capable (%d/%d). TD=%v H=%x\n", status.ProtocolVersion, status.NetworkId, status.TD, status.CurrentBlock[:4]) - //self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) - self.peer.Infoln("AddPeer(IGNORED)") + self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) return nil } func (self *ethProtocol) requestBlockHashes(from []byte) error { self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4]) - return self.rw.EncodeMsg(GetBlockHashesMsg, from, blockHashesBatchSize) + return self.rw.EncodeMsg(GetBlockHashesMsg, interface{}(from), uint64(blockHashesBatchSize)) } func (self *ethProtocol) requestBlocks(hashes [][]byte) error { self.peer.Debugf("fetching %v blocks", len(hashes)) - return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)) + return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)...) } func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) { err = ProtocolError(code, format, params...) if err.Fatal() { - self.peer.Errorln(err) + self.peer.Errorln("err %v", err) + // disconnect } else { - self.peer.Debugln(err) + self.peer.Debugf("fyi %v", err) } return } @@ -310,10 +320,10 @@ func (self *ethProtocol) protoError(code int, format string, params ...interface func (self *ethProtocol) protoErrorDisconnect(code int, format string, params ...interface{}) { err := ProtocolError(code, format, params...) if err.Fatal() { - self.peer.Errorln(err) + self.peer.Errorln("err %v", err) // disconnect } else { - self.peer.Debugln(err) + self.peer.Debugf("fyi %v", err) } } |