diff options
Diffstat (limited to 'p2p/peer.go')
-rw-r--r-- | p2p/peer.go | 472 |
1 files changed, 207 insertions, 265 deletions
diff --git a/p2p/peer.go b/p2p/peer.go index 2380a3285..fd5bec7d5 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,8 +1,7 @@ package p2p import ( - "bufio" - "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -11,159 +10,109 @@ import ( "sync" "time" - "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" ) -// peerAddr is the structure of a peer list element. -// It is also a valid net.Addr. -type peerAddr struct { - IP net.IP - Port uint64 - Pubkey []byte // optional -} - -func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { - n := addr.Network() - if n != "tcp" && n != "tcp4" && n != "tcp6" { - // for testing with non-TCP - return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} - } - ta := addr.(*net.TCPAddr) - return &peerAddr{ta.IP, uint64(ta.Port), pubkey} -} +const ( + baseProtocolVersion = 3 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 -func (d peerAddr) Network() string { - if d.IP.To4() != nil { - return "tcp4" - } else { - return "tcp6" - } -} + disconnectGracePeriod = 2 * time.Second +) -func (d peerAddr) String() string { - return fmt.Sprintf("%v:%d", d.IP, d.Port) -} +const ( + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 +) -func (d *peerAddr) RlpData() interface{} { - return []interface{}{string(d.IP), d.Port, d.Pubkey} +// handshake is the RLP structure of the protocol handshake. +type handshake struct { + Version uint64 + Name string + Caps []Cap + ListenPort uint64 + NodeID discover.NodeID } -// Peer represents a remote peer. +// Peer represents a connected remote node. type Peer struct { // Peers have all the log methods. // Use them to display messages related to the peer. *logger.Logger - infolock sync.Mutex - identity ClientIdentity - caps []Cap - listenAddr *peerAddr // what remote peer is listening on - dialAddr *peerAddr // non-nil if dialing + infoMu sync.Mutex + name string + caps []Cap + + ourID, remoteID *discover.NodeID + ourName string - // The mutex protects the connection - // so only one protocol can write at a time. - writeMu sync.Mutex - conn net.Conn - bufconn *bufio.ReadWriter + rw *frameRW // These fields maintain the running protocols. - protocols []Protocol - runBaseProtocol bool // for testing + protocols []Protocol + runlock sync.RWMutex // protects running + running map[string]*proto - runlock sync.RWMutex // protects running - running map[string]*proto + // disables protocol handshake, for testing + noHandshake bool protoWG sync.WaitGroup protoErr chan error closed chan struct{} disc chan DiscReason - - activity event.TypeMux // for activity events - - slot int // index into Server peer list - - // These fields are kept so base protocol can access them. - // TODO: this should be one or more interfaces - ourID ClientIdentity // client id of the Server - ourListenAddr *peerAddr // listen addr of Server, nil if not listening - newPeerAddr chan<- *peerAddr // tell server about received peers - otherPeers func() []*Peer // should return the list of all peers - pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey } // NewPeer returns a peer for testing purposes. -func NewPeer(id ClientIdentity, caps []Cap) *Peer { +func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { conn, _ := net.Pipe() - peer := newPeer(conn, nil, nil) - peer.setHandshakeInfo(id, nil, caps) - close(peer.closed) + peer := newPeer(conn, nil, "", nil, &id) + peer.setHandshakeInfo(name, caps) + close(peer.closed) // ensures Disconnect doesn't block return peer } -func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { - p := newPeer(conn, server.Protocols, dialAddr) - p.ourID = server.Identity - p.newPeerAddr = server.peerConnect - p.otherPeers = server.Peers - p.pubkeyHook = server.verifyPeer - p.runBaseProtocol = true - - // laddr can be updated concurrently by NAT traversal. - // newServerPeer must be called with the server lock held. - if server.laddr != nil { - p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) - } - return p -} - -func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { - p := &Peer{ - Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), - conn: conn, - dialAddr: dialAddr, - bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), - protocols: protocols, - running: make(map[string]*proto), - disc: make(chan DiscReason), - protoErr: make(chan error), - closed: make(chan struct{}), - } - return p +// ID returns the node's public key. +func (p *Peer) ID() discover.NodeID { + return *p.remoteID } -// Identity returns the client identity of the remote peer. The -// identity can be nil if the peer has not yet completed the -// handshake. -func (p *Peer) Identity() ClientIdentity { - p.infolock.Lock() - defer p.infolock.Unlock() - return p.identity +// Name returns the node name that the remote node advertised. +func (p *Peer) Name() string { + // this needs a lock because the information is part of the + // protocol handshake. + p.infoMu.Lock() + name := p.name + p.infoMu.Unlock() + return name } // Caps returns the capabilities (supported subprotocols) of the remote peer. func (p *Peer) Caps() []Cap { - p.infolock.Lock() - defer p.infolock.Unlock() - return p.caps -} - -func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { - p.infolock.Lock() - p.identity = id - p.listenAddr = laddr - p.caps = caps - p.infolock.Unlock() + // this needs a lock because the information is part of the + // protocol handshake. + p.infoMu.Lock() + caps := p.caps + p.infoMu.Unlock() + return caps } // RemoteAddr returns the remote address of the network connection. func (p *Peer) RemoteAddr() net.Addr { - return p.conn.RemoteAddr() + return p.rw.RemoteAddr() } // LocalAddr returns the local address of the network connection. func (p *Peer) LocalAddr() net.Addr { - return p.conn.LocalAddr() + return p.rw.LocalAddr() } // Disconnect terminates the peer connection with the given reason. @@ -177,149 +126,177 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - kind := "inbound" - p.infolock.Lock() - if p.dialAddr != nil { - kind = "outbound" - } - p.infolock.Unlock() - return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) + return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) } -const ( - // maximum amount of time allowed for reading a message - msgReadTimeout = 5 * time.Second - // maximum amount of time allowed for writing a message - msgWriteTimeout = 5 * time.Second - // messages smaller than this many bytes will be read at - // once before passing them to a protocol. - wholePayloadSize = 64 * 1024 -) +func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { + logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) + return &Peer{ + Logger: logger.NewLogger(logtag), + rw: newFrameRW(conn, msgWriteTimeout), + ourID: ourID, + ourName: ourName, + remoteID: remoteID, + protocols: protocols, + running: make(map[string]*proto), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } +} -var ( - inactivityTimeout = 2 * time.Second - disconnectGracePeriod = 2 * time.Second -) +func (p *Peer) setHandshakeInfo(name string, caps []Cap) { + p.infoMu.Lock() + p.name = name + p.caps = caps + p.infoMu.Unlock() +} -func (p *Peer) loop() (reason DiscReason, err error) { - defer p.activity.Stop() +func (p *Peer) run() DiscReason { + var readErr = make(chan error, 1) defer p.closeProtocols() defer close(p.closed) - defer p.conn.Close() - - // read loop - readMsg := make(chan Msg) - readErr := make(chan error) - readNext := make(chan bool, 1) - protoDone := make(chan struct{}, 1) - go p.readLoop(readMsg, readErr, readNext) - readNext <- true - - if p.runBaseProtocol { - p.startBaseProtocol() - } -loop: - for { - select { - case msg := <-readMsg: - // a new message has arrived. - var wait bool - if wait, err = p.dispatch(msg, protoDone); err != nil { - p.Errorf("msg dispatch error: %v\n", err) - reason = discReasonForError(err) - break loop - } - if !wait { - // Msg has already been read completely, continue with next message. - readNext <- true - } - p.activity.Post(time.Now()) - case <-protoDone: - // protocol has consumed the message payload, - // we can continue reading from the socket. - readNext <- true - - case err := <-readErr: - // read failed. there is no need to run the - // polite disconnect sequence because the connection - // is probably dead anyway. - // TODO: handle write errors as well - return DiscNetworkError, err - case err = <-p.protoErr: - reason = discReasonForError(err) - break loop - case reason = <-p.disc: - break loop + go func() { readErr <- p.readLoop() }() + + if !p.noHandshake { + if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { + p.DebugDetailf("Protocol handshake error: %v\n", err) + p.rw.Close() + return DiscProtocolError } } - // wait for read loop to return. - close(readNext) + // Wait for an error or disconnect. + var reason DiscReason + select { + case err := <-readErr: + // We rely on protocols to abort if there is a write error. It + // might be more robust to handle them here as well. + p.DebugDetailf("Read error: %v\n", err) + p.rw.Close() + return DiscNetworkError + + case err := <-p.protoErr: + reason = discReasonForError(err) + case reason = <-p.disc: + } + p.politeDisconnect(reason) + + // Wait for readLoop. It will end because conn is now closed. <-readErr - // tell the remote end to disconnect + p.Debugf("Disconnected: %v\n", reason) + return reason +} + +func (p *Peer) politeDisconnect(reason DiscReason) { done := make(chan struct{}) go func() { - p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) - p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) - io.Copy(ioutil.Discard, p.conn) + EncodeMsg(p.rw, discMsg, uint(reason)) + // Wait for the other side to close the connection. + // Discard any data that they send until then. + io.Copy(ioutil.Discard, p.rw) close(done) }() select { case <-done: case <-time.After(disconnectGracePeriod): } - return reason, err + p.rw.Close() } -func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { - for _ = range unblock { - p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) - if msg, err := readMsg(p.bufconn); err != nil { - errc <- err - } else { - msgc <- msg +func (p *Peer) readLoop() error { + if !p.noHandshake { + if err := readProtocolHandshake(p, p.rw); err != nil { + return err } } - close(errc) -} - -func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { - proto, err := p.getProto(msg.Code) - if err != nil { - return false, err + for { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + if err = p.handle(msg); err != nil { + return err + } } - if msg.Size <= wholePayloadSize { - // optimization: msg is small enough, read all - // of it and move on to the next message - buf, err := ioutil.ReadAll(msg.Payload) + return nil +} + +func (p *Peer) handle(msg Msg) error { + switch { + case msg.Code == pingMsg: + msg.Discard() + go EncodeMsg(p.rw, pongMsg) + case msg.Code == discMsg: + var reason DiscReason + // no need to discard or for error checking, we'll close the + // connection after this. + rlp.Decode(msg.Payload, &reason) + p.Disconnect(DiscRequested) + return discRequestedError(reason) + case msg.Code < baseProtocolLength: + // ignore other base protocol messages + return msg.Discard() + default: + // it's a subprotocol message + proto, err := p.getProto(msg.Code) if err != nil { - return false, err + return fmt.Errorf("msg code out of range: %v", msg.Code) } - msg.Payload = bytes.NewReader(buf) - proto.in <- msg - } else { - wait = true - pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone} - msg.Payload = pr proto.in <- msg } - return wait, nil + return nil } -func (p *Peer) startBaseProtocol() { - p.runlock.Lock() - defer p.runlock.Unlock() - p.running[""] = p.startProto(0, Protocol{ - Length: baseProtocolLength, - Run: runBaseProtocol, - }) +func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { + // read and handle remote handshake + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if msg.Code == discMsg { + // disconnect before protocol handshake is valid according to the + // spec and we send it ourself if Server.addPeer fails. + var reason DiscReason + rlp.Decode(msg.Payload, &reason) + return discRequestedError(reason) + } + if msg.Code != handshakeMsg { + return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) + } + if msg.Size > baseProtocolMaxMsgSize { + return newPeerError(errInvalidMsg, "message too big") + } + var hs handshake + if err := msg.Decode(&hs); err != nil { + return err + } + // validate handshake info + if hs.Version != baseProtocolVersion { + return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", + baseProtocolVersion, hs.Version) + } + if hs.NodeID == *p.remoteID { + return newPeerError(errPubkeyForbidden, "node ID mismatch") + } + // TODO: remove Caps with empty name + p.setHandshakeInfo(hs.Name, hs.Caps) + p.startSubprotocols(hs.Caps) + return nil +} + +func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error { + var caps []interface{} + for _, proto := range ps { + caps = append(caps, proto.cap()) + } + return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id) } // startProtocols starts matching named subprotocols. func (p *Peer) startSubprotocols(caps []Cap) { sort.Sort(capsByName(caps)) - p.runlock.Lock() defer p.runlock.Unlock() offset := baseProtocolLength @@ -338,20 +315,22 @@ outer: } func (p *Peer) startProto(offset uint64, impl Protocol) *proto { + p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version) rw := &proto{ + name: impl.Name, in: make(chan Msg), offset: offset, maxcode: impl.Length, - peer: p, + w: p.rw, } p.protoWG.Add(1) go func() { err := impl.Run(p, rw) if err == nil { - p.Infof("protocol %q returned", impl.Name) - err = newPeerError(errMisc, "protocol returned") + p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) + err = errors.New("protocol returned") } else { - p.Errorf("protocol %q error: %v\n", impl.Name, err) + p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) } select { case p.protoErr <- err: @@ -385,6 +364,7 @@ func (p *Peer) closeProtocols() { } // writeProtoMsg sends the given message on behalf of the given named protocol. +// this exists because of Server.Broadcast. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { p.runlock.RLock() proto, ok := p.running[protoName] @@ -396,25 +376,14 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) } msg.Code += proto.offset - return p.writeMsg(msg, msgWriteTimeout) -} - -// writeMsg writes a message to the connection. -func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { - p.writeMu.Lock() - defer p.writeMu.Unlock() - p.conn.SetWriteDeadline(time.Now().Add(timeout)) - if err := writeMsg(p.bufconn, msg); err != nil { - return newPeerError(errWrite, "%v", err) - } - return p.bufconn.Flush() + return p.rw.WriteMsg(msg) } type proto struct { name string in chan Msg maxcode, offset uint64 - peer *Peer + w MsgWriter } func (rw *proto) WriteMsg(msg Msg) error { @@ -422,11 +391,7 @@ func (rw *proto) WriteMsg(msg Msg) error { return newPeerError(errInvalidMsgCode, "not handled") } msg.Code += rw.offset - return rw.peer.writeMsg(msg, msgWriteTimeout) -} - -func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { - return rw.WriteMsg(NewMsg(code, data...)) + return rw.w.WriteMsg(msg) } func (rw *proto) ReadMsg() (Msg, error) { @@ -437,26 +402,3 @@ func (rw *proto) ReadMsg() (Msg, error) { msg.Code -= rw.offset return msg, nil } - -// eofSignal wraps a reader with eof signaling. the eof channel is -// closed when the wrapped reader returns an error or when count bytes -// have been read. -// -type eofSignal struct { - wrapped io.Reader - count int64 - eof chan<- struct{} -} - -// note: when using eofSignal to detect whether a message payload -// has been read, Read might not be called for zero sized messages. - -func (r *eofSignal) Read(buf []byte) (int, error) { - n, err := r.wrapped.Read(buf) - r.count -= int64(n) - if (err != nil || r.count <= 0) && r.eof != nil { - r.eof <- struct{}{} // tell Peer that msg has been consumed - r.eof = nil - } - return n, err -} |