diff options
Diffstat (limited to 'p2p/discover/table.go')
-rw-r--r-- | p2p/discover/table.go | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 7a3e41de1..afd4c9a27 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -72,21 +72,20 @@ type Table struct { ips netutil.DistinctNetSet db *enode.DB // database of known nodes + net transport refreshReq chan chan struct{} initDone chan struct{} closeReq chan struct{} closed chan struct{} nodeAddedHook func(*node) // for testing - - net transport - self *node // metadata of the local node } // transport is implemented by the UDP transport. // it is an interface so we can test without opening lots of UDP // sockets and without generating a private key. type transport interface { + self() *enode.Node ping(enode.ID, *net.UDPAddr) error findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error) close() @@ -100,11 +99,10 @@ type bucket struct { ips netutil.DistinctNetSet } -func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) { +func newTable(t transport, db *enode.DB, bootnodes []*enode.Node) (*Table, error) { tab := &Table{ net: t, db: db, - self: wrapNode(self), refreshReq: make(chan chan struct{}), initDone: make(chan struct{}), closeReq: make(chan struct{}), @@ -127,6 +125,10 @@ func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.No return tab, nil } +func (tab *Table) self() *enode.Node { + return tab.net.self() +} + func (tab *Table) seedRand() { var b [8]byte crand.Read(b[:]) @@ -136,11 +138,6 @@ func (tab *Table) seedRand() { tab.mutex.Unlock() } -// Self returns the local node. -func (tab *Table) Self() *enode.Node { - return unwrapNode(tab.self) -} - // ReadRandomNodes fills the given slice with random nodes from the table. The results // are guaranteed to be unique for a single invocation, no node will appear twice. func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { @@ -183,6 +180,10 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { // Close terminates the network listener and flushes the node database. func (tab *Table) Close() { + if tab.net != nil { + tab.net.close() + } + select { case <-tab.closed: // already closed. @@ -257,7 +258,7 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node { ) // don't query further if we hit ourself. // unlikely to happen often in practice. - asked[tab.self.ID()] = true + asked[tab.self().ID()] = true for { tab.mutex.Lock() @@ -340,8 +341,8 @@ func (tab *Table) loop() { revalidate = time.NewTimer(tab.nextRevalidateTime()) refresh = time.NewTicker(refreshInterval) copyNodes = time.NewTicker(copyNodesInterval) - revalidateDone = make(chan struct{}) refreshDone = make(chan struct{}) // where doRefresh reports completion + revalidateDone chan struct{} // where doRevalidate reports completion waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs ) defer refresh.Stop() @@ -372,9 +373,11 @@ loop: } waiting, refreshDone = nil, nil case <-revalidate.C: + revalidateDone = make(chan struct{}) go tab.doRevalidate(revalidateDone) case <-revalidateDone: revalidate.Reset(tab.nextRevalidateTime()) + revalidateDone = nil case <-copyNodes.C: go tab.copyLiveNodes() case <-tab.closeReq: @@ -382,15 +385,15 @@ loop: } } - if tab.net != nil { - tab.net.close() - } if refreshDone != nil { <-refreshDone } for _, ch := range waiting { close(ch) } + if revalidateDone != nil { + <-revalidateDone + } close(tab.closed) } @@ -408,7 +411,7 @@ func (tab *Table) doRefresh(done chan struct{}) { // Run self lookup to discover new neighbor nodes. // We can only do this if we have a secp256k1 identity. var key ecdsa.PublicKey - if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil { + if err := tab.self().Load((*enode.Secp256k1)(&key)); err == nil { tab.lookup(encodePubkey(&key), false) } @@ -530,7 +533,7 @@ func (tab *Table) len() (n int) { // bucket returns the bucket for the given node ID hash. func (tab *Table) bucket(id enode.ID) *bucket { - d := enode.LogDist(tab.self.ID(), id) + d := enode.LogDist(tab.self().ID(), id) if d <= bucketMinDistance { return tab.buckets[0] } @@ -543,7 +546,7 @@ func (tab *Table) bucket(id enode.ID) *bucket { // // The caller must not hold tab.mutex. func (tab *Table) add(n *node) { - if n.ID() == tab.self.ID() { + if n.ID() == tab.self().ID() { return } @@ -576,7 +579,7 @@ func (tab *Table) stuff(nodes []*node) { defer tab.mutex.Unlock() for _, n := range nodes { - if n.ID() == tab.self.ID() { + if n.ID() == tab.self().ID() { continue // don't add self } b := tab.bucket(n.ID()) |