aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--p2p/discover/node.go3
-rw-r--r--p2p/discover/table.go84
-rw-r--r--p2p/discover/table_test.go34
-rw-r--r--p2p/discover/table_util_test.go21
-rw-r--r--p2p/discover/udp.go222
-rw-r--r--p2p/discover/udp_test.go137
-rw-r--r--p2p/enode/nodedb.go210
-rw-r--r--p2p/enode/nodedb_test.go216
8 files changed, 595 insertions, 332 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index 7ddf04fe8..8d4af166b 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -33,7 +33,8 @@ import (
// The fields of Node may not be modified.
type node struct {
enode.Node
- addedAt time.Time // time when the node was added to the table
+ addedAt time.Time // time when the node was added to the table
+ livenessChecks uint // how often liveness was checked
}
type encPubkey [64]byte
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 9f7f1d41b..ba4c06327 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -75,8 +75,10 @@ type Table struct {
net transport
refreshReq chan chan struct{}
initDone chan struct{}
- closeReq chan struct{}
- closed chan struct{}
+
+ closeOnce sync.Once
+ closeReq chan struct{}
+ closed chan struct{}
nodeAddedHook func(*node) // for testing
}
@@ -180,16 +182,14 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
- if tab.net != nil {
- tab.net.close()
- }
-
- select {
- case <-tab.closed:
- // already closed.
- case tab.closeReq <- struct{}{}:
- <-tab.closed // wait for refreshLoop to end.
- }
+ tab.closeOnce.Do(func() {
+ if tab.net != nil {
+ tab.net.close()
+ }
+ // Wait for loop to end.
+ close(tab.closeReq)
+ <-tab.closed
+ })
}
// setFallbackNodes sets the initial points of contact. These nodes
@@ -290,12 +290,16 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
// we have asked all closest nodes, stop the search
break
}
- // wait for the next reply
- for _, n := range <-reply {
- if n != nil && !seen[n.ID()] {
- seen[n.ID()] = true
- result.push(n, bucketSize)
+ select {
+ case nodes := <-reply:
+ for _, n := range nodes {
+ if n != nil && !seen[n.ID()] {
+ seen[n.ID()] = true
+ result.push(n, bucketSize)
+ }
}
+ case <-tab.closeReq:
+ return nil // shutdown, no need to continue.
}
pendingQueries--
}
@@ -303,18 +307,22 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
}
func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) {
- fails := tab.db.FindFails(n.ID())
+ fails := tab.db.FindFails(n.ID(), n.IP())
r, err := tab.net.findnode(n.ID(), n.addr(), targetKey)
- if err != nil || len(r) == 0 {
+ if err == errClosed {
+ // Avoid recording failures on shutdown.
+ reply <- nil
+ return
+ } else if err != nil || len(r) == 0 {
fails++
- tab.db.UpdateFindFails(n.ID(), fails)
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
if fails >= maxFindnodeFailures {
log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
tab.delete(n)
}
} else if fails > 0 {
- tab.db.UpdateFindFails(n.ID(), fails-1)
+ tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1)
}
// Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
@@ -329,7 +337,7 @@ func (tab *Table) refresh() <-chan struct{} {
done := make(chan struct{})
select {
case tab.refreshReq <- done:
- case <-tab.closed:
+ case <-tab.closeReq:
close(done)
}
return done
@@ -433,7 +441,7 @@ func (tab *Table) loadSeedNodes() {
seeds = append(seeds, tab.nursery...)
for i := range seeds {
seed := seeds[i]
- age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID())) }}
+ age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }}
log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age)
tab.add(seed)
}
@@ -458,16 +466,17 @@ func (tab *Table) doRevalidate(done chan<- struct{}) {
b := tab.buckets[bi]
if err == nil {
// The node responded, move it to the front.
- log.Debug("Revalidated node", "b", bi, "id", last.ID())
+ last.livenessChecks++
+ log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks)
b.bump(last)
return
}
// No reply received, pick a replacement or delete the node if there aren't
// any replacements.
if r := tab.replace(b, last); r != nil {
- log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "r", r.ID(), "rip", r.IP())
+ log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP())
} else {
- log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP())
+ log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks)
}
}
@@ -502,7 +511,7 @@ func (tab *Table) copyLiveNodes() {
now := time.Now()
for _, b := range &tab.buckets {
for _, n := range b.entries {
- if now.Sub(n.addedAt) >= seedMinTableTime {
+ if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime {
tab.db.UpdateNode(unwrapNode(n))
}
}
@@ -518,7 +527,9 @@ func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance {
close := &nodesByDistance{target: target}
for _, b := range &tab.buckets {
for _, n := range b.entries {
- close.push(n, nresults)
+ if n.livenessChecks > 0 {
+ close.push(n, nresults)
+ }
}
}
return close
@@ -572,23 +583,6 @@ func (tab *Table) addThroughPing(n *node) {
tab.add(n)
}
-// stuff adds nodes the table to the end of their corresponding bucket
-// if the bucket is not full. The caller must not hold tab.mutex.
-func (tab *Table) stuff(nodes []*node) {
- tab.mutex.Lock()
- defer tab.mutex.Unlock()
-
- for _, n := range nodes {
- if n.ID() == tab.self().ID() {
- continue // don't add self
- }
- b := tab.bucket(n.ID())
- if len(b.entries) < bucketSize {
- tab.bumpOrAdd(b, n)
- }
- }
-}
-
// delete removes an entry from the node table. It is used to evacuate dead nodes.
func (tab *Table) delete(node *node) {
tab.mutex.Lock()
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 6b4cd2d18..b00a93211 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -50,8 +50,8 @@ func TestTable_pingReplace(t *testing.T) {
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
<-tab.initDone
@@ -137,8 +137,8 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
@@ -153,8 +153,8 @@ func TestTable_IPLimit(t *testing.T) {
func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
@@ -173,9 +173,9 @@ func TestTable_closest(t *testing.T) {
// for any node table, Target and N
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
- tab.stuff(test.All)
+ defer tab.Close()
+ fillTable(tab, test.All)
// check that closest(Target, N) returns nodes
result := tab.closest(test.Target, test.N).entries
@@ -234,13 +234,13 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
test := func(buf []*enode.Node) bool {
transport := newPingRecorder()
tab, db := newTestTable(transport)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
<-tab.initDone
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
- tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
+ fillTable(tab, []*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
}
gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() {
@@ -272,16 +272,19 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
N: rand.Intn(bucketSize),
}
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
- n := enode.SignNull(new(enr.Record), id)
- t.All = append(t.All, wrapNode(n))
+ r := new(enr.Record)
+ r.Set(enr.IP(genIP(rand)))
+ n := wrapNode(enode.SignNull(r, id))
+ n.livenessChecks = 1
+ t.All = append(t.All, n)
}
return reflect.ValueOf(t)
}
func TestTable_Lookup(t *testing.T) {
tab, db := newTestTable(lookupTestnet)
- defer tab.Close()
defer db.Close()
+ defer tab.Close()
// lookup on empty table returns no nodes
if results := tab.lookup(lookupTestnet.target, false); len(results) > 0 {
@@ -289,8 +292,9 @@ func TestTable_Lookup(t *testing.T) {
}
// seed table with initial node (otherwise lookup will terminate immediately)
seedKey, _ := decodePubkey(lookupTestnet.dists[256][0])
- seed := wrapNode(enode.NewV4(seedKey, net.IP{}, 0, 256))
- tab.stuff([]*node{seed})
+ seed := wrapNode(enode.NewV4(seedKey, net.IP{127, 0, 0, 1}, 0, 256))
+ seed.livenessChecks = 1
+ fillTable(tab, []*node{seed})
results := tab.lookup(lookupTestnet.target, true)
t.Logf("results:")
@@ -578,6 +582,12 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface()
}
+func genIP(rand *rand.Rand) net.IP {
+ ip := make(net.IP, 4)
+ rand.Read(ip)
+ return ip
+}
+
func quickcfg() *quick.Config {
return &quick.Config{
MaxCount: 5000,
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
index d41519452..3ce582b99 100644
--- a/p2p/discover/table_util_test.go
+++ b/p2p/discover/table_util_test.go
@@ -83,6 +83,23 @@ func fillBucket(tab *Table, n *node) (last *node) {
return b.entries[bucketSize-1]
}
+// fillTable adds nodes the table to the end of their corresponding bucket
+// if the bucket is not full. The caller must not hold tab.mutex.
+func fillTable(tab *Table, nodes []*node) {
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+
+ for _, n := range nodes {
+ if n.ID() == tab.self().ID() {
+ continue // don't add self
+ }
+ b := tab.bucket(n.ID())
+ if len(b.entries) < bucketSize {
+ tab.bumpOrAdd(b, n)
+ }
+ }
+}
+
type pingRecorder struct {
mu sync.Mutex
dead, pinged map[enode.ID]bool
@@ -109,10 +126,6 @@ func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPu
return nil, nil
}
-func (t *pingRecorder) waitping(from enode.ID) error {
- return nil // remote always pings
-}
-
func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error {
t.mu.Lock()
defer t.mu.Unlock()
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 37a044902..5ce4c43dc 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -67,6 +67,8 @@ const (
// RPC request structures
type (
ping struct {
+ senderKey *ecdsa.PublicKey // filled in by preverify
+
Version uint
From, To rpcEndpoint
Expiration uint64
@@ -155,8 +157,13 @@ func nodeToRPC(n *node) rpcNode {
return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())}
}
+// packet is implemented by all protocol messages.
type packet interface {
- handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error
+ // preverify checks whether the packet is valid and should be handled at all.
+ preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error
+ // handle handles the packet.
+ handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte)
+ // name returns the name of the packet for logging purposes.
name() string
}
@@ -177,43 +184,48 @@ type udp struct {
tab *Table
wg sync.WaitGroup
- addpending chan *pending
- gotreply chan reply
- closing chan struct{}
+ addReplyMatcher chan *replyMatcher
+ gotreply chan reply
+ closing chan struct{}
}
// pending represents a pending reply.
//
-// some implementations of the protocol wish to send more than one
-// reply packet to findnode. in general, any neighbors packet cannot
+// Some implementations of the protocol wish to send more than one
+// reply packet to findnode. In general, any neighbors packet cannot
// be matched up with a specific findnode packet.
//
-// our implementation handles this by storing a callback function for
-// each pending reply. incoming packets from a node are dispatched
-// to all the callback functions for that node.
-type pending struct {
+// Our implementation handles this by storing a callback function for
+// each pending reply. Incoming packets from a node are dispatched
+// to all callback functions for that node.
+type replyMatcher struct {
// these fields must match in the reply.
from enode.ID
+ ip net.IP
ptype byte
// time when the request must complete
deadline time.Time
- // callback is called when a matching reply arrives. if it returns
- // true, the callback is removed from the pending reply queue.
- // if it returns false, the reply is considered incomplete and
- // the callback will be invoked again for the next matching reply.
- callback func(resp interface{}) (done bool)
+ // callback is called when a matching reply arrives. If it returns matched == true, the
+ // reply was acceptable. The second return value indicates whether the callback should
+ // be removed from the pending reply queue. If it returns false, the reply is considered
+ // incomplete and the callback will be invoked again for the next matching reply.
+ callback replyMatchFunc
// errc receives nil when the callback indicates completion or an
// error if no further reply is received within the timeout.
errc chan<- error
}
+type replyMatchFunc func(interface{}) (matched bool, requestDone bool)
+
type reply struct {
from enode.ID
+ ip net.IP
ptype byte
- data interface{}
+ data packet
+
// loop indicates whether there was
// a matching request by sending on this channel.
matched chan<- bool
@@ -247,14 +259,14 @@ func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
udp := &udp{
- conn: c,
- priv: cfg.PrivateKey,
- netrestrict: cfg.NetRestrict,
- localNode: ln,
- db: ln.Database(),
- closing: make(chan struct{}),
- gotreply: make(chan reply),
- addpending: make(chan *pending),
+ conn: c,
+ priv: cfg.PrivateKey,
+ netrestrict: cfg.NetRestrict,
+ localNode: ln,
+ db: ln.Database(),
+ closing: make(chan struct{}),
+ gotreply: make(chan reply),
+ addReplyMatcher: make(chan *replyMatcher),
}
tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
if err != nil {
@@ -304,35 +316,37 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch
errc <- err
return errc
}
- errc := t.pending(toid, pongPacket, func(p interface{}) bool {
- ok := bytes.Equal(p.(*pong).ReplyTok, hash)
- if ok && callback != nil {
+ // Add a matcher for the reply to the pending reply queue. Pongs are matched if they
+ // reference the ping we're about to send.
+ errc := t.pending(toid, toaddr.IP, pongPacket, func(p interface{}) (matched bool, requestDone bool) {
+ matched = bytes.Equal(p.(*pong).ReplyTok, hash)
+ if matched && callback != nil {
callback()
}
- return ok
+ return matched, matched
})
+ // Send the packet.
t.localNode.UDPContact(toaddr)
- t.write(toaddr, req.name(), packet)
+ t.write(toaddr, toid, req.name(), packet)
return errc
}
-func (t *udp) waitping(from enode.ID) error {
- return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
-}
-
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// If we haven't seen a ping from the destination node for a while, it won't remember
// our endpoint proof and reject findnode. Solicit a ping first.
- if time.Since(t.db.LastPingReceived(toid)) > bondExpiration {
+ if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration {
t.ping(toid, toaddr)
- t.waitping(toid)
+ // Wait for them to ping back and process our pong.
+ time.Sleep(respTimeout)
}
+ // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is
+ // active until enough nodes have been received.
nodes := make([]*node, 0, bucketSize)
nreceived := 0
- errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
+ errc := t.pending(toid, toaddr.IP, neighborsPacket, func(r interface{}) (matched bool, requestDone bool) {
reply := r.(*neighbors)
for _, rn := range reply.Nodes {
nreceived++
@@ -343,22 +357,22 @@ func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]
}
nodes = append(nodes, n)
}
- return nreceived >= bucketSize
+ return true, nreceived >= bucketSize
})
- t.send(toaddr, findnodePacket, &findnode{
+ t.send(toaddr, toid, findnodePacket, &findnode{
Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return nodes, <-errc
}
-// pending adds a reply callback to the pending reply queue.
-// see the documentation of type pending for a detailed explanation.
-func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error {
+// pending adds a reply matcher to the pending reply queue.
+// see the documentation of type replyMatcher for a detailed explanation.
+func (t *udp) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) <-chan error {
ch := make(chan error, 1)
- p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch}
select {
- case t.addpending <- p:
+ case t.addReplyMatcher <- p:
// loop will handle it
case <-t.closing:
ch <- errClosed
@@ -366,10 +380,12 @@ func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool)
return ch
}
-func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
+// handleReply dispatches a reply packet, invoking reply matchers. It returns
+// whether any matcher considered the packet acceptable.
+func (t *udp) handleReply(from enode.ID, fromIP net.IP, ptype byte, req packet) bool {
matched := make(chan bool, 1)
select {
- case t.gotreply <- reply{from, ptype, req, matched}:
+ case t.gotreply <- reply{from, fromIP, ptype, req, matched}:
// loop will handle it
return <-matched
case <-t.closing:
@@ -385,8 +401,8 @@ func (t *udp) loop() {
var (
plist = list.New()
timeout = time.NewTimer(0)
- nextTimeout *pending // head of plist when timeout was last reset
- contTimeouts = 0 // number of continuous timeouts to do NTP checks
+ nextTimeout *replyMatcher // head of plist when timeout was last reset
+ contTimeouts = 0 // number of continuous timeouts to do NTP checks
ntpWarnTime = time.Unix(0, 0)
)
<-timeout.C // ignore first timeout
@@ -399,7 +415,7 @@ func (t *udp) loop() {
// Start the timer so it fires when the next pending reply has expired.
now := time.Now()
for el := plist.Front(); el != nil; el = el.Next() {
- nextTimeout = el.Value.(*pending)
+ nextTimeout = el.Value.(*replyMatcher)
if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
timeout.Reset(dist)
return
@@ -420,25 +436,23 @@ func (t *udp) loop() {
select {
case <-t.closing:
for el := plist.Front(); el != nil; el = el.Next() {
- el.Value.(*pending).errc <- errClosed
+ el.Value.(*replyMatcher).errc <- errClosed
}
return
- case p := <-t.addpending:
+ case p := <-t.addReplyMatcher:
p.deadline = time.Now().Add(respTimeout)
plist.PushBack(p)
case r := <-t.gotreply:
- var matched bool
+ var matched bool // whether any replyMatcher considered the reply acceptable.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
- if p.from == r.from && p.ptype == r.ptype {
- matched = true
- // Remove the matcher if its callback indicates
- // that all replies have been received. This is
- // required for packet types that expect multiple
- // reply packets.
- if p.callback(r.data) {
+ p := el.Value.(*replyMatcher)
+ if p.from == r.from && p.ptype == r.ptype && p.ip.Equal(r.ip) {
+ ok, requestDone := p.callback(r.data)
+ matched = matched || ok
+ // Remove the matcher if callback indicates that all replies have been received.
+ if requestDone {
p.errc <- nil
plist.Remove(el)
}
@@ -453,7 +467,7 @@ func (t *udp) loop() {
// Notify and remove callbacks whose deadline is in the past.
for el := plist.Front(); el != nil; el = el.Next() {
- p := el.Value.(*pending)
+ p := el.Value.(*replyMatcher)
if now.After(p.deadline) || now.Equal(p.deadline) {
p.errc <- errTimeout
plist.Remove(el)
@@ -504,17 +518,17 @@ func init() {
}
}
-func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
+func (t *udp) send(toaddr *net.UDPAddr, toid enode.ID, ptype byte, req packet) ([]byte, error) {
packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
return hash, err
}
- return hash, t.write(toaddr, req.name(), packet)
+ return hash, t.write(toaddr, toid, req.name(), packet)
}
-func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
+func (t *udp) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
- log.Trace(">> "+what, "addr", toaddr, "err", err)
+ log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err)
return err
}
@@ -573,13 +587,19 @@ func (t *udp) readLoop(unhandled chan<- ReadPacket) {
}
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
- packet, fromID, hash, err := decodePacket(buf)
+ packet, fromKey, hash, err := decodePacket(buf)
if err != nil {
log.Debug("Bad discv4 packet", "addr", from, "err", err)
return err
}
- err = packet.handle(t, from, fromID, hash)
- log.Trace("<< "+packet.name(), "addr", from, "err", err)
+ fromID := fromKey.id()
+ if err == nil {
+ err = packet.preverify(t, from, fromID, fromKey)
+ }
+ log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err)
+ if err == nil {
+ packet.handle(t, from, fromID, hash)
+ }
return err
}
@@ -615,54 +635,67 @@ func decodePacket(buf []byte) (packet, encPubkey, []byte, error) {
return req, fromKey, hash, err
}
-func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+// Packet Handlers
+
+func (req *ping) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
key, err := decodePubkey(fromKey)
if err != nil {
- return fmt.Errorf("invalid public key: %v", err)
+ return errors.New("invalid public key")
}
- t.send(from, pongPacket, &pong{
+ req.senderKey = key
+ return nil
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Reply.
+ t.send(from, fromID, pongPacket, &pong{
To: makeEndpoint(from, req.From.TCP),
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
- n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
- t.handleReply(n.ID(), pingPacket, req)
- if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration {
- t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) })
+
+ // Ping back if our last pong on file is too far in the past.
+ n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port))
+ if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
+ t.sendPing(fromID, from, func() {
+ t.tab.addThroughPing(n)
+ })
} else {
t.tab.addThroughPing(n)
}
+
+ // Update node database and endpoint predictor.
+ t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now())
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPingReceived(n.ID(), time.Now())
- return nil
}
func (req *ping) name() string { return "PING/v4" }
-func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *pong) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if !t.handleReply(fromID, pongPacket, req) {
+ if !t.handleReply(fromID, from.IP, pongPacket, req) {
return errUnsolicitedReply
}
- t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
- t.db.UpdateLastPongReceived(fromID, time.Now())
return nil
}
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
+ t.db.UpdateLastPongReceived(fromID, from.IP, time.Now())
+}
+
func (req *pong) name() string { return "PONG/v4" }
-func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *findnode) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- fromID := fromKey.id()
- if time.Since(t.db.LastPongReceived(fromID)) > bondExpiration {
+ if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration {
// No endpoint proof pong exists, we don't process the packet. This prevents an
// attack vector where the discovery protocol could be used to amplify traffic in a
// DDOS attack. A malicious actor would send a findnode request with the IP address
@@ -671,43 +704,50 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []
// findnode) to the victim.
return errUnknownNode
}
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+ // Determine closest nodes.
target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
t.tab.mutex.Lock()
closest := t.tab.closest(target, bucketSize).entries
t.tab.mutex.Unlock()
- p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
- var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
+ p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
+ var sent bool
for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
if len(p.Nodes) == maxNeighbors {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
sent = true
}
}
if len(p.Nodes) > 0 || !sent {
- t.send(from, neighborsPacket, &p)
+ t.send(from, fromID, neighborsPacket, &p)
}
- return nil
}
func (req *findnode) name() string { return "FINDNODE/v4" }
-func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error {
+func (req *neighbors) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
if expired(req.Expiration) {
return errExpired
}
- if !t.handleReply(fromKey.id(), neighborsPacket, req) {
+ if !t.handleReply(fromID, from.IP, neighborsPacket, req) {
return errUnsolicitedReply
}
return nil
}
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
+}
+
func (req *neighbors) name() string { return "NEIGHBORS/v4" }
func expired(ts uint64) bool {
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index a4ddaf750..3d53c9309 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -19,6 +19,7 @@ package discover
import (
"bytes"
"crypto/ecdsa"
+ crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
@@ -57,6 +58,7 @@ type udpTest struct {
t *testing.T
pipe *dgramPipe
table *Table
+ db *enode.DB
udp *udp
sent [][]byte
localkey, remotekey *ecdsa.PrivateKey
@@ -71,22 +73,32 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
- db, _ := enode.OpenDB("")
- ln := enode.NewLocalNode(db, test.localkey)
+ test.db, _ = enode.OpenDB("")
+ ln := enode.NewLocalNode(test.db, test.localkey)
test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey})
// Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone
return test
}
+func (test *udpTest) close() {
+ test.table.Close()
+ test.db.Close()
+}
+
// handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
- enc, _, err := encodePacket(test.remotekey, ptype, data)
+ return test.packetInFrom(wantError, test.remotekey, test.remoteaddr, ptype, data)
+}
+
+// handles a packet as if it had been sent to the transport by the key/endpoint.
+func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, ptype byte, data packet) error {
+ enc, _, err := encodePacket(key, ptype, data)
if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err)
}
test.sent = append(test.sent, enc)
- if err = test.udp.handlePacket(test.remoteaddr, enc); err != wantError {
+ if err = test.udp.handlePacket(addr, enc); err != wantError {
return test.errorf("error mismatch: got %q, want %q", err, wantError)
}
return nil
@@ -94,19 +106,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type.
-func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
+func (test *udpTest) waitPacketOut(validate interface{}) (*net.UDPAddr, []byte, error) {
dgram := test.pipe.waitPacketOut()
- p, _, hash, err := decodePacket(dgram)
+ p, _, hash, err := decodePacket(dgram.data)
if err != nil {
- return hash, test.errorf("sent packet decode error: %v", err)
+ return &dgram.to, hash, test.errorf("sent packet decode error: %v", err)
}
fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype {
- return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+ return &dgram.to, hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
}
fn.Call([]reflect.Value{reflect.ValueOf(p)})
- return hash, nil
+ return &dgram.to, hash, nil
}
func (test *udpTest) errorf(format string, args ...interface{}) error {
@@ -125,7 +137,7 @@ func (test *udpTest) errorf(format string, args ...interface{}) error {
func TestUDP_packetErrors(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4})
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
@@ -136,7 +148,7 @@ func TestUDP_packetErrors(t *testing.T) {
func TestUDP_pingTimeout(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
toid := enode.ID{1, 2, 3, 4}
@@ -148,7 +160,7 @@ func TestUDP_pingTimeout(t *testing.T) {
func TestUDP_responseTimeouts(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
rand.Seed(time.Now().UnixNano())
randomDuration := func(max time.Duration) time.Duration {
@@ -166,20 +178,20 @@ func TestUDP_responseTimeouts(t *testing.T) {
// with ptype <= 128 will not get a reply and should time out.
// For all other requests, a reply is scheduled to arrive
// within the timeout window.
- p := &pending{
+ p := &replyMatcher{
ptype: byte(rand.Intn(255)),
- callback: func(interface{}) bool { return true },
+ callback: func(interface{}) (bool, bool) { return true, true },
}
binary.BigEndian.PutUint64(p.from[:], uint64(i))
if p.ptype <= 128 {
p.errc = timeoutErr
- test.udp.addpending <- p
+ test.udp.addReplyMatcher <- p
nTimeouts++
} else {
p.errc = nilErr
- test.udp.addpending <- p
+ test.udp.addReplyMatcher <- p
time.AfterFunc(randomDuration(60*time.Millisecond), func() {
- if !test.udp.handleReply(p.from, p.ptype, nil) {
+ if !test.udp.handleReply(p.from, p.ip, p.ptype, nil) {
t.Logf("not matched: %v", p)
}
})
@@ -220,7 +232,7 @@ func TestUDP_responseTimeouts(t *testing.T) {
func TestUDP_findnodeTimeout(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
toid := enode.ID{1, 2, 3, 4}
@@ -236,50 +248,65 @@ func TestUDP_findnodeTimeout(t *testing.T) {
func TestUDP_findnode(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
// put a few nodes into the table. their exact
// distribution shouldn't matter much, although we need to
// take care not to overflow any bucket.
nodes := &nodesByDistance{target: testTarget.id()}
- for i := 0; i < bucketSize; i++ {
+ live := make(map[enode.ID]bool)
+ numCandidates := 2 * bucketSize
+ for i := 0; i < numCandidates; i++ {
key := newkey()
- n := wrapNode(enode.NewV4(&key.PublicKey, net.IP{10, 13, 0, 1}, 0, i))
- nodes.push(n, bucketSize)
+ ip := net.IP{10, 13, 0, byte(i)}
+ n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000))
+ // Ensure half of table content isn't verified live yet.
+ if i > numCandidates/2 {
+ n.livenessChecks = 1
+ live[n.ID()] = true
+ }
+ nodes.push(n, numCandidates)
}
- test.table.stuff(nodes.entries)
+ fillTable(test.table, nodes.entries)
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
remoteID := encodePubkey(&test.remotekey.PublicKey).id()
- test.table.db.UpdateLastPongReceived(remoteID, time.Now())
+ test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now())
// check that closest neighbors are returned.
- test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
expected := test.table.closest(testTarget.id(), bucketSize)
-
+ test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
waitNeighbors := func(want []*node) {
test.waitPacketOut(func(p *neighbors) {
if len(p.Nodes) != len(want) {
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
}
- for i := range p.Nodes {
- if p.Nodes[i].ID.id() != want[i].ID() {
- t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
+ for i, n := range p.Nodes {
+ if n.ID.id() != want[i].ID() {
+ t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i])
+ }
+ if !live[n.ID.id()] {
+ t.Errorf("result includes dead node %v", n.ID.id())
}
}
})
}
- waitNeighbors(expected.entries[:maxNeighbors])
- waitNeighbors(expected.entries[maxNeighbors:])
+ // Receive replies.
+ want := expected.entries
+ if len(want) > maxNeighbors {
+ waitNeighbors(want[:maxNeighbors])
+ want = want[maxNeighbors:]
+ }
+ waitNeighbors(want)
}
func TestUDP_findnodeMultiReply(t *testing.T) {
test := newUDPTest(t)
- defer test.table.Close()
+ defer test.close()
rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
- test.table.db.UpdateLastPingReceived(rid, time.Now())
+ test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now())
// queue a pending findnode request
resultc, errc := make(chan []*node), make(chan error)
@@ -329,11 +356,40 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
}
}
+func TestUDP_pingMatch(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.close()
+
+ randToken := make([]byte, 32)
+ crand.Read(randToken)
+
+ test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
+ test.waitPacketOut(func(*pong) error { return nil })
+ test.waitPacketOut(func(*ping) error { return nil })
+ test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp})
+}
+
+func TestUDP_pingMatchIP(t *testing.T) {
+ test := newUDPTest(t)
+ defer test.close()
+
+ test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
+ test.waitPacketOut(func(*pong) error { return nil })
+
+ _, hash, _ := test.waitPacketOut(func(*ping) error { return nil })
+ wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000}
+ test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, pongPacket, &pong{
+ ReplyTok: hash,
+ To: testLocalAnnounced,
+ Expiration: futureExp,
+ })
+}
+
func TestUDP_successfulPing(t *testing.T) {
test := newUDPTest(t)
added := make(chan *node, 1)
test.table.nodeAddedHook = func(n *node) { added <- n }
- defer test.table.Close()
+ defer test.close()
// The remote side sends a ping packet to initiate the exchange.
go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
@@ -356,7 +412,7 @@ func TestUDP_successfulPing(t *testing.T) {
})
// remote is unknown, the table pings back.
- hash, _ := test.waitPacketOut(func(p *ping) error {
+ _, hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) {
t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
}
@@ -510,7 +566,12 @@ type dgramPipe struct {
cond *sync.Cond
closing chan struct{}
closed bool
- queue [][]byte
+ queue []dgram
+}
+
+type dgram struct {
+ to net.UDPAddr
+ data []byte
}
func newpipe() *dgramPipe {
@@ -531,7 +592,7 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
if c.closed {
return 0, errors.New("closed")
}
- c.queue = append(c.queue, msg)
+ c.queue = append(c.queue, dgram{*to, b})
c.cond.Signal()
return len(b), nil
}
@@ -556,7 +617,7 @@ func (c *dgramPipe) LocalAddr() net.Addr {
return &net.UDPAddr{IP: testLocal.IP, Port: int(testLocal.UDP)}
}
-func (c *dgramPipe) waitPacketOut() []byte {
+func (c *dgramPipe) waitPacketOut() dgram {
c.mu.Lock()
defer c.mu.Unlock()
for len(c.queue) == 0 {
diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go
index 7ee0c09a9..9353b155c 100644
--- a/p2p/enode/nodedb.go
+++ b/p2p/enode/nodedb.go
@@ -21,11 +21,11 @@ import (
"crypto/rand"
"encoding/binary"
"fmt"
+ "net"
"os"
"sync"
"time"
- "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/errors"
@@ -37,24 +37,31 @@ import (
// Keys in the node database.
const (
- dbVersionKey = "version" // Version of the database to flush if changes
- dbItemPrefix = "n:" // Identifier to prefix node entries with
-
- dbDiscoverRoot = ":discover"
- dbDiscoverSeq = dbDiscoverRoot + ":seq"
- dbDiscoverPing = dbDiscoverRoot + ":lastping"
- dbDiscoverPong = dbDiscoverRoot + ":lastpong"
- dbDiscoverFindFails = dbDiscoverRoot + ":findfail"
- dbLocalRoot = ":local"
- dbLocalSeq = dbLocalRoot + ":seq"
+ dbVersionKey = "version" // Version of the database to flush if changes
+ dbNodePrefix = "n:" // Identifier to prefix node entries with
+ dbLocalPrefix = "local:"
+ dbDiscoverRoot = "v4"
+
+ // These fields are stored per ID and IP, the full key is "n:<ID>:v4:<IP>:findfail".
+ // Use nodeItemKey to create those keys.
+ dbNodeFindFails = "findfail"
+ dbNodePing = "lastping"
+ dbNodePong = "lastpong"
+ dbNodeSeq = "seq"
+
+ // Local information is keyed by ID only, the full key is "local:<ID>:seq".
+ // Use localItemKey to create those keys.
+ dbLocalSeq = "seq"
)
-var (
+const (
dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
dbCleanupCycle = time.Hour // Time period for running the expiration task.
- dbVersion = 7
+ dbVersion = 8
)
+var zeroIP = make(net.IP, 16)
+
// DB is the node database, storing previously seen nodes and any collected metadata about
// them for QoS purposes.
type DB struct {
@@ -119,27 +126,58 @@ func newPersistentDB(path string) (*DB, error) {
return &DB{lvl: db, quit: make(chan struct{})}, nil
}
-// makeKey generates the leveldb key-blob from a node id and its particular
-// field of interest.
-func makeKey(id ID, field string) []byte {
- if (id == ID{}) {
- return []byte(field)
- }
- return append([]byte(dbItemPrefix), append(id[:], field...)...)
+// nodeKey returns the database key for a node record.
+func nodeKey(id ID) []byte {
+ key := append([]byte(dbNodePrefix), id[:]...)
+ key = append(key, ':')
+ key = append(key, dbDiscoverRoot...)
+ return key
}
-// splitKey tries to split a database key into a node id and a field part.
-func splitKey(key []byte) (id ID, field string) {
- // If the key is not of a node, return it plainly
- if !bytes.HasPrefix(key, []byte(dbItemPrefix)) {
- return ID{}, string(key)
+// splitNodeKey returns the node ID of a key created by nodeKey.
+func splitNodeKey(key []byte) (id ID, rest []byte) {
+ if !bytes.HasPrefix(key, []byte(dbNodePrefix)) {
+ return ID{}, nil
}
- // Otherwise split the id and field
- item := key[len(dbItemPrefix):]
+ item := key[len(dbNodePrefix):]
copy(id[:], item[:len(id)])
- field = string(item[len(id):])
+ return id, item[len(id)+1:]
+}
- return id, field
+// nodeItemKey returns the database key for a node metadata field.
+func nodeItemKey(id ID, ip net.IP, field string) []byte {
+ ip16 := ip.To16()
+ if ip16 == nil {
+ panic(fmt.Errorf("invalid IP (length %d)", len(ip)))
+ }
+ return bytes.Join([][]byte{nodeKey(id), ip16, []byte(field)}, []byte{':'})
+}
+
+// splitNodeItemKey returns the components of a key created by nodeItemKey.
+func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) {
+ id, key = splitNodeKey(key)
+ // Skip discover root.
+ if string(key) == dbDiscoverRoot {
+ return id, nil, ""
+ }
+ key = key[len(dbDiscoverRoot)+1:]
+ // Split out the IP.
+ ip = net.IP(key[:16])
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ key = key[16+1:]
+ // Field is the remainder of key.
+ field = string(key)
+ return id, ip, field
+}
+
+// localItemKey returns the key of a local node item.
+func localItemKey(id ID, field string) []byte {
+ key := append([]byte(dbLocalPrefix), id[:]...)
+ key = append(key, ':')
+ key = append(key, field...)
+ return key
}
// fetchInt64 retrieves an integer associated with a particular key.
@@ -181,7 +219,7 @@ func (db *DB) storeUint64(key []byte, n uint64) error {
// Node retrieves a node with a given id from the database.
func (db *DB) Node(id ID) *Node {
- blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil)
+ blob, err := db.lvl.Get(nodeKey(id), nil)
if err != nil {
return nil
}
@@ -207,15 +245,15 @@ func (db *DB) UpdateNode(node *Node) error {
if err != nil {
return err
}
- if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil {
+ if err := db.lvl.Put(nodeKey(node.ID()), blob, nil); err != nil {
return err
}
- return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq())
+ return db.storeUint64(nodeItemKey(node.ID(), zeroIP, dbNodeSeq), node.Seq())
}
// NodeSeq returns the stored record sequence number of the given node.
func (db *DB) NodeSeq(id ID) uint64 {
- return db.fetchUint64(makeKey(id, dbDiscoverSeq))
+ return db.fetchUint64(nodeItemKey(id, zeroIP, dbNodeSeq))
}
// Resolve returns the stored record of the node if it has a larger sequence
@@ -227,15 +265,17 @@ func (db *DB) Resolve(n *Node) *Node {
return db.Node(n.ID())
}
-// DeleteNode deletes all information/keys associated with a node.
-func (db *DB) DeleteNode(id ID) error {
- deleter := db.lvl.NewIterator(util.BytesPrefix(makeKey(id, "")), nil)
- for deleter.Next() {
- if err := db.lvl.Delete(deleter.Key(), nil); err != nil {
- return err
- }
+// DeleteNode deletes all information associated with a node.
+func (db *DB) DeleteNode(id ID) {
+ deleteRange(db.lvl, nodeKey(id))
+}
+
+func deleteRange(db *leveldb.DB, prefix []byte) {
+ it := db.NewIterator(util.BytesPrefix(prefix), nil)
+ defer it.Release()
+ for it.Next() {
+ db.Delete(it.Key(), nil)
}
- return nil
}
// ensureExpirer is a small helper method ensuring that the data expiration
@@ -259,9 +299,7 @@ func (db *DB) expirer() {
for {
select {
case <-tick.C:
- if err := db.expireNodes(); err != nil {
- log.Error("Failed to expire nodedb items", "err", err)
- }
+ db.expireNodes()
case <-db.quit:
return
}
@@ -269,71 +307,85 @@ func (db *DB) expirer() {
}
// expireNodes iterates over the database and deletes all nodes that have not
-// been seen (i.e. received a pong from) for some allotted time.
-func (db *DB) expireNodes() error {
- threshold := time.Now().Add(-dbNodeExpiration)
-
- // Find discovered nodes that are older than the allowance
- it := db.lvl.NewIterator(nil, nil)
+// been seen (i.e. received a pong from) for some time.
+func (db *DB) expireNodes() {
+ it := db.lvl.NewIterator(util.BytesPrefix([]byte(dbNodePrefix)), nil)
defer it.Release()
+ if !it.Next() {
+ return
+ }
- for it.Next() {
- // Skip the item if not a discovery node
- id, field := splitKey(it.Key())
- if field != dbDiscoverRoot {
- continue
+ var (
+ threshold = time.Now().Add(-dbNodeExpiration).Unix()
+ youngestPong int64
+ atEnd = false
+ )
+ for !atEnd {
+ id, ip, field := splitNodeItemKey(it.Key())
+ if field == dbNodePong {
+ time, _ := binary.Varint(it.Value())
+ if time > youngestPong {
+ youngestPong = time
+ }
+ if time < threshold {
+ // Last pong from this IP older than threshold, remove fields belonging to it.
+ deleteRange(db.lvl, nodeItemKey(id, ip, ""))
+ }
}
- // Skip the node if not expired yet (and not self)
- if seen := db.LastPongReceived(id); seen.After(threshold) {
- continue
+ atEnd = !it.Next()
+ nextID, _ := splitNodeKey(it.Key())
+ if atEnd || nextID != id {
+ // We've moved beyond the last entry of the current ID.
+ // Remove everything if there was no recent enough pong.
+ if youngestPong > 0 && youngestPong < threshold {
+ deleteRange(db.lvl, nodeKey(id))
+ }
+ youngestPong = 0
}
- // Otherwise delete all associated information
- db.DeleteNode(id)
}
- return nil
}
// LastPingReceived retrieves the time of the last ping packet received from
// a remote node.
-func (db *DB) LastPingReceived(id ID) time.Time {
- return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0)
+func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time {
+ return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0)
}
// UpdateLastPingReceived updates the last time we tried contacting a remote node.
-func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error {
- return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix())
+func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error {
+ return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix())
}
// LastPongReceived retrieves the time of the last successful pong from remote node.
-func (db *DB) LastPongReceived(id ID) time.Time {
+func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time {
// Launch expirer
db.ensureExpirer()
- return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0)
+ return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePong)), 0)
}
// UpdateLastPongReceived updates the last pong time of a node.
-func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error {
- return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix())
+func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error {
+ return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix())
}
// FindFails retrieves the number of findnode failures since bonding.
-func (db *DB) FindFails(id ID) int {
- return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails)))
+func (db *DB) FindFails(id ID, ip net.IP) int {
+ return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails)))
}
// UpdateFindFails updates the number of findnode failures since bonding.
-func (db *DB) UpdateFindFails(id ID, fails int) error {
- return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails))
+func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error {
+ return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails))
}
// LocalSeq retrieves the local record sequence counter.
func (db *DB) localSeq(id ID) uint64 {
- return db.fetchUint64(makeKey(id, dbLocalSeq))
+ return db.fetchUint64(nodeItemKey(id, zeroIP, dbLocalSeq))
}
// storeLocalSeq stores the local record sequence counter.
func (db *DB) storeLocalSeq(id ID, n uint64) {
- db.storeUint64(makeKey(id, dbLocalSeq), n)
+ db.storeUint64(nodeItemKey(id, zeroIP, dbLocalSeq), n)
}
// QuerySeeds retrieves random nodes to be used as potential seed nodes
@@ -355,14 +407,14 @@ seek:
ctr := id[0]
rand.Read(id[:])
id[0] = ctr + id[0]%16
- it.Seek(makeKey(id, dbDiscoverRoot))
+ it.Seek(nodeKey(id))
n := nextNode(it)
if n == nil {
id[0] = 0
continue seek // iterator exhausted
}
- if now.Sub(db.LastPongReceived(n.ID())) > maxAge {
+ if now.Sub(db.LastPongReceived(n.ID(), n.IP())) > maxAge {
continue seek
}
for i := range nodes {
@@ -379,8 +431,8 @@ seek:
// database entries.
func nextNode(it iterator.Iterator) *Node {
for end := false; !end; end = !it.Next() {
- id, field := splitKey(it.Key())
- if field != dbDiscoverRoot {
+ id, rest := splitNodeKey(it.Key())
+ if string(rest) != dbDiscoverRoot {
continue
}
return mustDecodeNode(id[:], it.Value())
diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go
index 96794827c..341b61a28 100644
--- a/p2p/enode/nodedb_test.go
+++ b/p2p/enode/nodedb_test.go
@@ -28,42 +28,54 @@ import (
"time"
)
-var nodeDBKeyTests = []struct {
- id ID
- field string
- key []byte
-}{
- {
- id: ID{},
- field: "version",
- key: []byte{0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e}, // field
- },
- {
- id: HexID("51232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
- field: ":discover",
- key: []byte{
- 0x6e, 0x3a, // prefix
- 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id
- 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
- 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
- 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
- 0x3a, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x76, 0x65, 0x72, // field
- },
- },
+var keytestID = HexID("51232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439")
+
+func TestDBNodeKey(t *testing.T) {
+ enc := nodeKey(keytestID)
+ want := []byte{
+ 'n', ':',
+ 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id
+ 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
+ 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
+ 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
+ ':', 'v', '4',
+ }
+ if !bytes.Equal(enc, want) {
+ t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want)
+ }
+ id, _ := splitNodeKey(enc)
+ if id != keytestID {
+ t.Errorf("wrong ID from splitNodeKey")
+ }
}
-func TestDBKeys(t *testing.T) {
- for i, tt := range nodeDBKeyTests {
- if key := makeKey(tt.id, tt.field); !bytes.Equal(key, tt.key) {
- t.Errorf("make test %d: key mismatch: have 0x%x, want 0x%x", i, key, tt.key)
- }
- id, field := splitKey(tt.key)
- if !bytes.Equal(id[:], tt.id[:]) {
- t.Errorf("split test %d: id mismatch: have 0x%x, want 0x%x", i, id, tt.id)
- }
- if field != tt.field {
- t.Errorf("split test %d: field mismatch: have 0x%x, want 0x%x", i, field, tt.field)
- }
+func TestDBNodeItemKey(t *testing.T) {
+ wantIP := net.IP{127, 0, 0, 3}
+ wantField := "foobar"
+ enc := nodeItemKey(keytestID, wantIP, wantField)
+ want := []byte{
+ 'n', ':',
+ 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id
+ 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
+ 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
+ 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
+ ':', 'v', '4', ':',
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // IP
+ 0x00, 0x00, 0xff, 0xff, 0x7f, 0x00, 0x00, 0x03, //
+ ':', 'f', 'o', 'o', 'b', 'a', 'r',
+ }
+ if !bytes.Equal(enc, want) {
+ t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want)
+ }
+ id, ip, field := splitNodeItemKey(enc)
+ if id != keytestID {
+ t.Errorf("splitNodeItemKey returned wrong ID: %v", id)
+ }
+ if !bytes.Equal(ip, wantIP) {
+ t.Errorf("splitNodeItemKey returned wrong IP: %v", ip)
+ }
+ if field != wantField {
+ t.Errorf("splitNodeItemKey returned wrong field: %q", field)
}
}
@@ -113,33 +125,33 @@ func TestDBFetchStore(t *testing.T) {
defer db.Close()
// Check fetch/store operations on a node ping object
- if stored := db.LastPingReceived(node.ID()); stored.Unix() != 0 {
+ if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != 0 {
t.Errorf("ping: non-existing object: %v", stored)
}
- if err := db.UpdateLastPingReceived(node.ID(), inst); err != nil {
+ if err := db.UpdateLastPingReceived(node.ID(), node.IP(), inst); err != nil {
t.Errorf("ping: failed to update: %v", err)
}
- if stored := db.LastPingReceived(node.ID()); stored.Unix() != inst.Unix() {
+ if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node pong object
- if stored := db.LastPongReceived(node.ID()); stored.Unix() != 0 {
+ if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != 0 {
t.Errorf("pong: non-existing object: %v", stored)
}
- if err := db.UpdateLastPongReceived(node.ID(), inst); err != nil {
+ if err := db.UpdateLastPongReceived(node.ID(), node.IP(), inst); err != nil {
t.Errorf("pong: failed to update: %v", err)
}
- if stored := db.LastPongReceived(node.ID()); stored.Unix() != inst.Unix() {
+ if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node findnode-failure object
- if stored := db.FindFails(node.ID()); stored != 0 {
+ if stored := db.FindFails(node.ID(), node.IP()); stored != 0 {
t.Errorf("find-node fails: non-existing object: %v", stored)
}
- if err := db.UpdateFindFails(node.ID(), num); err != nil {
+ if err := db.UpdateFindFails(node.ID(), node.IP(), num); err != nil {
t.Errorf("find-node fails: failed to update: %v", err)
}
- if stored := db.FindFails(node.ID()); stored != num {
+ if stored := db.FindFails(node.ID(), node.IP()); stored != num {
t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num)
}
// Check fetch/store operations on an actual node object
@@ -256,7 +268,7 @@ func testSeedQuery() error {
if err := db.UpdateNode(seed.node); err != nil {
return fmt.Errorf("node %d: failed to insert: %v", i, err)
}
- if err := db.UpdateLastPongReceived(seed.node.ID(), seed.pong); err != nil {
+ if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil {
return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err)
}
}
@@ -321,10 +333,12 @@ func TestDBPersistency(t *testing.T) {
}
var nodeDBExpirationNodes = []struct {
- node *Node
- pong time.Time
- exp bool
+ node *Node
+ pong time.Time
+ storeNode bool
+ exp bool
}{
+ // Node has new enough pong time and isn't expired:
{
node: NewV4(
hexPubkey("8d110e2ed4b446d9b5fb50f117e5f37fb7597af455e1dab0e6f045a6eeaa786a6781141659020d38bdc5e698ed3d4d2bafa8b5061810dfa63e8ac038db2e9b67"),
@@ -332,17 +346,79 @@ var nodeDBExpirationNodes = []struct {
30303,
30303,
),
- pong: time.Now().Add(-dbNodeExpiration + time.Minute),
- exp: false,
- }, {
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration + time.Minute),
+ exp: false,
+ },
+ // Node with pong time before expiration is removed:
+ {
node: NewV4(
hexPubkey("913a205579c32425b220dfba999d215066e5bdbf900226b11da1907eae5e93eb40616d47412cf819664e9eacbdfcca6b0c6e07e09847a38472d4be46ab0c3672"),
net.IP{127, 0, 0, 2},
30303,
30303,
),
- pong: time.Now().Add(-dbNodeExpiration - time.Minute),
- exp: true,
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
+ },
+ // Just pong time, no node stored:
+ {
+ node: NewV4(
+ hexPubkey("b56670e0b6bad2c5dab9f9fe6f061a16cf78d68b6ae2cfda3144262d08d97ce5f46fd8799b6d1f709b1abe718f2863e224488bd7518e5e3b43809ac9bd1138ca"),
+ net.IP{127, 0, 0, 3},
+ 30303,
+ 30303,
+ ),
+ storeNode: false,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
+ },
+ // Node with multiple pong times, all older than expiration.
+ {
+ node: NewV4(
+ hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"),
+ net.IP{127, 0, 0, 4},
+ 30303,
+ 30303,
+ ),
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
+ },
+ {
+ node: NewV4(
+ hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"),
+ net.IP{127, 0, 0, 5},
+ 30303,
+ 30303,
+ ),
+ storeNode: false,
+ pong: time.Now().Add(-dbNodeExpiration - 2*time.Minute),
+ exp: true,
+ },
+ // Node with multiple pong times, one newer, one older than expiration.
+ {
+ node: NewV4(
+ hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"),
+ net.IP{127, 0, 0, 6},
+ 30303,
+ 30303,
+ ),
+ storeNode: true,
+ pong: time.Now().Add(-dbNodeExpiration + time.Minute),
+ exp: false,
+ },
+ {
+ node: NewV4(
+ hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"),
+ net.IP{127, 0, 0, 7},
+ 30303,
+ 30303,
+ ),
+ storeNode: false,
+ pong: time.Now().Add(-dbNodeExpiration - time.Minute),
+ exp: true,
},
}
@@ -350,23 +426,39 @@ func TestDBExpiration(t *testing.T) {
db, _ := OpenDB("")
defer db.Close()
- // Add all the test nodes and set their last pong time
+ // Add all the test nodes and set their last pong time.
for i, seed := range nodeDBExpirationNodes {
- if err := db.UpdateNode(seed.node); err != nil {
- t.Fatalf("node %d: failed to insert: %v", i, err)
+ if seed.storeNode {
+ if err := db.UpdateNode(seed.node); err != nil {
+ t.Fatalf("node %d: failed to insert: %v", i, err)
+ }
}
- if err := db.UpdateLastPongReceived(seed.node.ID(), seed.pong); err != nil {
+ if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil {
t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
- // Expire some of them, and check the rest
- if err := db.expireNodes(); err != nil {
- t.Fatalf("failed to expire nodes: %v", err)
- }
+
+ db.expireNodes()
+
+ // Check that expired entries have been removed.
+ unixZeroTime := time.Unix(0, 0)
for i, seed := range nodeDBExpirationNodes {
node := db.Node(seed.node.ID())
- if (node == nil && !seed.exp) || (node != nil && seed.exp) {
- t.Errorf("node %d: expiration mismatch: have %v, want %v", i, node, seed.exp)
+ pong := db.LastPongReceived(seed.node.ID(), seed.node.IP())
+ if seed.exp {
+ if seed.storeNode && node != nil {
+ t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString())
+ }
+ if !pong.Equal(unixZeroTime) {
+ t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IP())
+ }
+ } else {
+ if seed.storeNode && node == nil {
+ t.Errorf("node %d (%s) should be present after expiration", i, seed.node.ID().TerminalString())
+ }
+ if !pong.Equal(seed.pong.Truncate(1 * time.Second)) {
+ t.Errorf("pong time %d (%s) should be %v after expiration, but is %v", i, seed.node.ID().TerminalString(), seed.pong, pong)
+ }
}
}
}