Diffstat (limited to 'p2p/discover')
-rw-r--r--  p2p/discover/node.go        | 291
-rw-r--r--  p2p/discover/node_test.go   | 201
-rw-r--r--  p2p/discover/table.go       | 280
-rw-r--r--  p2p/discover/table_test.go  | 311
-rw-r--r--  p2p/discover/udp.go         | 431
-rw-r--r--  p2p/discover/udp_test.go    | 211
6 files changed, 1725 insertions, 0 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
new file mode 100644
index 000000000..c6d2e9766
--- /dev/null
+++ b/p2p/discover/node.go
@@ -0,0 +1,291 @@
+package discover
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const nodeIDBits = 512
+
+// Node represents a host on the network.
+type Node struct {
+ ID NodeID
+ IP net.IP
+
+ DiscPort int // UDP listening port for discovery protocol
+ TCPPort int // TCP listening port for RLPx
+
+ active time.Time
+}
+
+func newNode(id NodeID, addr *net.UDPAddr) *Node {
+ return &Node{
+ ID: id,
+ IP: addr.IP,
+ DiscPort: addr.Port,
+ TCPPort: addr.Port,
+ active: time.Now(),
+ }
+}
+
+func (n *Node) isValid() bool {
+ // TODO: don't accept localhost, LAN addresses from internet hosts
+ return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
+}
+
+// The string representation of a Node is a URL.
+// Please see ParseNode for a description of the format.
+func (n *Node) String() string {
+ addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
+ u := url.URL{
+ Scheme: "enode",
+ User: url.User(fmt.Sprintf("%x", n.ID[:])),
+ Host: addr.String(),
+ }
+ if n.DiscPort != n.TCPPort {
+ u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
+ }
+ return u.String()
+}
+
+// ParseNode parses a node URL.
+//
+// A node URL has scheme "enode".
+//
+// The hexadecimal node ID is encoded in the username portion of the
+// URL, separated from the host by an @ sign. The hostname can only be
+// given as an IP address; DNS domain names are not allowed. The port
+// in the host name section is the TCP listening port. If the TCP and
+// UDP (discovery) ports differ, the UDP port is specified as query
+// parameter "discport".
+//
+// In the following example, the node URL describes
+// a node with IP address 10.3.58.6, TCP listening port 30303
+// and UDP discovery port 30301.
+//
+// enode://<hex node id>@10.3.58.6:30303?discport=30301
+func ParseNode(rawurl string) (*Node, error) {
+ var n Node
+ u, err := url.Parse(rawurl)
+ if err != nil {
+ return nil, err
+ }
+ if u.Scheme != "enode" {
+ return nil, errors.New("invalid URL scheme, want \"enode\"")
+ }
+ if u.User == nil {
+ return nil, errors.New("does not contain node ID")
+ }
+ if n.ID, err = HexID(u.User.String()); err != nil {
+ return nil, fmt.Errorf("invalid node ID (%v)", err)
+ }
+ ip, port, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ return nil, fmt.Errorf("invalid host: %v", err)
+ }
+ if n.IP = net.ParseIP(ip); n.IP == nil {
+ return nil, errors.New("invalid IP address")
+ }
+ if n.TCPPort, err = strconv.Atoi(port); err != nil {
+ return nil, errors.New("invalid port")
+ }
+ qv := u.Query()
+ if qv.Get("discport") == "" {
+ n.DiscPort = n.TCPPort
+ } else {
+ if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
+ return nil, errors.New("invalid discport in query")
+ }
+ }
+ return &n, nil
+}
+
+// MustParseNode parses a node URL. It panics if the URL is not valid.
+func MustParseNode(rawurl string) *Node {
+ n, err := ParseNode(rawurl)
+ if err != nil {
+ panic("invalid node URL: " + err.Error())
+ }
+ return n
+}
+
+func (n Node) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
+}
+func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
+ var ext rpcNode
+ if err = s.Decode(&ext); err == nil {
+ n.TCPPort = int(ext.Port)
+ n.DiscPort = int(ext.Port)
+ n.ID = ext.ID
+ if n.IP = net.ParseIP(ext.IP); n.IP == nil {
+ return errors.New("invalid IP string")
+ }
+ }
+ return err
+}
+
+// NodeID is a unique identifier for each node.
+// The node identifier is a marshaled elliptic curve public key.
+type NodeID [nodeIDBits / 8]byte
+
+// NodeID prints as a long hexadecimal number.
+func (n NodeID) String() string {
+ return fmt.Sprintf("%#x", n[:])
+}
+
+// The Go syntax representation of a NodeID is a call to HexID.
+func (n NodeID) GoString() string {
+ return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
+}
+
+// HexID converts a hex string to a NodeID.
+// The string may be prefixed with 0x.
+func HexID(in string) (NodeID, error) {
+ if strings.HasPrefix(in, "0x") {
+ in = in[2:]
+ }
+ var id NodeID
+ b, err := hex.DecodeString(in)
+ if err != nil {
+ return id, err
+ } else if len(b) != len(id) {
+ return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
+ }
+ copy(id[:], b)
+ return id, nil
+}
+
+// MustHexID converts a hex string to a NodeID.
+// It panics if the string is not a valid NodeID.
+func MustHexID(in string) NodeID {
+ id, err := HexID(in)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PubkeyID returns a marshaled representation of the given public key.
+func PubkeyID(pub *ecdsa.PublicKey) NodeID {
+ var id NodeID
+ pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
+ if len(pbytes)-1 != len(id) {
+ panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)*8))
+ }
+ copy(id[:], pbytes[1:])
+ return id
+}
+
+// recoverNodeID computes the public key used to sign the
+// given hash from the signature.
+func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
+ pubkey, err := secp256k1.RecoverPubkey(hash, sig)
+ if err != nil {
+ return id, err
+ }
+ if len(pubkey)-1 != len(id) {
+ return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
+ }
+ for i := range id {
+ id[i] = pubkey[i+1]
+ }
+ return id, nil
+}
+
+// distcmp compares the distances a->target and b->target.
+// Returns -1 if a is closer to target, 1 if b is closer to target
+// and 0 if they are equal.
+func distcmp(target, a, b NodeID) int {
+ for i := range target {
+ da := a[i] ^ target[i]
+ db := b[i] ^ target[i]
+ if da > db {
+ return 1
+ } else if da < db {
+ return -1
+ }
+ }
+ return 0
+}
+
+// table of leading zero counts for bytes [0..255]
+var lzcount = [256]int{
+ 8, 7, 6, 6, 5, 5, 5, 5,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+}
+
+// logdist returns the logarithmic distance between a and b, log2(a ^ b).
+func logdist(a, b NodeID) int {
+ lz := 0
+ for i := range a {
+ x := a[i] ^ b[i]
+ if x == 0 {
+ lz += 8
+ } else {
+ lz += lzcount[x]
+ break
+ }
+ }
+ return len(a)*8 - lz
+}
+
+// randomID returns a random NodeID such that logdist(a, b) == n
+func randomID(a NodeID, n int) (b NodeID) {
+ if n == 0 {
+ return a
+ }
+ // flip bit at position n, fill the rest with random bits
+ b = a
+ pos := len(a) - n/8 - 1
+ bit := byte(0x01) << (byte(n%8) - 1)
+ if bit == 0 {
+ pos++
+ bit = 0x80
+ }
+ b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
+ for i := pos + 1; i < len(a); i++ {
+ b[i] = byte(rand.Intn(255))
+ }
+ return b
+}
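
A minimal usage sketch (not part of this patch) showing the enode URL format handled by ParseNode and Node.String above. It assumes the package is importable as github.com/ethereum/go-ethereum/p2p/discover, as the diffstat suggests; the node ID is the placeholder key already used in node_test.go.

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/discover"
)

func main() {
	// TCP port and discovery port differ, so String() reproduces the
	// discport query parameter described in the ParseNode documentation.
	n := discover.MustParseNode("enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@10.3.58.6:30303?discport=30301")
	fmt.Println(n.ID, n.IP, n.TCPPort, n.DiscPort)
	fmt.Println(n) // prints the same URL back
}
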
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go
new file mode 100644
index 000000000..ae82ae4f1
--- /dev/null
+++ b/p2p/discover/node_test.go
@@ -0,0 +1,201 @@
+package discover
+
+import (
+ "math/big"
+ "math/rand"
+ "net"
+ "reflect"
+ "testing"
+ "testing/quick"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+var (
+ quickrand = rand.New(rand.NewSource(time.Now().Unix()))
+ quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
+)
+
+var parseNodeTests = []struct {
+ rawurl string
+ wantError string
+ wantResult *Node
+}{
+ {
+ rawurl: "http://foobar",
+ wantError: `invalid URL scheme, want "enode"`,
+ },
+ {
+ rawurl: "enode://foobar",
+ wantError: `does not contain node ID`,
+ },
+ {
+ rawurl: "enode://01010101@123.124.125.126:3",
+ wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
+ wantError: `invalid IP address`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
+ wantError: `invalid port`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
+ wantError: `invalid discport in query`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("::"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 223344,
+ TCPPort: 52150,
+ },
+ },
+}
+
+func TestParseNode(t *testing.T) {
+ for i, test := range parseNodeTests {
+ n, err := ParseNode(test.rawurl)
+ if err == nil && test.wantError != "" {
+ t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
+ continue
+ }
+ if err != nil && err.Error() != test.wantError {
+ t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
+ continue
+ }
+ if !reflect.DeepEqual(n, test.wantResult) {
+ t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
+ }
+ }
+}
+
+func TestNodeString(t *testing.T) {
+ for i, test := range parseNodeTests {
+ if test.wantError != "" {
+ continue
+ }
+ str := test.wantResult.String()
+ if str != test.rawurl {
+ t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
+ }
+ }
+}
+
+func TestHexID(t *testing.T) {
+ ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
+ id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+ id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+
+ if id1 != ref {
+ t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
+ }
+ if id2 != ref {
+ t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
+ }
+}
+
+func TestNodeID_recover(t *testing.T) {
+ prv := newkey()
+ hash := make([]byte, 32)
+ sig, err := crypto.Sign(hash, prv)
+ if err != nil {
+ t.Fatalf("signing error: %v", err)
+ }
+
+ pub := PubkeyID(&prv.PublicKey)
+ recpub, err := recoverNodeID(hash, sig)
+ if err != nil {
+ t.Fatalf("recovery error: %v", err)
+ }
+ if pub != recpub {
+ t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
+ }
+}
+
+func TestNodeID_distcmp(t *testing.T) {
+ distcmpBig := func(target, a, b NodeID) int {
+ tbig := new(big.Int).SetBytes(target[:])
+ abig := new(big.Int).SetBytes(a[:])
+ bbig := new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
+ }
+ if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests above are likely to miss the case where the arguments are equal.
+func TestNodeID_distcmpEqual(t *testing.T) {
+ base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
+ if distcmp(base, x, x) != 0 {
+ t.Errorf("distcmp(base, x, x) != 0")
+ }
+}
+
+func TestNodeID_logdist(t *testing.T) {
+ logdistBig := func(a, b NodeID) int {
+ abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(abig, bbig).BitLen()
+ }
+ if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests above are likely to miss the case where the arguments are equal.
+func TestNodeID_logdistEqual(t *testing.T) {
+ x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ if logdist(x, x) != 0 {
+ t.Errorf("logdist(x, x) != 0")
+ }
+}
+
+func TestNodeID_randomID(t *testing.T) {
+ // we don't use quick.Check here because its output isn't
+ // very helpful when the test fails.
+ for i := 0; i < quickcfg.MaxCount; i++ {
+ a := gen(NodeID{}, quickrand).(NodeID)
+ dist := quickrand.Intn(len(NodeID{}) * 8)
+ result := randomID(a, dist)
+ actualdist := logdist(result, a)
+
+ if dist != actualdist {
+ t.Log("a: ", a)
+ t.Log("result:", result)
+ t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
+ }
+ }
+}
+
+func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
+ var id NodeID
+ m := rand.Intn(len(id))
+ for i := len(id) - 1; i > m; i-- {
+ id[i] = byte(rand.Uint32())
+ }
+ return reflect.ValueOf(id)
+}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
new file mode 100644
index 000000000..e3bec9328
--- /dev/null
+++ b/p2p/discover/table.go
@@ -0,0 +1,280 @@
+// Package discover implements the Node Discovery Protocol.
+//
+// The Node Discovery protocol provides a way to find RLPx nodes that
+// can be connected to. It uses a Kademlia-like protocol to maintain a
+// distributed database of the IDs and endpoints of all listening
+// nodes.
+package discover
+
+import (
+ "net"
+ "sort"
+ "sync"
+ "time"
+)
+
+const (
+ alpha = 3 // Kademlia concurrency factor
+ bucketSize = 16 // Kademlia bucket size
+ nBuckets = nodeIDBits + 1 // Number of buckets
+)
+
+type Table struct {
+ mutex sync.Mutex // protects buckets, their content, and nursery
+ buckets [nBuckets]*bucket // index of known nodes by distance
+ nursery []*Node // bootstrap nodes
+
+ net transport
+ self *Node // metadata of the local node
+}
+
+// transport is implemented by the UDP transport.
+// it is an interface so we can test without opening lots of UDP
+// sockets and without generating a private key.
+type transport interface {
+ ping(*Node) error
+ findnode(e *Node, target NodeID) ([]*Node, error)
+ close()
+}
+
+// bucket contains nodes, ordered by their last activity.
+type bucket struct {
+ lastLookup time.Time
+ entries []*Node
+}
+
+func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
+ tab := &Table{net: t, self: newNode(ourID, ourAddr)}
+ for i := range tab.buckets {
+ tab.buckets[i] = new(bucket)
+ }
+ return tab
+}
+
+// Self returns the local node ID.
+func (tab *Table) Self() NodeID {
+ return tab.self.ID
+}
+
+// Close terminates the network listener.
+func (tab *Table) Close() {
+ tab.net.close()
+}
+
+// Bootstrap sets the bootstrap nodes. These nodes are used to connect
+// to the network if the table is empty. Bootstrap will also attempt to
+// fill the table by performing random lookup operations on the
+// network.
+func (tab *Table) Bootstrap(nodes []*Node) {
+ tab.mutex.Lock()
+ // TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
+ tab.nursery = make([]*Node, 0, len(nodes))
+ for _, n := range nodes {
+ cpy := *n
+ tab.nursery = append(tab.nursery, &cpy)
+ }
+ tab.mutex.Unlock()
+ tab.refresh()
+}
+
+// Lookup performs a network search for nodes close
+// to the given target. It approaches the target by querying
+// nodes that are closer to it on each iteration.
+func (tab *Table) Lookup(target NodeID) []*Node {
+ var (
+ asked = make(map[NodeID]bool)
+ seen = make(map[NodeID]bool)
+ reply = make(chan []*Node, alpha)
+ pendingQueries = 0
+ )
+ // don't query further if we hit the target or ourself.
+ // unlikely to happen often in practice.
+ asked[target] = true
+ asked[tab.self.ID] = true
+
+ tab.mutex.Lock()
+ // update last lookup stamp (for refresh logic)
+ tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
+ // generate initial result set
+ result := tab.closest(target, bucketSize)
+ tab.mutex.Unlock()
+
+ for {
+ // ask the alpha closest nodes that we haven't asked yet
+ for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
+ n := result.entries[i]
+ if !asked[n.ID] {
+ asked[n.ID] = true
+ pendingQueries++
+ go func() {
+ result, _ := tab.net.findnode(n, target)
+ reply <- result
+ }()
+ }
+ }
+ if pendingQueries == 0 {
+ // we have asked all closest nodes, stop the search
+ break
+ }
+
+ // wait for the next reply
+ for _, n := range <-reply {
+ cn := n
+ if !seen[n.ID] {
+ seen[n.ID] = true
+ result.push(cn, bucketSize)
+ }
+ }
+ pendingQueries--
+ }
+ return result.entries
+}
+
+// refresh performs a lookup for a random target to keep buckets full.
+func (tab *Table) refresh() {
+ ld := -1 // logdist of chosen bucket
+ tab.mutex.Lock()
+ for i, b := range tab.buckets {
+ if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
+ ld = i
+ break
+ }
+ }
+ tab.mutex.Unlock()
+
+ result := tab.Lookup(randomID(tab.self.ID, ld))
+ if len(result) == 0 {
+ // bootstrap the table with a self lookup
+ tab.mutex.Lock()
+ tab.add(tab.nursery)
+ tab.mutex.Unlock()
+ tab.Lookup(tab.self.ID)
+ // TODO: the Kademlia paper says that we're supposed to perform
+ // random lookups in all buckets further away than our closest neighbor.
+ }
+}
+
+// closest returns the n nodes in the table that are closest to the
+// given id. The caller must hold tab.mutex.
+func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
+ // This is a very wasteful way to find the closest nodes but
+ // obviously correct. I believe that tree-based buckets would make
+ // this easier to implement efficiently.
+ close := &nodesByDistance{target: target}
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ close.push(n, nresults)
+ }
+ }
+ return close
+}
+
+func (tab *Table) len() (n int) {
+ for _, b := range tab.buckets {
+ n += len(b.entries)
+ }
+ return n
+}
+
+// bumpOrAdd updates the activity timestamp for the given node and
+// attempts to insert the node into a bucket. The returned Node might
+// not be part of the table. The caller must hold tab.mutex.
+func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
+ b := tab.buckets[logdist(tab.self.ID, node)]
+ if n = b.bump(node); n == nil {
+ n = newNode(node, from)
+ if len(b.entries) == bucketSize {
+ tab.pingReplace(n, b)
+ } else {
+ b.entries = append(b.entries, n)
+ }
+ }
+ return n
+}
+
+func (tab *Table) pingReplace(n *Node, b *bucket) {
+ old := b.entries[bucketSize-1]
+ go func() {
+ if err := tab.net.ping(old); err == nil {
+ // it responded, we don't need to replace it.
+ return
+ }
+ // it didn't respond, replace the node if it is still the oldest node.
+ tab.mutex.Lock()
+ if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
+ // slide down other entries and put the new one in front.
+ // TODO: insert in correct position to keep the order
+ copy(b.entries[1:], b.entries)
+ b.entries[0] = n
+ }
+ tab.mutex.Unlock()
+ }()
+}
+
+// bump updates the activity timestamp for the given node.
+// The caller must hold tab.mutex.
+func (tab *Table) bump(node NodeID) {
+ tab.buckets[logdist(tab.self.ID, node)].bump(node)
+}
+
+// add puts the entries into the table if their corresponding
+// bucket is not full. The caller must hold tab.mutex.
+func (tab *Table) add(entries []*Node) {
+outer:
+ for _, n := range entries {
+ if n == nil || n.ID == tab.self.ID {
+ // skip bad entries. The RLP decoder returns nil for empty
+ // input lists.
+ continue
+ }
+ bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
+ for i := range bucket.entries {
+ if bucket.entries[i].ID == n.ID {
+ // already in bucket
+ continue outer
+ }
+ }
+ if len(bucket.entries) < bucketSize {
+ bucket.entries = append(bucket.entries, n)
+ }
+ }
+}
+
+func (b *bucket) bump(id NodeID) *Node {
+ for i, n := range b.entries {
+ if n.ID == id {
+ n.active = time.Now()
+ // move it to the front
+ copy(b.entries[1:], b.entries[:i])
+ b.entries[0] = n
+ return n
+ }
+ }
+ return nil
+}
+
+// nodesByDistance is a list of nodes, ordered by
+// distance to target.
+type nodesByDistance struct {
+ entries []*Node
+ target NodeID
+}
+
+// push adds the given node to the list, keeping the total size below maxElems.
+func (h *nodesByDistance) push(n *Node, maxElems int) {
+ ix := sort.Search(len(h.entries), func(i int) bool {
+ return distcmp(h.target, h.entries[i].ID, n.ID) > 0
+ })
+ if len(h.entries) < maxElems {
+ h.entries = append(h.entries, n)
+ }
+ if ix == len(h.entries) {
+ // farther away than all nodes we already have.
+ // if there was room for it, the node is now the last element.
+ } else {
+ // slide existing entries down to make room
+ // this will overwrite the entry we just appended.
+ copy(h.entries[ix+1:], h.entries[ix:])
+ h.entries[ix] = n
+ }
+}
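
The append-then-shift trick in push above is compact but easy to misread. Here is a standalone sketch (not part of the patch) of the same bounded sorted insert, using plain ints in place of XOR distances: append only while below the cap, then slide larger entries right so the slice stays sorted and the farthest value falls off the end.

package main

import (
	"fmt"
	"sort"
)

// push keeps entries sorted ascending and capped at maxElems, mirroring
// nodesByDistance.push with int values standing in for distances to the target.
func push(entries []int, x, maxElems int) []int {
	ix := sort.SearchInts(entries, x) // first index whose value is >= x
	if len(entries) < maxElems {
		entries = append(entries, x)
	}
	if ix < len(entries) {
		copy(entries[ix+1:], entries[ix:]) // overwrites the appended slot, or drops the last element
		entries[ix] = x
	}
	// if ix == len(entries), x is farther than everything we keep: it was either
	// appended at the end (when there was room) or discarded.
	return entries
}

func main() {
	var e []int
	for _, x := range []int{5, 1, 9, 3, 7} {
		e = push(e, x, 3)
	}
	fmt.Println(e) // [1 3 5]
}
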
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
new file mode 100644
index 000000000..08faea68e
--- /dev/null
+++ b/p2p/discover/table_test.go
@@ -0,0 +1,311 @@
+package discover
+
+import (
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "math/rand"
+ "net"
+ "reflect"
+ "testing"
+ "testing/quick"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+func TestTable_bumpOrAddBucketAssign(t *testing.T) {
+ tab := newTable(nil, NodeID{}, &net.UDPAddr{})
+ for i := 1; i < len(tab.buckets); i++ {
+ tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
+ }
+ for i, b := range tab.buckets {
+ if i > 0 && len(b.entries) != 1 {
+ t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
+ }
+ }
+}
+
+func TestTable_bumpOrAddPingReplace(t *testing.T) {
+ pingC := make(pingC)
+ tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
+ last := fillBucket(tab, 200)
+
+ // this bumpOrAdd should not replace the last node
+ // because the node replies to ping.
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
+
+ pinged := <-pingC
+ if pinged != last.ID {
+ t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
+ }
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if l := len(tab.buckets[200].entries); l != bucketSize {
+ t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", l, bucketSize)
+ }
+ if !contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was removed")
+ }
+ if contains(tab.buckets[200].entries, new.ID) {
+ t.Error("new entry was added")
+ }
+}
+
+func TestTable_bumpOrAddPingTimeout(t *testing.T) {
+ tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
+ last := fillBucket(tab, 200)
+
+ // this bumpOrAdd should replace the last node
+ // because the node does not reply to ping.
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
+
+ // wait for async bucket update. damn. this needs to go away.
+ time.Sleep(2 * time.Millisecond)
+
+ tab.mutex.Lock()
+ defer tab.mutex.Unlock()
+ if l := len(tab.buckets[200].entries); l != bucketSize {
+ t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", l, bucketSize)
+ }
+ if contains(tab.buckets[200].entries, last.ID) {
+ t.Error("last entry was not removed")
+ }
+ if !contains(tab.buckets[200].entries, new.ID) {
+ t.Error("new entry was not added")
+ }
+}
+
+func fillBucket(tab *Table, ld int) (last *Node) {
+ b := tab.buckets[ld]
+ for len(b.entries) < bucketSize {
+ b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
+ }
+ return b.entries[bucketSize-1]
+}
+
+type pingC chan NodeID
+
+func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
+ panic("findnode called on pingRecorder")
+}
+func (t pingC) close() {
+ panic("close called on pingRecorder")
+}
+func (t pingC) ping(n *Node) error {
+ if t == nil {
+ return errTimeout
+ }
+ t <- n.ID
+ return nil
+}
+
+func TestTable_bump(t *testing.T) {
+ tab := newTable(nil, NodeID{}, &net.UDPAddr{})
+
+ // add an old entry and two recent ones
+ oldactive := time.Now().Add(-2 * time.Minute)
+ old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
+ others := []*Node{
+ &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+ &Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+ }
+ tab.add(append(others, old))
+ if tab.buckets[200].entries[0] == old {
+ t.Fatal("old entry is at front of bucket")
+ }
+
+ // bumping the old entry should move it to the front
+ tab.bump(old.ID)
+ if old.active == oldactive {
+ t.Error("activity timestamp not updated")
+ }
+ if tab.buckets[200].entries[0] != old {
+ t.Errorf("bumped entry did not move to the front of bucket")
+ }
+}
+
+func TestTable_closest(t *testing.T) {
+ t.Parallel()
+
+ test := func(test *closeTest) bool {
+ // for any node table, Target and N
+ tab := newTable(nil, test.Self, &net.UDPAddr{})
+ tab.add(test.All)
+
+ // check that doClosest(Target, N) returns nodes
+ result := tab.closest(test.Target, test.N).entries
+ if hasDuplicates(result) {
+ t.Errorf("result contains duplicates")
+ return false
+ }
+ if !sortedByDistanceTo(test.Target, result) {
+ t.Errorf("result is not sorted by distance to target")
+ return false
+ }
+
+ // check that the number of results is min(N, tablen)
+ wantN := test.N
+ if tlen := tab.len(); tlen < test.N {
+ wantN = tlen
+ }
+ if len(result) != wantN {
+ t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
+ return false
+ } else if len(result) == 0 {
+ return true // no need to check distance
+ }
+
+ // check that the result nodes have minimum distance to target.
+ for _, b := range tab.buckets {
+ for _, n := range b.entries {
+ if contains(result, n.ID) {
+ continue // don't run the check below for nodes in result
+ }
+ farthestResult := result[len(result)-1].ID
+ if distcmp(test.Target, n.ID, farthestResult) < 0 {
+ t.Errorf("table contains node that is closer to target but it's not in result")
+ t.Logf(" Target: %v", test.Target)
+ t.Logf(" Farthest Result: %v", farthestResult)
+ t.Logf(" ID: %v", n.ID)
+ return false
+ }
+ }
+ }
+ return true
+ }
+ if err := quick.Check(test, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+type closeTest struct {
+ Self NodeID
+ Target NodeID
+ All []*Node
+ N int
+}
+
+func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
+ t := &closeTest{
+ Self: gen(NodeID{}, rand).(NodeID),
+ Target: gen(NodeID{}, rand).(NodeID),
+ N: rand.Intn(bucketSize),
+ }
+ for _, id := range gen([]NodeID{}, rand).([]NodeID) {
+ t.All = append(t.All, &Node{ID: id})
+ }
+ return reflect.ValueOf(t)
+}
+
+func TestTable_Lookup(t *testing.T) {
+ self := gen(NodeID{}, quickrand).(NodeID)
+ target := randomID(self, 200)
+ transport := findnodeOracle{t, target}
+ tab := newTable(transport, self, &net.UDPAddr{})
+
+ // lookup on empty table returns no nodes
+ if results := tab.Lookup(target); len(results) > 0 {
+ t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
+ }
+ // seed table with initial node (otherwise lookup will terminate immediately)
+ tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+
+ results := tab.Lookup(target)
+ t.Logf("results:")
+ for _, e := range results {
+ t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID)
+ }
+ if len(results) != bucketSize {
+ t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
+ }
+ if hasDuplicates(results) {
+ t.Errorf("result set contains duplicate entries")
+ }
+ if !sortedByDistanceTo(target, results) {
+ t.Errorf("result set not sorted by distance to target")
+ }
+ if !contains(results, target) {
+ t.Errorf("result set does not contain target")
+ }
+}
+
+// findnode on this transport always returns at least one node
+// that is one bucket closer to the target.
+type findnodeOracle struct {
+ t *testing.T
+ target NodeID
+}
+
+func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
+ t.t.Logf("findnode query at dist %d", n.DiscPort)
+ // current log distance is encoded in port number
+ var result []*Node
+ switch n.DiscPort {
+ case 0:
+ panic("query to node at distance 0")
+ default:
+ // TODO: add more randomness to distances
+ next := n.DiscPort - 1
+ for i := 0; i < bucketSize; i++ {
+ result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
+ }
+ }
+ return result, nil
+}
+
+func (t findnodeOracle) close() {}
+
+func (t findnodeOracle) ping(n *Node) error {
+ return errors.New("ping is not supported by this transport")
+}
+
+func hasDuplicates(slice []*Node) bool {
+ seen := make(map[NodeID]bool)
+ for _, e := range slice {
+ if seen[e.ID] {
+ return true
+ }
+ seen[e.ID] = true
+ }
+ return false
+}
+
+func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
+ var last NodeID
+ for i, e := range slice {
+ if i > 0 && distcmp(distbase, e.ID, last) < 0 {
+ return false
+ }
+ last = e.ID
+ }
+ return true
+}
+
+func contains(ns []*Node, id NodeID) bool {
+ for _, n := range ns {
+ if n.ID == id {
+ return true
+ }
+ }
+ return false
+}
+
+// gen wraps quick.Value so it's easier to use.
+// it generates a random value of the given value's type.
+func gen(typ interface{}, rand *rand.Rand) interface{} {
+ v, ok := quick.Value(reflect.TypeOf(typ), rand)
+ if !ok {
+ panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
+ }
+ return v.Interface()
+}
+
+func newkey() *ecdsa.PrivateKey {
+ key, err := crypto.GenerateKey()
+ if err != nil {
+ panic("couldn't generate key: " + err.Error())
+ }
+ return key
+}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
new file mode 100644
index 000000000..b2a895442
--- /dev/null
+++ b/p2p/discover/udp.go
@@ -0,0 +1,431 @@
+package discover
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/p2p/nat"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+var log = logger.NewLogger("P2P Discovery")
+
+// Errors
+var (
+ errPacketTooSmall = errors.New("too small")
+ errBadHash = errors.New("bad hash")
+ errExpired = errors.New("expired")
+ errTimeout = errors.New("RPC timeout")
+ errClosed = errors.New("socket closed")
+)
+
+// Timeouts
+const (
+ respTimeout = 300 * time.Millisecond
+ sendTimeout = 300 * time.Millisecond
+ expiration = 20 * time.Second
+
+ refreshInterval = 1 * time.Hour
+)
+
+// RPC packet types
+const (
+ pingPacket = iota + 1 // zero is 'reserved'
+ pongPacket
+ findnodePacket
+ neighborsPacket
+)
+
+// RPC request structures
+type (
+ ping struct {
+ IP string // our IP
+ Port uint16 // our port
+ Expiration uint64
+ }
+
+ // reply to Ping
+ pong struct {
+ ReplyTok []byte
+ Expiration uint64
+ }
+
+ findnode struct {
+ // Id to look up. The responding node will send back nodes
+ // closest to the target.
+ Target NodeID
+ Expiration uint64
+ }
+
+ // reply to findnode
+ neighbors struct {
+ Nodes []*Node
+ Expiration uint64
+ }
+)
+
+type rpcNode struct {
+ IP string
+ Port uint16
+ ID NodeID
+}
+
+// udp implements the RPC protocol.
+type udp struct {
+ conn *net.UDPConn
+ priv *ecdsa.PrivateKey
+ addpending chan *pending
+ replies chan reply
+ closing chan struct{}
+ nat nat.Interface
+
+ *Table
+}
+
+// pending represents a pending reply.
+//
+// some implementations of the protocol wish to send more than one
+// reply packet to findnode. in general, any neighbors packet cannot
+// be matched up with a specific findnode packet.
+//
+// our implementation handles this by storing a callback function for
+// each pending reply. incoming packets from a node are dispatched
+// to all the callback functions for that node.
+type pending struct {
+ // these fields must match in the reply.
+ from NodeID
+ ptype byte
+
+ // time when the request must complete
+ deadline time.Time
+
+ // callback is called when a matching reply arrives. if it returns
+ // true, the callback is removed from the pending reply queue.
+ // if it returns false, the reply is considered incomplete and
+ // the callback will be invoked again for the next matching reply.
+ callback func(resp interface{}) (done bool)
+
+ // errc receives nil when the callback indicates completion or an
+ // error if no further reply is received within the timeout.
+ errc chan<- error
+}
+
+type reply struct {
+ from NodeID
+ ptype byte
+ data interface{}
+}
+
+// ListenUDP returns a new table that listens for UDP packets on laddr.
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
+ addr, err := net.ResolveUDPAddr("udp", laddr)
+ if err != nil {
+ return nil, err
+ }
+ conn, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+ udp := &udp{
+ conn: conn,
+ priv: priv,
+ closing: make(chan struct{}),
+ addpending: make(chan *pending),
+ replies: make(chan reply),
+ }
+
+ realaddr := conn.LocalAddr().(*net.UDPAddr)
+ if natm != nil {
+ if !realaddr.IP.IsLoopback() {
+ go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
+ }
+ // TODO: react to external IP changes over time.
+ if ext, err := natm.ExternalIP(); err == nil {
+ realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
+ }
+ }
+ udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
+
+ go udp.loop()
+ go udp.readLoop()
+ log.Infoln("Listening, ", udp.self)
+ return udp.Table, nil
+}
+
+func (t *udp) close() {
+ close(t.closing)
+ t.conn.Close()
+ // TODO: wait for the loops to end.
+}
+
+// ping sends a ping message to the given node and waits for a reply.
+func (t *udp) ping(e *Node) error {
+ // TODO: maybe check for ReplyTo field in callback to measure RTT
+ errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
+ t.send(e, pingPacket, ping{
+ IP: t.self.IP.String(),
+ Port: uint16(t.self.TCPPort),
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return <-errc
+}
+
+// findnode sends a findnode request to the given node and waits until
+// the node has sent up to k neighbors.
+func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+ nodes := make([]*Node, 0, bucketSize)
+ nreceived := 0
+ errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+ reply := r.(*neighbors)
+ for _, n := range reply.Nodes {
+ nreceived++
+ if n.isValid() {
+ nodes = append(nodes, n)
+ }
+ }
+ return nreceived >= bucketSize
+ })
+
+ t.send(to, findnodePacket, findnode{
+ Target: target,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ err := <-errc
+ return nodes, err
+}
+
+// pending adds a reply callback to the pending reply queue.
+// see the documentation of type pending for a detailed explanation.
+func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
+ ch := make(chan error, 1)
+ p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+ select {
+ case t.addpending <- p:
+ // loop will handle it
+ case <-t.closing:
+ ch <- errClosed
+ }
+ return ch
+}
+
+// loop runs in its own goroutine. it keeps track of
+// the refresh timer and the pending reply queue.
+func (t *udp) loop() {
+ var (
+ pending []*pending
+ nextDeadline time.Time
+ timeout = time.NewTimer(0)
+ refresh = time.NewTicker(refreshInterval)
+ )
+ <-timeout.C // ignore first timeout
+ defer refresh.Stop()
+ defer timeout.Stop()
+
+ rearmTimeout := func() {
+ if len(pending) == 0 || nextDeadline == pending[0].deadline {
+ return
+ }
+ nextDeadline = pending[0].deadline
+ timeout.Reset(nextDeadline.Sub(time.Now()))
+ }
+
+ for {
+ select {
+ case <-refresh.C:
+ go t.refresh()
+
+ case <-t.closing:
+ for _, p := range pending {
+ p.errc <- errClosed
+ }
+ return
+
+ case p := <-t.addpending:
+ p.deadline = time.Now().Add(respTimeout)
+ pending = append(pending, p)
+ rearmTimeout()
+
+ case reply := <-t.replies:
+ // run matching callbacks, remove if they return false.
+ for i := 0; i < len(pending); i++ {
+ p := pending[i]
+ if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
+ p.errc <- nil
+ copy(pending[i:], pending[i+1:])
+ pending = pending[:len(pending)-1]
+ i-- // recheck the entry that was shifted into slot i
+ }
+ }
+ rearmTimeout()
+
+ case now := <-timeout.C:
+ // notify and remove callbacks whose deadline is in the past.
+ i := 0
+ for ; i < len(pending) && now.After(pending[i].deadline); i++ {
+ pending[i].errc <- errTimeout
+ }
+ if i > 0 {
+ copy(pending, pending[i:])
+ pending = pending[:len(pending)-i]
+ }
+ rearmTimeout()
+ }
+ }
+}
+
+const (
+ macSize = 256 / 8
+ sigSize = 520 / 8
+ headSize = macSize + sigSize // space of packet frame data
+)
+
+var headSpace = make([]byte, headSize)
+
+func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+ b := new(bytes.Buffer)
+ b.Write(headSpace)
+ b.WriteByte(ptype)
+ if err := rlp.Encode(b, req); err != nil {
+ log.Errorln("error encoding packet:", err)
+ return err
+ }
+
+ packet := b.Bytes()
+ sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+ if err != nil {
+ log.Errorln("could not sign packet:", err)
+ return err
+ }
+ copy(packet[macSize:], sig)
+ // add the hash to the front. Note: this doesn't protect the
+ // packet in any way. Our public key will be part of this hash in
+ // the future.
+ copy(packet, crypto.Sha3(packet[macSize:]))
+
+ toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
+ log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+ if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
+ log.DebugDetailln("UDP send failed:", err)
+ }
+ return err
+}
+
+// readLoop runs in its own goroutine. it handles incoming UDP packets.
+func (t *udp) readLoop() {
+ defer t.conn.Close()
+ buf := make([]byte, 4096) // TODO: good buffer size
+ for {
+ nbytes, from, err := t.conn.ReadFromUDP(buf)
+ if err != nil {
+ return
+ }
+ if err := t.packetIn(from, buf[:nbytes]); err != nil {
+ log.Debugf("Bad packet from %v: %v\n", from, err)
+ }
+ }
+}
+
+func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+ if len(buf) < headSize+1 {
+ return errPacketTooSmall
+ }
+ hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
+ shouldhash := crypto.Sha3(buf[macSize:])
+ if !bytes.Equal(hash, shouldhash) {
+ return errBadHash
+ }
+ fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
+ if err != nil {
+ return err
+ }
+
+ var req interface {
+ handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+ }
+ switch ptype := sigdata[0]; ptype {
+ case pingPacket:
+ req = new(ping)
+ case pongPacket:
+ req = new(pong)
+ case findnodePacket:
+ req = new(findnode)
+ case neighborsPacket:
+ req = new(neighbors)
+ default:
+ return fmt.Errorf("unknown type: %d", ptype)
+ }
+ if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
+ return err
+ }
+ log.DebugDetailf("<<< %v %T %v\n", from, req, req)
+ return req.handle(t, from, fromID, hash)
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ // Note: we're ignoring the provided IP address right now
+ n := t.bumpOrAdd(fromID, from)
+ if req.Port != 0 {
+ n.TCPPort = int(req.Port)
+ }
+ t.mutex.Unlock()
+
+ t.send(n, pongPacket, pong{
+ ReplyTok: mac,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return nil
+}
+
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ t.bump(fromID)
+ t.mutex.Unlock()
+
+ t.replies <- reply{fromID, pongPacket, req}
+ return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ e := t.bumpOrAdd(fromID, from)
+ closest := t.closest(req.Target, bucketSize).entries
+ t.mutex.Unlock()
+
+ t.send(e, neighborsPacket, neighbors{
+ Nodes: closest,
+ Expiration: uint64(time.Now().Add(expiration).Unix()),
+ })
+ return nil
+}
+
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+ if expired(req.Expiration) {
+ return errExpired
+ }
+ t.mutex.Lock()
+ t.bump(fromID)
+ t.add(req.Nodes)
+ t.mutex.Unlock()
+
+ t.replies <- reply{fromID, neighborsPacket, req}
+ return nil
+}
+
+func expired(ts uint64) bool {
+ return time.Unix(int64(ts), 0).Before(time.Now())
+}
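
For reference, the framing that send and packetIn agree on is: hash (32 bytes) || signature (65 bytes) || packet type (1 byte) || RLP payload, where the hash covers everything after itself and the signature covers everything after the header. The sketch below (not part of the patch) walks through those offsets; sha256 stands in for the Keccak-256 (crypto.Sha3) hash and the recoverable secp256k1 signature used by the real code, purely so the example runs standalone.

package main

import (
	"bytes"
	"crypto/sha256"
	"errors"
	"fmt"
)

const (
	macSize  = 32
	sigSize  = 65
	headSize = macSize + sigSize
)

// hash32 stands in for crypto.Sha3 (Keccak-256) so the sketch runs without go-ethereum.
func hash32(b []byte) []byte { h := sha256.Sum256(b); return h[:] }

// frame assembles hash || sig || ptype || payload the same way udp.send lays it out.
func frame(sig []byte, ptype byte, payload []byte) []byte {
	packet := make([]byte, headSize+1+len(payload))
	copy(packet[macSize:], sig)        // signature over hash(ptype || payload) in the real code
	packet[headSize] = ptype           // packet type byte
	copy(packet[headSize+1:], payload) // RLP-encoded request body
	copy(packet, hash32(packet[macSize:]))
	return packet
}

// check mirrors the size and hash checks at the top of packetIn.
func check(packet []byte) (ptype byte, payload []byte, err error) {
	if len(packet) < headSize+1 {
		return 0, nil, errors.New("too small")
	}
	if !bytes.Equal(packet[:macSize], hash32(packet[macSize:])) {
		return 0, nil, errors.New("bad hash")
	}
	return packet[headSize], packet[headSize+1:], nil
}

func main() {
	sig := make([]byte, sigSize)        // a real packet carries a recoverable secp256k1 signature here
	p := frame(sig, 0x01, []byte{0xc0}) // ptype 1 = ping, 0xc0 = empty RLP list
	fmt.Println(check(p))
}
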
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
new file mode 100644
index 000000000..0a8ff6358
--- /dev/null
+++ b/p2p/discover/udp_test.go
@@ -0,0 +1,211 @@
+package discover
+
+import (
+ "fmt"
+ logpkg "log"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/logger"
+)
+
+func init() {
+ logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
+}
+
+func TestUDP_ping(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ defer n2.Close()
+
+ if err := n1.net.ping(n2.self); err != nil {
+ t.Fatalf("ping error: %v", err)
+ }
+ if find(n2, n1.self.ID) == nil {
+ t.Errorf("node 2 does not contain id of node 1")
+ }
+ if e := find(n1, n2.self.ID); e != nil {
+ t.Errorf("node 1 contains id of node 2: %v", e)
+ }
+}
+
+func find(tab *Table, id NodeID) *Node {
+ for _, b := range tab.buckets {
+ for _, e := range b.entries {
+ if e.ID == id {
+ return e
+ }
+ }
+ }
+ return nil
+}
+
+func TestUDP_findnode(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ defer n2.Close()
+
+ // put a few nodes into n2. the exact distribution shouldn't
+ // matter much, although we need to take care not to overflow
+ // any bucket.
+ target := randomID(n1.self.ID, 100)
+ nodes := &nodesByDistance{target: target}
+ for i := 0; i < bucketSize; i++ {
+ n2.add([]*Node{&Node{
+ IP: net.IP{1, 2, 3, byte(i)},
+ DiscPort: i + 2,
+ TCPPort: i + 2,
+ ID: randomID(n2.self.ID, i+2),
+ }})
+ }
+ n2.add(nodes.entries)
+ n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
+ expected := n2.closest(target, bucketSize)
+
+ err := runUDP(10, func() error {
+ result, _ := n1.net.findnode(n2.self, target)
+ if len(result) != bucketSize {
+ return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
+ }
+ for i := range result {
+ if result[i].ID != expected.entries[i].ID {
+ return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestUDP_replytimeout(t *testing.T) {
+ t.Parallel()
+
+ // reserve a port so we don't talk to an existing service by accident
+ addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
+ fd, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fd.Close()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ defer n1.Close()
+ n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
+
+ if err := n1.net.ping(n2); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ }
+
+ if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
+ t.Error("expected timeout error, got", err)
+ } else if len(result) > 0 {
+ t.Error("expected empty result, got", result)
+ }
+}
+
+func TestUDP_findnodeMultiReply(t *testing.T) {
+ t.Parallel()
+
+ n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
+ udp2 := n2.net.(*udp)
+ defer n1.Close()
+ defer n2.Close()
+
+ err := runUDP(10, func() error {
+ nodes := make([]*Node, bucketSize)
+ for i := range nodes {
+ nodes[i] = &Node{
+ IP: net.IP{1, 2, 3, 4},
+ DiscPort: i + 1,
+ TCPPort: i + 1,
+ ID: randomID(n2.self.ID, i+1),
+ }
+ }
+
+ // ask N2 for neighbors. it will send an empty reply back.
+ // the request will wait for up to bucketSize replies.
+ resultc := make(chan []*Node)
+ errc := make(chan error)
+ go func() {
+ ns, err := n1.net.findnode(n2.self, n1.self.ID)
+ if err != nil {
+ errc <- err
+ } else {
+ resultc <- ns
+ }
+ }()
+
+ // send a few more neighbors packets to N1.
+ // it should collect those.
+ for end := 0; end < len(nodes); {
+ off := end
+ if end = end + 5; end > len(nodes) {
+ end = len(nodes)
+ }
+ udp2.send(n1.self, neighborsPacket, neighbors{
+ Nodes: nodes[off:end],
+ Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
+ })
+ }
+
+ // check that they are all returned. we cannot just check for
+ // equality because they might not be returned in the order they
+ // were sent.
+ var result []*Node
+ select {
+ case result = <-resultc:
+ case err := <-errc:
+ return err
+ }
+ if hasDuplicates(result) {
+ return fmt.Errorf("result slice contains duplicates")
+ }
+ if len(result) != len(nodes) {
+ return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
+ }
+ matched := make(map[NodeID]bool)
+ for _, n := range result {
+ for _, expn := range nodes {
+ if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
+ matched[n.ID] = true
+ }
+ }
+ }
+ if len(matched) != len(nodes) {
+ return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
+ }
+ return nil
+ })
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+// runUDP runs a test n times and returns an error if the test failed
+// in all n runs. This is necessary because UDP is unreliable even for
+// connections on the local machine, causing test failures.
+func runUDP(n int, test func() error) error {
+ errcount := 0
+ errors := ""
+ for i := 0; i < n; i++ {
+ if err := test(); err != nil {
+ errors += fmt.Sprintf("\n#%d: %v", i, err)
+ errcount++
+ }
+ }
+ if errcount == n {
+ return fmt.Errorf("failed on all %d iterations:%s", n, errors)
+ }
+ return nil
+}
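
Finally, an end-to-end usage sketch (not part of the patch) showing how a caller is expected to wire the exported API together: ListenUDP opens the discovery socket, Bootstrap seeds the table, Lookup searches the network. Import paths match this tree; the bootstrap node reuses the placeholder ID from the tests, and error handling is kept minimal.

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/p2p/discover"
)

func main() {
	key, err := crypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	// Open the discovery UDP socket; nil disables NAT port mapping.
	tab, err := discover.ListenUDP(key, ":30301", nil)
	if err != nil {
		panic(err)
	}
	defer tab.Close()

	// Seed the table. Bootstrap also kicks off a refresh lookup.
	tab.Bootstrap([]*discover.Node{
		discover.MustParseNode("enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@10.3.58.6:30303"),
	})

	// Ask the network for the nodes closest to our own ID.
	for _, n := range tab.Lookup(tab.Self()) {
		fmt.Println(n)
	}
}
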