aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/discover
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/discover')
-rw-r--r--p2p/discover/node.go289
-rw-r--r--p2p/discover/node_test.go201
-rw-r--r--p2p/discover/table.go197
-rw-r--r--p2p/discover/table_test.go116
-rw-r--r--p2p/discover/udp.go32
-rw-r--r--p2p/discover/udp_test.go13
6 files changed, 527 insertions, 321 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
new file mode 100644
index 000000000..a49a2f547
--- /dev/null
+++ b/p2p/discover/node.go
@@ -0,0 +1,289 @@
+package discover
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Node represents a host on the network.
+type Node struct {
+ ID NodeID
+ IP net.IP
+
+ DiscPort int // UDP listening port for discovery protocol
+ TCPPort int // TCP listening port for RLPx
+
+ active time.Time
+}
+
+func newNode(id NodeID, addr *net.UDPAddr) *Node {
+ return &Node{
+ ID: id,
+ IP: addr.IP,
+ DiscPort: addr.Port,
+ TCPPort: addr.Port,
+ active: time.Now(),
+ }
+}
+
+func (n *Node) isValid() bool {
+ // TODO: don't accept localhost, LAN addresses from internet hosts
+ return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
+}
+
+// The string representation of a Node is a URL.
+// Please see ParseNode for a description of the format.
+func (n *Node) String() string {
+ addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
+ u := url.URL{
+ Scheme: "enode",
+ User: url.User(fmt.Sprintf("%x", n.ID[:])),
+ Host: addr.String(),
+ }
+ if n.DiscPort != n.TCPPort {
+ u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
+ }
+ return u.String()
+}
+
+// ParseNode parses a node URL.
+//
+// A node URL has scheme "enode".
+//
+// The hexadecimal node ID is encoded in the username portion of the
+// URL, separated from the host by an @ sign. The hostname can only be
+// given as an IP address, DNS domain names are not allowed. The port
+// in the host name section is the TCP listening port. If the TCP and
+// UDP (discovery) ports differ, the UDP port is specified as query
+// parameter "discport".
+//
+// In the following example, the node URL describes
+// a node with IP address 10.3.58.6, TCP listening port 30303
+// and UDP discovery port 30301.
+//
+// enode://<hex node id>@10.3.58.6:30303?discport=30301
+func ParseNode(rawurl string) (*Node, error) {
+ var n Node
+ u, err := url.Parse(rawurl)
+ if u.Scheme != "enode" {
+ return nil, errors.New("invalid URL scheme, want \"enode\"")
+ }
+ if u.User == nil {
+ return nil, errors.New("does not contain node ID")
+ }
+ if n.ID, err = HexID(u.User.String()); err != nil {
+ return nil, fmt.Errorf("invalid node ID (%v)", err)
+ }
+ ip, port, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ return nil, fmt.Errorf("invalid host: %v", err)
+ }
+ if n.IP = net.ParseIP(ip); n.IP == nil {
+ return nil, errors.New("invalid IP address")
+ }
+ if n.TCPPort, err = strconv.Atoi(port); err != nil {
+ return nil, errors.New("invalid port")
+ }
+ qv := u.Query()
+ if qv.Get("discport") == "" {
+ n.DiscPort = n.TCPPort
+ } else {
+ if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
+ return nil, errors.New("invalid discport in query")
+ }
+ }
+ return &n, nil
+}
+
+// MustParseNode parses a node URL. It panics if the URL is not valid.
+func MustParseNode(rawurl string) *Node {
+ n, err := ParseNode(rawurl)
+ if err != nil {
+ panic("invalid node URL: " + err.Error())
+ }
+ return n
+}
+
+func (n Node) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
+}
+func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
+ var ext rpcNode
+ if err = s.Decode(&ext); err == nil {
+ n.TCPPort = int(ext.Port)
+ n.DiscPort = int(ext.Port)
+ n.ID = ext.ID
+ if n.IP = net.ParseIP(ext.IP); n.IP == nil {
+ return errors.New("invalid IP string")
+ }
+ }
+ return err
+}
+
+// NodeID is a unique identifier for each node.
+// The node identifier is a marshaled elliptic curve public key.
+type NodeID [512 / 8]byte
+
+// NodeID prints as a long hexadecimal number.
+func (n NodeID) String() string {
+ return fmt.Sprintf("%#x", n[:])
+}
+
+// The Go syntax representation of a NodeID is a call to HexID.
+func (n NodeID) GoString() string {
+ return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
+}
+
+// HexID converts a hex string to a NodeID.
+// The string may be prefixed with 0x.
+func HexID(in string) (NodeID, error) {
+ if strings.HasPrefix(in, "0x") {
+ in = in[2:]
+ }
+ var id NodeID
+ b, err := hex.DecodeString(in)
+ if err != nil {
+ return id, err
+ } else if len(b) != len(id) {
+ return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
+ }
+ copy(id[:], b)
+ return id, nil
+}
+
+// MustHexID converts a hex string to a NodeID.
+// It panics if the string is not a valid NodeID.
+func MustHexID(in string) NodeID {
+ id, err := HexID(in)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PubkeyID returns a marshaled representation of the given public key.
+func PubkeyID(pub *ecdsa.PublicKey) NodeID {
+ var id NodeID
+ pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
+ if len(pbytes)-1 != len(id) {
+ panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
+ }
+ copy(id[:], pbytes[1:])
+ return id
+}
+
+// recoverNodeID computes the public key used to sign the
+// given hash from the signature.
+func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
+ pubkey, err := secp256k1.RecoverPubkey(hash, sig)
+ if err != nil {
+ return id, err
+ }
+ if len(pubkey)-1 != len(id) {
+ return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
+ }
+ for i := range id {
+ id[i] = pubkey[i+1]
+ }
+ return id, nil
+}
+
+// distcmp compares the distances a->target and b->target.
+// Returns -1 if a is closer to target, 1 if b is closer to target
+// and 0 if they are equal.
+func distcmp(target, a, b NodeID) int {
+ for i := range target {
+ da := a[i] ^ target[i]
+ db := b[i] ^ target[i]
+ if da > db {
+ return 1
+ } else if da < db {
+ return -1
+ }
+ }
+ return 0
+}
+
+// table of leading zero counts for bytes [0..255]
+var lzcount = [256]int{
+ 8, 7, 6, 6, 5, 5, 5, 5,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 3, 3, 3, 3, 3, 3, 3, 3,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+}
+
+// logdist returns the logarithmic distance between a and b, log2(a ^ b).
+func logdist(a, b NodeID) int {
+ lz := 0
+ for i := range a {
+ x := a[i] ^ b[i]
+ if x == 0 {
+ lz += 8
+ } else {
+ lz += lzcount[x]
+ break
+ }
+ }
+ return len(a)*8 - lz
+}
+
+// randomID returns a random NodeID such that logdist(a, b) == n
+func randomID(a NodeID, n int) (b NodeID) {
+ if n == 0 {
+ return a
+ }
+ // flip bit at position n, fill the rest with random bits
+ b = a
+ pos := len(a) - n/8 - 1
+ bit := byte(0x01) << (byte(n%8) - 1)
+ if bit == 0 {
+ pos++
+ bit = 0x80
+ }
+ b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
+ for i := pos + 1; i < len(a); i++ {
+ b[i] = byte(rand.Intn(255))
+ }
+ return b
+}
diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go
new file mode 100644
index 000000000..ae82ae4f1
--- /dev/null
+++ b/p2p/discover/node_test.go
@@ -0,0 +1,201 @@
+package discover
+
+import (
+ "math/big"
+ "math/rand"
+ "net"
+ "reflect"
+ "testing"
+ "testing/quick"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+var (
+ quickrand = rand.New(rand.NewSource(time.Now().Unix()))
+ quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
+)
+
+var parseNodeTests = []struct {
+ rawurl string
+ wantError string
+ wantResult *Node
+}{
+ {
+ rawurl: "http://foobar",
+ wantError: `invalid URL scheme, want "enode"`,
+ },
+ {
+ rawurl: "enode://foobar",
+ wantError: `does not contain node ID`,
+ },
+ {
+ rawurl: "enode://01010101@123.124.125.126:3",
+ wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
+ wantError: `invalid IP address`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
+ wantError: `invalid port`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
+ wantError: `invalid discport in query`,
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("::"),
+ DiscPort: 52150,
+ TCPPort: 52150,
+ },
+ },
+ {
+ rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
+ wantResult: &Node{
+ ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+ IP: net.ParseIP("127.0.0.1"),
+ DiscPort: 223344,
+ TCPPort: 52150,
+ },
+ },
+}
+
+func TestParseNode(t *testing.T) {
+ for i, test := range parseNodeTests {
+ n, err := ParseNode(test.rawurl)
+ if err == nil && test.wantError != "" {
+ t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
+ continue
+ }
+ if err != nil && err.Error() != test.wantError {
+ t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
+ continue
+ }
+ if !reflect.DeepEqual(n, test.wantResult) {
+ t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
+ }
+ }
+}
+
+func TestNodeString(t *testing.T) {
+ for i, test := range parseNodeTests {
+ if test.wantError != "" {
+ continue
+ }
+ str := test.wantResult.String()
+ if str != test.rawurl {
+ t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
+ }
+ }
+}
+
+func TestHexID(t *testing.T) {
+ ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
+ id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+ id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+
+ if id1 != ref {
+ t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
+ }
+ if id2 != ref {
+ t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
+ }
+}
+
+func TestNodeID_recover(t *testing.T) {
+ prv := newkey()
+ hash := make([]byte, 32)
+ sig, err := crypto.Sign(hash, prv)
+ if err != nil {
+ t.Fatalf("signing error: %v", err)
+ }
+
+ pub := PubkeyID(&prv.PublicKey)
+ recpub, err := recoverNodeID(hash, sig)
+ if err != nil {
+ t.Fatalf("recovery error: %v", err)
+ }
+ if pub != recpub {
+ t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
+ }
+}
+
+func TestNodeID_distcmp(t *testing.T) {
+ distcmpBig := func(target, a, b NodeID) int {
+ tbig := new(big.Int).SetBytes(target[:])
+ abig := new(big.Int).SetBytes(a[:])
+ bbig := new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
+ }
+ if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_distcmpEqual(t *testing.T) {
+ base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
+ if distcmp(base, x, x) != 0 {
+ t.Errorf("distcmp(base, x, x) != 0")
+ }
+}
+
+func TestNodeID_logdist(t *testing.T) {
+ logdistBig := func(a, b NodeID) int {
+ abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
+ return new(big.Int).Xor(abig, bbig).BitLen()
+ }
+ if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
+ t.Error(err)
+ }
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_logdistEqual(t *testing.T) {
+ x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ if logdist(x, x) != 0 {
+ t.Errorf("logdist(x, x) != 0")
+ }
+}
+
+func TestNodeID_randomID(t *testing.T) {
+ // we don't use quick.Check here because its output isn't
+ // very helpful when the test fails.
+ for i := 0; i < quickcfg.MaxCount; i++ {
+ a := gen(NodeID{}, quickrand).(NodeID)
+ dist := quickrand.Intn(len(NodeID{}) * 8)
+ result := randomID(a, dist)
+ actualdist := logdist(result, a)
+
+ if dist != actualdist {
+ t.Log("a: ", a)
+ t.Log("result:", result)
+ t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
+ }
+ }
+}
+
+func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
+ var id NodeID
+ m := rand.Intn(len(id))
+ for i := len(id) - 1; i > m; i-- {
+ id[i] = byte(rand.Uint32())
+ }
+ return reflect.ValueOf(id)
+}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index ea9680404..6025507eb 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -7,20 +7,10 @@
package discover
import (
- "crypto/ecdsa"
- "crypto/elliptic"
- "encoding/hex"
- "fmt"
- "io"
- "math/rand"
"net"
"sort"
- "strings"
"sync"
"time"
-
- "github.com/ethereum/go-ethereum/crypto/secp256k1"
- "github.com/ethereum/go-ethereum/rlp"
)
const (
@@ -53,36 +43,10 @@ type bucket struct {
entries []*Node
}
-// Node represents node metadata that is stored in the table.
-type Node struct {
- Addr *net.UDPAddr
- ID NodeID
-
- active time.Time
-}
-
-type rpcNode struct {
- IP string
- Port uint16
- ID NodeID
-}
-
-func (n Node) EncodeRLP(w io.Writer) error {
- return rlp.Encode(w, rpcNode{IP: n.Addr.IP.String(), Port: uint16(n.Addr.Port), ID: n.ID})
-}
-func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
- var ext rpcNode
- if err = s.Decode(&ext); err == nil {
- n.Addr = &net.UDPAddr{IP: net.ParseIP(ext.IP), Port: int(ext.Port)}
- n.ID = ext.ID
- }
- return err
-}
-
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
- tab := &Table{net: t, self: &Node{ID: ourID, Addr: ourAddr}}
+ tab := &Table{net: t, self: newNode(ourID, ourAddr)}
for i := range tab.buckets {
- tab.buckets[i] = &bucket{}
+ tab.buckets[i] = new(bucket)
}
return tab
}
@@ -217,7 +181,7 @@ func (tab *Table) len() (n int) {
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
b := tab.buckets[logdist(tab.self.ID, node)]
if n = b.bump(node); n == nil {
- n = &Node{ID: node, Addr: from, active: time.Now()}
+ n = newNode(node, from)
if len(b.entries) == bucketSize {
tab.pingReplace(n, b)
} else {
@@ -238,6 +202,7 @@ func (tab *Table) pingReplace(n *Node, b *bucket) {
tab.mutex.Lock()
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
// slide down other entries and put the new one in front.
+ // TODO: insert in correct position to keep the order
copy(b.entries[1:], b.entries)
b.entries[0] = n
}
@@ -312,157 +277,3 @@ func (h *nodesByDistance) push(n *Node, maxElems int) {
h.entries[ix] = n
}
}
-
-// NodeID is a unique identifier for each node.
-// The node identifier is a marshaled elliptic curve public key.
-type NodeID [512 / 8]byte
-
-// NodeID prints as a long hexadecimal number.
-func (n NodeID) String() string {
- return fmt.Sprintf("%#x", n[:])
-}
-
-// The Go syntax representation of a NodeID is a call to HexID.
-func (n NodeID) GoString() string {
- return fmt.Sprintf("HexID(\"%#x\")", n[:])
-}
-
-// HexID converts a hex string to a NodeID.
-// The string may be prefixed with 0x.
-func HexID(in string) (NodeID, error) {
- if strings.HasPrefix(in, "0x") {
- in = in[2:]
- }
- var id NodeID
- b, err := hex.DecodeString(in)
- if err != nil {
- return id, err
- } else if len(b) != len(id) {
- return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
- }
- copy(id[:], b)
- return id, nil
-}
-
-// MustHexID converts a hex string to a NodeID.
-// It panics if the string is not a valid NodeID.
-func MustHexID(in string) NodeID {
- id, err := HexID(in)
- if err != nil {
- panic(err)
- }
- return id
-}
-
-func PubkeyID(pub *ecdsa.PublicKey) NodeID {
- var id NodeID
- pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
- if len(pbytes)-1 != len(id) {
- panic(fmt.Errorf("invalid key: need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
- }
- copy(id[:], pbytes[1:])
- return id
-}
-
-// recoverNodeID computes the public key used to sign the
-// given hash from the signature.
-func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
- pubkey, err := secp256k1.RecoverPubkey(hash, sig)
- if err != nil {
- return id, err
- }
- if len(pubkey)-1 != len(id) {
- return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
- }
- for i := range id {
- id[i] = pubkey[i+1]
- }
- return id, nil
-}
-
-// distcmp compares the distances a->target and b->target.
-// Returns -1 if a is closer to target, 1 if b is closer to target
-// and 0 if they are equal.
-func distcmp(target, a, b NodeID) int {
- for i := range target {
- da := a[i] ^ target[i]
- db := b[i] ^ target[i]
- if da > db {
- return 1
- } else if da < db {
- return -1
- }
- }
- return 0
-}
-
-// table of leading zero counts for bytes [0..255]
-var lzcount = [256]int{
- 8, 7, 6, 6, 5, 5, 5, 5,
- 4, 4, 4, 4, 4, 4, 4, 4,
- 3, 3, 3, 3, 3, 3, 3, 3,
- 3, 3, 3, 3, 3, 3, 3, 3,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
-}
-
-// logdist returns the logarithmic distance between a and b, log2(a ^ b).
-func logdist(a, b NodeID) int {
- lz := 0
- for i := range a {
- x := a[i] ^ b[i]
- if x == 0 {
- lz += 8
- } else {
- lz += lzcount[x]
- break
- }
- }
- return len(a)*8 - lz
-}
-
-// randomID returns a random NodeID such that logdist(a, b) == n
-func randomID(a NodeID, n int) (b NodeID) {
- if n == 0 {
- return a
- }
- // flip bit at position n, fill the rest with random bits
- b = a
- pos := len(a) - n/8 - 1
- bit := byte(0x01) << (byte(n%8) - 1)
- if bit == 0 {
- pos++
- bit = 0x80
- }
- b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
- for i := pos + 1; i < len(a); i++ {
- b[i] = byte(rand.Intn(255))
- }
- return b
-}
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index 90f89b147..9b05ce7f4 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -4,7 +4,6 @@ import (
"crypto/ecdsa"
"errors"
"fmt"
- "math/big"
"math/rand"
"net"
"reflect"
@@ -15,107 +14,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
)
-var (
- quickrand = rand.New(rand.NewSource(time.Now().Unix()))
- quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
-)
-
-func TestHexID(t *testing.T) {
- ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
- id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
- id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
-
- if id1 != ref {
- t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
- }
- if id2 != ref {
- t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
- }
-}
-
-func TestNodeID_recover(t *testing.T) {
- prv := newkey()
- hash := make([]byte, 32)
- sig, err := crypto.Sign(hash, prv)
- if err != nil {
- t.Fatalf("signing error: %v", err)
- }
-
- pub := PubkeyID(&prv.PublicKey)
- recpub, err := recoverNodeID(hash, sig)
- if err != nil {
- t.Fatalf("recovery error: %v", err)
- }
- if pub != recpub {
- t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
- }
-}
-
-func TestNodeID_distcmp(t *testing.T) {
- distcmpBig := func(target, a, b NodeID) int {
- tbig := new(big.Int).SetBytes(target[:])
- abig := new(big.Int).SetBytes(a[:])
- bbig := new(big.Int).SetBytes(b[:])
- return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
- }
- if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
- t.Error(err)
- }
-}
-
-// the random tests is likely to miss the case where they're equal.
-func TestNodeID_distcmpEqual(t *testing.T) {
- base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
- x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
- if distcmp(base, x, x) != 0 {
- t.Errorf("distcmp(base, x, x) != 0")
- }
-}
-
-func TestNodeID_logdist(t *testing.T) {
- logdistBig := func(a, b NodeID) int {
- abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
- return new(big.Int).Xor(abig, bbig).BitLen()
- }
- if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
- t.Error(err)
- }
-}
-
-// the random tests is likely to miss the case where they're equal.
-func TestNodeID_logdistEqual(t *testing.T) {
- x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
- if logdist(x, x) != 0 {
- t.Errorf("logdist(x, x) != 0")
- }
-}
-
-func TestNodeID_randomID(t *testing.T) {
- // we don't use quick.Check here because its output isn't
- // very helpful when the test fails.
- for i := 0; i < quickcfg.MaxCount; i++ {
- a := gen(NodeID{}, quickrand).(NodeID)
- dist := quickrand.Intn(len(NodeID{}) * 8)
- result := randomID(a, dist)
- actualdist := logdist(result, a)
-
- if dist != actualdist {
- t.Log("a: ", a)
- t.Log("result:", result)
- t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
- }
- }
-}
-
-func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
- var id NodeID
- m := rand.Intn(len(id))
- for i := len(id) - 1; i > m; i-- {
- id[i] = byte(rand.Uint32())
- }
- return reflect.ValueOf(id)
-}
-
func TestTable_bumpOrAddPingReplace(t *testing.T) {
pingC := make(pingC)
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
@@ -123,7 +21,7 @@ func TestTable_bumpOrAddPingReplace(t *testing.T) {
// this bumpOrAdd should not replace the last node
// because the node replies to ping.
- new := tab.bumpOrAdd(randomID(tab.self.ID, 200), nil)
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
pinged := <-pingC
if pinged != last.ID {
@@ -149,7 +47,7 @@ func TestTable_bumpOrAddPingTimeout(t *testing.T) {
// this bumpOrAdd should replace the last node
// because the node does not reply to ping.
- new := tab.bumpOrAdd(randomID(tab.self.ID, 200), nil)
+ new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
// wait for async bucket update. damn. this needs to go away.
time.Sleep(2 * time.Millisecond)
@@ -329,19 +227,17 @@ type findnodeOracle struct {
}
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
- t.t.Logf("findnode query at dist %d", n.Addr.Port)
+ t.t.Logf("findnode query at dist %d", n.DiscPort)
// current log distance is encoded in port number
var result []*Node
- switch port := n.Addr.Port; port {
+ switch n.DiscPort {
case 0:
panic("query to node at distance 0")
- case 1:
- result = append(result, &Node{ID: t.target, Addr: &net.UDPAddr{Port: 0}})
default:
// TODO: add more randomness to distances
- port--
+ next := n.DiscPort - 1
for i := 0; i < bucketSize; i++ {
- result = append(result, &Node{ID: randomID(t.target, port), Addr: &net.UDPAddr{Port: port}})
+ result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
}
}
return result, nil
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
index 449139c1d..67e349aa4 100644
--- a/p2p/discover/udp.go
+++ b/p2p/discover/udp.go
@@ -69,6 +69,12 @@ type (
}
)
+type rpcNode struct {
+ IP string
+ Port uint16
+ ID NodeID
+}
+
// udp implements the RPC protocol.
type udp struct {
conn *net.UDPConn
@@ -121,7 +127,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string) (*Table, error) {
return nil, err
}
net.Table = newTable(net, PubkeyID(&priv.PublicKey), realaddr)
- log.Debugf("Listening on %v, my ID %x\n", realaddr, net.self.ID[:])
+ log.Debugf("Listening, %v\n", net.self)
return net.Table, nil
}
@@ -159,8 +165,8 @@ func (t *udp) ping(e *Node) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
t.send(e, pingPacket, ping{
- IP: t.self.Addr.String(),
- Port: uint16(t.self.Addr.Port),
+ IP: t.self.IP.String(),
+ Port: uint16(t.self.TCPPort),
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
return <-errc
@@ -176,7 +182,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
for i := 0; i < len(reply.Nodes); i++ {
nreceived++
n := reply.Nodes[i]
- if validAddr(n.Addr) && n.ID != t.self.ID {
+ if n.ID != t.self.ID && n.isValid() {
nodes = append(nodes, n)
}
}
@@ -191,10 +197,6 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
return nodes, err
}
-func validAddr(a *net.UDPAddr) bool {
- return !a.IP.IsMulticast() && !a.IP.IsUnspecified() && a.Port != 0
-}
-
// pending adds a reply callback to the pending reply queue.
// see the documentation of type pending for a detailed explanation.
func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
@@ -302,8 +304,9 @@ func (t *udp) send(to *Node, ptype byte, req interface{}) error {
// the future.
copy(packet, crypto.Sha3(packet[macSize:]))
- log.DebugDetailf(">>> %v %T %v\n", to.Addr, req, req)
- if _, err = t.conn.WriteToUDP(packet, to.Addr); err != nil {
+ toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
+ log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
+ if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
log.DebugDetailln("UDP send failed:", err)
}
return err
@@ -365,11 +368,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
return errExpired
}
t.mutex.Lock()
- // Note: we're ignoring the provided IP/Port right now.
- e := t.bumpOrAdd(fromID, from)
+ // Note: we're ignoring the provided IP address right now
+ n := t.bumpOrAdd(fromID, from)
+ if req.Port != 0 {
+ n.TCPPort = int(req.Port)
+ }
t.mutex.Unlock()
- t.send(e, pongPacket, pong{
+ t.send(n, pongPacket, pong{
ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
index 8ed6ec9c0..94680d6fc 100644
--- a/p2p/discover/udp_test.go
+++ b/p2p/discover/udp_test.go
@@ -4,6 +4,7 @@ import (
logpkg "log"
"net"
"os"
+ "reflect"
"testing"
"time"
@@ -11,7 +12,7 @@ import (
)
func init() {
- logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.DebugLevel))
+ logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
}
func TestUDP_ping(t *testing.T) {
@@ -52,7 +53,7 @@ func TestUDP_findnode(t *testing.T) {
defer n1.Close()
defer n2.Close()
- entry := &Node{ID: NodeID{1}, Addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 15}}
+ entry := MustParseNode("enode://9d8a19597e312ef32d76e6b4903bb43d7bcd892d17b769d30b404bd3a4c2dca6c86184b17d0fdeeafe3b01e0e912d990ddc853db3f325d5419f31446543c30be@127.0.0.1:54194")
n2.add([]*Node{entry})
target := randomID(n1.self.ID, 100)
@@ -60,7 +61,7 @@ func TestUDP_findnode(t *testing.T) {
if len(result) != 1 {
t.Fatalf("wrong number of results: got %d, want 1", len(result))
}
- if result[0].ID != entry.ID {
+ if !reflect.DeepEqual(result[0], entry) {
t.Errorf("wrong result: got %v, want %v", result[0], entry)
}
}
@@ -103,8 +104,10 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
nodes := make([]*Node, bucketSize)
for i := range nodes {
nodes[i] = &Node{
- Addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: i + 1},
- ID: randomID(n2.self.ID, i+1),
+ IP: net.IP{1, 2, 3, 4},
+ DiscPort: i + 1,
+ TCPPort: i + 1,
+ ID: randomID(n2.self.ID, i+1),
}
}