diff options
Diffstat (limited to 'p2p/peer.go')
-rw-r--r-- | p2p/peer.go | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/p2p/peer.go b/p2p/peer.go index 1fa8264a3..b61cf96da 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,6 +1,7 @@ package p2p import ( + "errors" "fmt" "io" "io/ioutil" @@ -71,7 +72,8 @@ type Peer struct { runlock sync.RWMutex // protects running running map[string]*proto - protocolHandshakeEnabled bool + // disables protocol handshake, for testing + noHandshake bool protoWG sync.WaitGroup protoErr chan error @@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr()) + return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) } func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { - logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr()) + logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) return &Peer{ Logger: logger.NewLogger(logtag), rw: newFrameRW(conn, msgWriteTimeout), @@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason { var readErr = make(chan error, 1) defer p.closeProtocols() defer close(p.closed) - defer p.rw.Close() - // start the read loop go func() { readErr <- p.readLoop() }() - if p.protocolHandshakeEnabled { + if !p.noHandshake { if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { p.DebugDetailf("Protocol handshake error: %v\n", err) + p.rw.Close() return DiscProtocolError } } - // wait for an error or disconnect + // Wait for an error or disconnect. var reason DiscReason select { case err := <-readErr: // We rely on protocols to abort if there is a write error. It // might be more robust to handle them here as well. p.DebugDetailf("Read error: %v\n", err) - reason = DiscNetworkError + p.rw.Close() + return DiscNetworkError + case err := <-p.protoErr: reason = discReasonForError(err) case reason = <-p.disc: } - if reason != DiscNetworkError { - p.politeDisconnect(reason) - } + p.politeDisconnect(reason) + + // Wait for readLoop. It will end because conn is now closed. + <-readErr p.Debugf("Disconnected: %v\n", reason) return reason } @@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason { func (p *Peer) politeDisconnect(reason DiscReason) { done := make(chan struct{}) go func() { - // send reason EncodeMsg(p.rw, discMsg, uint(reason)) - // discard any data that might arrive + // Wait for the other side to close the connection. + // Discard any data that they send until then. io.Copy(ioutil.Discard, p.rw) close(done) }() @@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) { case <-done: case <-time.After(disconnectGracePeriod): } + p.rw.Close() } func (p *Peer) readLoop() error { - if p.protocolHandshakeEnabled { + if !p.noHandshake { if err := readProtocolHandshake(p, p.rw); err != nil { return err } @@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) } if msg.Size > baseProtocolMaxMsgSize { - return newPeerError(errMisc, "message too big") + return newPeerError(errInvalidMsg, "message too big") } var hs handshake if err := msg.Decode(&hs); err != nil { @@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto { err := impl.Run(p, rw) if err == nil { p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) - err = newPeerError(errMisc, "protocol returned") + err = errors.New("protocol returned") } else { p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) } |