diff options
Diffstat (limited to 'p2p/peer.go')
-rw-r--r-- | p2p/peer.go | 131 |
1 files changed, 69 insertions, 62 deletions
diff --git a/p2p/peer.go b/p2p/peer.go index 6b97ea58d..a82ee4bca 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -44,7 +44,7 @@ type Peer struct { rw *conn running map[string]*protoRW - protoWG sync.WaitGroup + wg sync.WaitGroup protoErr chan error closed chan struct{} disc chan DiscReason @@ -102,58 +102,50 @@ func (p *Peer) String() string { func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr()) + protomap := matchProtocols(protocols, conn.Caps, conn) p := &Peer{ Logger: logger.NewLogger(logtag), conn: fd, rw: conn, - running: matchProtocols(protocols, conn.Caps, conn), + running: protomap, disc: make(chan DiscReason), - protoErr: make(chan error), + protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop closed: make(chan struct{}), } return p } func (p *Peer) run() DiscReason { - var readErr = make(chan error, 1) - defer p.closeProtocols() - defer close(p.closed) + readErr := make(chan error, 1) + p.wg.Add(2) + go p.readLoop(readErr) + go p.pingLoop() p.startProtocols() - go func() { readErr <- p.readLoop() }() - - ping := time.NewTicker(pingInterval) - defer ping.Stop() // Wait for an error or disconnect. var reason DiscReason -loop: - for { - select { - case <-ping.C: - go func() { - if err := SendItems(p.rw, pingMsg); err != nil { - p.protoErr <- err - return - } - }() - case err := <-readErr: - // We rely on protocols to abort if there is a write error. It - // might be more robust to handle them here as well. - p.DebugDetailf("Read error: %v\n", err) - p.conn.Close() - return DiscNetworkError - case err := <-p.protoErr: - reason = discReasonForError(err) - break loop - case reason = <-p.disc: - break loop + select { + case err := <-readErr: + if r, ok := err.(DiscReason); ok { + reason = r + break } + // Note: We rely on protocols to abort if there is a write + // error. It might be more robust to handle them here as well. + p.DebugDetailf("Read error: %v\n", err) + p.conn.Close() + reason = DiscNetworkError + case err := <-p.protoErr: + reason = discReasonForError(err) + case reason = <-p.disc: } - p.politeDisconnect(reason) - // Wait for readLoop. It will end because conn is now closed. - <-readErr + close(p.closed) + p.wg.Wait() + if reason != DiscNetworkError { + p.politeDisconnect(reason) + } p.Debugf("Disconnected: %v\n", reason) return reason } @@ -174,18 +166,37 @@ func (p *Peer) politeDisconnect(reason DiscReason) { p.conn.Close() } -func (p *Peer) readLoop() error { +func (p *Peer) pingLoop() { + ping := time.NewTicker(pingInterval) + defer p.wg.Done() + defer ping.Stop() + for { + select { + case <-ping.C: + if err := SendItems(p.rw, pingMsg); err != nil { + p.protoErr <- err + return + } + case <-p.closed: + return + } + } +} + +func (p *Peer) readLoop(errc chan<- error) { + defer p.wg.Done() for { p.conn.SetDeadline(time.Now().Add(frameReadTimeout)) msg, err := p.rw.ReadMsg() if err != nil { - return err + errc <- err + return } if err = p.handle(msg); err != nil { - return err + errc <- err + return } } - return nil } func (p *Peer) handle(msg Msg) error { @@ -195,12 +206,11 @@ func (p *Peer) handle(msg Msg) error { go SendItems(p.rw, pongMsg) case msg.Code == discMsg: var reason [1]DiscReason - // no need to discard or for error checking, we'll close the - // connection after this. + // This is the last message. We don't need to discard or + // check errors because, the connection will be closed after it. rlp.Decode(msg.Payload, &reason) p.Debugf("Disconnect requested: %v\n", reason[0]) - p.Disconnect(DiscRequested) - return discRequestedError(reason[0]) + return DiscRequested case msg.Code < baseProtocolLength: // ignore other base protocol messages return msg.Discard() @@ -210,7 +220,12 @@ func (p *Peer) handle(msg Msg) error { if err != nil { return fmt.Errorf("msg code out of range: %v", msg.Code) } - proto.in <- msg + select { + case proto.in <- msg: + return nil + case <-p.closed: + return io.EOF + } } return nil } @@ -234,10 +249,11 @@ outer: } func (p *Peer) startProtocols() { + p.wg.Add(len(p.running)) for _, proto := range p.running { proto := proto + proto.closed = p.closed p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version) - p.protoWG.Add(1) go func() { err := proto.Run(p, proto) if err == nil { @@ -246,11 +262,8 @@ func (p *Peer) startProtocols() { } else { p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err) } - select { - case p.protoErr <- err: - case <-p.closed: - } - p.protoWG.Done() + p.protoErr <- err + p.wg.Done() }() } } @@ -266,13 +279,6 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) { return nil, newPeerError(errInvalidMsgCode, "%d", code) } -func (p *Peer) closeProtocols() { - for _, p := range p.running { - close(p.in) - } - p.protoWG.Wait() -} - // writeProtoMsg sends the given message on behalf of the given named protocol. // this exists because of Server.Broadcast. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { @@ -289,8 +295,8 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { type protoRW struct { Protocol - in chan Msg + closed <-chan struct{} offset uint64 w MsgWriter } @@ -304,10 +310,11 @@ func (rw *protoRW) WriteMsg(msg Msg) error { } func (rw *protoRW) ReadMsg() (Msg, error) { - msg, ok := <-rw.in - if !ok { - return msg, io.EOF + select { + case msg := <-rw.in: + msg.Code -= rw.offset + return msg, nil + case <-rw.closed: + return Msg{}, io.EOF } - msg.Code -= rw.offset - return msg, nil } |