aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover')
-rw-r--r--p2p/discover/node.go3
-rw-r--r--p2p/discover/table.go84
-rw-r--r--p2p/discover/table_test.go34
-rw-r--r--p2p/discover/table_util_test.go21
-rw-r--r--p2p/discover/udp.go222
-rw-r--r--p2p/discover/udp_test.go137
6 files changed, 310 insertions, 191 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index 7ddf04fe8..8d4af166b 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -33,7 +33,8 @@ import (
// The fields of Node may not be modified.
type node struct {
enode.Node
- addedAt time.Time // time when the node was added to the table
+ addedAt time.Time // time when the node was added to the table
+ livenessChecks uint // how often liveness was checked
}
type encPubkey [64]byte
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 9f7f1d41b..ba4c06327 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -75,8 +75,10 @@ type Table struct {
net transport
refreshReq chan chan struct{}
initDone chan struct{}
- closeReq chan struct{}
- closed chan struct{}
+
+ closeOnce sync.Once
+ closeReq chan struct{}
+ closed chan struct{}
nodeAddedHook func(*node) // for testing
}
@@ -180,16 +182,14 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
- if tab.net != nil {
- tab.net.close()
- }
-
- select {
- case <-tab.closed:
- // already closed.
- case tab.closeReq <- struct{}{}:
- <-tab.closed // wait for refreshLoop to end.
- }
+ tab.closeOnce.Do(func() {
+ if tab.net != nil {
+ tab.net.close()
+ }
+ // Wait for loop to end.
+ close(tab.closeReq)
+ <-tab.closed
+ })
}
// setFallbackNodes sets the initial points of contact. These nodes
@@ -290,12 +290,16 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
// we have asked all closest nodes, stop the search
break
}
- // wait for the next reply
- for _, n := range <-reply {
- if n != nil && !seen[n.ID()] {
- seen[n.ID()] = true
- result.push(n, bucketSize)
+ select {
+ case nodes := <-reply:
+ for _, n := range nodes {
+ if n != nil && !seen[n.ID()] {
+ seen[n.ID()] = true
+ result.push(n, bucketSize)
+ }
}
+ case <-tab.closeReq:
+ return nil // shutdown, no need to continue.
}
pendingQueries--
}
@@ -303,18 +307,22 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
}
func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) {
- fails := tab.db.FindFails(n.ID())
+ fails := tab.db.FindFails(n.ID(), n.IP())
r, err := tab.net.findnode(n.ID(), n.addr(), targetKey)
- if err != nil || len(r) == 0 {
+ if err == errClosed {
+ // Avoid recording failures on shutdown.
+ reply <- nil
+ return
+ } else if err != nil || len(r) == 0 {
fails++
- tab.db.UpdateFindFails(n.ID(), fails)
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
if fails >= maxFindnodeFailures {
log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
tab.delete(n)
}
} else if fails > 0 {
- tab.db.UpdateFindFails(n.ID(), fails-1)
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1)
}
// Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
@@ -329,7 +337,7 @@ func (tab *Table) refresh() <-chan struct{} {
done := make(chan struct{})
select {
case tab.refreshReq <- done:
- case <-tab.closed:
+ case <-tab.closeReq:
close(done)
}
return done
@@ -433,7 +441,7 @@ func (tab *Table) loadSeedNodes() {
seeds = append(seeds, tab.nursery...)
for i := range seeds {
seed := seeds[i]
- age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID())) }}
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }}
log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age)
tab.add(seed)
}
@@ -458,16 +466,17 @@ func (tab *Table) doRevalidate(done chan<- struct{}) {
b := tab.buckets[bi]
if err == nil {
// The node responded, move it to the front.
- log.Debug("Revalidated node", "b", bi, "id", last.ID())
+ last.livenessChecks++
+ log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks)
b.bump(last)
return
}
// No reply received, pick a replacement or delete the node if there aren't
// any replacements.
if r := tab.replace(b, last); r != nil {
- log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "r", r.ID(), "rip", r.IP())
+ log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP())
} else {
- log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP())
+ log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks)
}
}
@@ -502,7 +511,7 @@ func (tab *Table) copyLiveNodes() {
now := time.Now()
for _, b := range &tab.buckets {
for _, n := range b.entries {
- if now.Sub(n.addedAt) >= seedMinTableTime {
+ if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime {
tab.db.UpdateNode(unwrapNode(n))
}
}
@@ -518,7 +527,9 @@ func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance {
close := &nodesByDistance{target: target}
for _, b := range &tab.buckets {
for _, n := range b.entries {
- close.push(n, nresults)
+ if n.livenessChecks > 0 {
+ close.push(n, nresults)
+ }
}
}
return close
@@ -572,23 +583,6 @@ func (tab *Table) addThroughPing(n *node) {
tab.add(n)
}
-// stuff adds nodes the table to the end of their corresponding bucket
-// if the bucket is not full. The caller must not hold tab.mutex.
-func (tab *Table) stuff(nodes []*node) {
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
-
- for _, n := range nodes {
- if n.ID() == tab.self().ID() {
- continue // don't add self
- }
- b := tab.bucket(n.ID())
- if len(b.entries) < bucketSize {
- tab.bumpOrAdd(b, n)
- }
- }
-}
-
// delete removes an entry from the node table. It is used to evacuate dead nodes.
func (tab *Table) delete(node *node) {
tab.mutex.Lock()
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 6b4cd2d18..b00a93211 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -50,8 +50,8 @@ func TestTable_pingReplace(t *testing.T) {
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
<-tab.initDone
@@ -137,8 +137,8 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
@@ -153,8 +153,8 @@ func TestTable_IPLimit(t *testing.T) {
func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
@@ -173,9 +173,9 @@ func TestTable_closest(t *testing.T) {
// for any node table, Target and N
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
- tab.stuff(test.All)
+ defer tab.Close()
+ fillTable(tab, test.All)
// check that closest(Target, N) returns nodes
result := tab.closest(test.Target, test.N).entries
@@ -234,13 +234,13 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
test := func(buf []*enode.Node) bool {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
<-tab.initDone
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
- tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
+ fillTable(tab, []*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
}
gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() {
@@ -272,16 +272,19 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
N: rand.Intn(bucketSize),
}
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
- n := enode.SignNull(new(enr.Record), id)
- t.All = append(t.All, wrapNode(n))
+ r := new(enr.Record)
+ r.Set(enr.IP(genIP(rand)))
+ n := wrapNode(enode.SignNull(r, id))
+ n.livenessChecks = 1
+ t.All = append(t.All, n)
}
return reflect.ValueOf(t)
}
func TestTable_Lookup(t *testing.T) {
tab, db := newTestTable(lookupTestnet)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
// lookup on empty table returns no nodes
if results := tab.lookup(lookupTestnet.target, false); len(results) > 0 {
@@ -289,8 +292,9 @@ func TestTable_Lookup(t *testing.T) {
}
// seed table with initial node (otherwise lookup will terminate immediately)
seedKey, _ := decodePubkey(lookupTestnet.dists[256][0])
- seed := wrapNode(enode.NewV4(seedKey, net.IP{}, 0, 256))
- tab.stuff([]*node{seed})
+ seed := wrapNode(enode.NewV4(seedKey, net.IP{127, 0, 0, 1}, 0, 256))
+ seed.livenessChecks = 1
+ fillTable(tab, []*node{seed})
results := tab.lookup(lookupTestnet.target, true)
t.Logf("results:")
@@ -578,6 +582,12 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface()
}
+func genIP(rand *rand.Rand) net.IP {
+ ip := make(net.IP, 4)
+ rand.Read(ip)
+ return ip
+}
+
func quickcfg() *quick.Config {
return &quick.Config{
MaxCount: 5000,
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
index d41519452..3ce582b99 100644
--- a/p2p/discover/table_util_test.go
+++ b/p2p/discover/table_util_test.go
@@ -83,6 +83,23 @@ func fillBucket(tab *Table, n *node) (last *node) {
return b.entries[bucketSize-1]
}
+// fillTable adds nodes the table to the end of their corresponding bucket
+// if the bucket is not full. The caller must not hold tab.mutex.
+func fillTable(tab *Table, nodes []*node) {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ for _, n := range nodes {
+ if n.ID() == tab.self().ID() {
+ continue // don't add self
+ }
+ b := tab.bucket(n.ID())
+ if len(b.entries) < bucketSize {
+ tab.bumpOrAdd(b, n)
+ }
+ }
+}
+
type pingRecorder struct {
mu sync.Mutex
dead, pinged map[enode.ID]bool
@@ -109,10 +126,6 @@ func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPu
return nil, nil
}
-func (t *pingRecorder) waitping(from enode.ID) error {
- return nil // remote always pings
-}
-
func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error {
t.mu.Lock()
defer t.mu.Unlock()
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 37a044902..5ce4c43dc 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -67,6 +67,8 @@ const (
// RPC request structures
type (
ping struct {
+ senderKey *ecdsa.PublicKey // filled in by preverify
+
Version uint
From, To rpcEndpoint
Expiration uint64
@@ -155,8 +157,13 @@ func nodeToRPC(n *node) rpcNode {
return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())}
}
+// packet is implemented by all protocol messages.
type packet interface {
- handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error
+ // preverify checks whether the packet is valid and should be handled at all.
+ preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error
+ // handle handles the packet.
+ handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte)
+ // name returns the name of the packet for logging purposes.
name() string
}
@@ -177,43 +184,48 @@ type udp struct {
tab *Table
wg sync.WaitGroup
- addpending chan *pending
- gotreply chan reply
- closing chan struct{}
+ addReplyMatcher chan *replyMatcher
+ gotreply chan reply
+ closing chan struct{}
}
// pending represents a pending reply.
//
-// some implementations of the protocol wish to send more than one
-// reply packet to findnode. in general, any neighbors packet cannot
+// Some implementations of the protocol wish to send more than one
+// reply packet to findnode. In general, any neighbors packet cannot
// be matched up with a specific findnode packet.
//
-// our implementation handles this by storing a callback function for
-// each pending reply. incoming packets from a node are dispatched
-// to all the callback functions for that node.
-type pending struct {
+// Our implementation handles this by storing a callback function for
+// each pending reply. Incoming packets from a node are dispatched
+// to all callback functions for that node.
+type replyMatcher struct {
// these fields must match in the reply.
from enode.ID
+ ip net.IP
ptype byte
// time when the request must complete
deadline time.Time
- // callback is called when a matching reply arrives. if it returns
- // true, the callback is removed from the pending reply queue.
- // if it returns false, the reply is considered incomplete and
- // the callback will be invoked again for the next matching reply.
- callback func(resp interface{}) (done bool)
+ // callback is called when a matching reply arrives. If it returns matched == true, the
+ // reply was acceptable. The second return value indicates whether the callback should
+ // be removed from the pending reply queue. If it returns false, the reply is considered
+ // incomplete and the callback will be invoked again for the next matching reply.
+ callback replyMatchFunc
// errc receives nil when the callback indicates completion or an
// error if no further reply is received within the timeout.
errc chan<- error
}
+type replyMatchFunc func(interface{}) (matched bool, requestDone bool)
+
type reply struct {
from enode.ID
+ ip net.IP
ptype byte
- data interface{}
+ data packet
+
// loop indicates whether there was
// a matching request by sending on this channel.
matched chan<- bool
@@ -247,14 +259,14 @@ func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
udp := &udp{
- conn: c,
- priv: cfg.PrivateKey,
- netrestrict: cfg.NetRestrict,
- localNode: ln,
- db: ln.Database(),
- closing: make(chan struct{}),
- gotreply: make(chan reply),
- addpending: make(chan *pending),
+ conn: c,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
+ localNode: ln,
+ db: ln.Database(),
+ closing: make(chan struct{}),
+ gotreply: make(chan reply),
+ addReplyMatcher: make(chan *replyMatcher),
}
tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
if err != nil {
@@ -304,35 +316,37 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch
errc <- err
return errc
}
- errc := t.pending(toid, pongPacket, func(p interface{}) bool {
- ok := bytes.Equal(p.(*pong).ReplyTok, hash)
- if ok && callback != nil {
+ // Add a matcher for the reply to the pending reply queue. Pongs are matched if they
+ // reference the ping we're about to send.
+ errc := t.pending(toid, toaddr.IP, pongPacket, func(p interface{}) (matched bool, requestDone bool) {
+ matched = bytes.Equal(p.(*pong).ReplyTok, hash)
+ if matched && callback != nil {
callback()
}
- return ok
+ return matched, matched
})
+ // Send the packet.
t.localNode.UDPContact(toaddr)
- t.write(toaddr, req.name(), packet)
+ t.write(toaddr, toid, req.name(), packet)
return errc
}
-func (t *udp) waitping(from enode.ID) error {
- return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
-}
-
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// If we haven't seen a ping from the destination node for a while, it won't remember
// our endpoint proof and reject findnode. Solicit a ping first.
- if time.Since(t.db.LastPingReceived(toid)) > bondExpiration {
+ if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration {
t.ping(toid, toaddr)
- t.waitping(toid)
+ // Wait for them to ping back and process our pong.
+ time.Sleep(respTimeout)
}
+ // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is
+ // active until enough nodes have been received.
nodes := make([]*node, 0, bucketSize)
nreceived := 0
- errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
+ errc := t.pending(toid, toaddr.IP, neighborsPacket, func(r interface{}) (matched bool, requestDone bool) {
reply := r.(*neighbors)
for _, rn := range reply.Nodes {
nreceived++
@@ -343,22 +357,22 @@ func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]
}
nodes = append(nodes, n)
}
- return nreceived >= bucketSize
+ return true, nreceived >= bucketSize
})
- t.send(toaddr, findnodePacket, &findnode{
+ t.send(toaddr, toid, findnodePacket, &findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return nodes, <-errc
}
-// pending adds a reply callback to the pending reply queue.
-// see the documentation of type pending for a detailed explanation.
-func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error {
+// pending adds a reply matcher to the pending reply queue.
+// see the documentation of type replyMatcher for a detailed explanation.
+func (t *udp) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) <-chan error {
ch := make(chan error, 1)
- p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch}
select {
- case t.addpending <- p:
+ case t.addReplyMatcher <- p:
// loop will handle it
case <-t.closing:
ch <- errClosed
@@ -366,10 +380,12 @@ func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool)
return ch
}
-func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
+// handleReply dispatches a reply packet, invoking reply matchers. It returns
+// whether any matcher considered the packet acceptable.
+func (t *udp) handleReply(from enode.ID, fromIP net.IP, ptype byte, req packet) bool {
matched := make(chan bool, 1)
select {
- case t.gotreply <- reply{from, ptype, req, matched}:
+ case t.gotreply <- reply{from, fromIP, ptype, req, matched}:
// loop will handle it
return <-matched
case <-t.closing:
@@ -385,8 +401,8 @@ func (t *udp) loop() {
var (
plist = list.New()
timeout = time.NewTimer(0)
- nextTimeout *pending // head of plist when timeout was last reset
- contTimeouts = 0 // number of continuous timeouts to do NTP checks
+ nextTimeout *replyMatcher // head of plist when timeout was last reset
+ contTimeouts = 0 // number of continuous timeouts to do NTP checks
ntpWarnTime = time.Unix(0, 0)
)
<-timeout.C // ignore first timeout
@@ -399,7 +415,7 @@ func (t *udp) loop() {
// Start the timer so it fires when the next pending reply has expired.
now := time.Now()
for el := plist.Front(); el != nil; el = el.Next() {
- nextTimeout = el.Value.(*pending)
+ nextTimeout = el.Value.(*replyMatcher)
if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
timeout.Reset(dist)
return
@@ -420,25 +436,23 @@ func (t *udp) loop() {
select {
case <-t.closing:
for el := plist.Front(); el != nil; el = el.Next() {
- el.Value.(*pending).errc <- errClosed
+ el.Value.(*replyMatcher).errc <- errClosed
}
return
- case p := <-t.addpending:
+ case p := <-t.addReplyMatcher:
p.deadline = time.Now().Add(respTimeout)
plist.PushBack(p)
case r := <-t.gotreply:
- var matched bool
+ var matched bool // whether any replyMatcher considered the reply acceptable.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
- if p.from == r.from && p.ptype == r.ptype {
- matched = true
- // Remove the matcher if its callback indicates
- // that all replies have been received. This is
- // required for packet types that expect multiple
- // reply packets.
- if p.callback(r.data) {
+ p := el.Value.(*replyMatcher)
+ if p.from == r.from && p.ptype == r.ptype && p.ip.Equal(r.ip) {
+ ok, requestDone := p.callback(r.data)
+ matched = matched || ok
+ // Remove the matcher if callback indicates that all replies have been received.
+ if requestDone {
p.errc <- nil
plist.Remove(el)
}
@@ -453,7 +467,7 @@ func (t *udp) loop() {
// Notify and remove callbacks whose deadline is in the past.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
+ p := el.Value.(*replyMatcher)
if now.After(p.deadline) || now.Equal(p.deadline) {
p.errc <- errTimeout
plist.Remove(el)
@@ -504,17 +518,17 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+func (t *udp) send(toaddr *net.UDPAddr, toid enode.ID, ptype byte, req packet) ([]byte, error) {
packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
return hash, err
}
- return hash, t.write(toaddr, req.name(), packet)
+ return hash, t.write(toaddr, toid, req.name(), packet)
}
-func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+func (t *udp) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+what, "addr", toaddr, "err", err)
+ log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err)
return err
}
@@ -573,13 +587,19 @@ func (t *udp) readLoop(unhandled chan<- ReadPacket) {
}
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
- packet, fromID, hash, err := decodePacket(buf)
+ packet, fromKey, hash, err := decodePacket(buf)
if err != nil {
log.Debug("Bad discv4 packet", "addr", from, "err", err)
return err
}
- err = packet.handle(t, from, fromID, hash)
- log.Trace("<< "+packet.name(), "addr", from, "err", err)
+ fromID := fromKey.id()
+ if err == nil {
+ err = packet.preverify(t, from, fromID, fromKey)
+ }
+ log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err)
+ if err == nil {
+ packet.handle(t, from, fromID, hash)
+ }
return err
}
@@ -615,54 +635,67 @@ func decodePacket(buf []byte) (packet, encPubkey, []byte, error) {
return req, fromKey, hash, err
}
-func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+// Packet Handlers
+
+func (req *ping) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
key, err := decodePubkey(fromKey)
if err != nil {
- return fmt.Errorf("invalid public key: %v", err)
+ return errors.New("invalid public key")
}
- t.send(from, pongPacket, &pong{
+ req.senderKey = key
+ return nil
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Reply.
+ t.send(from, fromID, pongPacket, &pong{
To: makeEndpoint(from, req.From.TCP),
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
- t.handleReply(n.ID(), pingPacket, req)
- if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration {
- t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) })
+
+ // Ping back if our last pong on file is too far in the past.
+ n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port))
+ if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
+ t.sendPing(fromID, from, func() {
+ t.tab.addThroughPing(n)
+ })
} else {
t.tab.addThroughPing(n)
}
+
+ // Update node database and endpoint predictor.
+ t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now())
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPingReceived(n.ID(), time.Now())
- return nil
}
func (req *ping) name() string { return "PING/v4" }
-func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *pong) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if !t.handleReply(fromID, pongPacket, req) {
+ if !t.handleReply(fromID, from.IP, pongPacket, req) {
return errUnsolicitedReply
}
- t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPongReceived(fromID, time.Now())
return nil
}
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
+ t.db.UpdateLastPongReceived(fromID, from.IP, time.Now())
+}
+
func (req *pong) name() string { return "PONG/v4" }
-func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *findnode) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if time.Since(t.db.LastPongReceived(fromID)) > bondExpiration {
+ if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration {
// No endpoint proof pong exists, we don't process the packet. This prevents an
// attack vector where the discovery protocol could be used to amplify traffic in a
// DDOS attack. A malicious actor would send a findnode request with the IP address
@@ -671,43 +704,50 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []
// findnode) to the victim.
return errUnknownNode
}
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Determine closest nodes.
target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
t.tab.mutex.Lock()
closest := t.tab.closest(target, bucketSize).entries
t.tab.mutex.Unlock()
- p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
- var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
+ p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
if len(p.Nodes) == maxNeighbors {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
sent = true
}
}
if len(p.Nodes) > 0 || !sent {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
}
- return nil
}
func (req *findnode) name() string { return "FINDNODE/v4" }
-func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *neighbors) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- if !t.handleReply(fromKey.id(), neighborsPacket, req) {
+ if !t.handleReply(fromID, from.IP, neighborsPacket, req) {
return errUnsolicitedReply
}
return nil
}
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+}
+
func (req *neighbors) name() string { return "NEIGHBORS/v4" }
func expired(ts uint64) bool {
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index a4ddaf750..3d53c9309 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -19,6 +19,7 @@ package discover
import (
"bytes"
"crypto/ecdsa"
+ crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
@@ -57,6 +58,7 @@ type udpTest struct {
t *testing.T
pipe *dgramPipe
table *Table
+ db *enode.DB
udp *udp
sent [][]byte
localkey, remotekey *ecdsa.PrivateKey
@@ -71,22 +73,32 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- db, _ := enode.OpenDB("")
- ln := enode.NewLocalNode(db, test.localkey)
+ test.db, _ = enode.OpenDB("")
+ ln := enode.NewLocalNode(test.db, test.localkey)
test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey})
// Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone
return test
}
+func (test *udpTest) close() {
+ test.table.Close()
+ test.db.Close()
+}
+
// handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
- enc, _, err := encodePacket(test.remotekey, ptype, data)
+ return test.packetInFrom(wantError, test.remotekey, test.remoteaddr, ptype, data)
+}
+
+// handles a packet as if it had been sent to the transport by the key/endpoint.
+func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, ptype byte, data packet) error {
+ enc, _, err := encodePacket(key, ptype, data)
if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err)
}
test.sent = append(test.sent, enc)
- if err = test.udp.handlePacket(test.remoteaddr, enc); err != wantError {
+ if err = test.udp.handlePacket(addr, enc); err != wantError {
return test.errorf("error mismatch: got %q, want %q", err, wantError)
}
return nil
@@ -94,19 +106,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type.
-func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
+func (test *udpTest) waitPacketOut(validate interface{}) (*net.UDPAddr, []byte, error) {
dgram := test.pipe.waitPacketOut()
- p, _, hash, err := decodePacket(dgram)
+ p, _, hash, err := decodePacket(dgram.data)
if err != nil {
- return hash, test.errorf("sent packet decode error: %v", err)
+ return &dgram.to, hash, test.errorf("sent packet decode error: %v", err)
}
fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype {
- return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ return &dgram.to, hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
}
fn.Call([]reflect.Value{reflect.ValueOf(p)})
- return hash, nil
+ return &dgram.to, hash, nil
}
func (test *udpTest) errorf(format string, args ...interface{}) error {
@@ -125,7 +137,7 @@ func (test *udpTest) errorf(format string, args ...interface{}) error {
func TestUDP_packetErrors(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4})
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
@@ -136,7 +148,7 @@ func TestUDP_packetErrors(t *testing.T) {
func TestUDP_pingTimeout(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
toid := enode.ID{1, 2, 3, 4}
@@ -148,7 +160,7 @@ func TestUDP_pingTimeout(t *testing.T) {
func TestUDP_responseTimeouts(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
rand.Seed(time.Now().UnixNano())
randomDuration := func(max time.Duration) time.Duration {
@@ -166,20 +178,20 @@ func TestUDP_responseTimeouts(t *testing.T) {
// with ptype <= 128 will not get a reply and should time out.
// For all other requests, a reply is scheduled to arrive
// within the timeout window.
- p := &pending{
+ p := &replyMatcher{
ptype: byte(rand.Intn(255)),
- callback: func(interface{}) bool { return true },
+ callback: func(interface{}) (bool, bool) { return true, true },
}
binary.BigEndian.PutUint64(p.from[:], uint64(i))
if p.ptype <= 128 {
p.errc = timeoutErr
- test.udp.addpending <- p
+ test.udp.addReplyMatcher <- p
nTimeouts++
} else {
p.errc = nilErr
- test.udp.addpending <- p
+ test.udp.addReplyMatcher <- p
time.AfterFunc(randomDuration(60*time.Millisecond), func() {
- if !test.udp.handleReply(p.from, p.ptype, nil) {
+ if !test.udp.handleReply(p.from, p.ip, p.ptype, nil) {
t.Logf("not matched: %v", p)
}
})
@@ -220,7 +232,7 @@ func TestUDP_responseTimeouts(t *testing.T) {
func TestUDP_findnodeTimeout(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
toid := enode.ID{1, 2, 3, 4}
@@ -236,50 +248,65 @@ func TestUDP_findnodeTimeout(t *testing.T) {
func TestUDP_findnode(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
// put a few nodes into the table. their exact
// distribution shouldn't matter much, although we need to
// take care not to overflow any bucket.
nodes := &nodesByDistance{target: testTarget.id()}
- for i := 0; i < bucketSize; i++ {
+ live := make(map[enode.ID]bool)
+ numCandidates := 2 * bucketSize
+ for i := 0; i < numCandidates; i++ {
key := newkey()
- n := wrapNode(enode.NewV4(&key.PublicKey, net.IP{10, 13, 0, 1}, 0, i))
- nodes.push(n, bucketSize)
+ ip := net.IP{10, 13, 0, byte(i)}
+ n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000))
+ // Ensure half of table content isn't verified live yet.
+ if i > numCandidates/2 {
+ n.livenessChecks = 1
+ live[n.ID()] = true
+ }
+ nodes.push(n, numCandidates)
}
- test.table.stuff(nodes.entries)
+ fillTable(test.table, nodes.entries)
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
remoteID := encodePubkey(&test.remotekey.PublicKey).id()
- test.table.db.UpdateLastPongReceived(remoteID, time.Now())
+ test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now())
// check that closest neighbors are returned.
- test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
expected := test.table.closest(testTarget.id(), bucketSize)
-
+ test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
waitNeighbors := func(want []*node) {
test.waitPacketOut(func(p *neighbors) {
if len(p.Nodes) != len(want) {
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
}
- for i := range p.Nodes {
- if p.Nodes[i].ID.id() != want[i].ID() {
- t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
+ for i, n := range p.Nodes {
+ if n.ID.id() != want[i].ID() {
+ t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i])
+ }
+ if !live[n.ID.id()] {
+ t.Errorf("result includes dead node %v", n.ID.id())
}
}
})
}
- waitNeighbors(expected.entries[:maxNeighbors])
- waitNeighbors(expected.entries[maxNeighbors:])
+ // Receive replies.
+ want := expected.entries
+ if len(want) > maxNeighbors {
+ waitNeighbors(want[:maxNeighbors])
+ want = want[maxNeighbors:]
+ }
+ waitNeighbors(want)
}
func TestUDP_findnodeMultiReply(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
- test.table.db.UpdateLastPingReceived(rid, time.Now())
+ test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now())
// queue a pending findnode request
resultc, errc := make(chan []*node), make(chan error)
@@ -329,11 +356,40 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
}
}
+func TestUDP_pingMatch(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.close()
+
+ randToken := make([]byte, 32)
+ crand.Read(randToken)
+
+ test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
+ test.waitPacketOut(func(*pong) error { return nil })
+ test.waitPacketOut(func(*ping) error { return nil })
+ test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp})
+}
+
+func TestUDP_pingMatchIP(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.close()
+
+ test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
+ test.waitPacketOut(func(*pong) error { return nil })
+
+ _, hash, _ := test.waitPacketOut(func(*ping) error { return nil })
+ wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000}
+ test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, pongPacket, &pong{
+ ReplyTok: hash,
+ To: testLocalAnnounced,
+ Expiration: futureExp,
+ })
+}
+
func TestUDP_successfulPing(t *testing.T) {
test := newUDPTest(t)
added := make(chan *node, 1)
test.table.nodeAddedHook = func(n *node) { added <- n }
- defer test.table.Close()
+ defer test.close()
// The remote side sends a ping packet to initiate the exchange.
go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
@@ -356,7 +412,7 @@ func TestUDP_successfulPing(t *testing.T) {
})
// remote is unknown, the table pings back.
- hash, _ := test.waitPacketOut(func(p *ping) error {
+ _, hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) {
t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
}
@@ -510,7 +566,12 @@ type dgramPipe struct {
cond *sync.Cond
closing chan struct{}
closed bool
- queue [][]byte
+ queue []dgram
+}
+
+type dgram struct {
+ to net.UDPAddr
+ data []byte
}
func newpipe() *dgramPipe {
@@ -531,7 +592,7 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
if c.closed {
return 0, errors.New("closed")
}
- c.queue = append(c.queue, msg)
+ c.queue = append(c.queue, dgram{*to, b})
c.cond.Signal()
return len(b), nil
}
@@ -556,7 +617,7 @@ func (c *dgramPipe) LocalAddr() net.Addr {
return &net.UDPAddr{IP: testLocal.IP, Port: int(testLocal.UDP)}
}
-func (c *dgramPipe) waitPacketOut() []byte {
+func (c *dgramPipe) waitPacketOut() dgram {
c.mu.Lock()
defer c.mu.Unlock()
for len(c.queue) == 0 {