aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover
diff options
context:
space:
mode:
authorFelix Lange <fjl@users.noreply.github.com>2019-01-30 00:39:20 +0800
committerFelix Lange <fjl@twurst.com>2019-01-30 00:50:15 +0800
commit4cd90e02e23ecf2bb11bcb4bba4fea2ae164ef74 (patch)
tree9f4752c2ead76a83998eb9aec12fb1b136cf51b6 /p2p/discover
parent1f3dfed19e0d1e0e2536d547e8fd37e9d0ad3cdf (diff)
downloaddexon-4cd90e02e23ecf2bb11bcb4bba4fea2ae164ef74.tar.gz
dexon-4cd90e02e23ecf2bb11bcb4bba4fea2ae164ef74.tar.zst
dexon-4cd90e02e23ecf2bb11bcb4bba4fea2ae164ef74.zip
p2p/discover, p2p/enode: rework endpoint proof handling, packet logging (#18963)
This change resolves multiple issues around handling of endpoint proofs. The proof is now done separately for each IP and completing the proof requires a matching ping hash. Also remove waitping because it's equivalent to sleep. waitping was slightly more efficient, but that may cause issues with findnode if packets are reordered and the remote end sees findnode before pong. Logging of received packets was hitherto done after handling the packet, which meant that sent replies were logged before the packet that generated them. This change splits up packet handling into 'preverify' and 'handle'. The error from 'preverify' is logged, but 'handle' happens after the message is logged. This fixes the order. Packet logs now contain the node ID.
Diffstat (limited to 'p2p/discover')
-rw-r--r--p2p/discover/node.go3
-rw-r--r--p2p/discover/table.go84
-rw-r--r--p2p/discover/table_test.go34
-rw-r--r--p2p/discover/table_util_test.go21
-rw-r--r--p2p/discover/udp.go222
-rw-r--r--p2p/discover/udp_test.go137
6 files changed, 310 insertions, 191 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index 7ddf04fe8..8d4af166b 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -33,7 +33,8 @@ import (
// The fields of Node may not be modified.
type node struct {
enode.Node
- addedAt time.Time // time when the node was added to the table
+ addedAt time.Time // time when the node was added to the table
+ livenessChecks uint // how often liveness was checked
}
type encPubkey [64]byte
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 9f7f1d41b..ba4c06327 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -75,8 +75,10 @@ type Table struct {
net transport
refreshReq chan chan struct{}
initDone chan struct{}
- closeReq chan struct{}
- closed chan struct{}
+
+ closeOnce sync.Once
+ closeReq chan struct{}
+ closed chan struct{}
nodeAddedHook func(*node) // for testing
}
@@ -180,16 +182,14 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
- if tab.net != nil {
- tab.net.close()
- }
-
- select {
- case <-tab.closed:
- // already closed.
- case tab.closeReq <- struct{}{}:
- <-tab.closed // wait for refreshLoop to end.
- }
+ tab.closeOnce.Do(func() {
+ if tab.net != nil {
+ tab.net.close()
+ }
+ // Wait for loop to end.
+ close(tab.closeReq)
+ <-tab.closed
+ })
}
// setFallbackNodes sets the initial points of contact. These nodes
@@ -290,12 +290,16 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
// we have asked all closest nodes, stop the search
break
}
- // wait for the next reply
- for _, n := range <-reply {
- if n != nil && !seen[n.ID()] {
- seen[n.ID()] = true
- result.push(n, bucketSize)
+ select {
+ case nodes := <-reply:
+ for _, n := range nodes {
+ if n != nil && !seen[n.ID()] {
+ seen[n.ID()] = true
+ result.push(n, bucketSize)
+ }
}
+ case <-tab.closeReq:
+ return nil // shutdown, no need to continue.
}
pendingQueries--
}
@@ -303,18 +307,22 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
}
func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) {
- fails := tab.db.FindFails(n.ID())
+ fails := tab.db.FindFails(n.ID(), n.IP())
r, err := tab.net.findnode(n.ID(), n.addr(), targetKey)
- if err != nil || len(r) == 0 {
+ if err == errClosed {
+ // Avoid recording failures on shutdown.
+ reply <- nil
+ return
+ } else if err != nil || len(r) == 0 {
fails++
- tab.db.UpdateFindFails(n.ID(), fails)
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
if fails >= maxFindnodeFailures {
log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
tab.delete(n)
}
} else if fails > 0 {
- tab.db.UpdateFindFails(n.ID(), fails-1)
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1)
}
// Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
@@ -329,7 +337,7 @@ func (tab *Table) refresh() <-chan struct{} {
done := make(chan struct{})
select {
case tab.refreshReq <- done:
- case <-tab.closed:
+ case <-tab.closeReq:
close(done)
}
return done
@@ -433,7 +441,7 @@ func (tab *Table) loadSeedNodes() {
seeds = append(seeds, tab.nursery...)
for i := range seeds {
seed := seeds[i]
- age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID())) }}
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }}
log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age)
tab.add(seed)
}
@@ -458,16 +466,17 @@ func (tab *Table) doRevalidate(done chan<- struct{}) {
b := tab.buckets[bi]
if err == nil {
// The node responded, move it to the front.
- log.Debug("Revalidated node", "b", bi, "id", last.ID())
+ last.livenessChecks++
+ log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks)
b.bump(last)
return
}
// No reply received, pick a replacement or delete the node if there aren't
// any replacements.
if r := tab.replace(b, last); r != nil {
- log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "r", r.ID(), "rip", r.IP())
+ log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP())
} else {
- log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP())
+ log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks)
}
}
@@ -502,7 +511,7 @@ func (tab *Table) copyLiveNodes() {
now := time.Now()
for _, b := range &tab.buckets {
for _, n := range b.entries {
- if now.Sub(n.addedAt) >= seedMinTableTime {
+ if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime {
tab.db.UpdateNode(unwrapNode(n))
}
}
@@ -518,7 +527,9 @@ func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance {
close := &nodesByDistance{target: target}
for _, b := range &tab.buckets {
for _, n := range b.entries {
- close.push(n, nresults)
+ if n.livenessChecks > 0 {
+ close.push(n, nresults)
+ }
}
}
return close
@@ -572,23 +583,6 @@ func (tab *Table) addThroughPing(n *node) {
tab.add(n)
}
-// stuff adds nodes the table to the end of their corresponding bucket
-// if the bucket is not full. The caller must not hold tab.mutex.
-func (tab *Table) stuff(nodes []*node) {
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
-
- for _, n := range nodes {
- if n.ID() == tab.self().ID() {
- continue // don't add self
- }
- b := tab.bucket(n.ID())
- if len(b.entries) < bucketSize {
- tab.bumpOrAdd(b, n)
- }
- }
-}
-
// delete removes an entry from the node table. It is used to evacuate dead nodes.
func (tab *Table) delete(node *node) {
tab.mutex.Lock()
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 6b4cd2d18..b00a93211 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -50,8 +50,8 @@ func TestTable_pingReplace(t *testing.T) {
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
<-tab.initDone
@@ -137,8 +137,8 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
@@ -153,8 +153,8 @@ func TestTable_IPLimit(t *testing.T) {
func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
@@ -173,9 +173,9 @@ func TestTable_closest(t *testing.T) {
// for any node table, Target and N
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
- tab.stuff(test.All)
+ defer tab.Close()
+ fillTable(tab, test.All)
// check that closest(Target, N) returns nodes
result := tab.closest(test.Target, test.N).entries
@@ -234,13 +234,13 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
test := func(buf []*enode.Node) bool {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
<-tab.initDone
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
- tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
+ fillTable(tab, []*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
}
gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() {
@@ -272,16 +272,19 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
N: rand.Intn(bucketSize),
}
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
- n := enode.SignNull(new(enr.Record), id)
- t.All = append(t.All, wrapNode(n))
+ r := new(enr.Record)
+ r.Set(enr.IP(genIP(rand)))
+ n := wrapNode(enode.SignNull(r, id))
+ n.livenessChecks = 1
+ t.All = append(t.All, n)
}
return reflect.ValueOf(t)
}
func TestTable_Lookup(t *testing.T) {
tab, db := newTestTable(lookupTestnet)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
// lookup on empty table returns no nodes
if results := tab.lookup(lookupTestnet.target, false); len(results) > 0 {
@@ -289,8 +292,9 @@ func TestTable_Lookup(t *testing.T) {
}
// seed table with initial node (otherwise lookup will terminate immediately)
seedKey, _ := decodePubkey(lookupTestnet.dists[256][0])
- seed := wrapNode(enode.NewV4(seedKey, net.IP{}, 0, 256))
- tab.stuff([]*node{seed})
+ seed := wrapNode(enode.NewV4(seedKey, net.IP{127, 0, 0, 1}, 0, 256))
+ seed.livenessChecks = 1
+ fillTable(tab, []*node{seed})
results := tab.lookup(lookupTestnet.target, true)
t.Logf("results:")
@@ -578,6 +582,12 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface()
}
+func genIP(rand *rand.Rand) net.IP {
+ ip := make(net.IP, 4)
+ rand.Read(ip)
+ return ip
+}
+
func quickcfg() *quick.Config {
return &quick.Config{
MaxCount: 5000,
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
index d41519452..3ce582b99 100644
--- a/p2p/discover/table_util_test.go
+++ b/p2p/discover/table_util_test.go
@@ -83,6 +83,23 @@ func fillBucket(tab *Table, n *node) (last *node) {
return b.entries[bucketSize-1]
}
+// fillTable adds nodes the table to the end of their corresponding bucket
+// if the bucket is not full. The caller must not hold tab.mutex.
+func fillTable(tab *Table, nodes []*node) {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ for _, n := range nodes {
+ if n.ID() == tab.self().ID() {
+ continue // don't add self
+ }
+ b := tab.bucket(n.ID())
+ if len(b.entries) < bucketSize {
+ tab.bumpOrAdd(b, n)
+ }
+ }
+}
+
type pingRecorder struct {
mu sync.Mutex
dead, pinged map[enode.ID]bool
@@ -109,10 +126,6 @@ func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPu
return nil, nil
}
-func (t *pingRecorder) waitping(from enode.ID) error {
- return nil // remote always pings
-}
-
func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error {
t.mu.Lock()
defer t.mu.Unlock()
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 37a044902..5ce4c43dc 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -67,6 +67,8 @@ const (
// RPC request structures
type (
ping struct {
+ senderKey *ecdsa.PublicKey // filled in by preverify
+
Version uint
From, To rpcEndpoint
Expiration uint64
@@ -155,8 +157,13 @@ func nodeToRPC(n *node) rpcNode {
return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())}
}
+// packet is implemented by all protocol messages.
type packet interface {
- handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error
+ // preverify checks whether the packet is valid and should be handled at all.
+ preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error
+ // handle handles the packet.
+ handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte)
+ // name returns the name of the packet for logging purposes.
name() string
}
@@ -177,43 +184,48 @@ type udp struct {
tab *Table
wg sync.WaitGroup
- addpending chan *pending
- gotreply chan reply
- closing chan struct{}
+ addReplyMatcher chan *replyMatcher
+ gotreply chan reply
+ closing chan struct{}
}
// pending represents a pending reply.
//
-// some implementations of the protocol wish to send more than one
-// reply packet to findnode. in general, any neighbors packet cannot
+// Some implementations of the protocol wish to send more than one
+// reply packet to findnode. In general, any neighbors packet cannot
// be matched up with a specific findnode packet.
//
-// our implementation handles this by storing a callback function for
-// each pending reply. incoming packets from a node are dispatched
-// to all the callback functions for that node.
-type pending struct {
+// Our implementation handles this by storing a callback function for
+// each pending reply. Incoming packets from a node are dispatched
+// to all callback functions for that node.
+type replyMatcher struct {
// these fields must match in the reply.
from enode.ID
+ ip net.IP
ptype byte
// time when the request must complete
deadline time.Time
- // callback is called when a matching reply arrives. if it returns
- // true, the callback is removed from the pending reply queue.
- // if it returns false, the reply is considered incomplete and
- // the callback will be invoked again for the next matching reply.
- callback func(resp interface{}) (done bool)
+ // callback is called when a matching reply arrives. If it returns matched == true, the
+ // reply was acceptable. The second return value indicates whether the callback should
+ // be removed from the pending reply queue. If it returns false, the reply is considered
+ // incomplete and the callback will be invoked again for the next matching reply.
+ callback replyMatchFunc
// errc receives nil when the callback indicates completion or an
// error if no further reply is received within the timeout.
errc chan<- error
}
+type replyMatchFunc func(interface{}) (matched bool, requestDone bool)
+
type reply struct {
from enode.ID
+ ip net.IP
ptype byte
- data interface{}
+ data packet
+
// loop indicates whether there was
// a matching request by sending on this channel.
matched chan<- bool
@@ -247,14 +259,14 @@ func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
udp := &udp{
- conn: c,
- priv: cfg.PrivateKey,
- netrestrict: cfg.NetRestrict,
- localNode: ln,
- db: ln.Database(),
- closing: make(chan struct{}),
- gotreply: make(chan reply),
- addpending: make(chan *pending),
+ conn: c,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
+ localNode: ln,
+ db: ln.Database(),
+ closing: make(chan struct{}),
+ gotreply: make(chan reply),
+ addReplyMatcher: make(chan *replyMatcher),
}
tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
if err != nil {
@@ -304,35 +316,37 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch
errc <- err
return errc
}
- errc := t.pending(toid, pongPacket, func(p interface{}) bool {
- ok := bytes.Equal(p.(*pong).ReplyTok, hash)
- if ok && callback != nil {
+ // Add a matcher for the reply to the pending reply queue. Pongs are matched if they
+ // reference the ping we're about to send.
+ errc := t.pending(toid, toaddr.IP, pongPacket, func(p interface{}) (matched bool, requestDone bool) {
+ matched = bytes.Equal(p.(*pong).ReplyTok, hash)
+ if matched && callback != nil {
callback()
}
- return ok
+ return matched, matched
})
+ // Send the packet.
t.localNode.UDPContact(toaddr)
- t.write(toaddr, req.name(), packet)
+ t.write(toaddr, toid, req.name(), packet)
return errc
}
-func (t *udp) waitping(from enode.ID) error {
- return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
-}
-
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// If we haven't seen a ping from the destination node for a while, it won't remember
// our endpoint proof and reject findnode. Solicit a ping first.
- if time.Since(t.db.LastPingReceived(toid)) > bondExpiration {
+ if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration {
t.ping(toid, toaddr)
- t.waitping(toid)
+ // Wait for them to ping back and process our pong.
+ time.Sleep(respTimeout)
}
+ // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is
+ // active until enough nodes have been received.
nodes := make([]*node, 0, bucketSize)
nreceived := 0
- errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
+ errc := t.pending(toid, toaddr.IP, neighborsPacket, func(r interface{}) (matched bool, requestDone bool) {
reply := r.(*neighbors)
for _, rn := range reply.Nodes {
nreceived++
@@ -343,22 +357,22 @@ func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]
}
nodes = append(nodes, n)
}
- return nreceived >= bucketSize
+ return true, nreceived >= bucketSize
})
- t.send(toaddr, findnodePacket, &findnode{
+ t.send(toaddr, toid, findnodePacket, &findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return nodes, <-errc
}
-// pending adds a reply callback to the pending reply queue.
-// see the documentation of type pending for a detailed explanation.
-func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error {
+// pending adds a reply matcher to the pending reply queue.
+// see the documentation of type replyMatcher for a detailed explanation.
+func (t *udp) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) <-chan error {
ch := make(chan error, 1)
- p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch}
select {
- case t.addpending <- p:
+ case t.addReplyMatcher <- p:
// loop will handle it
case <-t.closing:
ch <- errClosed
@@ -366,10 +380,12 @@ func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool)
return ch
}
-func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
+// handleReply dispatches a reply packet, invoking reply matchers. It returns
+// whether any matcher considered the packet acceptable.
+func (t *udp) handleReply(from enode.ID, fromIP net.IP, ptype byte, req packet) bool {
matched := make(chan bool, 1)
select {
- case t.gotreply <- reply{from, ptype, req, matched}:
+ case t.gotreply <- reply{from, fromIP, ptype, req, matched}:
// loop will handle it
return <-matched
case <-t.closing:
@@ -385,8 +401,8 @@ func (t *udp) loop() {
var (
plist = list.New()
timeout = time.NewTimer(0)
- nextTimeout *pending // head of plist when timeout was last reset
- contTimeouts = 0 // number of continuous timeouts to do NTP checks
+ nextTimeout *replyMatcher // head of plist when timeout was last reset
+ contTimeouts = 0 // number of continuous timeouts to do NTP checks
ntpWarnTime = time.Unix(0, 0)
)
<-timeout.C // ignore first timeout
@@ -399,7 +415,7 @@ func (t *udp) loop() {
// Start the timer so it fires when the next pending reply has expired.
now := time.Now()
for el := plist.Front(); el != nil; el = el.Next() {
- nextTimeout = el.Value.(*pending)
+ nextTimeout = el.Value.(*replyMatcher)
if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
timeout.Reset(dist)
return
@@ -420,25 +436,23 @@ func (t *udp) loop() {
select {
case <-t.closing:
for el := plist.Front(); el != nil; el = el.Next() {
- el.Value.(*pending).errc <- errClosed
+ el.Value.(*replyMatcher).errc <- errClosed
}
return
- case p := <-t.addpending:
+ case p := <-t.addReplyMatcher:
p.deadline = time.Now().Add(respTimeout)
plist.PushBack(p)
case r := <-t.gotreply:
- var matched bool
+ var matched bool // whether any replyMatcher considered the reply acceptable.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
- if p.from == r.from && p.ptype == r.ptype {
- matched = true
- // Remove the matcher if its callback indicates
- // that all replies have been received. This is
- // required for packet types that expect multiple
- // reply packets.
- if p.callback(r.data) {
+ p := el.Value.(*replyMatcher)
+ if p.from == r.from && p.ptype == r.ptype && p.ip.Equal(r.ip) {
+ ok, requestDone := p.callback(r.data)
+ matched = matched || ok
+ // Remove the matcher if callback indicates that all replies have been received.
+ if requestDone {
p.errc <- nil
plist.Remove(el)
}
@@ -453,7 +467,7 @@ func (t *udp) loop() {
// Notify and remove callbacks whose deadline is in the past.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
+ p := el.Value.(*replyMatcher)
if now.After(p.deadline) || now.Equal(p.deadline) {
p.errc <- errTimeout
plist.Remove(el)
@@ -504,17 +518,17 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+func (t *udp) send(toaddr *net.UDPAddr, toid enode.ID, ptype byte, req packet) ([]byte, error) {
packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
return hash, err
}
- return hash, t.write(toaddr, req.name(), packet)
+ return hash, t.write(toaddr, toid, req.name(), packet)
}
-func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+func (t *udp) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+what, "addr", toaddr, "err", err)
+ log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err)
return err
}
@@ -573,13 +587,19 @@ func (t *udp) readLoop(unhandled chan<- ReadPacket) {
}
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
- packet, fromID, hash, err := decodePacket(buf)
+ packet, fromKey, hash, err := decodePacket(buf)
if err != nil {
log.Debug("Bad discv4 packet", "addr", from, "err", err)
return err
}
- err = packet.handle(t, from, fromID, hash)
- log.Trace("<< "+packet.name(), "addr", from, "err", err)
+ fromID := fromKey.id()
+ if err == nil {
+ err = packet.preverify(t, from, fromID, fromKey)
+ }
+ log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err)
+ if err == nil {
+ packet.handle(t, from, fromID, hash)
+ }
return err
}
@@ -615,54 +635,67 @@ func decodePacket(buf []byte) (packet, encPubkey, []byte, error) {
return req, fromKey, hash, err
}
-func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+// Packet Handlers
+
+func (req *ping) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
key, err := decodePubkey(fromKey)
if err != nil {
- return fmt.Errorf("invalid public key: %v", err)
+ return errors.New("invalid public key")
}
- t.send(from, pongPacket, &pong{
+ req.senderKey = key
+ return nil
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Reply.
+ t.send(from, fromID, pongPacket, &pong{
To: makeEndpoint(from, req.From.TCP),
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
- t.handleReply(n.ID(), pingPacket, req)
- if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration {
- t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) })
+
+ // Ping back if our last pong on file is too far in the past.
+ n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port))
+ if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
+ t.sendPing(fromID, from, func() {
+ t.tab.addThroughPing(n)
+ })
} else {
t.tab.addThroughPing(n)
}
+
+ // Update node database and endpoint predictor.
+ t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now())
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPingReceived(n.ID(), time.Now())
- return nil
}
func (req *ping) name() string { return "PING/v4" }
-func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *pong) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if !t.handleReply(fromID, pongPacket, req) {
+ if !t.handleReply(fromID, from.IP, pongPacket, req) {
return errUnsolicitedReply
}
- t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPongReceived(fromID, time.Now())
return nil
}
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
+ t.db.UpdateLastPongReceived(fromID, from.IP, time.Now())
+}
+
func (req *pong) name() string { return "PONG/v4" }
-func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *findnode) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if time.Since(t.db.LastPongReceived(fromID)) > bondExpiration {
+ if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration {
// No endpoint proof pong exists, we don't process the packet. This prevents an
// attack vector where the discovery protocol could be used to amplify traffic in a
// DDOS attack. A malicious actor would send a findnode request with the IP address
@@ -671,43 +704,50 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []
// findnode) to the victim.
return errUnknownNode
}
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Determine closest nodes.
target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
t.tab.mutex.Lock()
closest := t.tab.closest(target, bucketSize).entries
t.tab.mutex.Unlock()
- p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
- var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
+ p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
if len(p.Nodes) == maxNeighbors {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
sent = true
}
}
if len(p.Nodes) > 0 || !sent {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
}
- return nil
}
func (req *findnode) name() string { return "FINDNODE/v4" }
-func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *neighbors) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- if !t.handleReply(fromKey.id(), neighborsPacket, req) {
+ if !t.handleReply(fromID, from.IP, neighborsPacket, req) {
return errUnsolicitedReply
}
return nil
}
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+}
+
func (req *neighbors) name() string { return "NEIGHBORS/v4" }
func expired(ts uint64) bool {
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index a4ddaf750..3d53c9309 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -19,6 +19,7 @@ package discover
import (
"bytes"
"crypto/ecdsa"
+ crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
@@ -57,6 +58,7 @@ type udpTest struct {
t *testing.T
pipe *dgramPipe
table *Table
+ db *enode.DB
udp *udp
sent [][]byte
localkey, remotekey *ecdsa.PrivateKey
@@ -71,22 +73,32 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- db, _ := enode.OpenDB("")
- ln := enode.NewLocalNode(db, test.localkey)
+ test.db, _ = enode.OpenDB("")
+ ln := enode.NewLocalNode(test.db, test.localkey)
test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey})
// Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone
return test
}
+func (test *udpTest) close() {
+ test.table.Close()
+ test.db.Close()
+}
+
// handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
- enc, _, err := encodePacket(test.remotekey, ptype, data)
+ return test.packetInFrom(wantError, test.remotekey, test.remoteaddr, ptype, data)
+}
+
+// handles a packet as if it had been sent to the transport by the key/endpoint.
+func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, ptype byte, data packet) error {
+ enc, _, err := encodePacket(key, ptype, data)
if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err)
}
test.sent = append(test.sent, enc)
- if err = test.udp.handlePacket(test.remoteaddr, enc); err != wantError {
+ if err = test.udp.handlePacket(addr, enc); err != wantError {
return test.errorf("error mismatch: got %q, want %q", err, wantError)
}
return nil
@@ -94,19 +106,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type.
-func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
+func (test *udpTest) waitPacketOut(validate interface{}) (*net.UDPAddr, []byte, error) {
dgram := test.pipe.waitPacketOut()
- p, _, hash, err := decodePacket(dgram)
+ p, _, hash, err := decodePacket(dgram.data)
if err != nil {
- return hash, test.errorf("sent packet decode error: %v", err)
+ return &dgram.to, hash, test.errorf("sent packet decode error: %v", err)
}
fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype {
- return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ return &dgram.to, hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
}
fn.Call([]reflect.Value{reflect.ValueOf(p)})
- return hash, nil
+ return &dgram.to, hash, nil
}
func (test *udpTest) errorf(format string, args ...interface{}) error {
@@ -125,7 +137,7 @@ func (test *udpTest) errorf(format string, args ...interface{}) error {
func TestUDP_packetErrors(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4})
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
@@ -136,7 +148,7 @@ func TestUDP_packetErrors(t *testing.T) {
func TestUDP_pingTimeout(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
toid := enode.ID{1, 2, 3, 4}
@@ -148,7 +160,7 @@ func TestUDP_pingTimeout(t *testing.T) {
func TestUDP_responseTimeouts(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
rand.Seed(time.Now().UnixNano())
randomDuration := func(max time.Duration) time.Duration {
@@ -166,20 +178,20 @@ func TestUDP_responseTimeouts(t *testing.T) {
// with ptype <= 128 will not get a reply and should time out.
// For all other requests, a reply is scheduled to arrive
// within the timeout window.
- p := &pending{
+ p := &replyMatcher{
ptype: byte(rand.Intn(255)),
- callback: func(interface{}) bool { return true },
+ callback: func(interface{}) (bool, bool) { return true, true },
}
binary.BigEndian.PutUint64(p.from[:], uint64(i))
if p.ptype <= 128 {
p.errc = timeoutErr
- test.udp.addpending <- p
+ test.udp.addReplyMatcher <- p
nTimeouts++
} else {
p.errc = nilErr
- test.udp.addpending <- p
+ test.udp.addReplyMatcher <- p
time.AfterFunc(randomDuration(60*time.Millisecond), func() {
- if !test.udp.handleReply(p.from, p.ptype, nil) {
+ if !test.udp.handleReply(p.from, p.ip, p.ptype, nil) {
t.Logf("not matched: %v", p)
}
})
@@ -220,7 +232,7 @@ func TestUDP_responseTimeouts(t *testing.T) {
func TestUDP_findnodeTimeout(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
toid := enode.ID{1, 2, 3, 4}
@@ -236,50 +248,65 @@ func TestUDP_findnodeTimeout(t *testing.T) {
func TestUDP_findnode(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
// put a few nodes into the table. their exact
// distribution shouldn't matter much, although we need to
// take care not to overflow any bucket.
nodes := &nodesByDistance{target: testTarget.id()}
- for i := 0; i < bucketSize; i++ {
+ live := make(map[enode.ID]bool)
+ numCandidates := 2 * bucketSize
+ for i := 0; i < numCandidates; i++ {
key := newkey()
- n := wrapNode(enode.NewV4(&key.PublicKey, net.IP{10, 13, 0, 1}, 0, i))
- nodes.push(n, bucketSize)
+ ip := net.IP{10, 13, 0, byte(i)}
+ n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000))
+ // Ensure half of table content isn't verified live yet.
+ if i > numCandidates/2 {
+ n.livenessChecks = 1
+ live[n.ID()] = true
+ }
+ nodes.push(n, numCandidates)
}
- test.table.stuff(nodes.entries)
+ fillTable(test.table, nodes.entries)
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
remoteID := encodePubkey(&test.remotekey.PublicKey).id()
- test.table.db.UpdateLastPongReceived(remoteID, time.Now())
+ test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now())
// check that closest neighbors are returned.
- test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
expected := test.table.closest(testTarget.id(), bucketSize)
-
+ test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
waitNeighbors := func(want []*node) {
test.waitPacketOut(func(p *neighbors) {
if len(p.Nodes) != len(want) {
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
}
- for i := range p.Nodes {
- if p.Nodes[i].ID.id() != want[i].ID() {
- t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
+ for i, n := range p.Nodes {
+ if n.ID.id() != want[i].ID() {
+ t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i])
+ }
+ if !live[n.ID.id()] {
+ t.Errorf("result includes dead node %v", n.ID.id())
}
}
})
}
- waitNeighbors(expected.entries[:maxNeighbors])
- waitNeighbors(expected.entries[maxNeighbors:])
+ // Receive replies.
+ want := expected.entries
+ if len(want) > maxNeighbors {
+ waitNeighbors(want[:maxNeighbors])
+ want = want[maxNeighbors:]
+ }
+ waitNeighbors(want)
}
func TestUDP_findnodeMultiReply(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
- test.table.db.UpdateLastPingReceived(rid, time.Now())
+ test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now())
// queue a pending findnode request
resultc, errc := make(chan []*node), make(chan error)
@@ -329,11 +356,40 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
}
}
+func TestUDP_pingMatch(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.close()
+
+ randToken := make([]byte, 32)
+ crand.Read(randToken)
+
+ test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
+ test.waitPacketOut(func(*pong) error { return nil })
+ test.waitPacketOut(func(*ping) error { return nil })
+ test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp})
+}
+
+func TestUDP_pingMatchIP(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.close()
+
+ test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
+ test.waitPacketOut(func(*pong) error { return nil })
+
+ _, hash, _ := test.waitPacketOut(func(*ping) error { return nil })
+ wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000}
+ test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, pongPacket, &pong{
+ ReplyTok: hash,
+ To: testLocalAnnounced,
+ Expiration: futureExp,
+ })
+}
+
func TestUDP_successfulPing(t *testing.T) {
test := newUDPTest(t)
added := make(chan *node, 1)
test.table.nodeAddedHook = func(n *node) { added <- n }
- defer test.table.Close()
+ defer test.close()
// The remote side sends a ping packet to initiate the exchange.
go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
@@ -356,7 +412,7 @@ func TestUDP_successfulPing(t *testing.T) {
})
// remote is unknown, the table pings back.
- hash, _ := test.waitPacketOut(func(p *ping) error {
+ _, hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) {
t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
}
@@ -510,7 +566,12 @@ type dgramPipe struct {
cond *sync.Cond
closing chan struct{}
closed bool
- queue [][]byte
+ queue []dgram
+}
+
+type dgram struct {
+ to net.UDPAddr
+ data []byte
}
func newpipe() *dgramPipe {
@@ -531,7 +592,7 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
if c.closed {
return 0, errors.New("closed")
}
- c.queue = append(c.queue, msg)
+ c.queue = append(c.queue, dgram{*to, b})
c.cond.Signal()
return len(b), nil
}
@@ -556,7 +617,7 @@ func (c *dgramPipe) LocalAddr() net.Addr {
return &net.UDPAddr{IP: testLocal.IP, Port: int(testLocal.UDP)}
}
-func (c *dgramPipe) waitPacketOut() []byte {
+func (c *dgramPipe) waitPacketOut() dgram {
c.mu.Lock()
defer c.mu.Unlock()
for len(c.queue) == 0 {