From 1440f9a37a8baf67b989ddf0b8cc30c9a1970e14 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Sat, 16 May 2015 00:38:28 +0200 Subject: p2p: new dialer, peer management without locks The most visible change is event-based dialing, which should be an improvement over the timer-based system that we have at the moment. The dialer gets a chance to compute new tasks whenever peers change or dials complete. This is better than checking peers on a timer because dials happen faster. The dialer can now make more precise decisions about whom to dial based on the peer set and we can test those decisions without actually opening any sockets. Peer management is easier to test because the tests can inject connections at checkpoints (after enc handshake, after protocol handshake). Most of the handshake stuff is now part of the RLPx code. It could be exported or move to its own package because it is no longer entangled with Server logic. --- p2p/dial.go | 276 +++++++++++++++++++++ p2p/dial_test.go | 482 ++++++++++++++++++++++++++++++++++++ p2p/handshake.go | 448 --------------------------------- p2p/handshake_test.go | 172 ------------- p2p/peer.go | 55 +++-- p2p/peer_error.go | 56 +---- p2p/peer_test.go | 20 +- p2p/rlpx.go | 444 ++++++++++++++++++++++++++++++++- p2p/rlpx_test.go | 244 +++++++++++++++++- p2p/server.go | 666 +++++++++++++++++++++++++++----------------------- p2p/server_test.go | 584 ++++++++++++++++++++----------------------- 11 files changed, 2118 insertions(+), 1329 deletions(-) create mode 100644 p2p/dial.go create mode 100644 p2p/dial_test.go delete mode 100644 p2p/handshake.go delete mode 100644 p2p/handshake_test.go (limited to 'p2p') diff --git a/p2p/dial.go b/p2p/dial.go new file mode 100644 index 000000000..71065c5ee --- /dev/null +++ b/p2p/dial.go @@ -0,0 +1,276 @@ +package p2p + +import ( + "container/heap" + "crypto/rand" + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +const ( + // This is the amount of time spent waiting in between + // redialing a certain node. + dialHistoryExpiration = 30 * time.Second + + // Discovery lookup tasks will wait for this long when + // no results are returned. This can happen if the table + // becomes empty (i.e. not often). + emptyLookupDelay = 10 * time.Second +) + +// dialstate schedules dials and discovery lookups. +// it get's a chance to compute new tasks on every iteration +// of the main loop in Server.run. +type dialstate struct { + maxDynDials int + ntab discoverTable + + lookupRunning bool + bootstrapped bool + + dialing map[discover.NodeID]connFlag + lookupBuf []*discover.Node // current discovery lookup results + randomNodes []*discover.Node // filled from Table + static map[discover.NodeID]*discover.Node + hist *dialHistory +} + +type discoverTable interface { + Self() *discover.Node + Close() + Bootstrap([]*discover.Node) + Lookup(target discover.NodeID) []*discover.Node + ReadRandomNodes([]*discover.Node) int +} + +// the dial history remembers recent dials. +type dialHistory []pastDial + +// pastDial is an entry in the dial history. +type pastDial struct { + id discover.NodeID + exp time.Time +} + +type task interface { + Do(*Server) +} + +// A dialTask is generated for each node that is dialed. +type dialTask struct { + flags connFlag + dest *discover.Node +} + +// discoverTask runs discovery table operations. +// Only one discoverTask is active at any time. +// +// If bootstrap is true, the task runs Table.Bootstrap, +// otherwise it performs a random lookup and leaves the +// results in the task. +type discoverTask struct { + bootstrap bool + results []*discover.Node +} + +// A waitExpireTask is generated if there are no other tasks +// to keep the loop in Server.run ticking. +type waitExpireTask struct { + time.Duration +} + +func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { + s := &dialstate{ + maxDynDials: maxdyn, + ntab: ntab, + static: make(map[discover.NodeID]*discover.Node), + dialing: make(map[discover.NodeID]connFlag), + randomNodes: make([]*discover.Node, maxdyn/2), + hist: new(dialHistory), + } + for _, n := range static { + s.static[n.ID] = n + } + return s +} + +func (s *dialstate) addStatic(n *discover.Node) { + s.static[n.ID] = n +} + +func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { + var newtasks []task + addDial := func(flag connFlag, n *discover.Node) bool { + _, dialing := s.dialing[n.ID] + if dialing || peers[n.ID] != nil || s.hist.contains(n.ID) { + return false + } + s.dialing[n.ID] = flag + newtasks = append(newtasks, &dialTask{flags: flag, dest: n}) + return true + } + + // Compute number of dynamic dials necessary at this point. + needDynDials := s.maxDynDials + for _, p := range peers { + if p.rw.is(dynDialedConn) { + needDynDials-- + } + } + for _, flag := range s.dialing { + if flag&dynDialedConn != 0 { + needDynDials-- + } + } + + // Expire the dial history on every invocation. + s.hist.expire(now) + + // Create dials for static nodes if they are not connected. + for _, n := range s.static { + addDial(staticDialedConn, n) + } + + // Use random nodes from the table for half of the necessary + // dynamic dials. + randomCandidates := needDynDials / 2 + if randomCandidates > 0 && s.bootstrapped { + n := s.ntab.ReadRandomNodes(s.randomNodes) + for i := 0; i < randomCandidates && i < n; i++ { + if addDial(dynDialedConn, s.randomNodes[i]) { + needDynDials-- + } + } + } + // Create dynamic dials from random lookup results, removing tried + // items from the result buffer. + i := 0 + for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { + if addDial(dynDialedConn, s.lookupBuf[i]) { + needDynDials-- + } + } + s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] + // Launch a discovery lookup if more candidates are needed. The + // first discoverTask bootstraps the table and won't return any + // results. + if len(s.lookupBuf) < needDynDials && !s.lookupRunning { + s.lookupRunning = true + newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped}) + } + + // Launch a timer to wait for the next node to expire if all + // candidates have been tried and no task is currently active. + // This should prevent cases where the dialer logic is not ticked + // because there are no pending events. + if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { + t := &waitExpireTask{s.hist.min().exp.Sub(now)} + newtasks = append(newtasks, t) + } + return newtasks +} + +func (s *dialstate) taskDone(t task, now time.Time) { + switch t := t.(type) { + case *dialTask: + s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration)) + delete(s.dialing, t.dest.ID) + case *discoverTask: + if t.bootstrap { + s.bootstrapped = true + } + s.lookupRunning = false + s.lookupBuf = append(s.lookupBuf, t.results...) + } +} + +func (t *dialTask) Do(srv *Server) { + addr := &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)} + glog.V(logger.Debug).Infof("dialing %v\n", t.dest) + fd, err := srv.Dialer.Dial("tcp", addr.String()) + if err != nil { + glog.V(logger.Detail).Infof("dial error: %v", err) + return + } + srv.setupConn(fd, t.flags, t.dest) +} +func (t *dialTask) String() string { + return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP) +} + +func (t *discoverTask) Do(srv *Server) { + if t.bootstrap { + srv.ntab.Bootstrap(srv.BootstrapNodes) + } else { + var target discover.NodeID + rand.Read(target[:]) + t.results = srv.ntab.Lookup(target) + // newTasks generates a lookup task whenever dynamic dials are + // necessary. Lookups need to take some time, otherwise the + // event loop spins too fast. An empty result can only be + // returned if the table is empty. + if len(t.results) == 0 { + time.Sleep(emptyLookupDelay) + } + } +} + +func (t *discoverTask) String() (s string) { + if t.bootstrap { + s = "discovery bootstrap" + } else { + s = "discovery lookup" + } + if len(t.results) > 0 { + s += fmt.Sprintf(" (%d results)", len(t.results)) + } + return s +} + +func (t waitExpireTask) Do(*Server) { + time.Sleep(t.Duration) +} +func (t waitExpireTask) String() string { + return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) +} + +// Use only these methods to access or modify dialHistory. +func (h dialHistory) min() pastDial { + return h[0] +} +func (h *dialHistory) add(id discover.NodeID, exp time.Time) { + heap.Push(h, pastDial{id, exp}) +} +func (h dialHistory) contains(id discover.NodeID) bool { + for _, v := range h { + if v.id == id { + return true + } + } + return false +} +func (h *dialHistory) expire(now time.Time) { + for h.Len() > 0 && h.min().exp.Before(now) { + heap.Pop(h) + } +} + +// heap.Interface boilerplate +func (h dialHistory) Len() int { return len(h) } +func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } +func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *dialHistory) Push(x interface{}) { + *h = append(*h, x.(pastDial)) +} +func (h *dialHistory) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/p2p/dial_test.go b/p2p/dial_test.go new file mode 100644 index 000000000..78568c5ed --- /dev/null +++ b/p2p/dial_test.go @@ -0,0 +1,482 @@ +package p2p + +import ( + "encoding/binary" + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +func init() { + spew.Config.Indent = "\t" +} + +type dialtest struct { + init *dialstate // state before and after the test. + rounds []round +} + +type round struct { + peers []*Peer // current peer set + done []task // tasks that got done this round + new []task // the result must match this one +} + +func runDialTest(t *testing.T, test dialtest) { + var ( + vtime time.Time + running int + ) + pm := func(ps []*Peer) map[discover.NodeID]*Peer { + m := make(map[discover.NodeID]*Peer) + for _, p := range ps { + m[p.rw.id] = p + } + return m + } + for i, round := range test.rounds { + for _, task := range round.done { + running-- + if running < 0 { + panic("running task counter underflow") + } + test.init.taskDone(task, vtime) + } + + new := test.init.newTasks(running, pm(round.peers), vtime) + if !sametasks(new, round.new) { + t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n", + i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) + } + + // Time advances by 16 seconds on every round. + vtime = vtime.Add(16 * time.Second) + running += len(new) + } +} + +type fakeTable []*discover.Node + +func (t fakeTable) Self() *discover.Node { return new(discover.Node) } +func (t fakeTable) Close() {} +func (t fakeTable) Bootstrap([]*discover.Node) {} +func (t fakeTable) Lookup(target discover.NodeID) []*discover.Node { + return nil +} +func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { + return copy(buf, t) +} + +// This test checks that dynamic dials are launched from discovery results. +func TestDialStateDynDial(t *testing.T) { + runDialTest(t, dialtest{ + init: newDialState(nil, fakeTable{}, 5), + rounds: []round{ + // A discovery query is launched. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + new: []task{&discoverTask{bootstrap: true}}, + }, + // Dynamic dials are launched when it completes. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + done: []task{ + &discoverTask{bootstrap: true, results: []*discover.Node{ + {ID: uintID(2)}, // this one is already connected and not dialed. + {ID: uintID(3)}, + {ID: uintID(4)}, + {ID: uintID(5)}, + {ID: uintID(6)}, // these are not tried because max dyn dials is 5 + {ID: uintID(7)}, // ... + }}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + }, + }, + // Some of the dials complete but no new ones are launched yet because + // the sum of active dial count and dynamic peer count is == maxDynDials. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + }, + }, + // No new dial tasks are launched in the this round because + // maxDynDials has been reached. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + }, + new: []task{ + &waitExpireTask{Duration: 14 * time.Second}, + }, + }, + // In this round, the peer with id 2 drops off. The query + // results from last discovery lookup are reused. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}}, + }, + }, + // More peers (3,4) drop off and dial for ID 6 completes. + // The last query result from the discovery lookup is reused + // and a new one is spawned because more candidates are needed. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}}, + &discoverTask{}, + }, + }, + // Peer 7 is connected, but there still aren't enough dynamic peers + // (4 out of 5). However, a discovery is already running, so ensure + // no new is started. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(7)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}}, + }, + }, + // Finish the running node discovery with an empty set. A new lookup + // should be immediately requested. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(7)}}, + }, + done: []task{ + &discoverTask{}, + }, + new: []task{ + &discoverTask{}, + }, + }, + }, + }) +} + +func TestDialStateDynDialFromTable(t *testing.T) { + // This table always returns the same random nodes + // in the order given below. + table := fakeTable{ + {ID: uintID(1)}, + {ID: uintID(2)}, + {ID: uintID(3)}, + {ID: uintID(4)}, + {ID: uintID(5)}, + {ID: uintID(6)}, + {ID: uintID(7)}, + {ID: uintID(8)}, + } + + runDialTest(t, dialtest{ + init: newDialState(nil, table, 10), + rounds: []round{ + // Discovery bootstrap is launched. + { + new: []task{&discoverTask{bootstrap: true}}, + }, + // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. + { + done: []task{ + &discoverTask{bootstrap: true}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + &discoverTask{bootstrap: false}, + }, + }, + // Dialing nodes 1,2 succeeds. Dials from the lookup are launched. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}}, + &discoverTask{results: []*discover.Node{ + {ID: uintID(10)}, + {ID: uintID(11)}, + {ID: uintID(12)}, + }}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}}, + &discoverTask{bootstrap: false}, + }, + }, + // Dialing nodes 3,4,5 fails. The dials from the lookup succeed. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}}, + }, + }, + // Waiting for expiry. No waitExpireTask is launched because the + // discovery query is still running. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + }, + }, + // Nodes 3,4 are not tried again because only the first two + // returned random nodes (nodes 1,2) are tried and they're + // already connected. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + }, + }, + }, + }) +} + +// This test checks that static dials are launched. +func TestDialStateStaticDial(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + {ID: uintID(3)}, + {ID: uintID(4)}, + {ID: uintID(5)}, + } + + runDialTest(t, dialtest{ + init: newDialState(wantStatic, fakeTable{}, 0), + rounds: []round{ + // Static dials are launched for the nodes that + // aren't yet connected. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}}, + }, + }, + // No new tasks are launched in this round because all static + // nodes are either connected or still being dialed. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + }, + // No new dial tasks are launched because all static + // nodes are now connected. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(4)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}}, + }, + new: []task{ + &waitExpireTask{Duration: 14 * time.Second}, + }, + }, + // Wait a round for dial history to expire, no new tasks should spawn. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(4)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + }, + }, + // If a static node is dropped, it should be immediately redialed, + // irrespective whether it was originally static or dynamic. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + }, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}}, + }, + }, + }, + }) +} + +// This test checks that past dials are not retried for some time. +func TestDialStateCache(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + {ID: uintID(3)}, + } + + runDialTest(t, dialtest{ + init: newDialState(wantStatic, fakeTable{}, 0), + rounds: []round{ + // Static dials are launched for the nodes that + // aren't yet connected. + { + peers: nil, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + }, + // No new tasks are launched in this round because all static + // nodes are either connected or still being dialed. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}}, + }, + }, + // A salvage task is launched to wait for node 3's history + // entry to expire. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + new: []task{ + &waitExpireTask{Duration: 14 * time.Second}, + }, + }, + // Still waiting for node 3's entry to expire in the cache. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + }, + // The cache entry for node 3 has expired and is retried. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + }, + }, + }) +} + +// compares task lists but doesn't care about the order. +func sametasks(a, b []task) bool { + if len(a) != len(b) { + return false + } +next: + for _, ta := range a { + for _, tb := range b { + if reflect.DeepEqual(ta, tb) { + continue next + } + } + return false + } + return true +} + +func uintID(i uint32) discover.NodeID { + var id discover.NodeID + binary.BigEndian.PutUint32(id[:], i) + return id +} diff --git a/p2p/handshake.go b/p2p/handshake.go deleted file mode 100644 index 4cdcee6d4..000000000 --- a/p2p/handshake.go +++ /dev/null @@ -1,448 +0,0 @@ -package p2p - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "errors" - "fmt" - "hash" - "io" - "net" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/crypto/secp256k1" - "github.com/ethereum/go-ethereum/crypto/sha3" - "github.com/ethereum/go-ethereum/p2p/discover" - "github.com/ethereum/go-ethereum/rlp" -) - -const ( - sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 - sigLen = 65 // elliptic S256 - pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte - shaLen = 32 // hash length (for nonce etc) - - authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 - authRespLen = pubLen + shaLen + 1 - - eciesBytes = 65 + 16 + 32 - encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake - encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake -) - -// conn represents a remote connection after encryption handshake -// and protocol handshake have completed. -// -// The MsgReadWriter is usually layered as follows: -// -// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg) -// rlpxFrameRW (message encoding, encryption, authentication) -// bufio.ReadWriter (buffering) -// net.Conn (network I/O) -// -type conn struct { - MsgReadWriter - *protoHandshake -} - -// secrets represents the connection secrets -// which are negotiated during the encryption handshake. -type secrets struct { - RemoteID discover.NodeID - AES, MAC []byte - EgressMAC, IngressMAC hash.Hash - Token []byte -} - -// protoHandshake is the RLP structure of the protocol handshake. -type protoHandshake struct { - Version uint64 - Name string - Caps []Cap - ListenPort uint64 - ID discover.NodeID -} - -// setupConn starts a protocol session on the given connection. It -// runs the encryption handshake and the protocol handshake. If dial -// is non-nil, the connection the local node is the initiator. If -// keepconn returns false, the connection will be disconnected with -// DiscTooManyPeers after the key exchange. -func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) { - if dial == nil { - return setupInboundConn(fd, prv, our, keepconn) - } else { - return setupOutboundConn(fd, prv, our, dial, keepconn) - } -} - -func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, keepconn func(discover.NodeID) bool) (*conn, error) { - secrets, err := receiverEncHandshake(fd, prv, nil) - if err != nil { - return nil, fmt.Errorf("encryption handshake failed: %v", err) - } - rw := newRlpxFrameRW(fd, secrets) - if !keepconn(secrets.RemoteID) { - SendItems(rw, discMsg, DiscTooManyPeers) - return nil, errors.New("we have too many peers") - } - // Run the protocol handshake using authenticated messages. - rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our) - if err != nil { - return nil, err - } - if err := Send(rw, handshakeMsg, our); err != nil { - return nil, fmt.Errorf("protocol handshake write error: %v", err) - } - return &conn{rw, rhs}, nil -} - -func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) { - secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil) - if err != nil { - return nil, fmt.Errorf("encryption handshake failed: %v", err) - } - rw := newRlpxFrameRW(fd, secrets) - if !keepconn(secrets.RemoteID) { - SendItems(rw, discMsg, DiscTooManyPeers) - return nil, errors.New("we have too many peers") - } - // Run the protocol handshake using authenticated messages. - // - // Note that even though writing the handshake is first, we prefer - // returning the handshake read error. If the remote side - // disconnects us early with a valid reason, we should return it - // as the error so it can be tracked elsewhere. - werr := make(chan error, 1) - go func() { werr <- Send(rw, handshakeMsg, our) }() - rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our) - if err != nil { - return nil, err - } - if err := <-werr; err != nil { - return nil, fmt.Errorf("protocol handshake write error: %v", err) - } - if rhs.ID != dial.ID { - return nil, errors.New("dialed node id mismatch") - } - return &conn{rw, rhs}, nil -} - -// encHandshake contains the state of the encryption handshake. -type encHandshake struct { - initiator bool - remoteID discover.NodeID - - remotePub *ecies.PublicKey // remote-pubk - initNonce, respNonce []byte // nonce - randomPrivKey *ecies.PrivateKey // ecdhe-random - remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk -} - -// secrets is called after the handshake is completed. -// It extracts the connection secrets from the handshake values. -func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { - ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen) - if err != nil { - return secrets{}, err - } - - // derive base secrets from ephemeral key agreement - sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce)) - aesSecret := crypto.Sha3(ecdheSecret, sharedSecret) - s := secrets{ - RemoteID: h.remoteID, - AES: aesSecret, - MAC: crypto.Sha3(ecdheSecret, aesSecret), - Token: crypto.Sha3(sharedSecret), - } - - // setup sha3 instances for the MACs - mac1 := sha3.NewKeccak256() - mac1.Write(xor(s.MAC, h.respNonce)) - mac1.Write(auth) - mac2 := sha3.NewKeccak256() - mac2.Write(xor(s.MAC, h.initNonce)) - mac2.Write(authResp) - if h.initiator { - s.EgressMAC, s.IngressMAC = mac1, mac2 - } else { - s.EgressMAC, s.IngressMAC = mac2, mac1 - } - - return s, nil -} - -func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) { - return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen) -} - -// initiatorEncHandshake negotiates a session token on conn. -// it should be called on the dialing side of the connection. -// -// prv is the local client's private key. -// token is the token from a previous session with this node. -func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) { - h, err := newInitiatorHandshake(remoteID) - if err != nil { - return s, err - } - auth, err := h.authMsg(prv, token) - if err != nil { - return s, err - } - if _, err = conn.Write(auth); err != nil { - return s, err - } - - response := make([]byte, encAuthRespLen) - if _, err = io.ReadFull(conn, response); err != nil { - return s, err - } - if err := h.decodeAuthResp(response, prv); err != nil { - return s, err - } - return h.secrets(auth, response) -} - -func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) { - // generate random initiator nonce - n := make([]byte, shaLen) - if _, err := rand.Read(n); err != nil { - return nil, err - } - // generate random keypair to use for signing - randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil) - if err != nil { - return nil, err - } - rpub, err := remoteID.Pubkey() - if err != nil { - return nil, fmt.Errorf("bad remoteID: %v", err) - } - h := &encHandshake{ - initiator: true, - remoteID: remoteID, - remotePub: ecies.ImportECDSAPublic(rpub), - initNonce: n, - randomPrivKey: randpriv, - } - return h, nil -} - -// authMsg creates an encrypted initiator handshake message. -func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { - var tokenFlag byte - if token == nil { - // no session token found means we need to generate shared secret. - // ecies shared secret is used as initial session token for new peers - // generate shared key from prv and remote pubkey - var err error - if token, err = h.ecdhShared(prv); err != nil { - return nil, err - } - } else { - // for known peers, we use stored token from the previous session - tokenFlag = 0x01 - } - - // sign known message: - // ecdh-shared-secret^nonce for new peers - // token^nonce for old peers - signed := xor(token, h.initNonce) - signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA()) - if err != nil { - return nil, err - } - - // encode auth message - // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag - msg := make([]byte, authMsgLen) - n := copy(msg, signature) - n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey))) - n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:]) - n += copy(msg[n:], h.initNonce) - msg[n] = tokenFlag - - // encrypt auth message using remote-pubk - return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil) -} - -// decodeAuthResp decode an encrypted authentication response message. -func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error { - msg, err := crypto.Decrypt(prv, auth) - if err != nil { - return fmt.Errorf("could not decrypt auth response (%v)", err) - } - h.respNonce = msg[pubLen : pubLen+shaLen] - h.remoteRandomPub, err = importPublicKey(msg[:pubLen]) - if err != nil { - return err - } - // ignore token flag for now - return nil -} - -// receiverEncHandshake negotiates a session token on conn. -// it should be called on the listening side of the connection. -// -// prv is the local client's private key. -// token is the token from a previous session with this node. -func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) { - // read remote auth sent by initiator. - auth := make([]byte, encAuthMsgLen) - if _, err := io.ReadFull(conn, auth); err != nil { - return s, err - } - h, err := decodeAuthMsg(prv, token, auth) - if err != nil { - return s, err - } - - // send auth response - resp, err := h.authResp(prv, token) - if err != nil { - return s, err - } - if _, err = conn.Write(resp); err != nil { - return s, err - } - - return h.secrets(auth, resp) -} - -func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) { - var err error - h := new(encHandshake) - // generate random keypair for session - h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) - if err != nil { - return nil, err - } - // generate random nonce - h.respNonce = make([]byte, shaLen) - if _, err = rand.Read(h.respNonce); err != nil { - return nil, err - } - - msg, err := crypto.Decrypt(prv, auth) - if err != nil { - return nil, fmt.Errorf("could not decrypt auth message (%v)", err) - } - - // decode message parameters - // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag - h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] - copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen]) - rpub, err := h.remoteID.Pubkey() - if err != nil { - return nil, fmt.Errorf("bad remoteID: %#v", err) - } - h.remotePub = ecies.ImportECDSAPublic(rpub) - - // recover remote random pubkey from signed message. - if token == nil { - // TODO: it is an error if the initiator has a token and we don't. check that. - - // no session token means we need to generate shared secret. - // ecies shared secret is used as initial session token for new peers. - // generate shared key from prv and remote pubkey. - if token, err = h.ecdhShared(prv); err != nil { - return nil, err - } - } - signedMsg := xor(token, h.initNonce) - remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]) - if err != nil { - return nil, err - } - h.remoteRandomPub, _ = importPublicKey(remoteRandomPub) - return h, nil -} - -// authResp generates the encrypted authentication response message. -func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { - // responder auth message - // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0) - resp := make([]byte, authRespLen) - n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey)) - n += copy(resp[n:], h.respNonce) - if token == nil { - resp[n] = 0 - } else { - resp[n] = 1 - } - // encrypt using remote-pubk - return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil) -} - -// importPublicKey unmarshals 512 bit public keys. -func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { - var pubKey65 []byte - switch len(pubKey) { - case 64: - // add 'uncompressed key' flag - pubKey65 = append([]byte{0x04}, pubKey...) - case 65: - pubKey65 = pubKey - default: - return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) - } - // TODO: fewer pointless conversions - return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil -} - -func exportPubkey(pub *ecies.PublicKey) []byte { - if pub == nil { - panic("nil pubkey") - } - return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:] -} - -func xor(one, other []byte) (xor []byte) { - xor = make([]byte, len(one)) - for i := 0; i < len(one); i++ { - xor[i] = one[i] ^ other[i] - } - return xor -} - -func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) { - msg, err := rw.ReadMsg() - if err != nil { - return nil, err - } - if msg.Code == discMsg { - // disconnect before protocol handshake is valid according to the - // spec and we send it ourself if Server.addPeer fails. - var reason [1]DiscReason - rlp.Decode(msg.Payload, &reason) - return nil, reason[0] - } - if msg.Code != handshakeMsg { - return nil, fmt.Errorf("expected handshake, got %x", msg.Code) - } - if msg.Size > baseProtocolMaxMsgSize { - return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize) - } - var hs protoHandshake - if err := msg.Decode(&hs); err != nil { - return nil, err - } - // validate handshake info - if hs.Version != our.Version { - SendItems(rw, discMsg, DiscIncompatibleVersion) - return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version) - } - if (hs.ID == discover.NodeID{}) { - SendItems(rw, discMsg, DiscInvalidIdentity) - return nil, errors.New("invalid public key in handshake") - } - if hs.ID != wantID { - SendItems(rw, discMsg, DiscUnexpectedIdentity) - return nil, errors.New("handshake node ID does not match encryption handshake") - } - return &hs, nil -} diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go deleted file mode 100644 index ab75921a3..000000000 --- a/p2p/handshake_test.go +++ /dev/null @@ -1,172 +0,0 @@ -package p2p - -import ( - "bytes" - "crypto/rand" - "fmt" - "net" - "reflect" - "testing" - "time" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/p2p/discover" -) - -func TestSharedSecret(t *testing.T) { - prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) - pub0 := &prv0.PublicKey - prv1, _ := crypto.GenerateKey() - pub1 := &prv1.PublicKey - - ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) - if err != nil { - return - } - ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) - if err != nil { - return - } - t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) - if !bytes.Equal(ss0, ss1) { - t.Errorf("dont match :(") - } -} - -func TestEncHandshake(t *testing.T) { - for i := 0; i < 20; i++ { - start := time.Now() - if err := testEncHandshake(nil); err != nil { - t.Fatalf("i=%d %v", i, err) - } - t.Logf("(without token) %d %v\n", i+1, time.Since(start)) - } - - for i := 0; i < 20; i++ { - tok := make([]byte, shaLen) - rand.Reader.Read(tok) - start := time.Now() - if err := testEncHandshake(tok); err != nil { - t.Fatalf("i=%d %v", i, err) - } - t.Logf("(with token) %d %v\n", i+1, time.Since(start)) - } -} - -func testEncHandshake(token []byte) error { - type result struct { - side string - s secrets - err error - } - var ( - prv0, _ = crypto.GenerateKey() - prv1, _ = crypto.GenerateKey() - rw0, rw1 = net.Pipe() - output = make(chan result) - ) - - go func() { - r := result{side: "initiator"} - defer func() { output <- r }() - - pub1s := discover.PubkeyID(&prv1.PublicKey) - r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token) - if r.err != nil { - return - } - id1 := discover.PubkeyID(&prv1.PublicKey) - if r.s.RemoteID != id1 { - r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1) - } - }() - go func() { - r := result{side: "receiver"} - defer func() { output <- r }() - - r.s, r.err = receiverEncHandshake(rw1, prv1, token) - if r.err != nil { - return - } - id0 := discover.PubkeyID(&prv0.PublicKey) - if r.s.RemoteID != id0 { - r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0) - } - }() - - // wait for results from both sides - r1, r2 := <-output, <-output - - if r1.err != nil { - return fmt.Errorf("%s side error: %v", r1.side, r1.err) - } - if r2.err != nil { - return fmt.Errorf("%s side error: %v", r2.side, r2.err) - } - - // don't compare remote node IDs - r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{} - // flip MACs on one of them so they compare equal - r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC - if !reflect.DeepEqual(r1.s, r2.s) { - return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s) - } - return nil -} - -func TestSetupConn(t *testing.T) { - prv0, _ := crypto.GenerateKey() - prv1, _ := crypto.GenerateKey() - node0 := &discover.Node{ - ID: discover.PubkeyID(&prv0.PublicKey), - IP: net.IP{1, 2, 3, 4}, - TCP: 33, - } - node1 := &discover.Node{ - ID: discover.PubkeyID(&prv1.PublicKey), - IP: net.IP{5, 6, 7, 8}, - TCP: 44, - } - hs0 := &protoHandshake{ - Version: baseProtocolVersion, - ID: node0.ID, - Caps: []Cap{{"a", 0}, {"b", 2}}, - } - hs1 := &protoHandshake{ - Version: baseProtocolVersion, - ID: node1.ID, - Caps: []Cap{{"c", 1}, {"d", 3}}, - } - fd0, fd1 := net.Pipe() - - done := make(chan struct{}) - keepalways := func(discover.NodeID) bool { return true } - go func() { - defer close(done) - conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways) - if err != nil { - t.Errorf("outbound side error: %v", err) - return - } - if conn0.ID != node1.ID { - t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID) - } - if !reflect.DeepEqual(conn0.Caps, hs1.Caps) { - t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps) - } - }() - - conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways) - if err != nil { - t.Fatalf("inbound side error: %v", err) - } - if conn1.ID != node0.ID { - t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID) - } - if !reflect.DeepEqual(conn1.Caps, hs0.Caps) { - t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps) - } - - <-done -} diff --git a/p2p/peer.go b/p2p/peer.go index 87a91d406..cbe5ccc84 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -33,9 +33,17 @@ const ( peersMsg = 0x05 ) +// protoHandshake is the RLP structure of the protocol handshake. +type protoHandshake struct { + Version uint64 + Name string + Caps []Cap + ListenPort uint64 + ID discover.NodeID +} + // Peer represents a connected remote node. type Peer struct { - conn net.Conn rw *conn running map[string]*protoRW @@ -48,37 +56,36 @@ type Peer struct { // NewPeer returns a peer for testing purposes. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { pipe, _ := net.Pipe() - msgpipe, _ := MsgPipe() - conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}} - peer := newPeer(pipe, conn, nil) + conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name} + peer := newPeer(conn, nil) close(peer.closed) // ensures Disconnect doesn't block return peer } // ID returns the node's public key. func (p *Peer) ID() discover.NodeID { - return p.rw.ID + return p.rw.id } // Name returns the node name that the remote node advertised. func (p *Peer) Name() string { - return p.rw.Name + return p.rw.name } // Caps returns the capabilities (supported subprotocols) of the remote peer. func (p *Peer) Caps() []Cap { // TODO: maybe return copy - return p.rw.Caps + return p.rw.caps } // RemoteAddr returns the remote address of the network connection. func (p *Peer) RemoteAddr() net.Addr { - return p.conn.RemoteAddr() + return p.rw.fd.RemoteAddr() } // LocalAddr returns the local address of the network connection. func (p *Peer) LocalAddr() net.Addr { - return p.conn.LocalAddr() + return p.rw.fd.LocalAddr() } // Disconnect terminates the peer connection with the given reason. @@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) + return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr()) } -func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { - protomap := matchProtocols(protocols, conn.Caps, conn) +func newPeer(conn *conn, protocols []Protocol) *Peer { + protomap := matchProtocols(protocols, conn.caps, conn) p := &Peer{ - conn: fd, rw: conn, running: protomap, disc: make(chan DiscReason), @@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason { p.startProtocols() // Wait for an error or disconnect. - var reason DiscReason + var ( + reason DiscReason + requested bool + ) select { case err := <-readErr: if r, ok := err.(DiscReason); ok { @@ -131,21 +140,17 @@ func (p *Peer) run() DiscReason { case err := <-p.protoErr: reason = discReasonForError(err) case reason = <-p.disc: - p.politeDisconnect(reason) - reason = DiscRequested + requested = true } - close(p.closed) + p.rw.close(reason) p.wg.Wait() - glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason) - return reason -} -func (p *Peer) politeDisconnect(reason DiscReason) { - if reason != DiscNetworkError { - SendItems(p.rw, discMsg, uint(reason)) + if requested { + reason = DiscRequested } - p.conn.Close() + glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason) + return reason } func (p *Peer) pingLoop() { @@ -254,7 +259,7 @@ func (p *Peer) startProtocols() { glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version) err = errors.New("protocol returned") } else if err != io.EOF { - glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err) + glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err) } p.protoErr <- err p.wg.Done() diff --git a/p2p/peer_error.go b/p2p/peer_error.go index a912f6064..6938a9801 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -5,39 +5,17 @@ import ( ) const ( - errMagicTokenMismatch = iota - errRead - errWrite - errMisc - errInvalidMsgCode + errInvalidMsgCode = iota errInvalidMsg - errP2PVersionMismatch - errPubkeyInvalid - errPubkeyForbidden - errProtocolBreach - errPingTimeout - errInvalidNetworkId - errInvalidProtocolVersion ) var errorToString = map[int]string{ - errMagicTokenMismatch: "magic token mismatch", - errRead: "read error", - errWrite: "write error", - errMisc: "misc error", - errInvalidMsgCode: "invalid message code", - errInvalidMsg: "invalid message", - errP2PVersionMismatch: "P2P Version Mismatch", - errPubkeyInvalid: "public key invalid", - errPubkeyForbidden: "public key forbidden", - errProtocolBreach: "protocol Breach", - errPingTimeout: "ping timeout", - errInvalidNetworkId: "invalid network id", - errInvalidProtocolVersion: "invalid protocol version", + errInvalidMsgCode: "invalid message code", + errInvalidMsg: "invalid message", } type peerError struct { - Code int + code int message string } @@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason { return reason } peerError, ok := err.(*peerError) - if !ok { - return DiscSubprotocolError - } - switch peerError.Code { - case errP2PVersionMismatch: - return DiscIncompatibleVersion - case errPubkeyInvalid: - return DiscInvalidIdentity - case errPubkeyForbidden: - return DiscUselessPeer - case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: - return DiscProtocolError - case errPingTimeout: - return DiscReadTimeout - case errRead, errWrite: - return DiscNetworkError - default: - return DiscSubprotocolError + if ok { + switch peerError.code { + case errInvalidMsgCode, errInvalidMsg: + return DiscProtocolError + default: + return DiscSubprotocolError + } } + return DiscSubprotocolError } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 0ac032ab7..7b772e198 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -28,24 +28,20 @@ var discard = Protocol{ } func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) { - fd1, _ := net.Pipe() - hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} - hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} + fd1, fd2 := net.Pipe() + c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)} + c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)} for _, p := range protos { - hs1.Caps = append(hs1.Caps, p.cap()) - hs2.Caps = append(hs2.Caps, p.cap()) + c1.caps = append(c1.caps, p.cap()) + c2.caps = append(c2.caps, p.cap()) } - p1, p2 := MsgPipe() - peer := newPeer(fd1, &conn{p1, hs1}, protos) + peer := newPeer(c1, protos) errc := make(chan DiscReason, 1) go func() { errc <- peer.run() }() - closer := func() { - p1.Close() - fd1.Close() - } - return closer, &conn{p2, hs2}, peer, errc + closer := func() { c2.close(errors.New("close func called")) } + return closer, c2, peer, errc } func TestPeerProtoReadMsg(t *testing.T) { diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 6b533e275..e1cb13aae 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -4,23 +4,459 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" "crypto/hmac" + "crypto/rand" "errors" + "fmt" "hash" "io" + "net" + "sync" + "time" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/crypto/secp256k1" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) +const ( + maxUint24 = ^uint32(0) >> 8 + + sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 + sigLen = 65 // elliptic S256 + pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte + shaLen = 32 // hash length (for nonce etc) + + authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 + authRespLen = pubLen + shaLen + 1 + + eciesBytes = 65 + 16 + 32 + encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake + encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake + + // total timeout for encryption handshake and protocol + // handshake in both directions. + handshakeTimeout = 5 * time.Second + + // This is the timeout for sending the disconnect reason. + // This is shorter than the usual timeout because we don't want + // to wait if the connection is known to be bad anyway. + discWriteTimeout = 1 * time.Second +) + +// rlpx is the transport protocol used by actual (non-test) connections. +// It wraps the frame encoder with locks and read/write deadlines. +type rlpx struct { + fd net.Conn + + rmu, wmu sync.Mutex + rw *rlpxFrameRW +} + +func newRLPX(fd net.Conn) transport { + fd.SetDeadline(time.Now().Add(handshakeTimeout)) + return &rlpx{fd: fd} +} + +func (t *rlpx) ReadMsg() (Msg, error) { + t.rmu.Lock() + defer t.rmu.Unlock() + t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout)) + return t.rw.ReadMsg() +} + +func (t *rlpx) WriteMsg(msg Msg) error { + t.wmu.Lock() + defer t.wmu.Unlock() + t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout)) + return t.rw.WriteMsg(msg) +} + +func (t *rlpx) close(err error) { + t.wmu.Lock() + defer t.wmu.Unlock() + // Tell the remote end why we're disconnecting if possible. + if t.rw != nil { + if r, ok := err.(DiscReason); ok && r != DiscNetworkError { + t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)) + SendItems(t.rw, discMsg, r) + } + } + t.fd.Close() +} + +// doEncHandshake runs the protocol handshake using authenticated +// messages. the protocol handshake is the first authenticated message +// and also verifies whether the encryption handshake 'worked' and the +// remote side actually provided the right public key. +func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) { + // Writing our handshake happens concurrently, we prefer + // returning the handshake read error. If the remote side + // disconnects us early with a valid reason, we should return it + // as the error so it can be tracked elsewhere. + werr := make(chan error, 1) + go func() { werr <- Send(t.rw, handshakeMsg, our) }() + if their, err = readProtocolHandshake(t.rw, our); err != nil { + return nil, err + } + if err := <-werr; err != nil { + return nil, fmt.Errorf("write error: %v", err) + } + return their, nil +} + +func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, error) { + msg, err := rw.ReadMsg() + if err != nil { + return nil, err + } + if msg.Size > baseProtocolMaxMsgSize { + return nil, fmt.Errorf("message too big") + } + if msg.Code == discMsg { + // Disconnect before protocol handshake is valid according to the + // spec and we send it ourself if the posthanshake checks fail. + // We can't return the reason directly, though, because it is echoed + // back otherwise. Wrap it in a string instead. + var reason [1]DiscReason + rlp.Decode(msg.Payload, &reason) + return nil, reason[0] + } + if msg.Code != handshakeMsg { + return nil, fmt.Errorf("expected handshake, got %x", msg.Code) + } + var hs protoHandshake + if err := msg.Decode(&hs); err != nil { + return nil, err + } + // validate handshake info + if hs.Version != our.Version { + return nil, DiscIncompatibleVersion + } + if (hs.ID == discover.NodeID{}) { + return nil, DiscInvalidIdentity + } + return &hs, nil +} + +func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) { + var ( + sec secrets + err error + ) + if dial == nil { + sec, err = receiverEncHandshake(t.fd, prv, nil) + } else { + sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil) + } + if err != nil { + return discover.NodeID{}, err + } + t.wmu.Lock() + t.rw = newRLPXFrameRW(t.fd, sec) + t.wmu.Unlock() + return sec.RemoteID, nil +} + +// encHandshake contains the state of the encryption handshake. +type encHandshake struct { + initiator bool + remoteID discover.NodeID + + remotePub *ecies.PublicKey // remote-pubk + initNonce, respNonce []byte // nonce + randomPrivKey *ecies.PrivateKey // ecdhe-random + remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk +} + +// secrets represents the connection secrets +// which are negotiated during the encryption handshake. +type secrets struct { + RemoteID discover.NodeID + AES, MAC []byte + EgressMAC, IngressMAC hash.Hash + Token []byte +} + +// secrets is called after the handshake is completed. +// It extracts the connection secrets from the handshake values. +func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { + ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen) + if err != nil { + return secrets{}, err + } + + // derive base secrets from ephemeral key agreement + sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce)) + aesSecret := crypto.Sha3(ecdheSecret, sharedSecret) + s := secrets{ + RemoteID: h.remoteID, + AES: aesSecret, + MAC: crypto.Sha3(ecdheSecret, aesSecret), + Token: crypto.Sha3(sharedSecret), + } + + // setup sha3 instances for the MACs + mac1 := sha3.NewKeccak256() + mac1.Write(xor(s.MAC, h.respNonce)) + mac1.Write(auth) + mac2 := sha3.NewKeccak256() + mac2.Write(xor(s.MAC, h.initNonce)) + mac2.Write(authResp) + if h.initiator { + s.EgressMAC, s.IngressMAC = mac1, mac2 + } else { + s.EgressMAC, s.IngressMAC = mac2, mac1 + } + + return s, nil +} + +func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) { + return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen) +} + +// initiatorEncHandshake negotiates a session token on conn. +// it should be called on the dialing side of the connection. +// +// prv is the local client's private key. +// token is the token from a previous session with this node. +func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) { + h, err := newInitiatorHandshake(remoteID) + if err != nil { + return s, err + } + auth, err := h.authMsg(prv, token) + if err != nil { + return s, err + } + if _, err = conn.Write(auth); err != nil { + return s, err + } + + response := make([]byte, encAuthRespLen) + if _, err = io.ReadFull(conn, response); err != nil { + return s, err + } + if err := h.decodeAuthResp(response, prv); err != nil { + return s, err + } + return h.secrets(auth, response) +} + +func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) { + // generate random initiator nonce + n := make([]byte, shaLen) + if _, err := rand.Read(n); err != nil { + return nil, err + } + // generate random keypair to use for signing + randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + if err != nil { + return nil, err + } + rpub, err := remoteID.Pubkey() + if err != nil { + return nil, fmt.Errorf("bad remoteID: %v", err) + } + h := &encHandshake{ + initiator: true, + remoteID: remoteID, + remotePub: ecies.ImportECDSAPublic(rpub), + initNonce: n, + randomPrivKey: randpriv, + } + return h, nil +} + +// authMsg creates an encrypted initiator handshake message. +func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { + var tokenFlag byte + if token == nil { + // no session token found means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers + // generate shared key from prv and remote pubkey + var err error + if token, err = h.ecdhShared(prv); err != nil { + return nil, err + } + } else { + // for known peers, we use stored token from the previous session + tokenFlag = 0x01 + } + + // sign known message: + // ecdh-shared-secret^nonce for new peers + // token^nonce for old peers + signed := xor(token, h.initNonce) + signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA()) + if err != nil { + return nil, err + } + + // encode auth message + // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag + msg := make([]byte, authMsgLen) + n := copy(msg, signature) + n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey))) + n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:]) + n += copy(msg[n:], h.initNonce) + msg[n] = tokenFlag + + // encrypt auth message using remote-pubk + return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil) +} + +// decodeAuthResp decode an encrypted authentication response message. +func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error { + msg, err := crypto.Decrypt(prv, auth) + if err != nil { + return fmt.Errorf("could not decrypt auth response (%v)", err) + } + h.respNonce = msg[pubLen : pubLen+shaLen] + h.remoteRandomPub, err = importPublicKey(msg[:pubLen]) + if err != nil { + return err + } + // ignore token flag for now + return nil +} + +// receiverEncHandshake negotiates a session token on conn. +// it should be called on the listening side of the connection. +// +// prv is the local client's private key. +// token is the token from a previous session with this node. +func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) { + // read remote auth sent by initiator. + auth := make([]byte, encAuthMsgLen) + if _, err := io.ReadFull(conn, auth); err != nil { + return s, err + } + h, err := decodeAuthMsg(prv, token, auth) + if err != nil { + return s, err + } + + // send auth response + resp, err := h.authResp(prv, token) + if err != nil { + return s, err + } + if _, err = conn.Write(resp); err != nil { + return s, err + } + + return h.secrets(auth, resp) +} + +func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) { + var err error + h := new(encHandshake) + // generate random keypair for session + h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + if err != nil { + return nil, err + } + // generate random nonce + h.respNonce = make([]byte, shaLen) + if _, err = rand.Read(h.respNonce); err != nil { + return nil, err + } + + msg, err := crypto.Decrypt(prv, auth) + if err != nil { + return nil, fmt.Errorf("could not decrypt auth message (%v)", err) + } + + // decode message parameters + // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag + h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] + copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen]) + rpub, err := h.remoteID.Pubkey() + if err != nil { + return nil, fmt.Errorf("bad remoteID: %#v", err) + } + h.remotePub = ecies.ImportECDSAPublic(rpub) + + // recover remote random pubkey from signed message. + if token == nil { + // TODO: it is an error if the initiator has a token and we don't. check that. + + // no session token means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers. + // generate shared key from prv and remote pubkey. + if token, err = h.ecdhShared(prv); err != nil { + return nil, err + } + } + signedMsg := xor(token, h.initNonce) + remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]) + if err != nil { + return nil, err + } + h.remoteRandomPub, _ = importPublicKey(remoteRandomPub) + return h, nil +} + +// authResp generates the encrypted authentication response message. +func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { + // responder auth message + // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0) + resp := make([]byte, authRespLen) + n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey)) + n += copy(resp[n:], h.respNonce) + if token == nil { + resp[n] = 0 + } else { + resp[n] = 1 + } + // encrypt using remote-pubk + return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil) +} + +// importPublicKey unmarshals 512 bit public keys. +func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { + var pubKey65 []byte + switch len(pubKey) { + case 64: + // add 'uncompressed key' flag + pubKey65 = append([]byte{0x04}, pubKey...) + case 65: + pubKey65 = pubKey + default: + return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) + } + // TODO: fewer pointless conversions + return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil +} + +func exportPubkey(pub *ecies.PublicKey) []byte { + if pub == nil { + panic("nil pubkey") + } + return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:] +} + +func xor(one, other []byte) (xor []byte) { + xor = make([]byte, len(one)) + for i := 0; i < len(one); i++ { + xor[i] = one[i] ^ other[i] + } + return xor +} + var ( // this is used in place of actual frame header data. // TODO: replace this when Msg contains the protocol type code. zeroHeader = []byte{0xC2, 0x80, 0x80} - // sixteen zero bytes zero16 = make([]byte, 16) - - maxUint24 = ^uint32(0) >> 8 ) // rlpxFrameRW implements a simplified version of RLPx framing. @@ -38,7 +474,7 @@ type rlpxFrameRW struct { ingressMAC hash.Hash } -func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { +func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { macc, err := aes.NewCipher(s.MAC) if err != nil { panic("invalid MAC secret: " + err.Error()) diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index d98f1c2cd..44be46a99 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -3,19 +3,253 @@ package p2p import ( "bytes" "crypto/rand" + "errors" + "fmt" "io/ioutil" + "net" + "reflect" "strings" + "sync" "testing" + "time" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) -func TestRlpxFrameFake(t *testing.T) { +func TestSharedSecret(t *testing.T) { + prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) + pub0 := &prv0.PublicKey + prv1, _ := crypto.GenerateKey() + pub1 := &prv1.PublicKey + + ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) + if err != nil { + return + } + ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) + if err != nil { + return + } + t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) + if !bytes.Equal(ss0, ss1) { + t.Errorf("dont match :(") + } +} + +func TestEncHandshake(t *testing.T) { + for i := 0; i < 10; i++ { + start := time.Now() + if err := testEncHandshake(nil); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(without token) %d %v\n", i+1, time.Since(start)) + } + for i := 0; i < 10; i++ { + tok := make([]byte, shaLen) + rand.Reader.Read(tok) + start := time.Now() + if err := testEncHandshake(tok); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(with token) %d %v\n", i+1, time.Since(start)) + } +} + +func testEncHandshake(token []byte) error { + type result struct { + side string + id discover.NodeID + err error + } + var ( + prv0, _ = crypto.GenerateKey() + prv1, _ = crypto.GenerateKey() + fd0, fd1 = net.Pipe() + c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx) + output = make(chan result) + ) + + go func() { + r := result{side: "initiator"} + defer func() { output <- r }() + + dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)} + r.id, r.err = c0.doEncHandshake(prv0, dest) + if r.err != nil { + return + } + id1 := discover.PubkeyID(&prv1.PublicKey) + if r.id != id1 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1) + } + }() + go func() { + r := result{side: "receiver"} + defer func() { output <- r }() + + r.id, r.err = c1.doEncHandshake(prv1, nil) + if r.err != nil { + return + } + id0 := discover.PubkeyID(&prv0.PublicKey) + if r.id != id0 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0) + } + }() + + // wait for results from both sides + r1, r2 := <-output, <-output + if r1.err != nil { + return fmt.Errorf("%s side error: %v", r1.side, r1.err) + } + if r2.err != nil { + return fmt.Errorf("%s side error: %v", r2.side, r2.err) + } + + // compare derived secrets + if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) { + return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC) + } + if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) { + return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC) + } + if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) { + return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc) + } + if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) { + return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec) + } + return nil +} + +func TestProtocolHandshake(t *testing.T) { + var ( + prv0, _ = crypto.GenerateKey() + node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33} + hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}} + + prv1, _ = crypto.GenerateKey() + node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} + hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} + + fd0, fd1 = net.Pipe() + wg sync.WaitGroup + ) + + wg.Add(2) + go func() { + defer wg.Done() + rlpx := newRLPX(fd0) + remid, err := rlpx.doEncHandshake(prv0, node1) + if err != nil { + t.Errorf("dial side enc handshake failed: %v", err) + return + } + if remid != node1.ID { + t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID) + return + } + + phs, err := rlpx.doProtoHandshake(hs0) + if err != nil { + t.Errorf("dial side proto handshake error: %v", err) + return + } + if !reflect.DeepEqual(phs, hs1) { + t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1)) + return + } + rlpx.close(DiscQuitting) + }() + go func() { + defer wg.Done() + rlpx := newRLPX(fd1) + remid, err := rlpx.doEncHandshake(prv1, nil) + if err != nil { + t.Errorf("listen side enc handshake failed: %v", err) + return + } + if remid != node0.ID { + t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID) + return + } + + phs, err := rlpx.doProtoHandshake(hs1) + if err != nil { + t.Errorf("listen side proto handshake error: %v", err) + return + } + if !reflect.DeepEqual(phs, hs0) { + t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0)) + return + } + + if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { + t.Errorf("error receiving disconnect: %v", err) + } + }() + wg.Wait() +} + +func TestProtocolHandshakeErrors(t *testing.T) { + our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"} + id := randomID() + tests := []struct { + code uint64 + msg interface{} + err error + }{ + { + code: discMsg, + msg: []DiscReason{DiscQuitting}, + err: DiscQuitting, + }, + { + code: 0x989898, + msg: []byte{1}, + err: errors.New("expected handshake, got 989898"), + }, + { + code: handshakeMsg, + msg: make([]byte, baseProtocolMaxMsgSize+2), + err: errors.New("message too big"), + }, + { + code: handshakeMsg, + msg: []byte{1, 2, 3}, + err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"), + }, + { + code: handshakeMsg, + msg: &protoHandshake{Version: 9944, ID: id}, + err: DiscIncompatibleVersion, + }, + { + code: handshakeMsg, + msg: &protoHandshake{Version: 3}, + err: DiscInvalidIdentity, + }, + } + + for i, test := range tests { + p1, p2 := MsgPipe() + go Send(p1, test.code, test.msg) + _, err := readProtocolHandshake(p2, our) + if !reflect.DeepEqual(err, test.err) { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } + } +} + +func TestRLPXFrameFake(t *testing.T) { buf := new(bytes.Buffer) hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) - rw := newRlpxFrameRW(buf, secrets{ + rw := newRLPXFrameRW(buf, secrets{ AES: crypto.Sha3(), MAC: crypto.Sha3(), IngressMAC: hash, @@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 } func (h fakeHash) Size() int { return len(h) } func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } -func TestRlpxFrameRW(t *testing.T) { +func TestRLPXFrameRW(t *testing.T) { var ( aesSecret = make([]byte, 16) macSecret = make([]byte, 16) @@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) { } s1.EgressMAC.Write(egressMACinit) s1.IngressMAC.Write(ingressMACinit) - rw1 := newRlpxFrameRW(conn, s1) + rw1 := newRLPXFrameRW(conn, s1) s2 := secrets{ AES: aesSecret, @@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) { } s2.EgressMAC.Write(ingressMACinit) s2.IngressMAC.Write(egressMACinit) - rw2 := newRlpxFrameRW(conn, s2) + rw2 := newRLPXFrameRW(conn, s2) // send some messages for i := 0; i < 10; i++ { diff --git a/p2p/server.go b/p2p/server.go index 529fedbca..27e617610 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -2,7 +2,6 @@ package p2p import ( "crypto/ecdsa" - "crypto/rand" "errors" "fmt" "net" @@ -24,11 +23,8 @@ const ( maxAcceptConns = 50 // Maximum number of concurrently dialing outbound connections. - maxDialingConns = 10 + maxActiveDialTasks = 16 - // total timeout for encryption handshake and protocol - // handshake in both directions. - handshakeTimeout = 5 * time.Second // maximum time allowed for reading a complete message. // this is effectively the amount of time a connection can be idle. frameReadTimeout = 1 * time.Minute @@ -36,6 +32,8 @@ const ( frameWriteTimeout = 5 * time.Second ) +var errServerStopped = errors.New("server stopped") + var srvjslog = logger.NewJsonLogger() // Server manages all peer connections. @@ -103,68 +101,173 @@ type Server struct { // Hooks for testing. These are useful because we can inhibit // the whole protocol stack. - setupFunc - newPeerHook + newTransport func(net.Conn) transport + newPeerHook func(*Peer) + + lock sync.Mutex // protects running + running bool + ntab discoverTable + listener net.Listener ourHandshake *protoHandshake - lock sync.RWMutex // protects running, peers and the trust fields - running bool - peers map[discover.NodeID]*Peer - staticNodes map[discover.NodeID]*discover.Node // Map of currently maintained static remote nodes - staticDial chan *discover.Node // Dial request channel reserved for the static nodes - staticCycle time.Duration // Overrides staticPeerCheckInterval, used for testing - trustedNodes map[discover.NodeID]bool // Set of currently trusted remote nodes + // These are for Peers, PeerCount (and nothing else). + peerOp chan peerOpFunc + peerOpDone chan struct{} + + quit chan struct{} + addstatic chan *discover.Node + posthandshake chan *conn + addpeer chan *conn + delpeer chan *Peer + loopWG sync.WaitGroup // loop, listenLoop +} + +type peerOpFunc func(map[discover.NodeID]*Peer) + +type connFlag int - ntab *discover.Table - listener net.Listener +const ( + dynDialedConn connFlag = 1 << iota + staticDialedConn + inboundConn + trustedConn +) + +// conn wraps a network connection with information gathered +// during the two handshakes. +type conn struct { + fd net.Conn + transport + flags connFlag + cont chan error // The run loop uses cont to signal errors to setupConn. + id discover.NodeID // valid after the encryption handshake + caps []Cap // valid after the protocol handshake + name string // valid after the protocol handshake +} - quit chan struct{} - loopWG sync.WaitGroup // {dial,listen,nat}Loop - peerWG sync.WaitGroup // active peer goroutines +type transport interface { + // The two handshakes. + doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) + doProtoHandshake(our *protoHandshake) (*protoHandshake, error) + // The MsgReadWriter can only be used after the encryption + // handshake has completed. The code uses conn.id to track this + // by setting it to a non-nil value after the encryption handshake. + MsgReadWriter + // transports must provide Close because we use MsgPipe in some of + // the tests. Closing the actual network connection doesn't do + // anything in those tests because NsgPipe doesn't use it. + close(err error) } -type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, func(discover.NodeID) bool) (*conn, error) -type newPeerHook func(*Peer) +func (c *conn) String() string { + s := c.flags.String() + " conn" + if (c.id != discover.NodeID{}) { + s += fmt.Sprintf(" %x", c.id[:8]) + } + s += " " + c.fd.RemoteAddr().String() + return s +} + +func (f connFlag) String() string { + s := "" + if f&trustedConn != 0 { + s += " trusted" + } + if f&dynDialedConn != 0 { + s += " dyn dial" + } + if f&staticDialedConn != 0 { + s += " static dial" + } + if f&inboundConn != 0 { + s += " inbound" + } + if s != "" { + s = s[1:] + } + return s +} + +func (c *conn) is(f connFlag) bool { + return c.flags&f != 0 +} // Peers returns all connected peers. -func (srv *Server) Peers() (peers []*Peer) { - srv.lock.RLock() - defer srv.lock.RUnlock() - for _, peer := range srv.peers { - if peer != nil { - peers = append(peers, peer) +func (srv *Server) Peers() []*Peer { + var ps []*Peer + select { + // Note: We'd love to put this function into a variable but + // that seems to cause a weird compiler error in some + // environments. + case srv.peerOp <- func(peers map[discover.NodeID]*Peer) { + for _, p := range peers { + ps = append(ps, p) } + }: + <-srv.peerOpDone + case <-srv.quit: } - return + return ps } // PeerCount returns the number of connected peers. func (srv *Server) PeerCount() int { - srv.lock.RLock() - n := len(srv.peers) - srv.lock.RUnlock() - return n + var count int + select { + case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }: + <-srv.peerOpDone + case <-srv.quit: + } + return count } // AddPeer connects to the given node and maintains the connection until the // server is shut down. If the connection fails for any reason, the server will // attempt to reconnect the peer. func (srv *Server) AddPeer(node *discover.Node) { + select { + case srv.addstatic <- node: + case <-srv.quit: + } +} + +// Self returns the local node's endpoint information. +func (srv *Server) Self() *discover.Node { srv.lock.Lock() defer srv.lock.Unlock() + if !srv.running { + return &discover.Node{IP: net.ParseIP("0.0.0.0")} + } + return srv.ntab.Self() +} - srv.staticNodes[node.ID] = node +// Stop terminates the server and all active peer connections. +// It blocks until all active connections have been closed. +func (srv *Server) Stop() { + srv.lock.Lock() + defer srv.lock.Unlock() + if !srv.running { + return + } + srv.running = false + if srv.listener != nil { + // this unblocks listener Accept + srv.listener.Close() + } + close(srv.quit) + srv.loopWG.Wait() } // Start starts running the server. -// Servers can be re-used and started again after stopping. +// Servers can not be re-used after stopping. func (srv *Server) Start() (err error) { srv.lock.Lock() defer srv.lock.Unlock() if srv.running { return errors.New("server already running") } + srv.running = true glog.V(logger.Info).Infoln("Starting Server") // static fields @@ -174,23 +277,19 @@ func (srv *Server) Start() (err error) { if srv.MaxPeers <= 0 { return fmt.Errorf("Server.MaxPeers must be > 0") } - srv.quit = make(chan struct{}) - srv.peers = make(map[discover.NodeID]*Peer) - - // Create the current trust maps, and the associated dialing channel - srv.trustedNodes = make(map[discover.NodeID]bool) - for _, node := range srv.TrustedNodes { - srv.trustedNodes[node.ID] = true - } - srv.staticNodes = make(map[discover.NodeID]*discover.Node) - for _, node := range srv.StaticNodes { - srv.staticNodes[node.ID] = node + if srv.newTransport == nil { + srv.newTransport = newRLPX } - srv.staticDial = make(chan *discover.Node) - - if srv.setupFunc == nil { - srv.setupFunc = setupConn + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} } + srv.quit = make(chan struct{}) + srv.addpeer = make(chan *conn) + srv.delpeer = make(chan *Peer) + srv.posthandshake = make(chan *conn) + srv.addstatic = make(chan *discover.Node) + srv.peerOp = make(chan peerOpFunc) + srv.peerOpDone = make(chan struct{}) // node table ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) @@ -198,37 +297,31 @@ func (srv *Server) Start() (err error) { return err } srv.ntab = ntab + dialer := newDialState(srv.StaticNodes, srv.ntab, srv.MaxPeers/2) // handshake srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID} for _, p := range srv.Protocols { srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) } - // listen/dial if srv.ListenAddr != "" { if err := srv.startListening(); err != nil { return err } } - if srv.Dialer == nil { - srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} - } - if !srv.NoDial { - srv.loopWG.Add(1) - go srv.dialLoop() - } if srv.NoDial && srv.ListenAddr == "" { glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.") } - // maintain the static peers - go srv.staticNodesLoop() + srv.loopWG.Add(1) + go srv.run(dialer) srv.running = true return nil } func (srv *Server) startListening() error { + // Launch the TCP listener. listener, err := net.Listen("tcp", srv.ListenAddr) if err != nil { return err @@ -238,6 +331,7 @@ func (srv *Server) startListening() error { srv.listener = listener srv.loopWG.Add(1) go srv.listenLoop() + // Map the TCP listening port if NAT is configured. if !laddr.IP.IsLoopback() && srv.NAT != nil { srv.loopWG.Add(1) go func() { @@ -248,50 +342,164 @@ func (srv *Server) startListening() error { return nil } -// Stop terminates the server and all active peer connections. -// It blocks until all active connections have been closed. -func (srv *Server) Stop() { - srv.lock.Lock() - if !srv.running { - srv.lock.Unlock() - return +type dialer interface { + newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task + taskDone(task, time.Time) + addStatic(*discover.Node) +} + +func (srv *Server) run(dialstate dialer) { + defer srv.loopWG.Done() + var ( + peers = make(map[discover.NodeID]*Peer) + trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) + + tasks []task + pendingTasks []task + taskdone = make(chan task, maxActiveDialTasks) + ) + // Put trusted nodes into a map to speed up checks. + // Trusted peers are loaded on startup and cannot be + // modified while the server is running. + for _, n := range srv.TrustedNodes { + trusted[n.ID] = true + } + + // Some task list helpers. + delTask := func(t task) { + for i := range tasks { + if tasks[i] == t { + tasks = append(tasks[:i], tasks[i+1:]...) + break + } + } } - srv.running = false - srv.lock.Unlock() + scheduleTasks := func(new []task) { + pt := append(pendingTasks, new...) + start := maxActiveDialTasks - len(tasks) + if len(pt) < start { + start = len(pt) + } + if start > 0 { + tasks = append(tasks, pt[:start]...) + for _, t := range pt[:start] { + t := t + glog.V(logger.Detail).Infoln("new task:", t) + go func() { t.Do(srv); taskdone <- t }() + } + copy(pt, pt[start:]) + pendingTasks = pt[:len(pt)-start] + } + } + +running: + for { + // Query the dialer for new tasks and launch them. + now := time.Now() + nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now) + scheduleTasks(nt) - glog.V(logger.Info).Infoln("Stopping Server") + select { + case <-srv.quit: + // The server was stopped. Run the cleanup logic. + glog.V(logger.Detail).Infoln("<-quit: spinning down") + break running + case n := <-srv.addstatic: + // This channel is used by AddPeer to add to the + // ephemeral static peer list. Add it to the dialer, + // it will keep the node connected. + glog.V(logger.Detail).Infoln("<-addstatic:", n) + dialstate.addStatic(n) + case op := <-srv.peerOp: + // This channel is used by Peers and PeerCount. + op(peers) + srv.peerOpDone <- struct{}{} + case t := <-taskdone: + // A task got done. Tell dialstate about it so it + // can update its state and remove it from the active + // tasks list. + glog.V(logger.Detail).Infoln("<-taskdone:", t) + dialstate.taskDone(t, now) + delTask(t) + case c := <-srv.posthandshake: + // A connection has passed the encryption handshake so + // the remote identity is known (but hasn't been verified yet). + if trusted[c.id] { + // Ensure that the trusted flag is set before checking against MaxPeers. + c.flags |= trustedConn + } + glog.V(logger.Detail).Infoln("<-posthandshake:", c) + // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. + c.cont <- srv.encHandshakeChecks(peers, c) + case c := <-srv.addpeer: + // At this point the connection is past the protocol handshake. + // Its capabilities are known and the remote identity is verified. + glog.V(logger.Detail).Infoln("<-addpeer:", c) + err := srv.protoHandshakeChecks(peers, c) + if err != nil { + glog.V(logger.Detail).Infof("Not adding %v as peer: %v", c, err) + } else { + // The handshakes are done and it passed all checks. + p := newPeer(c, srv.Protocols) + peers[c.id] = p + go srv.runPeer(p) + } + // The dialer logic relies on the assumption that + // dial tasks complete after the peer has been added or + // discarded. Unblock the task last. + c.cont <- err + case p := <-srv.delpeer: + // A peer disconnected. + glog.V(logger.Detail).Infoln("<-delpeer:", p) + delete(peers, p.ID()) + } + } + + // Terminate discovery. If there is a running lookup it will terminate soon. srv.ntab.Close() - if srv.listener != nil { - // this unblocks listener Accept - srv.listener.Close() + // Disconnect all peers. + for _, p := range peers { + p.Disconnect(DiscQuitting) + } + // Wait for peers to shut down. Pending connections and tasks are + // not handled here and will terminate soon-ish because srv.quit + // is closed. + glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks)) + for len(peers) > 0 { + p := <-srv.delpeer + glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p) + delete(peers, p.ID()) } - close(srv.quit) - srv.loopWG.Wait() +} - // No new peers can be added at this point because dialLoop and - // listenLoop are down. It is safe to call peerWG.Wait because - // peerWG.Add is not called outside of those loops. - srv.lock.Lock() - for _, peer := range srv.peers { - peer.Disconnect(DiscQuitting) +func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { + // Drop connections with no matching protocols. + if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { + return DiscUselessPeer } - srv.lock.Unlock() - srv.peerWG.Wait() + // Repeat the encryption handshake checks because the + // peer set might have changed between the handshakes. + return srv.encHandshakeChecks(peers, c) } -// Self returns the local node's endpoint information. -func (srv *Server) Self() *discover.Node { - srv.lock.RLock() - defer srv.lock.RUnlock() - if !srv.running { - return &discover.Node{IP: net.ParseIP("0.0.0.0")} +func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { + switch { + case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: + return DiscTooManyPeers + case peers[c.id] != nil: + return DiscAlreadyConnected + case c.id == srv.ntab.Self().ID: + return DiscSelf + default: + return nil } - return srv.ntab.Self() } -// main loop for adding connections via listening +// listenLoop runs in its own goroutine and accepts +// inbound connections. func (srv *Server) listenLoop() { defer srv.loopWG.Done() + glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr()) // This channel acts as a semaphore limiting // active inbound connections that are lingering pre-handshake. @@ -305,204 +513,92 @@ func (srv *Server) listenLoop() { slots <- struct{}{} } - glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr()) for { <-slots - conn, err := srv.listener.Accept() + fd, err := srv.listener.Accept() if err != nil { return } - glog.V(logger.Debug).Infof("Accepted conn %v\n", conn.RemoteAddr()) - srv.peerWG.Add(1) + glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) go func() { - srv.startPeer(conn, nil) + srv.setupConn(fd, inboundConn, nil) slots <- struct{}{} }() } } -// staticNodesLoop is responsible for periodically checking that static -// connections are actually live, and requests dialing if not. -func (srv *Server) staticNodesLoop() { - // Create a default maintenance ticker, but override it requested - cycle := staticPeerCheckInterval - if srv.staticCycle != 0 { - cycle = srv.staticCycle - } - tick := time.NewTicker(cycle) - - for { - select { - case <-srv.quit: - return - - case <-tick.C: - // Collect all the non-connected static nodes - needed := []*discover.Node{} - srv.lock.RLock() - for id, node := range srv.staticNodes { - if _, ok := srv.peers[id]; !ok { - needed = append(needed, node) - } - } - srv.lock.RUnlock() - - // Try to dial each of them (don't hang if server terminates) - for _, node := range needed { - glog.V(logger.Debug).Infof("Dialing static peer %v", node) - select { - case srv.staticDial <- node: - case <-srv.quit: - return - } - } - } - } -} - -func (srv *Server) dialLoop() { - var ( - dialed = make(chan *discover.Node) - dialing = make(map[discover.NodeID]bool) - findresults = make(chan []*discover.Node) - refresh = time.NewTimer(0) - ) - defer srv.loopWG.Done() - defer refresh.Stop() - - // Limit the number of concurrent dials - tokens := maxDialingConns - if srv.MaxPendingPeers > 0 { - tokens = srv.MaxPendingPeers - } - slots := make(chan struct{}, tokens) - for i := 0; i < tokens; i++ { - slots <- struct{}{} +// setupConn runs the handshakes and attempts to add the connection +// as a peer. It returns when the connection has been added as a peer +// or the handshakes have failed. +func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) { + // Prevent leftover pending conns from entering the handshake. + srv.lock.Lock() + running := srv.running + srv.lock.Unlock() + c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)} + if !running { + c.close(errServerStopped) + return } - dial := func(dest *discover.Node) { - // Don't dial nodes that would fail the checks in addPeer. - // This is important because the connection handshake is a lot - // of work and we'd rather avoid doing that work for peers - // that can't be added. - srv.lock.RLock() - ok, _ := srv.checkPeer(dest.ID) - srv.lock.RUnlock() - if !ok || dialing[dest.ID] { - return - } - // Request a dial slot to prevent CPU exhaustion - <-slots - - dialing[dest.ID] = true - srv.peerWG.Add(1) - go func() { - srv.dialNode(dest) - slots <- struct{}{} - dialed <- dest - }() + // Run the encryption handshake. + var err error + if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil { + glog.V(logger.Debug).Infof("%v faild enc handshake: %v", c, err) + c.close(err) + return } - - srv.ntab.Bootstrap(srv.BootstrapNodes) - for { - select { - case <-refresh.C: - // Grab some nodes to connect to if we're not at capacity. - srv.lock.RLock() - needpeers := len(srv.peers) < srv.MaxPeers/2 - srv.lock.RUnlock() - if needpeers { - go func() { - var target discover.NodeID - rand.Read(target[:]) - findresults <- srv.ntab.Lookup(target) - }() - } else { - // Make sure we check again if the peer count falls - // below MaxPeers. - refresh.Reset(refreshPeersInterval) - } - case dest := <-srv.staticDial: - dial(dest) - case dests := <-findresults: - for _, dest := range dests { - dial(dest) - } - refresh.Reset(refreshPeersInterval) - case dest := <-dialed: - delete(dialing, dest.ID) - if len(dialing) == 0 { - // Check again immediately after dialing all current candidates. - refresh.Reset(0) - } - case <-srv.quit: - // TODO: maybe wait for active dials - return - } + // For dialed connections, check that the remote public key matches. + if dialDest != nil && c.id != dialDest.ID { + c.close(DiscUnexpectedIdentity) + glog.V(logger.Debug).Infof("%v dialed identity mismatch, want %x", c, dialDest.ID[:8]) + return } -} - -func (srv *Server) dialNode(dest *discover.Node) { - addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)} - glog.V(logger.Debug).Infof("Dialing %v\n", dest) - conn, err := srv.Dialer.Dial("tcp", addr.String()) - if err != nil { - // dialLoop adds to the wait group counter when launching - // dialNode, so we need to count it down again. startPeer also - // does that when an error occurs. - srv.peerWG.Done() - glog.V(logger.Detail).Infof("dial error: %v", err) + if err := srv.checkpoint(c, srv.posthandshake); err != nil { + glog.V(logger.Debug).Infof("%v failed checkpoint posthandshake: %v", c, err) + c.close(err) return } - srv.startPeer(conn, dest) -} - -func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) { - // TODO: handle/store session token - - // Run setupFunc, which should create an authenticated connection - // and run the capability exchange. Note that any early error - // returns during that exchange need to call peerWG.Done because - // the callers of startPeer added the peer to the wait group already. - fd.SetDeadline(time.Now().Add(handshakeTimeout)) - - conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, srv.keepconn) + // Run the protocol handshake + phs, err := c.doProtoHandshake(srv.ourHandshake) if err != nil { - fd.Close() - glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err) - srv.peerWG.Done() + glog.V(logger.Debug).Infof("%v failed proto handshake: %v", c, err) + c.close(err) return } - conn.MsgReadWriter = &netWrapper{ - wrapped: conn.MsgReadWriter, - conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout, + if phs.ID != c.id { + glog.V(logger.Debug).Infof("%v wrong proto handshake identity: %x", c, phs.ID[:8]) + c.close(DiscUnexpectedIdentity) + return } - p := newPeer(fd, conn, srv.Protocols) - if ok, reason := srv.addPeer(conn, p); !ok { - glog.V(logger.Detail).Infof("Not adding %v (%v)\n", p, reason) - p.politeDisconnect(reason) - srv.peerWG.Done() + c.caps, c.name = phs.Caps, phs.Name + if err := srv.checkpoint(c, srv.addpeer); err != nil { + glog.V(logger.Debug).Infof("%v failed checkpoint addpeer: %v", c, err) + c.close(err) return } - // The handshakes are done and it passed all checks. - // Spawn the Peer loops. - go srv.runPeer(p) + // If the checks completed successfully, runPeer has now been + // launched by run. } -// preflight checks whether a connection should be kept. it runs -// after the encryption handshake, as soon as the remote identity is -// known. -func (srv *Server) keepconn(id discover.NodeID) bool { - srv.lock.RLock() - defer srv.lock.RUnlock() - if _, ok := srv.staticNodes[id]; ok { - return true // static nodes are always allowed +// checkpoint sends the conn to run, which performs the +// post-handshake checks for the stage (posthandshake, addpeer). +func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error { + select { + case stage <- c: + case <-srv.quit: + return errServerStopped } - if _, ok := srv.trustedNodes[id]; ok { - return true // trusted nodes are always allowed + select { + case err := <-c.cont: + return err + case <-srv.quit: + return errServerStopped } - return len(srv.peers) < srv.MaxPeers } +// runPeer runs in its own goroutine for each peer. +// it waits until the Peer logic returns and removes +// the peer. func (srv *Server) runPeer(p *Peer) { glog.V(logger.Debug).Infof("Added %v\n", p) srvjslog.LogJson(&logger.P2PConnected{ @@ -511,58 +607,18 @@ func (srv *Server) runPeer(p *Peer) { RemoteVersionString: p.Name(), NumConnections: srv.PeerCount(), }) + if srv.newPeerHook != nil { srv.newPeerHook(p) } discreason := p.run() - srv.removePeer(p) + // Note: run waits for existing peers to be sent on srv.delpeer + // before returning, so this send should not select on srv.quit. + srv.delpeer <- p + glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason) srvjslog.LogJson(&logger.P2PDisconnected{ RemoteId: p.ID().String(), NumConnections: srv.PeerCount(), }) } - -func (srv *Server) addPeer(conn *conn, p *Peer) (bool, DiscReason) { - // drop connections with no matching protocols. - if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, conn.protoHandshake.Caps) == 0 { - return false, DiscUselessPeer - } - // add the peer if it passes the other checks. - srv.lock.Lock() - defer srv.lock.Unlock() - if ok, reason := srv.checkPeer(conn.ID); !ok { - return false, reason - } - srv.peers[conn.ID] = p - return true, 0 -} - -// checkPeer verifies whether a peer looks promising and should be allowed/kept -// in the pool, or if it's of no use. -func (srv *Server) checkPeer(id discover.NodeID) (bool, DiscReason) { - // First up, figure out if the peer is static or trusted - _, static := srv.staticNodes[id] - trusted := srv.trustedNodes[id] - - // Make sure the peer passes all required checks - switch { - case !srv.running: - return false, DiscQuitting - case !static && !trusted && len(srv.peers) >= srv.MaxPeers: - return false, DiscTooManyPeers - case srv.peers[id] != nil: - return false, DiscAlreadyConnected - case id == srv.ntab.Self().ID: - return false, DiscSelf - default: - return true, 0 - } -} - -func (srv *Server) removePeer(p *Peer) { - srv.lock.Lock() - delete(srv.peers, p.ID()) - srv.lock.Unlock() - srv.peerWG.Done() -} diff --git a/p2p/server_test.go b/p2p/server_test.go index 6f7aaf8e1..01448cc7b 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -2,8 +2,10 @@ package p2p import ( "crypto/ecdsa" + "errors" "math/rand" "net" + "reflect" "testing" "time" @@ -12,29 +14,50 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" ) -func startTestServer(t *testing.T, pf newPeerHook) *Server { +func init() { + // glog.SetV(6) + // glog.SetToStderr(true) +} + +type testTransport struct { + id discover.NodeID + *rlpx + + closeErr error +} + +func newTestTransport(id discover.NodeID, fd net.Conn) transport { + wrapped := newRLPX(fd).(*rlpx) + wrapped.rw = newRLPXFrameRW(fd, secrets{ + MAC: zero16, + AES: zero16, + IngressMAC: sha3.NewKeccak256(), + EgressMAC: sha3.NewKeccak256(), + }) + return &testTransport{id: id, rlpx: wrapped} +} + +func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { + return c.id, nil +} + +func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { + return &protoHandshake{ID: c.id, Name: "test"}, nil +} + +func (c *testTransport) close(err error) { + c.rlpx.fd.Close() + c.closeErr = err +} + +func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { server := &Server{ - Name: "test", - MaxPeers: 10, - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - newPeerHook: pf, - setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) { - id := randomID() - if !keepconn(id) { - return nil, DiscAlreadyConnected - } - rw := newRlpxFrameRW(fd, secrets{ - MAC: zero16, - AES: zero16, - IngressMAC: sha3.NewKeccak256(), - EgressMAC: sha3.NewKeccak256(), - }) - return &conn{ - MsgReadWriter: rw, - protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion}, - }, nil - }, + Name: "test", + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + PrivateKey: newkey(), + newPeerHook: pf, + newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, } if err := server.Start(); err != nil { t.Fatalf("Could not start server: %v", err) @@ -45,7 +68,11 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server { func TestServerListen(t *testing.T) { // start the test server connected := make(chan *Peer) - srv := startTestServer(t, func(p *Peer) { + remid := randomID() + srv := startTestServer(t, remid, func(p *Peer) { + if p.ID() != remid { + t.Error("peer func called with wrong node id") + } if p == nil { t.Error("peer func called with nil conn") } @@ -67,6 +94,10 @@ func TestServerListen(t *testing.T) { t.Errorf("peer started with wrong conn: got %v, want %v", peer.LocalAddr(), conn.RemoteAddr()) } + peers := srv.Peers() + if !reflect.DeepEqual(peers, []*Peer{peer}) { + t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) + } case <-time.After(1 * time.Second): t.Error("server did not accept within one second") } @@ -92,23 +123,33 @@ func TestServerDial(t *testing.T) { // start the server connected := make(chan *Peer) - srv := startTestServer(t, func(p *Peer) { connected <- p }) + remid := randomID() + srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) defer close(connected) defer srv.Stop() // tell the server to connect tcpAddr := listener.Addr().(*net.TCPAddr) - srv.staticDial <- &discover.Node{IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)} + srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) select { case conn := <-accepted: select { case peer := <-connected: + if peer.ID() != remid { + t.Errorf("peer has wrong id") + } + if peer.Name() != "test" { + t.Errorf("peer has wrong name") + } if peer.RemoteAddr().String() != conn.LocalAddr().String() { t.Errorf("peer started with wrong conn: got %v, want %v", peer.RemoteAddr(), conn.LocalAddr()) } - // TODO: validate more fields + peers := srv.Peers() + if !reflect.DeepEqual(peers, []*Peer{peer}) { + t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) + } case <-time.After(1 * time.Second): t.Error("server did not launch peer within one second") } @@ -118,331 +159,250 @@ func TestServerDial(t *testing.T) { } } -// This test checks that connections are disconnected -// just after the encryption handshake when the server is -// at capacity. -// -// It also serves as a light-weight integration test. -func TestServerDisconnectAtCap(t *testing.T) { - started := make(chan *Peer) - srv := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 10, - NoDial: true, - // This hook signals that the peer was actually started. We - // need to wait for the peer to be started before dialing the - // next connection to get a deterministic peer count. - newPeerHook: func(p *Peer) { started <- p }, - } - if err := srv.Start(); err != nil { - t.Fatal(err) - } - defer srv.Stop() - - nconns := srv.MaxPeers + 1 - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} - for i := 0; i < nconns; i++ { - conn, err := dialer.Dial("tcp", srv.ListenAddr) - if err != nil { - t.Fatalf("conn %d: dial error: %v", i, err) - } - // Close the connection when the test ends, before - // shutting down the server. - defer conn.Close() - // Run the handshakes just like a real peer would. - key := newkey() - hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - _, err = setupConn(conn, key, hs, srv.Self(), keepalways) - if i == nconns-1 { - // When handling the last connection, the server should - // disconnect immediately instead of running the protocol - // handshake. - if err != DiscTooManyPeers { - t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers) - } - } else { - // For all earlier connections, the handshake should go through. - if err != nil { - t.Fatalf("conn %d: unexpected error: %v", i, err) - } - // Wait for runPeer to be started. - <-started +// This test checks that tasks generated by dialstate are +// actually executed and taskdone is called for them. +func TestServerTaskScheduling(t *testing.T) { + var ( + done = make(chan *testTask) + quit, returned = make(chan struct{}), make(chan struct{}) + tc = 0 + tg = taskgen{ + newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { + tc++ + return []task{&testTask{index: tc - 1}} + }, + doneFunc: func(t task) { + select { + case done <- t.(*testTask): + case <-quit: + } + }, } - } -} + ) -// Tests that static peers are (re)connected, and done so even above max peers. -func TestServerStaticPeers(t *testing.T) { - // Create a test server with limited connection slots - started := make(chan *Peer) - server := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 3, - newPeerHook: func(p *Peer) { started <- p }, - staticCycle: time.Second, - } - if err := server.Start(); err != nil { - t.Fatal(err) + // The Server in this test isn't actually running + // because we're only interested in what run does. + srv := &Server{ + MaxPeers: 10, + quit: make(chan struct{}), + ntab: fakeTable{}, + running: true, } - defer server.Stop() - - // Fill up all the slots on the server - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} - for i := 0; i < server.MaxPeers; i++ { - // Establish a new connection - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("conn %d: dial error: %v", i, err) - } - defer conn.Close() + srv.loopWG.Add(1) + go func() { + srv.run(tg) + close(returned) + }() - // Run the handshakes just like a real peer would, and wait for completion - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("conn %d: unexpected error: %v", i, err) - } - <-started + var gotdone []*testTask + for i := 0; i < 100; i++ { + gotdone = append(gotdone, <-done) } - // Open a TCP listener to accept static connections - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("failed to setup listener: %v", err) - } - defer listener.Close() - - connected := make(chan net.Conn) - go func() { - for i := 0; i < 3; i++ { - conn, err := listener.Accept() - if err == nil { - connected <- conn - } + for i, task := range gotdone { + if task.index != i { + t.Errorf("task %d has wrong index, got %d", i, task.index) + break + } + if !task.called { + t.Errorf("task %d was not called", i) + break } - }() - // Inject a static node and wait for a remote dial, then redial, then nothing - addr := listener.Addr().(*net.TCPAddr) - static := &discover.Node{ - ID: discover.PubkeyID(&newkey().PublicKey), - IP: addr.IP, - TCP: uint16(addr.Port), } - server.AddPeer(static) + close(quit) + srv.Stop() select { - case conn := <-connected: - // Close the first connection, expect redial - conn.Close() - - case <-time.After(2 * server.staticCycle): - t.Fatalf("remote dial timeout") + case <-returned: + case <-time.After(500 * time.Millisecond): + t.Error("Server.run did not return within 500ms") } +} - select { - case conn := <-connected: - // Keep the second connection, don't expect redial - defer conn.Close() - - case <-time.After(2 * server.staticCycle): - t.Fatalf("remote re-dial timeout") - } +type taskgen struct { + newFunc func(running int, peers map[discover.NodeID]*Peer) []task + doneFunc func(task) +} - select { - case <-time.After(2 * server.staticCycle): - // Timeout as no dial occurred +func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { + return tg.newFunc(running, peers) +} +func (tg taskgen) taskDone(t task, now time.Time) { + tg.doneFunc(t) +} +func (tg taskgen) addStatic(*discover.Node) { +} - case <-connected: - t.Fatalf("connected node dialed") - } +type testTask struct { + index int + called bool } -// Tests that trusted peers and can connect above max peer caps. -func TestServerTrustedPeers(t *testing.T) { +func (t *testTask) Do(srv *Server) { + t.called = true +} - // Create a trusted peer to accept connections from - key := newkey() - trusted := &discover.Node{ - ID: discover.PubkeyID(&key.PublicKey), - } - // Create a test server with limited connection slots - started := make(chan *Peer) - server := &Server{ - ListenAddr: "127.0.0.1:0", +// This test checks that connections are disconnected +// just after the encryption handshake when the server is +// at capacity. Trusted connections should still be accepted. +func TestServerAtCap(t *testing.T) { + trustedID := randomID() + srv := &Server{ PrivateKey: newkey(), - MaxPeers: 3, + MaxPeers: 10, NoDial: true, - TrustedNodes: []*discover.Node{trusted}, - newPeerHook: func(p *Peer) { started <- p }, + TrustedNodes: []*discover.Node{{ID: trustedID}}, } - if err := server.Start(); err != nil { - t.Fatal(err) + if err := srv.Start(); err != nil { + t.Fatalf("could not start: %v", err) } - defer server.Stop() + defer srv.Stop() - // Fill up all the slots on the server - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} - for i := 0; i < server.MaxPeers; i++ { - // Establish a new connection - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("conn %d: dial error: %v", i, err) - } - defer conn.Close() + newconn := func(id discover.NodeID) *conn { + fd, _ := net.Pipe() + tx := newTestTransport(id, fd) + return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} + } - // Run the handshakes just like a real peer would, and wait for completion - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("conn %d: unexpected error: %v", i, err) + // Inject a few connections to fill up the peer set. + for i := 0; i < 10; i++ { + c := newconn(randomID()) + if err := srv.checkpoint(c, srv.addpeer); err != nil { + t.Fatalf("could not add conn %d: %v", i, err) } - <-started } - // Dial from the trusted peer, ensure connection is accepted - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("trusted node: dial error: %v", err) + // Try inserting a non-trusted connection. + c := newconn(randomID()) + if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { + t.Error("wrong error for insert:", err) } - defer conn.Close() - - shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID} - if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("trusted node: unexpected error: %v", err) + // Try inserting a trusted connection. + c = newconn(trustedID) + if err := srv.checkpoint(c, srv.posthandshake); err != nil { + t.Error("unexpected error for trusted conn @posthandshake:", err) } - select { - case <-started: - // Ok, trusted peer accepted - - case <-time.After(100 * time.Millisecond): - t.Fatalf("trusted node timeout") + if !c.is(trustedConn) { + t.Error("Server did not set trusted flag") } + } -// Tests that a failed dial will temporarily throttle a peer. -func TestServerMaxPendingDials(t *testing.T) { - // Start a simple test server - server := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 10, - MaxPendingPeers: 1, - } - if err := server.Start(); err != nil { - t.Fatal("failed to start test server: %v", err) +func TestServerSetupConn(t *testing.T) { + id := randomID() + srvkey := newkey() + srvid := discover.PubkeyID(&srvkey.PublicKey) + tests := []struct { + dontstart bool + tt *setupTransport + flags connFlag + dialDest *discover.Node + + wantCloseErr error + wantCalls string + }{ + { + dontstart: true, + tt: &setupTransport{id: id}, + wantCalls: "close,", + wantCloseErr: errServerStopped, + }, + { + tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, + flags: inboundConn, + wantCalls: "doEncHandshake,close,", + wantCloseErr: errors.New("read error"), + }, + { + tt: &setupTransport{id: id}, + dialDest: &discover.Node{ID: randomID()}, + flags: dynDialedConn, + wantCalls: "doEncHandshake,close,", + wantCloseErr: DiscUnexpectedIdentity, + }, + { + tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, + dialDest: &discover.Node{ID: id}, + flags: dynDialedConn, + wantCalls: "doEncHandshake,doProtoHandshake,close,", + wantCloseErr: DiscUnexpectedIdentity, + }, + { + tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, + dialDest: &discover.Node{ID: id}, + flags: dynDialedConn, + wantCalls: "doEncHandshake,doProtoHandshake,close,", + wantCloseErr: errors.New("foo"), + }, + { + tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, + flags: inboundConn, + wantCalls: "doEncHandshake,close,", + wantCloseErr: DiscSelf, + }, + { + tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, + flags: inboundConn, + wantCalls: "doEncHandshake,doProtoHandshake,close,", + wantCloseErr: DiscUselessPeer, + }, } - defer server.Stop() - // Simulate two separate remote peers - peers := make(chan *discover.Node, 2) - conns := make(chan net.Conn, 2) - for i := 0; i < 2; i++ { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listener %d: failed to setup: %v", i, err) - } - defer listener.Close() - - addr := listener.Addr().(*net.TCPAddr) - peers <- &discover.Node{ - ID: discover.PubkeyID(&newkey().PublicKey), - IP: addr.IP, - TCP: uint16(addr.Port), + for i, test := range tests { + srv := &Server{ + PrivateKey: srvkey, + MaxPeers: 10, + NoDial: true, + Protocols: []Protocol{discard}, + newTransport: func(fd net.Conn) transport { return test.tt }, } - go func() { - conn, err := listener.Accept() - if err == nil { - conns <- conn + if !test.dontstart { + if err := srv.Start(); err != nil { + t.Fatalf("couldn't start server: %v", err) } - }() - } - // Request a dial for both peers - go func() { - for i := 0; i < 2; i++ { - server.staticDial <- <-peers // hack piggybacking the static implementation } - }() - - // Make sure only one outbound connection goes through - var conn net.Conn - - select { - case conn = <-conns: - case <-time.After(100 * time.Millisecond): - t.Fatalf("first dial timeout") - } - select { - case conn = <-conns: - t.Fatalf("second dial completed prematurely") - case <-time.After(100 * time.Millisecond): - } - // Finish the first dial, check the second - conn.Close() - select { - case conn = <-conns: - conn.Close() - - case <-time.After(100 * time.Millisecond): - t.Fatalf("second dial timeout") + p1, _ := net.Pipe() + srv.setupConn(p1, test.flags, test.dialDest) + if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { + t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) + } + if test.tt.calls != test.wantCalls { + t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) + } } } -func TestServerMaxPendingAccepts(t *testing.T) { - // Start a test server and a peer sink for synchronization - started := make(chan *Peer) - server := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 10, - MaxPendingPeers: 1, - NoDial: true, - newPeerHook: func(p *Peer) { started <- p }, - } - if err := server.Start(); err != nil { - t.Fatal("failed to start test server: %v", err) - } - defer server.Stop() +type setupTransport struct { + id discover.NodeID + encHandshakeErr error - // Try and connect to the server on multiple threads concurrently - conns := make([]net.Conn, 2) - for i := 0; i < 2; i++ { - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} + phs *protoHandshake + protoHandshakeErr error - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("failed to dial server: %v", err) - } - conns[i] = conn - } - // Check that a handshake on the second doesn't pass - go func() { - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err := setupConn(conns[1], key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("failed to run handshake: %v", err) - } - }() - select { - case <-started: - t.Fatalf("handshake on second connection accepted") + calls string + closeErr error +} - case <-time.After(time.Second): - } - // Shake on first, check that both go through - go func() { - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err := setupConn(conns[0], key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("failed to run handshake: %v", err) - } - }() - for i := 0; i < 2; i++ { - select { - case <-started: - case <-time.After(time.Second): - t.Fatalf("peer %d: handshake timeout", i) - } +func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { + c.calls += "doEncHandshake," + return c.id, c.encHandshakeErr +} +func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { + c.calls += "doProtoHandshake," + if c.protoHandshakeErr != nil { + return nil, c.protoHandshakeErr } + return c.phs, nil +} +func (c *setupTransport) close(err error) { + c.calls += "close," + c.closeErr = err +} + +// setupConn shouldn't write to/read from the connection. +func (c *setupTransport) WriteMsg(Msg) error { + panic("WriteMsg called on setupTransport") +} +func (c *setupTransport) ReadMsg() (Msg, error) { + panic("ReadMsg called on setupTransport") } func newkey() *ecdsa.PrivateKey { @@ -459,7 +419,3 @@ func randomID() (id discover.NodeID) { } return id } - -func keepalways(id discover.NodeID) bool { - return true -} -- cgit