aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/server.go')
-rw-r--r--p2p/server.go38
1 files changed, 22 insertions, 16 deletions
diff --git a/p2p/server.go b/p2p/server.go
index 35b584a27..194dc3f1c 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -5,7 +5,6 @@ import (
"crypto/ecdsa"
"errors"
"fmt"
- "io"
"net"
"runtime"
"sync"
@@ -83,9 +82,11 @@ type Server struct {
// Hooks for testing. These are useful because we can inhibit
// the whole protocol stack.
- handshakeFunc
+ setupFunc
newPeerHook
+ ourHandshake *protoHandshake
+
lock sync.RWMutex
running bool
listener net.Listener
@@ -99,7 +100,7 @@ type Server struct {
peerConnect chan *discover.Node
}
-type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
+type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
type newPeerHook func(*Peer)
// Peers returns all connected peers.
@@ -170,8 +171,8 @@ func (srv *Server) Start() (err error) {
srv.peers = make(map[discover.NodeID]*Peer)
srv.peerConnect = make(chan *discover.Node)
- if srv.handshakeFunc == nil {
- srv.handshakeFunc = encHandshake
+ if srv.setupFunc == nil {
+ srv.setupFunc = setupConn
}
if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist()
@@ -183,11 +184,17 @@ func (srv *Server) Start() (err error) {
}
// dial stuff
- dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
+ ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
if err != nil {
return err
}
- srv.ntab = dt
+ srv.ntab = ntab
+
+ srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self()}
+ for _, p := range srv.Protocols {
+ srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
+ }
+
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
@@ -347,18 +354,17 @@ func (srv *Server) findPeers() {
}
}
-func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
+func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
// TODO: handle/store session token
- conn.SetDeadline(time.Now().Add(handshakeTimeout))
- remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
+ fd.SetDeadline(time.Now().Add(handshakeTimeout))
+ conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
if err != nil {
- conn.Close()
- srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
+ fd.Close()
+ srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
return
}
- ourID := srv.ntab.Self()
- p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
- if ok, reason := srv.addPeer(remoteID, p); !ok {
+ p := newPeer(conn, srv.Protocols)
+ if ok, reason := srv.addPeer(conn.ID, p); !ok {
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
p.politeDisconnect(reason)
return
@@ -394,7 +400,7 @@ func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
func (srv *Server) removePeer(p *Peer) {
srv.lock.Lock()
- delete(srv.peers, *p.remoteID)
+ delete(srv.peers, p.ID())
srv.lock.Unlock()
srv.peerWG.Done()
}