diff options
Diffstat (limited to 'p2p/server.go')
-rw-r--r-- | p2p/server.go | 197 |
1 files changed, 116 insertions, 81 deletions
diff --git a/p2p/server.go b/p2p/server.go index f9a38fcc6..40db758e2 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -18,6 +18,7 @@ package p2p import ( + "bytes" "crypto/ecdsa" "errors" "fmt" @@ -28,10 +29,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/netutil" ) @@ -87,7 +90,7 @@ type Config struct { // BootstrapNodes are used to establish connectivity // with the rest of the network. - BootstrapNodes []*discover.Node + BootstrapNodes []*enode.Node // BootstrapNodesV5 are used to establish connectivity // with the rest of the network using the V5 discovery @@ -96,11 +99,11 @@ type Config struct { // Static nodes are used as pre-configured connections which are always // maintained and re-connected on disconnects. - StaticNodes []*discover.Node + StaticNodes []*enode.Node // Trusted nodes are used as pre-configured connections which are always // allowed to connect, even above the peer limit. - TrustedNodes []*discover.Node + TrustedNodes []*enode.Node // Connectivity can be restricted to certain IP networks. // If this option is set to a non-nil value, only hosts which match one of the @@ -168,10 +171,10 @@ type Server struct { peerOpDone chan struct{} quit chan struct{} - addstatic chan *discover.Node - removestatic chan *discover.Node - addtrusted chan *discover.Node - removetrusted chan *discover.Node + addstatic chan *enode.Node + removestatic chan *enode.Node + addtrusted chan *enode.Node + removetrusted chan *enode.Node posthandshake chan *conn addpeer chan *conn delpeer chan peerDrop @@ -180,7 +183,7 @@ type Server struct { log log.Logger } -type peerOpFunc func(map[discover.NodeID]*Peer) +type peerOpFunc func(map[enode.ID]*Peer) type peerDrop struct { *Peer @@ -202,16 +205,16 @@ const ( type conn struct { fd net.Conn transport + node *enode.Node flags connFlag - cont chan error // The run loop uses cont to signal errors to SetupConn. - id discover.NodeID // valid after the encryption handshake - caps []Cap // valid after the protocol handshake - name string // valid after the protocol handshake + cont chan error // The run loop uses cont to signal errors to SetupConn. + caps []Cap // valid after the protocol handshake + name string // valid after the protocol handshake } type transport interface { // The two handshakes. - doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) + doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) // The MsgReadWriter can only be used after the encryption // handshake has completed. The code uses conn.id to track this @@ -225,8 +228,8 @@ type transport interface { func (c *conn) String() string { s := c.flags.String() - if (c.id != discover.NodeID{}) { - s += " " + c.id.String() + if (c.node.ID() != enode.ID{}) { + s += " " + c.node.ID().String() } s += " " + c.fd.RemoteAddr().String() return s @@ -279,7 +282,7 @@ func (srv *Server) Peers() []*Peer { // Note: We'd love to put this function into a variable but // that seems to cause a weird compiler error in some // environments. - case srv.peerOp <- func(peers map[discover.NodeID]*Peer) { + case srv.peerOp <- func(peers map[enode.ID]*Peer) { for _, p := range peers { ps = append(ps, p) } @@ -294,7 +297,7 @@ func (srv *Server) Peers() []*Peer { func (srv *Server) PeerCount() int { var count int select { - case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }: + case srv.peerOp <- func(ps map[enode.ID]*Peer) { count = len(ps) }: <-srv.peerOpDone case <-srv.quit: } @@ -304,7 +307,7 @@ func (srv *Server) PeerCount() int { // AddPeer connects to the given node and maintains the connection until the // server is shut down. If the connection fails for any reason, the server will // attempt to reconnect the peer. -func (srv *Server) AddPeer(node *discover.Node) { +func (srv *Server) AddPeer(node *enode.Node) { select { case srv.addstatic <- node: case <-srv.quit: @@ -312,7 +315,7 @@ func (srv *Server) AddPeer(node *discover.Node) { } // RemovePeer disconnects from the given node -func (srv *Server) RemovePeer(node *discover.Node) { +func (srv *Server) RemovePeer(node *enode.Node) { select { case srv.removestatic <- node: case <-srv.quit: @@ -321,7 +324,7 @@ func (srv *Server) RemovePeer(node *discover.Node) { // AddTrustedPeer adds the given node to a reserved whitelist which allows the // node to always connect, even if the slot are full. -func (srv *Server) AddTrustedPeer(node *discover.Node) { +func (srv *Server) AddTrustedPeer(node *enode.Node) { select { case srv.addtrusted <- node: case <-srv.quit: @@ -329,7 +332,7 @@ func (srv *Server) AddTrustedPeer(node *discover.Node) { } // RemoveTrustedPeer removes the given node from the trusted peer set. -func (srv *Server) RemoveTrustedPeer(node *discover.Node) { +func (srv *Server) RemoveTrustedPeer(node *enode.Node) { select { case srv.removetrusted <- node: case <-srv.quit: @@ -342,36 +345,47 @@ func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription { } // Self returns the local node's endpoint information. -func (srv *Server) Self() *discover.Node { +func (srv *Server) Self() *enode.Node { srv.lock.Lock() - defer srv.lock.Unlock() + running, listener, ntab := srv.running, srv.listener, srv.ntab + srv.lock.Unlock() - if !srv.running { - return &discover.Node{IP: net.ParseIP("0.0.0.0")} + if !running { + return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0) } - return srv.makeSelf(srv.listener, srv.ntab) + return srv.makeSelf(listener, ntab) } -func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *discover.Node { - // If the server's not running, return an empty node. +func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *enode.Node { // If the node is running but discovery is off, manually assemble the node infos. if ntab == nil { - // Inbound connections disabled, use zero address. - if listener == nil { - return &discover.Node{IP: net.ParseIP("0.0.0.0"), ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} - } - // Otherwise inject the listener address too - addr := listener.Addr().(*net.TCPAddr) - return &discover.Node{ - ID: discover.PubkeyID(&srv.PrivateKey.PublicKey), - IP: addr.IP, - TCP: uint16(addr.Port), - } + addr := srv.tcpAddr(listener) + return enode.NewV4(&srv.PrivateKey.PublicKey, addr.IP, addr.Port, 0) } // Otherwise return the discovery node. return ntab.Self() } +func (srv *Server) tcpAddr(listener net.Listener) net.TCPAddr { + addr := net.TCPAddr{IP: net.IP{0, 0, 0, 0}} + if listener == nil { + return addr // Inbound connections disabled, use zero address. + } + // Otherwise inject the listener address too. + if a, ok := listener.Addr().(*net.TCPAddr); ok { + addr = *a + } + if srv.NAT != nil { + if ip, err := srv.NAT.ExternalIP(); err == nil { + addr.IP = ip + } + } + if addr.IP.IsUnspecified() { + addr.IP = net.IP{127, 0, 0, 1} + } + return addr +} + // Stop terminates the server and all active peer connections. // It blocks until all active connections have been closed. func (srv *Server) Stop() { @@ -445,10 +459,10 @@ func (srv *Server) Start() (err error) { srv.addpeer = make(chan *conn) srv.delpeer = make(chan peerDrop) srv.posthandshake = make(chan *conn) - srv.addstatic = make(chan *discover.Node) - srv.removestatic = make(chan *discover.Node) - srv.addtrusted = make(chan *discover.Node) - srv.removetrusted = make(chan *discover.Node) + srv.addstatic = make(chan *enode.Node) + srv.removestatic = make(chan *enode.Node) + srv.addtrusted = make(chan *enode.Node) + srv.removetrusted = make(chan *enode.Node) srv.peerOp = make(chan peerOpFunc) srv.peerOpDone = make(chan struct{}) @@ -525,7 +539,8 @@ func (srv *Server) Start() (err error) { dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) // handshake - srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} + pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey) + srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]} for _, p := range srv.Protocols { srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) } @@ -541,7 +556,6 @@ func (srv *Server) Start() (err error) { srv.loopWG.Add(1) go srv.run(dialer) - srv.running = true return nil } @@ -568,18 +582,18 @@ func (srv *Server) startListening() error { } type dialer interface { - newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task + newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task taskDone(task, time.Time) - addStatic(*discover.Node) - removeStatic(*discover.Node) + addStatic(*enode.Node) + removeStatic(*enode.Node) } func (srv *Server) run(dialstate dialer) { defer srv.loopWG.Done() var ( - peers = make(map[discover.NodeID]*Peer) + peers = make(map[enode.ID]*Peer) inboundCount = 0 - trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) + trusted = make(map[enode.ID]bool, len(srv.TrustedNodes)) taskdone = make(chan task, maxActiveDialTasks) runningTasks []task queuedTasks []task // tasks that can't run yet @@ -587,7 +601,7 @@ func (srv *Server) run(dialstate dialer) { // Put trusted nodes into a map to speed up checks. // Trusted peers are loaded on startup or added via AddTrustedPeer RPC. for _, n := range srv.TrustedNodes { - trusted[n.ID] = true + trusted[n.ID()] = true } // removes t from runningTasks @@ -640,27 +654,27 @@ running: // stop keeping the node connected. srv.log.Trace("Removing static node", "node", n) dialstate.removeStatic(n) - if p, ok := peers[n.ID]; ok { + if p, ok := peers[n.ID()]; ok { p.Disconnect(DiscRequested) } case n := <-srv.addtrusted: // This channel is used by AddTrustedPeer to add an enode // to the trusted node set. srv.log.Trace("Adding trusted node", "node", n) - trusted[n.ID] = true + trusted[n.ID()] = true // Mark any already-connected peer as trusted - if p, ok := peers[n.ID]; ok { + if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, true) } case n := <-srv.removetrusted: // This channel is used by RemoveTrustedPeer to remove an enode // from the trusted node set. srv.log.Trace("Removing trusted node", "node", n) - if _, ok := trusted[n.ID]; ok { - delete(trusted, n.ID) + if _, ok := trusted[n.ID()]; ok { + delete(trusted, n.ID()) } // Unmark any already-connected peer as trusted - if p, ok := peers[n.ID]; ok { + if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, false) } case op := <-srv.peerOp: @@ -677,7 +691,7 @@ running: case c := <-srv.posthandshake: // A connection has passed the encryption handshake so // the remote identity is known (but hasn't been verified yet). - if trusted[c.id] { + if trusted[c.node.ID()] { // Ensure that the trusted flag is set before checking against MaxPeers. c.flags |= trustedConn } @@ -702,7 +716,7 @@ running: name := truncateName(c.name) srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) go srv.runPeer(p) - peers[c.id] = p + peers[c.node.ID()] = p if p.Inbound() { inboundCount++ } @@ -749,7 +763,7 @@ running: } } -func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { +func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { // Drop connections with no matching protocols. if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { return DiscUselessPeer @@ -759,15 +773,15 @@ func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inbound return srv.encHandshakeChecks(peers, inboundCount, c) } -func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { +func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { switch { case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): return DiscTooManyPeers - case peers[c.id] != nil: + case peers[c.node.ID()] != nil: return DiscAlreadyConnected - case c.id == srv.Self().ID: + case c.node.ID() == srv.Self().ID(): return DiscSelf default: return nil @@ -777,7 +791,6 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCo func (srv *Server) maxInboundConns() int { return srv.MaxPeers - srv.maxDialedConns() } - func (srv *Server) maxDialedConns() int { if srv.NoDiscovery || srv.NoDial { return 0 @@ -797,7 +810,7 @@ type tempError interface { // inbound connections. func (srv *Server) listenLoop() { defer srv.loopWG.Done() - srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) + srv.log.Info("RLPx listener up", "self", srv.Self()) tokens := defaultMaxPendingPeers if srv.MaxPendingPeers > 0 { @@ -850,7 +863,7 @@ func (srv *Server) listenLoop() { // SetupConn runs the handshakes and attempts to add the connection // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. -func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) error { +func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error { self := srv.Self() if self == nil { return errors.New("shutdown") @@ -859,12 +872,12 @@ func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Nod err := srv.setupConn(c, flags, dialDest) if err != nil { c.close(err) - srv.log.Trace("Setting up connection failed", "id", c.id, "err", err) + srv.log.Trace("Setting up connection failed", "addr", fd.RemoteAddr(), "err", err) } return err } -func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) error { +func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) error { // Prevent leftover pending conns from entering the handshake. srv.lock.Lock() running := srv.running @@ -872,18 +885,30 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e if !running { return errServerStopped } + // If dialing, figure out the remote public key. + var dialPubkey *ecdsa.PublicKey + if dialDest != nil { + dialPubkey = new(ecdsa.PublicKey) + if err := dialDest.Load((*enode.Secp256k1)(dialPubkey)); err != nil { + return fmt.Errorf("dial destination doesn't have a secp256k1 public key") + } + } // Run the encryption handshake. - var err error - if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil { + remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey) + if err != nil { srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) return err } - clog := srv.log.New("id", c.id, "addr", c.fd.RemoteAddr(), "conn", c.flags) - // For dialed connections, check that the remote public key matches. - if dialDest != nil && c.id != dialDest.ID { - clog.Trace("Dialed identity mismatch", "want", c, dialDest.ID) - return DiscUnexpectedIdentity + if dialDest != nil { + // For dialed connections, check that the remote public key matches. + if dialPubkey.X.Cmp(remotePubkey.X) != 0 || dialPubkey.Y.Cmp(remotePubkey.Y) != 0 { + return DiscUnexpectedIdentity + } + c.node = dialDest + } else { + c.node = nodeFromConn(remotePubkey, c.fd) } + clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags) err = srv.checkpoint(c, srv.posthandshake) if err != nil { clog.Trace("Rejected peer before protocol handshake", "err", err) @@ -895,8 +920,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e clog.Trace("Failed proto handshake", "err", err) return err } - if phs.ID != c.id { - clog.Trace("Wrong devp2p handshake identity", "err", phs.ID) + if id := c.node.ID(); !bytes.Equal(crypto.Keccak256(phs.ID), id[:]) { + clog.Trace("Wrong devp2p handshake identity", "phsid", fmt.Sprintf("%x", phs.ID)) return DiscUnexpectedIdentity } c.caps, c.name = phs.Caps, phs.Name @@ -911,6 +936,16 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e return nil } +func nodeFromConn(pubkey *ecdsa.PublicKey, conn net.Conn) *enode.Node { + var ip net.IP + var port int + if tcp, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + ip = tcp.IP + port = tcp.Port + } + return enode.NewV4(pubkey, ip, port, port) +} + func truncateName(s string) string { if len(s) > 20 { return s[:20] + "..." @@ -985,13 +1020,13 @@ func (srv *Server) NodeInfo() *NodeInfo { info := &NodeInfo{ Name: srv.Name, Enode: node.String(), - ID: node.ID.String(), - IP: node.IP.String(), + ID: node.ID().String(), + IP: node.IP().String(), ListenAddr: srv.ListenAddr, Protocols: make(map[string]interface{}), } - info.Ports.Discovery = int(node.UDP) - info.Ports.Listener = int(node.TCP) + info.Ports.Discovery = node.UDP() + info.Ports.Listener = node.TCP() // Gather all the running protocol infos (only once per protocol type) for _, proto := range srv.Protocols { |