aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/server.go')
-rw-r--r-- p2p/server.go 666
1 files changed, 361 insertions, 305 deletions
diff --git a/p2p/server.go b/p2p/server.go
index 529fedbca..27e617610 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -2,7 +2,6 @@ package p2p
import (
"crypto/ecdsa"
- "crypto/rand"
"errors"
"fmt"
"net"
@@ -24,11 +23,8 @@ const (
maxAcceptConns = 50
// Maximum number of concurrently dialing outbound connections.
- maxDialingConns = 10
+ maxActiveDialTasks = 16
- // total timeout for encryption handshake and protocol
- // handshake in both directions.
- handshakeTimeout = 5 * time.Second
// maximum time allowed for reading a complete message.
// this is effectively the amount of time a connection can be idle.
frameReadTimeout = 1 * time.Minute
@@ -36,6 +32,8 @@ const (
frameWriteTimeout = 5 * time.Second
)
+var errServerStopped = errors.New("server stopped")
+
var srvjslog = logger.NewJsonLogger()
// Server manages all peer connections.
@@ -103,68 +101,173 @@ type Server struct {
// Hooks for testing. These are useful because we can inhibit
// the whole protocol stack.
- setupFunc
- newPeerHook
+ newTransport func(net.Conn) transport
+ newPeerHook func(*Peer)
+
+ lock sync.Mutex // protects running
+ running bool
+ ntab discoverTable
+ listener net.Listener
ourHandshake *protoHandshake
- lock sync.RWMutex // protects running, peers and the trust fields
- running bool
- peers map[discover.NodeID]*Peer
- staticNodes map[discover.NodeID]*discover.Node // Map of currently maintained static remote nodes
- staticDial chan *discover.Node // Dial request channel reserved for the static nodes
- staticCycle time.Duration // Overrides staticPeerCheckInterval, used for testing
- trustedNodes map[discover.NodeID]bool // Set of currently trusted remote nodes
+ // These are for Peers, PeerCount (and nothing else).
+ peerOp chan peerOpFunc
+ peerOpDone chan struct{}
+
+ quit chan struct{}
+ addstatic chan *discover.Node
+ posthandshake chan *conn
+ addpeer chan *conn
+ delpeer chan *Peer
+ loopWG sync.WaitGroup // loop, listenLoop
+}
+
+type peerOpFunc func(map[discover.NodeID]*Peer)
+
+type connFlag int
- ntab *discover.Table
- listener net.Listener
+const (
+ dynDialedConn connFlag = 1 << iota
+ staticDialedConn
+ inboundConn
+ trustedConn
+)
+
+// conn wraps a network connection with information gathered
+// during the two handshakes.
+type conn struct {
+ fd net.Conn
+ transport
+ flags connFlag
+ cont chan error // The run loop uses cont to signal errors to setupConn.
+ id discover.NodeID // valid after the encryption handshake
+ caps []Cap // valid after the protocol handshake
+ name string // valid after the protocol handshake
+}
- quit chan struct{}
- loopWG sync.WaitGroup // {dial,listen,nat}Loop
- peerWG sync.WaitGroup // active peer goroutines
+type transport interface {
+ // The two handshakes.
+ doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error)
+ doProtoHandshake(our *protoHandshake) (*protoHandshake, error)
+ // The MsgReadWriter can only be used after the encryption
+ // handshake has completed. The code uses conn.id to track this
+ // by setting it to a non-zero value after the encryption handshake.
+ MsgReadWriter
+ // transports must provide Close because we use MsgPipe in some of
+ // the tests. Closing the actual network connection doesn't do
+ // anything in those tests because MsgPipe doesn't use it.
+ close(err error)
}
-type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, func(discover.NodeID) bool) (*conn, error)
-type newPeerHook func(*Peer)
+func (c *conn) String() string {
+ s := c.flags.String() + " conn"
+ if (c.id != discover.NodeID{}) {
+ s += fmt.Sprintf(" %x", c.id[:8])
+ }
+ s += " " + c.fd.RemoteAddr().String()
+ return s
+}
+
+func (f connFlag) String() string {
+ s := ""
+ if f&trustedConn != 0 {
+ s += " trusted"
+ }
+ if f&dynDialedConn != 0 {
+ s += " dyn dial"
+ }
+ if f&staticDialedConn != 0 {
+ s += " static dial"
+ }
+ if f&inboundConn != 0 {
+ s += " inbound"
+ }
+ if s != "" {
+ s = s[1:]
+ }
+ return s
+}
+
+func (c *conn) is(f connFlag) bool {
+ return c.flags&f != 0
+}
// Peers returns all connected peers.
-func (srv *Server) Peers() (peers []*Peer) {
- srv.lock.RLock()
- defer srv.lock.RUnlock()
- for _, peer := range srv.peers {
- if peer != nil {
- peers = append(peers, peer)
+func (srv *Server) Peers() []*Peer {
+ var ps []*Peer
+ select {
+ // Note: We'd love to put this function into a variable but
+ // that seems to cause a weird compiler error in some
+ // environments.
+ case srv.peerOp <- func(peers map[discover.NodeID]*Peer) {
+ for _, p := range peers {
+ ps = append(ps, p)
}
+ }:
+ <-srv.peerOpDone
+ case <-srv.quit:
}
- return
+ return ps
}
// PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int {
- srv.lock.RLock()
- n := len(srv.peers)
- srv.lock.RUnlock()
- return n
+ var count int
+ select {
+ case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }:
+ <-srv.peerOpDone
+ case <-srv.quit:
+ }
+ return count
}
// AddPeer connects to the given node and maintains the connection until the
// server is shut down. If the connection fails for any reason, the server will
// attempt to reconnect the peer.
func (srv *Server) AddPeer(node *discover.Node) {
+ select {
+ case srv.addstatic <- node:
+ case <-srv.quit:
+ }
+}
+
+// Self returns the local node's endpoint information.
+func (srv *Server) Self() *discover.Node {
srv.lock.Lock()
defer srv.lock.Unlock()
+ if !srv.running {
+ return &discover.Node{IP: net.ParseIP("0.0.0.0")}
+ }
+ return srv.ntab.Self()
+}
- srv.staticNodes[node.ID] = node
+// Stop terminates the server and all active peer connections.
+// It blocks until all active connections have been closed.
+func (srv *Server) Stop() {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+ if !srv.running {
+ return
+ }
+ srv.running = false
+ if srv.listener != nil {
+ // this unblocks listener Accept
+ srv.listener.Close()
+ }
+ close(srv.quit)
+ srv.loopWG.Wait()
}
// Start starts running the server.
-// Servers can be re-used and started again after stopping.
+// Servers cannot be re-used after stopping.
func (srv *Server) Start() (err error) {
srv.lock.Lock()
defer srv.lock.Unlock()
if srv.running {
return errors.New("server already running")
}
+ srv.running = true
glog.V(logger.Info).Infoln("Starting Server")
// static fields
@@ -174,23 +277,19 @@ func (srv *Server) Start() (err error) {
if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0")
}
- srv.quit = make(chan struct{})
- srv.peers = make(map[discover.NodeID]*Peer)
-
- // Create the current trust maps, and the associated dialing channel
- srv.trustedNodes = make(map[discover.NodeID]bool)
- for _, node := range srv.TrustedNodes {
- srv.trustedNodes[node.ID] = true
- }
- srv.staticNodes = make(map[discover.NodeID]*discover.Node)
- for _, node := range srv.StaticNodes {
- srv.staticNodes[node.ID] = node
+ if srv.newTransport == nil {
+ srv.newTransport = newRLPX
}
- srv.staticDial = make(chan *discover.Node)
-
- if srv.setupFunc == nil {
- srv.setupFunc = setupConn
+ if srv.Dialer == nil {
+ srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
+ srv.quit = make(chan struct{})
+ srv.addpeer = make(chan *conn)
+ srv.delpeer = make(chan *Peer)
+ srv.posthandshake = make(chan *conn)
+ srv.addstatic = make(chan *discover.Node)
+ srv.peerOp = make(chan peerOpFunc)
+ srv.peerOpDone = make(chan struct{})
// node table
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
@@ -198,37 +297,31 @@ func (srv *Server) Start() (err error) {
return err
}
srv.ntab = ntab
+ dialer := newDialState(srv.StaticNodes, srv.ntab, srv.MaxPeers/2)
// handshake
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID}
for _, p := range srv.Protocols {
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
}
-
// listen/dial
if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil {
return err
}
}
- if srv.Dialer == nil {
- srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
- }
- if !srv.NoDial {
- srv.loopWG.Add(1)
- go srv.dialLoop()
- }
if srv.NoDial && srv.ListenAddr == "" {
glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.")
}
- // maintain the static peers
- go srv.staticNodesLoop()
+ srv.loopWG.Add(1)
+ go srv.run(dialer)
srv.running = true
return nil
}
func (srv *Server) startListening() error {
+ // Launch the TCP listener.
listener, err := net.Listen("tcp", srv.ListenAddr)
if err != nil {
return err
@@ -238,6 +331,7 @@ func (srv *Server) startListening() error {
srv.listener = listener
srv.loopWG.Add(1)
go srv.listenLoop()
+ // Map the TCP listening port if NAT is configured.
if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1)
go func() {
@@ -248,50 +342,164 @@ func (srv *Server) startListening() error {
return nil
}
-// Stop terminates the server and all active peer connections.
-// It blocks until all active connections have been closed.
-func (srv *Server) Stop() {
- srv.lock.Lock()
- if !srv.running {
- srv.lock.Unlock()
- return
+type dialer interface {
+ newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task
+ taskDone(task, time.Time)
+ addStatic(*discover.Node)
+}
+
+func (srv *Server) run(dialstate dialer) {
+ defer srv.loopWG.Done()
+ var (
+ peers = make(map[discover.NodeID]*Peer)
+ trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
+
+ tasks []task
+ pendingTasks []task
+ taskdone = make(chan task, maxActiveDialTasks)
+ )
+ // Put trusted nodes into a map to speed up checks.
+ // Trusted peers are loaded on startup and cannot be
+ // modified while the server is running.
+ for _, n := range srv.TrustedNodes {
+ trusted[n.ID] = true
+ }
+
+ // Some task list helpers.
+ delTask := func(t task) {
+ for i := range tasks {
+ if tasks[i] == t {
+ tasks = append(tasks[:i], tasks[i+1:]...)
+ break
+ }
+ }
}
- srv.running = false
- srv.lock.Unlock()
+ scheduleTasks := func(new []task) {
+ pt := append(pendingTasks, new...)
+ start := maxActiveDialTasks - len(tasks)
+ if len(pt) < start {
+ start = len(pt)
+ }
+ if start > 0 {
+ tasks = append(tasks, pt[:start]...)
+ for _, t := range pt[:start] {
+ t := t
+ glog.V(logger.Detail).Infoln("new task:", t)
+ go func() { t.Do(srv); taskdone <- t }()
+ }
+ copy(pt, pt[start:])
+ pendingTasks = pt[:len(pt)-start]
+ }
+ }
+
+running:
+ for {
+ // Query the dialer for new tasks and launch them.
+ now := time.Now()
+ nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now)
+ scheduleTasks(nt)
- glog.V(logger.Info).Infoln("Stopping Server")
+ select {
+ case <-srv.quit:
+ // The server was stopped. Run the cleanup logic.
+ glog.V(logger.Detail).Infoln("<-quit: spinning down")
+ break running
+ case n := <-srv.addstatic:
+ // This channel is used by AddPeer to add to the
+ // ephemeral static peer list. Add it to the dialer,
+ // it will keep the node connected.
+ glog.V(logger.Detail).Infoln("<-addstatic:", n)
+ dialstate.addStatic(n)
+ case op := <-srv.peerOp:
+ // This channel is used by Peers and PeerCount.
+ op(peers)
+ srv.peerOpDone <- struct{}{}
+ case t := <-taskdone:
+ // A task got done. Tell dialstate about it so it
+ // can update its state and remove it from the active
+ // tasks list.
+ glog.V(logger.Detail).Infoln("<-taskdone:", t)
+ dialstate.taskDone(t, now)
+ delTask(t)
+ case c := <-srv.posthandshake:
+ // A connection has passed the encryption handshake so
+ // the remote identity is known (but hasn't been verified yet).
+ if trusted[c.id] {
+ // Ensure that the trusted flag is set before checking against MaxPeers.
+ c.flags |= trustedConn
+ }
+ glog.V(logger.Detail).Infoln("<-posthandshake:", c)
+ // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
+ c.cont <- srv.encHandshakeChecks(peers, c)
+ case c := <-srv.addpeer:
+ // At this point the connection is past the protocol handshake.
+ // Its capabilities are known and the remote identity is verified.
+ glog.V(logger.Detail).Infoln("<-addpeer:", c)
+ err := srv.protoHandshakeChecks(peers, c)
+ if err != nil {
+ glog.V(logger.Detail).Infof("Not adding %v as peer: %v", c, err)
+ } else {
+ // The handshakes are done and it passed all checks.
+ p := newPeer(c, srv.Protocols)
+ peers[c.id] = p
+ go srv.runPeer(p)
+ }
+ // The dialer logic relies on the assumption that
+ // dial tasks complete after the peer has been added or
+ // discarded. Unblock the task last.
+ c.cont <- err
+ case p := <-srv.delpeer:
+ // A peer disconnected.
+ glog.V(logger.Detail).Infoln("<-delpeer:", p)
+ delete(peers, p.ID())
+ }
+ }
+
+ // Terminate discovery. If there is a running lookup it will terminate soon.
srv.ntab.Close()
- if srv.listener != nil {
- // this unblocks listener Accept
- srv.listener.Close()
+ // Disconnect all peers.
+ for _, p := range peers {
+ p.Disconnect(DiscQuitting)
+ }
+ // Wait for peers to shut down. Pending connections and tasks are
+ // not handled here and will terminate soon-ish because srv.quit
+ // is closed.
+ glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks))
+ for len(peers) > 0 {
+ p := <-srv.delpeer
+ glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p)
+ delete(peers, p.ID())
}
- close(srv.quit)
- srv.loopWG.Wait()
+}
- // No new peers can be added at this point because dialLoop and
- // listenLoop are down. It is safe to call peerWG.Wait because
- // peerWG.Add is not called outside of those loops.
- srv.lock.Lock()
- for _, peer := range srv.peers {
- peer.Disconnect(DiscQuitting)
+func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+ // Drop connections with no matching protocols.
+ if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
+ return DiscUselessPeer
}
- srv.lock.Unlock()
- srv.peerWG.Wait()
+ // Repeat the encryption handshake checks because the
+ // peer set might have changed between the handshakes.
+ return srv.encHandshakeChecks(peers, c)
}
-// Self returns the local node's endpoint information.
-func (srv *Server) Self() *discover.Node {
- srv.lock.RLock()
- defer srv.lock.RUnlock()
- if !srv.running {
- return &discover.Node{IP: net.ParseIP("0.0.0.0")}
+func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
+ switch {
+ case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
+ return DiscTooManyPeers
+ case peers[c.id] != nil:
+ return DiscAlreadyConnected
+ case c.id == srv.ntab.Self().ID:
+ return DiscSelf
+ default:
+ return nil
}
- return srv.ntab.Self()
}
-// main loop for adding connections via listening
+// listenLoop runs in its own goroutine and accepts
+// inbound connections.
func (srv *Server) listenLoop() {
defer srv.loopWG.Done()
+ glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
// This channel acts as a semaphore limiting
// active inbound connections that are lingering pre-handshake.
@@ -305,204 +513,92 @@ func (srv *Server) listenLoop() {
slots <- struct{}{}
}
- glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
for {
<-slots
- conn, err := srv.listener.Accept()
+ fd, err := srv.listener.Accept()
if err != nil {
return
}
- glog.V(logger.Debug).Infof("Accepted conn %v\n", conn.RemoteAddr())
- srv.peerWG.Add(1)
+ glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
go func() {
- srv.startPeer(conn, nil)
+ srv.setupConn(fd, inboundConn, nil)
slots <- struct{}{}
}()
}
}
-// staticNodesLoop is responsible for periodically checking that static
-// connections are actually live, and requests dialing if not.
-func (srv *Server) staticNodesLoop() {
- // Create a default maintenance ticker, but override it requested
- cycle := staticPeerCheckInterval
- if srv.staticCycle != 0 {
- cycle = srv.staticCycle
- }
- tick := time.NewTicker(cycle)
-
- for {
- select {
- case <-srv.quit:
- return
-
- case <-tick.C:
- // Collect all the non-connected static nodes
- needed := []*discover.Node{}
- srv.lock.RLock()
- for id, node := range srv.staticNodes {
- if _, ok := srv.peers[id]; !ok {
- needed = append(needed, node)
- }
- }
- srv.lock.RUnlock()
-
- // Try to dial each of them (don't hang if server terminates)
- for _, node := range needed {
- glog.V(logger.Debug).Infof("Dialing static peer %v", node)
- select {
- case srv.staticDial <- node:
- case <-srv.quit:
- return
- }
- }
- }
- }
-}
-
-func (srv *Server) dialLoop() {
- var (
- dialed = make(chan *discover.Node)
- dialing = make(map[discover.NodeID]bool)
- findresults = make(chan []*discover.Node)
- refresh = time.NewTimer(0)
- )
- defer srv.loopWG.Done()
- defer refresh.Stop()
-
- // Limit the number of concurrent dials
- tokens := maxDialingConns
- if srv.MaxPendingPeers > 0 {
- tokens = srv.MaxPendingPeers
- }
- slots := make(chan struct{}, tokens)
- for i := 0; i < tokens; i++ {
- slots <- struct{}{}
+// setupConn runs the handshakes and attempts to add the connection
+// as a peer. It returns when the connection has been added as a peer
+// or the handshakes have failed.
+func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) {
+ // Prevent leftover pending conns from entering the handshake.
+ srv.lock.Lock()
+ running := srv.running
+ srv.lock.Unlock()
+ c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
+ if !running {
+ c.close(errServerStopped)
+ return
}
- dial := func(dest *discover.Node) {
- // Don't dial nodes that would fail the checks in addPeer.
- // This is important because the connection handshake is a lot
- // of work and we'd rather avoid doing that work for peers
- // that can't be added.
- srv.lock.RLock()
- ok, _ := srv.checkPeer(dest.ID)
- srv.lock.RUnlock()
- if !ok || dialing[dest.ID] {
- return
- }
- // Request a dial slot to prevent CPU exhaustion
- <-slots
-
- dialing[dest.ID] = true
- srv.peerWG.Add(1)
- go func() {
- srv.dialNode(dest)
- slots <- struct{}{}
- dialed <- dest
- }()
+ // Run the encryption handshake.
+ var err error
+ if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil {
+ glog.V(logger.Debug).Infof("%v faild enc handshake: %v", c, err)
+ c.close(err)
+ return
}
-
- srv.ntab.Bootstrap(srv.BootstrapNodes)
- for {
- select {
- case <-refresh.C:
- // Grab some nodes to connect to if we're not at capacity.
- srv.lock.RLock()
- needpeers := len(srv.peers) < srv.MaxPeers/2
- srv.lock.RUnlock()
- if needpeers {
- go func() {
- var target discover.NodeID
- rand.Read(target[:])
- findresults <- srv.ntab.Lookup(target)
- }()
- } else {
- // Make sure we check again if the peer count falls
- // below MaxPeers.
- refresh.Reset(refreshPeersInterval)
- }
- case dest := <-srv.staticDial:
- dial(dest)
- case dests := <-findresults:
- for _, dest := range dests {
- dial(dest)
- }
- refresh.Reset(refreshPeersInterval)
- case dest := <-dialed:
- delete(dialing, dest.ID)
- if len(dialing) == 0 {
- // Check again immediately after dialing all current candidates.
- refresh.Reset(0)
- }
- case <-srv.quit:
- // TODO: maybe wait for active dials
- return
- }
+ // For dialed connections, check that the remote public key matches.
+ if dialDest != nil && c.id != dialDest.ID {
+ c.close(DiscUnexpectedIdentity)
+ glog.V(logger.Debug).Infof("%v dialed identity mismatch, want %x", c, dialDest.ID[:8])
+ return
}
-}
-
-func (srv *Server) dialNode(dest *discover.Node) {
- addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)}
- glog.V(logger.Debug).Infof("Dialing %v\n", dest)
- conn, err := srv.Dialer.Dial("tcp", addr.String())
- if err != nil {
- // dialLoop adds to the wait group counter when launching
- // dialNode, so we need to count it down again. startPeer also
- // does that when an error occurs.
- srv.peerWG.Done()
- glog.V(logger.Detail).Infof("dial error: %v", err)
+ if err := srv.checkpoint(c, srv.posthandshake); err != nil {
+ glog.V(logger.Debug).Infof("%v failed checkpoint posthandshake: %v", c, err)
+ c.close(err)
return
}
- srv.startPeer(conn, dest)
-}
-
-func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
- // TODO: handle/store session token
-
- // Run setupFunc, which should create an authenticated connection
- // and run the capability exchange. Note that any early error
- // returns during that exchange need to call peerWG.Done because
- // the callers of startPeer added the peer to the wait group already.
- fd.SetDeadline(time.Now().Add(handshakeTimeout))
-
- conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, srv.keepconn)
+ // Run the protocol handshake
+ phs, err := c.doProtoHandshake(srv.ourHandshake)
if err != nil {
- fd.Close()
- glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)
- srv.peerWG.Done()
+ glog.V(logger.Debug).Infof("%v failed proto handshake: %v", c, err)
+ c.close(err)
return
}
- conn.MsgReadWriter = &netWrapper{
- wrapped: conn.MsgReadWriter,
- conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout,
+ if phs.ID != c.id {
+ glog.V(logger.Debug).Infof("%v wrong proto handshake identity: %x", c, phs.ID[:8])
+ c.close(DiscUnexpectedIdentity)
+ return
}
- p := newPeer(fd, conn, srv.Protocols)
- if ok, reason := srv.addPeer(conn, p); !ok {
- glog.V(logger.Detail).Infof("Not adding %v (%v)\n", p, reason)
- p.politeDisconnect(reason)
- srv.peerWG.Done()
+ c.caps, c.name = phs.Caps, phs.Name
+ if err := srv.checkpoint(c, srv.addpeer); err != nil {
+ glog.V(logger.Debug).Infof("%v failed checkpoint addpeer: %v", c, err)
+ c.close(err)
return
}
- // The handshakes are done and it passed all checks.
- // Spawn the Peer loops.
- go srv.runPeer(p)
+ // If the checks completed successfully, runPeer has now been
+ // launched by run.
}
-// preflight checks whether a connection should be kept. it runs
-// after the encryption handshake, as soon as the remote identity is
-// known.
-func (srv *Server) keepconn(id discover.NodeID) bool {
- srv.lock.RLock()
- defer srv.lock.RUnlock()
- if _, ok := srv.staticNodes[id]; ok {
- return true // static nodes are always allowed
+// checkpoint sends the conn to run, which performs the
+// post-handshake checks for the stage (posthandshake, addpeer).
+func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
+ select {
+ case stage <- c:
+ case <-srv.quit:
+ return errServerStopped
}
- if _, ok := srv.trustedNodes[id]; ok {
- return true // trusted nodes are always allowed
+ select {
+ case err := <-c.cont:
+ return err
+ case <-srv.quit:
+ return errServerStopped
}
- return len(srv.peers) < srv.MaxPeers
}
+// runPeer runs in its own goroutine for each peer.
+// it waits until the Peer logic returns and removes
+// the peer.
func (srv *Server) runPeer(p *Peer) {
glog.V(logger.Debug).Infof("Added %v\n", p)
srvjslog.LogJson(&logger.P2PConnected{
@@ -511,58 +607,18 @@ func (srv *Server) runPeer(p *Peer) {
RemoteVersionString: p.Name(),
NumConnections: srv.PeerCount(),
})
+
if srv.newPeerHook != nil {
srv.newPeerHook(p)
}
discreason := p.run()
- srv.removePeer(p)
+ // Note: run waits for existing peers to be sent on srv.delpeer
+ // before returning, so this send should not select on srv.quit.
+ srv.delpeer <- p
+
glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason)
srvjslog.LogJson(&logger.P2PDisconnected{
RemoteId: p.ID().String(),
NumConnections: srv.PeerCount(),
})
}
-
-func (srv *Server) addPeer(conn *conn, p *Peer) (bool, DiscReason) {
- // drop connections with no matching protocols.
- if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, conn.protoHandshake.Caps) == 0 {
- return false, DiscUselessPeer
- }
- // add the peer if it passes the other checks.
- srv.lock.Lock()
- defer srv.lock.Unlock()
- if ok, reason := srv.checkPeer(conn.ID); !ok {
- return false, reason
- }
- srv.peers[conn.ID] = p
- return true, 0
-}
-
-// checkPeer verifies whether a peer looks promising and should be allowed/kept
-// in the pool, or if it's of no use.
-func (srv *Server) checkPeer(id discover.NodeID) (bool, DiscReason) {
- // First up, figure out if the peer is static or trusted
- _, static := srv.staticNodes[id]
- trusted := srv.trustedNodes[id]
-
- // Make sure the peer passes all required checks
- switch {
- case !srv.running:
- return false, DiscQuitting
- case !static && !trusted && len(srv.peers) >= srv.MaxPeers:
- return false, DiscTooManyPeers
- case srv.peers[id] != nil:
- return false, DiscAlreadyConnected
- case id == srv.ntab.Self().ID:
- return false, DiscSelf
- default:
- return true, 0
- }
-}
-
-func (srv *Server) removePeer(p *Peer) {
- srv.lock.Lock()
- delete(srv.peers, p.ID())
- srv.lock.Unlock()
- srv.peerWG.Done()
-}