diff options
Diffstat (limited to 'p2p/discover')
-rw-r--r-- | p2p/discover/node.go | 3 | ||||
-rw-r--r-- | p2p/discover/table.go | 84 | ||||
-rw-r--r-- | p2p/discover/table_test.go | 34 | ||||
-rw-r--r-- | p2p/discover/table_util_test.go | 21 | ||||
-rw-r--r-- | p2p/discover/udp.go | 222 | ||||
-rw-r--r-- | p2p/discover/udp_test.go | 137 |
6 files changed, 310 insertions, 191 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 7ddf04fe8..8d4af166b 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -33,7 +33,8 @@ import ( // The fields of Node may not be modified. type node struct { enode.Node - addedAt time.Time // time when the node was added to the table + addedAt time.Time // time when the node was added to the table + livenessChecks uint // how often liveness was checked } type encPubkey [64]byte diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 9f7f1d41b..ba4c06327 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -75,8 +75,10 @@ type Table struct { net transport refreshReq chan chan struct{} initDone chan struct{} - closeReq chan struct{} - closed chan struct{} + + closeOnce sync.Once + closeReq chan struct{} + closed chan struct{} nodeAddedHook func(*node) // for testing } @@ -180,16 +182,14 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { - if tab.net != nil { - tab.net.close() - } - - select { - case <-tab.closed: - // already closed. - case tab.closeReq <- struct{}{}: - <-tab.closed // wait for refreshLoop to end. - } + tab.closeOnce.Do(func() { + if tab.net != nil { + tab.net.close() + } + // Wait for loop to end. + close(tab.closeReq) + <-tab.closed + }) } // setFallbackNodes sets the initial points of contact. These nodes @@ -290,12 +290,16 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node { // we have asked all closest nodes, stop the search break } - // wait for the next reply - for _, n := range <-reply { - if n != nil && !seen[n.ID()] { - seen[n.ID()] = true - result.push(n, bucketSize) + select { + case nodes := <-reply: + for _, n := range nodes { + if n != nil && !seen[n.ID()] { + seen[n.ID()] = true + result.push(n, bucketSize) + } } + case <-tab.closeReq: + return nil // shutdown, no need to continue. } pendingQueries-- } @@ -303,18 +307,22 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node { } func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) { - fails := tab.db.FindFails(n.ID()) + fails := tab.db.FindFails(n.ID(), n.IP()) r, err := tab.net.findnode(n.ID(), n.addr(), targetKey) - if err != nil || len(r) == 0 { + if err == errClosed { + // Avoid recording failures on shutdown. + reply <- nil + return + } else if err != nil || len(r) == 0 { fails++ - tab.db.UpdateFindFails(n.ID(), fails) + tab.db.UpdateFindFails(n.ID(), n.IP(), fails) log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) if fails >= maxFindnodeFailures { log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) tab.delete(n) } } else if fails > 0 { - tab.db.UpdateFindFails(n.ID(), fails-1) + tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1) } // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll @@ -329,7 +337,7 @@ func (tab *Table) refresh() <-chan struct{} { done := make(chan struct{}) select { case tab.refreshReq <- done: - case <-tab.closed: + case <-tab.closeReq: close(done) } return done @@ -433,7 +441,7 @@ func (tab *Table) loadSeedNodes() { seeds = append(seeds, tab.nursery...) for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID())) }} + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }} log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) tab.add(seed) } @@ -458,16 +466,17 @@ func (tab *Table) doRevalidate(done chan<- struct{}) { b := tab.buckets[bi] if err == nil { // The node responded, move it to the front. - log.Debug("Revalidated node", "b", bi, "id", last.ID()) + last.livenessChecks++ + log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks) b.bump(last) return } // No reply received, pick a replacement or delete the node if there aren't // any replacements. if r := tab.replace(b, last); r != nil { - log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "r", r.ID(), "rip", r.IP()) + log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP()) } else { - log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP()) + log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks) } } @@ -502,7 +511,7 @@ func (tab *Table) copyLiveNodes() { now := time.Now() for _, b := range &tab.buckets { for _, n := range b.entries { - if now.Sub(n.addedAt) >= seedMinTableTime { + if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime { tab.db.UpdateNode(unwrapNode(n)) } } @@ -518,7 +527,9 @@ func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance { close := &nodesByDistance{target: target} for _, b := range &tab.buckets { for _, n := range b.entries { - close.push(n, nresults) + if n.livenessChecks > 0 { + close.push(n, nresults) + } } } return close @@ -572,23 +583,6 @@ func (tab *Table) addThroughPing(n *node) { tab.add(n) } -// stuff adds nodes the table to the end of their corresponding bucket -// if the bucket is not full. The caller must not hold tab.mutex. -func (tab *Table) stuff(nodes []*node) { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - for _, n := range nodes { - if n.ID() == tab.self().ID() { - continue // don't add self - } - b := tab.bucket(n.ID()) - if len(b.entries) < bucketSize { - tab.bumpOrAdd(b, n) - } - } -} - // delete removes an entry from the node table. It is used to evacuate dead nodes. func (tab *Table) delete(node *node) { tab.mutex.Lock() diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 6b4cd2d18..b00a93211 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -50,8 +50,8 @@ func TestTable_pingReplace(t *testing.T) { func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() <-tab.initDone @@ -137,8 +137,8 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { func TestTable_IPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() for i := 0; i < tableIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) @@ -153,8 +153,8 @@ func TestTable_IPLimit(t *testing.T) { func TestTable_BucketIPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() d := 3 for i := 0; i < bucketIPLimit+1; i++ { @@ -173,9 +173,9 @@ func TestTable_closest(t *testing.T) { // for any node table, Target and N transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() - tab.stuff(test.All) + defer tab.Close() + fillTable(tab, test.All) // check that closest(Target, N) returns nodes result := tab.closest(test.Target, test.N).entries @@ -234,13 +234,13 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { test := func(buf []*enode.Node) bool { transport := newPingRecorder() tab, db := newTestTable(transport) - defer tab.Close() defer db.Close() + defer tab.Close() <-tab.initDone for i := 0; i < len(buf); i++ { ld := cfg.Rand.Intn(len(tab.buckets)) - tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))}) + fillTable(tab, []*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))}) } gotN := tab.ReadRandomNodes(buf) if gotN != tab.len() { @@ -272,16 +272,19 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { N: rand.Intn(bucketSize), } for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { - n := enode.SignNull(new(enr.Record), id) - t.All = append(t.All, wrapNode(n)) + r := new(enr.Record) + r.Set(enr.IP(genIP(rand))) + n := wrapNode(enode.SignNull(r, id)) + n.livenessChecks = 1 + t.All = append(t.All, n) } return reflect.ValueOf(t) } func TestTable_Lookup(t *testing.T) { tab, db := newTestTable(lookupTestnet) - defer tab.Close() defer db.Close() + defer tab.Close() // lookup on empty table returns no nodes if results := tab.lookup(lookupTestnet.target, false); len(results) > 0 { @@ -289,8 +292,9 @@ func TestTable_Lookup(t *testing.T) { } // seed table with initial node (otherwise lookup will terminate immediately) seedKey, _ := decodePubkey(lookupTestnet.dists[256][0]) - seed := wrapNode(enode.NewV4(seedKey, net.IP{}, 0, 256)) - tab.stuff([]*node{seed}) + seed := wrapNode(enode.NewV4(seedKey, net.IP{127, 0, 0, 1}, 0, 256)) + seed.livenessChecks = 1 + fillTable(tab, []*node{seed}) results := tab.lookup(lookupTestnet.target, true) t.Logf("results:") @@ -578,6 +582,12 @@ func gen(typ interface{}, rand *rand.Rand) interface{} { return v.Interface() } +func genIP(rand *rand.Rand) net.IP { + ip := make(net.IP, 4) + rand.Read(ip) + return ip +} + func quickcfg() *quick.Config { return &quick.Config{ MaxCount: 5000, diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index d41519452..3ce582b99 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -83,6 +83,23 @@ func fillBucket(tab *Table, n *node) (last *node) { return b.entries[bucketSize-1] } +// fillTable adds nodes the table to the end of their corresponding bucket +// if the bucket is not full. The caller must not hold tab.mutex. +func fillTable(tab *Table, nodes []*node) { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + for _, n := range nodes { + if n.ID() == tab.self().ID() { + continue // don't add self + } + b := tab.bucket(n.ID()) + if len(b.entries) < bucketSize { + tab.bumpOrAdd(b, n) + } + } +} + type pingRecorder struct { mu sync.Mutex dead, pinged map[enode.ID]bool @@ -109,10 +126,6 @@ func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPu return nil, nil } -func (t *pingRecorder) waitping(from enode.ID) error { - return nil // remote always pings -} - func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error { t.mu.Lock() defer t.mu.Unlock() diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 37a044902..5ce4c43dc 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -67,6 +67,8 @@ const ( // RPC request structures type ( ping struct { + senderKey *ecdsa.PublicKey // filled in by preverify + Version uint From, To rpcEndpoint Expiration uint64 @@ -155,8 +157,13 @@ func nodeToRPC(n *node) rpcNode { return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())} } +// packet is implemented by all protocol messages. type packet interface { - handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error + // preverify checks whether the packet is valid and should be handled at all. + preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error + // handle handles the packet. + handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) + // name returns the name of the packet for logging purposes. name() string } @@ -177,43 +184,48 @@ type udp struct { tab *Table wg sync.WaitGroup - addpending chan *pending - gotreply chan reply - closing chan struct{} + addReplyMatcher chan *replyMatcher + gotreply chan reply + closing chan struct{} } // pending represents a pending reply. // -// some implementations of the protocol wish to send more than one -// reply packet to findnode. in general, any neighbors packet cannot +// Some implementations of the protocol wish to send more than one +// reply packet to findnode. In general, any neighbors packet cannot // be matched up with a specific findnode packet. // -// our implementation handles this by storing a callback function for -// each pending reply. incoming packets from a node are dispatched -// to all the callback functions for that node. -type pending struct { +// Our implementation handles this by storing a callback function for +// each pending reply. Incoming packets from a node are dispatched +// to all callback functions for that node. +type replyMatcher struct { // these fields must match in the reply. from enode.ID + ip net.IP ptype byte // time when the request must complete deadline time.Time - // callback is called when a matching reply arrives. if it returns - // true, the callback is removed from the pending reply queue. - // if it returns false, the reply is considered incomplete and - // the callback will be invoked again for the next matching reply. - callback func(resp interface{}) (done bool) + // callback is called when a matching reply arrives. If it returns matched == true, the + // reply was acceptable. The second return value indicates whether the callback should + // be removed from the pending reply queue. If it returns false, the reply is considered + // incomplete and the callback will be invoked again for the next matching reply. + callback replyMatchFunc // errc receives nil when the callback indicates completion or an // error if no further reply is received within the timeout. errc chan<- error } +type replyMatchFunc func(interface{}) (matched bool, requestDone bool) + type reply struct { from enode.ID + ip net.IP ptype byte - data interface{} + data packet + // loop indicates whether there was // a matching request by sending on this channel. matched chan<- bool @@ -247,14 +259,14 @@ func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) { func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) { udp := &udp{ - conn: c, - priv: cfg.PrivateKey, - netrestrict: cfg.NetRestrict, - localNode: ln, - db: ln.Database(), - closing: make(chan struct{}), - gotreply: make(chan reply), - addpending: make(chan *pending), + conn: c, + priv: cfg.PrivateKey, + netrestrict: cfg.NetRestrict, + localNode: ln, + db: ln.Database(), + closing: make(chan struct{}), + gotreply: make(chan reply), + addReplyMatcher: make(chan *replyMatcher), } tab, err := newTable(udp, ln.Database(), cfg.Bootnodes) if err != nil { @@ -304,35 +316,37 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch errc <- err return errc } - errc := t.pending(toid, pongPacket, func(p interface{}) bool { - ok := bytes.Equal(p.(*pong).ReplyTok, hash) - if ok && callback != nil { + // Add a matcher for the reply to the pending reply queue. Pongs are matched if they + // reference the ping we're about to send. + errc := t.pending(toid, toaddr.IP, pongPacket, func(p interface{}) (matched bool, requestDone bool) { + matched = bytes.Equal(p.(*pong).ReplyTok, hash) + if matched && callback != nil { callback() } - return ok + return matched, matched }) + // Send the packet. t.localNode.UDPContact(toaddr) - t.write(toaddr, req.name(), packet) + t.write(toaddr, toid, req.name(), packet) return errc } -func (t *udp) waitping(from enode.ID) error { - return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) -} - // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { // If we haven't seen a ping from the destination node for a while, it won't remember // our endpoint proof and reject findnode. Solicit a ping first. - if time.Since(t.db.LastPingReceived(toid)) > bondExpiration { + if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration { t.ping(toid, toaddr) - t.waitping(toid) + // Wait for them to ping back and process our pong. + time.Sleep(respTimeout) } + // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is + // active until enough nodes have been received. nodes := make([]*node, 0, bucketSize) nreceived := 0 - errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { + errc := t.pending(toid, toaddr.IP, neighborsPacket, func(r interface{}) (matched bool, requestDone bool) { reply := r.(*neighbors) for _, rn := range reply.Nodes { nreceived++ @@ -343,22 +357,22 @@ func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([] } nodes = append(nodes, n) } - return nreceived >= bucketSize + return true, nreceived >= bucketSize }) - t.send(toaddr, findnodePacket, &findnode{ + t.send(toaddr, toid, findnodePacket, &findnode{ Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) return nodes, <-errc } -// pending adds a reply callback to the pending reply queue. -// see the documentation of type pending for a detailed explanation. -func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error { +// pending adds a reply matcher to the pending reply queue. +// see the documentation of type replyMatcher for a detailed explanation. +func (t *udp) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) <-chan error { ch := make(chan error, 1) - p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} + p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch} select { - case t.addpending <- p: + case t.addReplyMatcher <- p: // loop will handle it case <-t.closing: ch <- errClosed @@ -366,10 +380,12 @@ func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) return ch } -func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool { +// handleReply dispatches a reply packet, invoking reply matchers. It returns +// whether any matcher considered the packet acceptable. +func (t *udp) handleReply(from enode.ID, fromIP net.IP, ptype byte, req packet) bool { matched := make(chan bool, 1) select { - case t.gotreply <- reply{from, ptype, req, matched}: + case t.gotreply <- reply{from, fromIP, ptype, req, matched}: // loop will handle it return <-matched case <-t.closing: @@ -385,8 +401,8 @@ func (t *udp) loop() { var ( plist = list.New() timeout = time.NewTimer(0) - nextTimeout *pending // head of plist when timeout was last reset - contTimeouts = 0 // number of continuous timeouts to do NTP checks + nextTimeout *replyMatcher // head of plist when timeout was last reset + contTimeouts = 0 // number of continuous timeouts to do NTP checks ntpWarnTime = time.Unix(0, 0) ) <-timeout.C // ignore first timeout @@ -399,7 +415,7 @@ func (t *udp) loop() { // Start the timer so it fires when the next pending reply has expired. now := time.Now() for el := plist.Front(); el != nil; el = el.Next() { - nextTimeout = el.Value.(*pending) + nextTimeout = el.Value.(*replyMatcher) if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { timeout.Reset(dist) return @@ -420,25 +436,23 @@ func (t *udp) loop() { select { case <-t.closing: for el := plist.Front(); el != nil; el = el.Next() { - el.Value.(*pending).errc <- errClosed + el.Value.(*replyMatcher).errc <- errClosed } return - case p := <-t.addpending: + case p := <-t.addReplyMatcher: p.deadline = time.Now().Add(respTimeout) plist.PushBack(p) case r := <-t.gotreply: - var matched bool + var matched bool // whether any replyMatcher considered the reply acceptable. for el := plist.Front(); el != nil; el = el.Next() { - p := el.Value.(*pending) - if p.from == r.from && p.ptype == r.ptype { - matched = true - // Remove the matcher if its callback indicates - // that all replies have been received. This is - // required for packet types that expect multiple - // reply packets. - if p.callback(r.data) { + p := el.Value.(*replyMatcher) + if p.from == r.from && p.ptype == r.ptype && p.ip.Equal(r.ip) { + ok, requestDone := p.callback(r.data) + matched = matched || ok + // Remove the matcher if callback indicates that all replies have been received. + if requestDone { p.errc <- nil plist.Remove(el) } @@ -453,7 +467,7 @@ func (t *udp) loop() { // Notify and remove callbacks whose deadline is in the past. for el := plist.Front(); el != nil; el = el.Next() { - p := el.Value.(*pending) + p := el.Value.(*replyMatcher) if now.After(p.deadline) || now.Equal(p.deadline) { p.errc <- errTimeout plist.Remove(el) @@ -504,17 +518,17 @@ func init() { } } -func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { +func (t *udp) send(toaddr *net.UDPAddr, toid enode.ID, ptype byte, req packet) ([]byte, error) { packet, hash, err := encodePacket(t.priv, ptype, req) if err != nil { return hash, err } - return hash, t.write(toaddr, req.name(), packet) + return hash, t.write(toaddr, toid, req.name(), packet) } -func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { +func (t *udp) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error { _, err := t.conn.WriteToUDP(packet, toaddr) - log.Trace(">> "+what, "addr", toaddr, "err", err) + log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err) return err } @@ -573,13 +587,19 @@ func (t *udp) readLoop(unhandled chan<- ReadPacket) { } func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { - packet, fromID, hash, err := decodePacket(buf) + packet, fromKey, hash, err := decodePacket(buf) if err != nil { log.Debug("Bad discv4 packet", "addr", from, "err", err) return err } - err = packet.handle(t, from, fromID, hash) - log.Trace("<< "+packet.name(), "addr", from, "err", err) + fromID := fromKey.id() + if err == nil { + err = packet.preverify(t, from, fromID, fromKey) + } + log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err) + if err == nil { + packet.handle(t, from, fromID, hash) + } return err } @@ -615,54 +635,67 @@ func decodePacket(buf []byte) (packet, encPubkey, []byte, error) { return req, fromKey, hash, err } -func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +// Packet Handlers + +func (req *ping) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if expired(req.Expiration) { return errExpired } key, err := decodePubkey(fromKey) if err != nil { - return fmt.Errorf("invalid public key: %v", err) + return errors.New("invalid public key") } - t.send(from, pongPacket, &pong{ + req.senderKey = key + return nil +} + +func (req *ping) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { + // Reply. + t.send(from, fromID, pongPacket, &pong{ To: makeEndpoint(from, req.From.TCP), ReplyTok: mac, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port)) - t.handleReply(n.ID(), pingPacket, req) - if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration { - t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) }) + + // Ping back if our last pong on file is too far in the past. + n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port)) + if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { + t.sendPing(fromID, from, func() { + t.tab.addThroughPing(n) + }) } else { t.tab.addThroughPing(n) } + + // Update node database and endpoint predictor. + t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now()) t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) - t.db.UpdateLastPingReceived(n.ID(), time.Now()) - return nil } func (req *ping) name() string { return "PING/v4" } -func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +func (req *pong) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if expired(req.Expiration) { return errExpired } - fromID := fromKey.id() - if !t.handleReply(fromID, pongPacket, req) { + if !t.handleReply(fromID, from.IP, pongPacket, req) { return errUnsolicitedReply } - t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) - t.db.UpdateLastPongReceived(fromID, time.Now()) return nil } +func (req *pong) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { + t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) + t.db.UpdateLastPongReceived(fromID, from.IP, time.Now()) +} + func (req *pong) name() string { return "PONG/v4" } -func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +func (req *findnode) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if expired(req.Expiration) { return errExpired } - fromID := fromKey.id() - if time.Since(t.db.LastPongReceived(fromID)) > bondExpiration { + if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration { // No endpoint proof pong exists, we don't process the packet. This prevents an // attack vector where the discovery protocol could be used to amplify traffic in a // DDOS attack. A malicious actor would send a findnode request with the IP address @@ -671,43 +704,50 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac [] // findnode) to the victim. return errUnknownNode } + return nil +} + +func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { + // Determine closest nodes. target := enode.ID(crypto.Keccak256Hash(req.Target[:])) t.tab.mutex.Lock() closest := t.tab.closest(target, bucketSize).entries t.tab.mutex.Unlock() - p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} - var sent bool // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. + p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} + var sent bool for _, n := range closest { if netutil.CheckRelayIP(from.IP, n.IP()) == nil { p.Nodes = append(p.Nodes, nodeToRPC(n)) } if len(p.Nodes) == maxNeighbors { - t.send(from, neighborsPacket, &p) + t.send(from, fromID, neighborsPacket, &p) p.Nodes = p.Nodes[:0] sent = true } } if len(p.Nodes) > 0 || !sent { - t.send(from, neighborsPacket, &p) + t.send(from, fromID, neighborsPacket, &p) } - return nil } func (req *findnode) name() string { return "FINDNODE/v4" } -func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { +func (req *neighbors) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if expired(req.Expiration) { return errExpired } - if !t.handleReply(fromKey.id(), neighborsPacket, req) { + if !t.handleReply(fromID, from.IP, neighborsPacket, req) { return errUnsolicitedReply } return nil } +func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) { +} + func (req *neighbors) name() string { return "NEIGHBORS/v4" } func expired(ts uint64) bool { diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index a4ddaf750..3d53c9309 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -19,6 +19,7 @@ package discover import ( "bytes" "crypto/ecdsa" + crand "crypto/rand" "encoding/binary" "encoding/hex" "errors" @@ -57,6 +58,7 @@ type udpTest struct { t *testing.T pipe *dgramPipe table *Table + db *enode.DB udp *udp sent [][]byte localkey, remotekey *ecdsa.PrivateKey @@ -71,22 +73,32 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - db, _ := enode.OpenDB("") - ln := enode.NewLocalNode(db, test.localkey) + test.db, _ = enode.OpenDB("") + ln := enode.NewLocalNode(test.db, test.localkey) test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey}) // Wait for initial refresh so the table doesn't send unexpected findnode. <-test.table.initDone return test } +func (test *udpTest) close() { + test.table.Close() + test.db.Close() +} + // handles a packet as if it had been sent to the transport. func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { - enc, _, err := encodePacket(test.remotekey, ptype, data) + return test.packetInFrom(wantError, test.remotekey, test.remoteaddr, ptype, data) +} + +// handles a packet as if it had been sent to the transport by the key/endpoint. +func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, ptype byte, data packet) error { + enc, _, err := encodePacket(key, ptype, data) if err != nil { return test.errorf("packet (%d) encode error: %v", ptype, err) } test.sent = append(test.sent, enc) - if err = test.udp.handlePacket(test.remoteaddr, enc); err != wantError { + if err = test.udp.handlePacket(addr, enc); err != wantError { return test.errorf("error mismatch: got %q, want %q", err, wantError) } return nil @@ -94,19 +106,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { // waits for a packet to be sent by the transport. // validate should have type func(*udpTest, X) error, where X is a packet type. -func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) { +func (test *udpTest) waitPacketOut(validate interface{}) (*net.UDPAddr, []byte, error) { dgram := test.pipe.waitPacketOut() - p, _, hash, err := decodePacket(dgram) + p, _, hash, err := decodePacket(dgram.data) if err != nil { - return hash, test.errorf("sent packet decode error: %v", err) + return &dgram.to, hash, test.errorf("sent packet decode error: %v", err) } fn := reflect.ValueOf(validate) exptype := fn.Type().In(0) if reflect.TypeOf(p) != exptype { - return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) + return &dgram.to, hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) } fn.Call([]reflect.Value{reflect.ValueOf(p)}) - return hash, nil + return &dgram.to, hash, nil } func (test *udpTest) errorf(format string, args ...interface{}) error { @@ -125,7 +137,7 @@ func (test *udpTest) errorf(format string, args ...interface{}) error { func TestUDP_packetErrors(t *testing.T) { test := newUDPTest(t) - defer test.table.Close() + defer test.close() test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4}) test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp}) @@ -136,7 +148,7 @@ func TestUDP_packetErrors(t *testing.T) { func TestUDP_pingTimeout(t *testing.T) { t.Parallel() test := newUDPTest(t) - defer test.table.Close() + defer test.close() toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} toid := enode.ID{1, 2, 3, 4} @@ -148,7 +160,7 @@ func TestUDP_pingTimeout(t *testing.T) { func TestUDP_responseTimeouts(t *testing.T) { t.Parallel() test := newUDPTest(t) - defer test.table.Close() + defer test.close() rand.Seed(time.Now().UnixNano()) randomDuration := func(max time.Duration) time.Duration { @@ -166,20 +178,20 @@ func TestUDP_responseTimeouts(t *testing.T) { // with ptype <= 128 will not get a reply and should time out. // For all other requests, a reply is scheduled to arrive // within the timeout window. - p := &pending{ + p := &replyMatcher{ ptype: byte(rand.Intn(255)), - callback: func(interface{}) bool { return true }, + callback: func(interface{}) (bool, bool) { return true, true }, } binary.BigEndian.PutUint64(p.from[:], uint64(i)) if p.ptype <= 128 { p.errc = timeoutErr - test.udp.addpending <- p + test.udp.addReplyMatcher <- p nTimeouts++ } else { p.errc = nilErr - test.udp.addpending <- p + test.udp.addReplyMatcher <- p time.AfterFunc(randomDuration(60*time.Millisecond), func() { - if !test.udp.handleReply(p.from, p.ptype, nil) { + if !test.udp.handleReply(p.from, p.ip, p.ptype, nil) { t.Logf("not matched: %v", p) } }) @@ -220,7 +232,7 @@ func TestUDP_responseTimeouts(t *testing.T) { func TestUDP_findnodeTimeout(t *testing.T) { t.Parallel() test := newUDPTest(t) - defer test.table.Close() + defer test.close() toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} toid := enode.ID{1, 2, 3, 4} @@ -236,50 +248,65 @@ func TestUDP_findnodeTimeout(t *testing.T) { func TestUDP_findnode(t *testing.T) { test := newUDPTest(t) - defer test.table.Close() + defer test.close() // put a few nodes into the table. their exact // distribution shouldn't matter much, although we need to // take care not to overflow any bucket. nodes := &nodesByDistance{target: testTarget.id()} - for i := 0; i < bucketSize; i++ { + live := make(map[enode.ID]bool) + numCandidates := 2 * bucketSize + for i := 0; i < numCandidates; i++ { key := newkey() - n := wrapNode(enode.NewV4(&key.PublicKey, net.IP{10, 13, 0, 1}, 0, i)) - nodes.push(n, bucketSize) + ip := net.IP{10, 13, 0, byte(i)} + n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000)) + // Ensure half of table content isn't verified live yet. + if i > numCandidates/2 { + n.livenessChecks = 1 + live[n.ID()] = true + } + nodes.push(n, numCandidates) } - test.table.stuff(nodes.entries) + fillTable(test.table, nodes.entries) // ensure there's a bond with the test node, // findnode won't be accepted otherwise. remoteID := encodePubkey(&test.remotekey.PublicKey).id() - test.table.db.UpdateLastPongReceived(remoteID, time.Now()) + test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now()) // check that closest neighbors are returned. - test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) expected := test.table.closest(testTarget.id(), bucketSize) - + test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) waitNeighbors := func(want []*node) { test.waitPacketOut(func(p *neighbors) { if len(p.Nodes) != len(want) { t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize) } - for i := range p.Nodes { - if p.Nodes[i].ID.id() != want[i].ID() { - t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i]) + for i, n := range p.Nodes { + if n.ID.id() != want[i].ID() { + t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i]) + } + if !live[n.ID.id()] { + t.Errorf("result includes dead node %v", n.ID.id()) } } }) } - waitNeighbors(expected.entries[:maxNeighbors]) - waitNeighbors(expected.entries[maxNeighbors:]) + // Receive replies. + want := expected.entries + if len(want) > maxNeighbors { + waitNeighbors(want[:maxNeighbors]) + want = want[maxNeighbors:] + } + waitNeighbors(want) } func TestUDP_findnodeMultiReply(t *testing.T) { test := newUDPTest(t) - defer test.table.Close() + defer test.close() rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) - test.table.db.UpdateLastPingReceived(rid, time.Now()) + test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now()) // queue a pending findnode request resultc, errc := make(chan []*node), make(chan error) @@ -329,11 +356,40 @@ func TestUDP_findnodeMultiReply(t *testing.T) { } } +func TestUDP_pingMatch(t *testing.T) { + test := newUDPTest(t) + defer test.close() + + randToken := make([]byte, 32) + crand.Read(randToken) + + test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) + test.waitPacketOut(func(*pong) error { return nil }) + test.waitPacketOut(func(*ping) error { return nil }) + test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp}) +} + +func TestUDP_pingMatchIP(t *testing.T) { + test := newUDPTest(t) + defer test.close() + + test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) + test.waitPacketOut(func(*pong) error { return nil }) + + _, hash, _ := test.waitPacketOut(func(*ping) error { return nil }) + wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000} + test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, pongPacket, &pong{ + ReplyTok: hash, + To: testLocalAnnounced, + Expiration: futureExp, + }) +} + func TestUDP_successfulPing(t *testing.T) { test := newUDPTest(t) added := make(chan *node, 1) test.table.nodeAddedHook = func(n *node) { added <- n } - defer test.table.Close() + defer test.close() // The remote side sends a ping packet to initiate the exchange. go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) @@ -356,7 +412,7 @@ func TestUDP_successfulPing(t *testing.T) { }) // remote is unknown, the table pings back. - hash, _ := test.waitPacketOut(func(p *ping) error { + _, hash, _ := test.waitPacketOut(func(p *ping) error { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) { t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint()) } @@ -510,7 +566,12 @@ type dgramPipe struct { cond *sync.Cond closing chan struct{} closed bool - queue [][]byte + queue []dgram +} + +type dgram struct { + to net.UDPAddr + data []byte } func newpipe() *dgramPipe { @@ -531,7 +592,7 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) { if c.closed { return 0, errors.New("closed") } - c.queue = append(c.queue, msg) + c.queue = append(c.queue, dgram{*to, b}) c.cond.Signal() return len(b), nil } @@ -556,7 +617,7 @@ func (c *dgramPipe) LocalAddr() net.Addr { return &net.UDPAddr{IP: testLocal.IP, Port: int(testLocal.UDP)} } -func (c *dgramPipe) waitPacketOut() []byte { +func (c *dgramPipe) waitPacketOut() dgram { c.mu.Lock() defer c.mu.Unlock() for len(c.queue) == 0 { |