aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorFelix Lange <fjl@twurst.com>2014-11-04 20:21:44 +0800
committerFelix Lange <fjl@twurst.com>2014-11-22 04:52:45 +0800
commitf38052c499c1fee61423efeddb1f52677f1442e9 (patch)
tree6cc4c4e9739d61edeba9dc62781b2ebdeb0faf11
parent8cf9ed0ea588e97f2baf0f834248727e8fbca18f (diff)
downloaddexon-f38052c499c1fee61423efeddb1f52677f1442e9.tar.gz
dexon-f38052c499c1fee61423efeddb1f52677f1442e9.tar.zst
dexon-f38052c499c1fee61423efeddb1f52677f1442e9.zip
p2p: rework protocol API
-rw-r--r--p2p/connection.go275
-rw-r--r--p2p/connection_test.go222
-rw-r--r--p2p/message.go201
-rw-r--r--p2p/message_test.go75
-rw-r--r--p2p/messenger.go353
-rw-r--r--p2p/messenger_test.go224
-rw-r--r--p2p/peer.go29
-rw-r--r--p2p/peer_error.go10
-rw-r--r--p2p/peer_error_handler.go31
-rw-r--r--p2p/peer_error_handler_test.go2
-rw-r--r--p2p/peer_test.go170
-rw-r--r--p2p/protocol.go353
-rw-r--r--p2p/server.go150
-rw-r--r--p2p/server_test.go204
14 files changed, 1017 insertions, 1282 deletions
diff --git a/p2p/connection.go b/p2p/connection.go
deleted file mode 100644
index be366235d..000000000
--- a/p2p/connection.go
+++ /dev/null
@@ -1,275 +0,0 @@
-package p2p
-
-import (
- "bytes"
- // "fmt"
- "net"
- "time"
-
- "github.com/ethereum/go-ethereum/ethutil"
-)
-
-type Connection struct {
- conn net.Conn
- // conn NetworkConnection
- timeout time.Duration
- in chan []byte
- out chan []byte
- err chan *PeerError
- closingIn chan chan bool
- closingOut chan chan bool
-}
-
-// const readBufferLength = 2 //for testing
-
-const readBufferLength = 1440
-const partialsQueueSize = 10
-const maxPendingQueueSize = 1
-const defaultTimeout = 500
-
-var magicToken = []byte{34, 64, 8, 145}
-
-func (self *Connection) Open() {
- go self.startRead()
- go self.startWrite()
-}
-
-func (self *Connection) Close() {
- self.closeIn()
- self.closeOut()
-}
-
-func (self *Connection) closeIn() {
- errc := make(chan bool)
- self.closingIn <- errc
- <-errc
-}
-
-func (self *Connection) closeOut() {
- errc := make(chan bool)
- self.closingOut <- errc
- <-errc
-}
-
-func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
- return &Connection{
- conn: conn,
- timeout: defaultTimeout,
- in: make(chan []byte),
- out: make(chan []byte),
- err: errchan,
- closingIn: make(chan chan bool, 1),
- closingOut: make(chan chan bool, 1),
- }
-}
-
-func (self *Connection) Read() <-chan []byte {
- return self.in
-}
-
-func (self *Connection) Write() chan<- []byte {
- return self.out
-}
-
-func (self *Connection) Error() <-chan *PeerError {
- return self.err
-}
-
-func (self *Connection) startRead() {
- payloads := make(chan []byte)
- done := make(chan *PeerError)
- pending := [][]byte{}
- var head []byte
- var wait time.Duration // initally 0 (no delay)
- read := time.After(wait * time.Millisecond)
-
- for {
- // if pending empty, nil channel blocks
- var in chan []byte
- if len(pending) > 0 {
- in = self.in // enable send case
- head = pending[0]
- } else {
- in = nil
- }
-
- select {
- case <-read:
- go self.read(payloads, done)
- case err := <-done:
- if err == nil { // no error but nothing to read
- if len(pending) < maxPendingQueueSize {
- wait = 100
- } else if wait == 0 {
- wait = 100
- } else {
- wait = 2 * wait
- }
- } else {
- self.err <- err // report error
- wait = 100
- }
- read = time.After(wait * time.Millisecond)
- case payload := <-payloads:
- pending = append(pending, payload)
- if len(pending) < maxPendingQueueSize {
- wait = 0
- } else {
- wait = 100
- }
- read = time.After(wait * time.Millisecond)
- case in <- head:
- pending = pending[1:]
- case errc := <-self.closingIn:
- errc <- true
- close(self.in)
- return
- }
-
- }
-}
-
-func (self *Connection) startWrite() {
- pending := [][]byte{}
- done := make(chan *PeerError)
- writing := false
- for {
- if len(pending) > 0 && !writing {
- writing = true
- go self.write(pending[0], done)
- }
- select {
- case payload := <-self.out:
- pending = append(pending, payload)
- case err := <-done:
- if err == nil {
- pending = pending[1:]
- writing = false
- } else {
- self.err <- err // report error
- }
- case errc := <-self.closingOut:
- errc <- true
- close(self.out)
- return
- }
- }
-}
-
-func pack(payload []byte) (packet []byte) {
- length := ethutil.NumberToBytes(uint32(len(payload)), 32)
- // return error if too long?
- // Write magic token and payload length (first 8 bytes)
- packet = append(magicToken, length...)
- packet = append(packet, payload...)
- return
-}
-
-func avoidPanic(done chan *PeerError) {
- if rec := recover(); rec != nil {
- err := NewPeerError(MiscError, " %v", rec)
- logger.Debugln(err)
- done <- err
- }
-}
-
-func (self *Connection) write(payload []byte, done chan *PeerError) {
- defer avoidPanic(done)
- var err *PeerError
- _, ok := self.conn.Write(pack(payload))
- if ok != nil {
- err = NewPeerError(WriteError, " %v", ok)
- logger.Debugln(err)
- }
- done <- err
-}
-
-func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
- //defer avoidPanic(done)
-
- partials := make(chan []byte, partialsQueueSize)
- errc := make(chan *PeerError)
- go self.readPartials(partials, errc)
-
- packet := []byte{}
- length := 8
- start := true
- var err *PeerError
-out:
- for {
- // appends partials read via connection until packet is
- // - either parseable (>=8bytes)
- // - or complete (payload fully consumed)
- for len(packet) < length {
- partial, ok := <-partials
- if !ok { // partials channel is closed
- err = <-errc
- if err == nil && len(packet) > 0 {
- if start {
- err = NewPeerError(PacketTooShort, "%v", packet)
- } else {
- err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
- }
- }
- break out
- }
- packet = append(packet, partial...)
- }
- if start {
- // at least 8 bytes read, can validate packet
- if bytes.Compare(magicToken, packet[:4]) != 0 {
- err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
- break
- }
- length = int(ethutil.BytesToNumber(packet[4:8]))
- packet = packet[8:]
-
- if length > 0 {
- start = false // now consuming payload
- } else { //penalize peer but read on
- self.err <- NewPeerError(EmptyPayload, "")
- length = 8
- }
- } else {
- // packet complete (payload fully consumed)
- payloads <- packet[:length]
- packet = packet[length:] // resclice packet
- start = true
- length = 8
- }
- }
-
- // this stops partials read via the connection, should we?
- //if err != nil {
- // select {
- // case errc <- err
- // default:
- //}
- done <- err
-}
-
-func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
- defer close(partials)
- for {
- // Give buffering some time
- self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
- buffer := make([]byte, readBufferLength)
- // read partial from connection
- bytesRead, err := self.conn.Read(buffer)
- if err == nil || err.Error() == "EOF" {
- if bytesRead > 0 {
- partials <- buffer[:bytesRead]
- }
- if err != nil && err.Error() == "EOF" {
- break
- }
- } else {
- // unexpected error, report to errc
- err := NewPeerError(ReadError, " %v", err)
- logger.Debugln(err)
- errc <- err
- return // will close partials channel
- }
- }
- close(errc)
-}
diff --git a/p2p/connection_test.go b/p2p/connection_test.go
deleted file mode 100644
index 76ee8021c..000000000
--- a/p2p/connection_test.go
+++ /dev/null
@@ -1,222 +0,0 @@
-package p2p
-
-import (
- "bytes"
- "fmt"
- "io"
- "net"
- "testing"
- "time"
-)
-
-type TestNetworkConnection struct {
- in chan []byte
- current []byte
- Out [][]byte
- addr net.Addr
-}
-
-func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
- return &TestNetworkConnection{
- in: make(chan []byte),
- current: []byte{},
- Out: [][]byte{},
- addr: addr,
- }
-}
-
-func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
- time.Sleep(latency)
- for _, s := range packets {
- self.in <- s
- }
-}
-
-func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
- if len(self.current) == 0 {
- select {
- case self.current = <-self.in:
- default:
- return 0, io.EOF
- }
- }
- length := len(self.current)
- if length > len(buff) {
- copy(buff[:], self.current[:len(buff)])
- self.current = self.current[len(buff):]
- return len(buff), nil
- } else {
- copy(buff[:length], self.current[:])
- self.current = []byte{}
- return length, io.EOF
- }
-}
-
-func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
- self.Out = append(self.Out, buff)
- fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
- return len(buff), nil
-}
-
-func (self *TestNetworkConnection) Close() (err error) {
- return
-}
-
-func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
- return
-}
-
-func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
- return self.addr
-}
-
-func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
- return
-}
-
-func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
- return
-}
-
-func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
- return
-}
-
-func setupConnection() (*Connection, *TestNetworkConnection) {
- addr := &TestAddr{"test:30303"}
- net := NewTestNetworkConnection(addr)
- conn := NewConnection(net, NewPeerErrorChannel())
- conn.Open()
- return conn, net
-}
-
-func TestReadingNilPacket(t *testing.T) {
- conn, net := setupConnection()
- go net.In(0, []byte{})
- // time.Sleep(10 * time.Millisecond)
- select {
- case packet := <-conn.Read():
- t.Errorf("read %v", packet)
- case err := <-conn.Error():
- t.Errorf("incorrect error %v", err)
- default:
- }
- conn.Close()
-}
-
-func TestReadingShortPacket(t *testing.T) {
- conn, net := setupConnection()
- go net.In(0, []byte{0})
- select {
- case packet := <-conn.Read():
- t.Errorf("read %v", packet)
- case err := <-conn.Error():
- if err.Code != PacketTooShort {
- t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
- }
- }
- conn.Close()
-}
-
-func TestReadingInvalidPacket(t *testing.T) {
- conn, net := setupConnection()
- go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
- select {
- case packet := <-conn.Read():
- t.Errorf("read %v", packet)
- case err := <-conn.Error():
- if err.Code != MagicTokenMismatch {
- t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
- }
- }
- conn.Close()
-}
-
-func TestReadingInvalidPayload(t *testing.T) {
- conn, net := setupConnection()
- go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
- select {
- case packet := <-conn.Read():
- t.Errorf("read %v", packet)
- case err := <-conn.Error():
- if err.Code != PayloadTooShort {
- t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
- }
- }
- conn.Close()
-}
-
-func TestReadingEmptyPayload(t *testing.T) {
- conn, net := setupConnection()
- go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
- time.Sleep(10 * time.Millisecond)
- select {
- case packet := <-conn.Read():
- t.Errorf("read %v", packet)
- default:
- }
- select {
- case err := <-conn.Error():
- code := err.Code
- if code != EmptyPayload {
- t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
- }
- default:
- t.Errorf("no error, expected EmptyPayload")
- }
- conn.Close()
-}
-
-func TestReadingCompletePacket(t *testing.T) {
- conn, net := setupConnection()
- go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
- time.Sleep(10 * time.Millisecond)
- select {
- case packet := <-conn.Read():
- if bytes.Compare(packet, []byte{1}) != 0 {
- t.Errorf("incorrect payload read")
- }
- case err := <-conn.Error():
- t.Errorf("incorrect error %v", err)
- default:
- t.Errorf("nothing read")
- }
- conn.Close()
-}
-
-func TestReadingTwoCompletePackets(t *testing.T) {
- conn, net := setupConnection()
- go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
-
- for i := 0; i < 2; i++ {
- time.Sleep(10 * time.Millisecond)
- select {
- case packet := <-conn.Read():
- if bytes.Compare(packet, []byte{byte(i)}) != 0 {
- t.Errorf("incorrect payload read")
- }
- case err := <-conn.Error():
- t.Errorf("incorrect error %v", err)
- default:
- t.Errorf("nothing read")
- }
- }
- conn.Close()
-}
-
-func TestWriting(t *testing.T) {
- conn, net := setupConnection()
- conn.Write() <- []byte{0}
- time.Sleep(10 * time.Millisecond)
- if len(net.Out) == 0 {
- t.Errorf("no output")
- } else {
- out := net.Out[0]
- if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
- t.Errorf("incorrect packet %v", out)
- }
- }
- conn.Close()
-}
-
-// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
diff --git a/p2p/message.go b/p2p/message.go
index 446e74dff..366cff5d7 100644
--- a/p2p/message.go
+++ b/p2p/message.go
@@ -1,75 +1,174 @@
package p2p
import (
- // "fmt"
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "math/big"
+
"github.com/ethereum/go-ethereum/ethutil"
)
-type MsgCode uint8
+type MsgCode uint64
+// Msg defines the structure of a p2p message.
+//
+// Note that a Msg can only be sent once since the Payload reader is
+// consumed during sending. It is not possible to create a Msg and
+// send it any number of times. If you want to reuse an encoded
+// structure, encode the payload into a byte array and create a
+// separate Msg with a bytes.Reader as Payload for each send.
type Msg struct {
- code MsgCode // this is the raw code as per adaptive msg code scheme
- data *ethutil.Value
- encoded []byte
+ Code MsgCode
+ Size uint32 // size of the paylod
+ Payload io.Reader
}
-func (self *Msg) Code() MsgCode {
- return self.code
+// NewMsg creates an RLP-encoded message with the given code.
+func NewMsg(code MsgCode, params ...interface{}) Msg {
+ buf := new(bytes.Buffer)
+ for _, p := range params {
+ buf.Write(ethutil.Encode(p))
+ }
+ return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
}
-func (self *Msg) Data() *ethutil.Value {
- return self.data
+func encodePayload(params ...interface{}) []byte {
+ buf := new(bytes.Buffer)
+ for _, p := range params {
+ buf.Write(ethutil.Encode(p))
+ }
+ return buf.Bytes()
}
-func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) {
-
- // // data := [][]interface{}{}
- // data := []interface{}{}
- // for _, value := range params {
- // if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
- // data = append(data, encodable.RlpValue())
- // } else if raw, ok := value.([]interface{}); ok {
- // data = append(data, raw)
- // } else {
- // // data = append(data, interface{}(raw))
- // err = fmt.Errorf("Unable to encode object of type %T", value)
- // return
- // }
- // }
- return &Msg{
- code: code,
- data: ethutil.NewValue(interface{}(params)),
- }, nil
+// Data returns the decoded RLP payload items in a message.
+func (msg Msg) Data() (*ethutil.Value, error) {
+ // TODO: avoid copying when we have a better RLP decoder
+ buf := new(bytes.Buffer)
+ var s []interface{}
+ if _, err := buf.ReadFrom(msg.Payload); err != nil {
+ return nil, err
+ }
+ for buf.Len() > 0 {
+ s = append(s, ethutil.DecodeWithReader(buf))
+ }
+ return ethutil.NewValue(s), nil
+}
+
+// Discard reads any remaining payload data into a black hole.
+func (msg Msg) Discard() error {
+ _, err := io.Copy(ioutil.Discard, msg.Payload)
+ return err
+}
+
+var magicToken = []byte{34, 64, 8, 145}
+
+func writeMsg(w io.Writer, msg Msg) error {
+ // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
+ code := ethutil.Encode(uint32(msg.Code))
+ listhdr := makeListHeader(msg.Size + uint32(len(code)))
+ payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size
+
+ start := make([]byte, 8)
+ copy(start, magicToken)
+ binary.BigEndian.PutUint32(start[4:], payloadLen)
+
+ for _, b := range [][]byte{start, listhdr, code} {
+ if _, err := w.Write(b); err != nil {
+ return err
+ }
+ }
+ _, err := io.CopyN(w, msg.Payload, int64(msg.Size))
+ return err
}
-func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) {
- value := ethutil.NewValueFromBytes(encoded)
- // Type of message
- code := value.Get(0).Uint()
- // Actual data
- data := value.SliceFrom(1)
-
- msg = &Msg{
- code: MsgCode(code),
- data: data,
- // data: ethutil.NewValue(data),
- encoded: encoded,
+func makeListHeader(length uint32) []byte {
+ if length < 56 {
+ return []byte{byte(length + 0xc0)}
}
- return
+ enc := big.NewInt(int64(length)).Bytes()
+ lenb := byte(len(enc)) + 0xf7
+ return append([]byte{lenb}, enc...)
}
-func (self *Msg) Decode(offset MsgCode) {
- self.code = self.code - offset
+type byteReader interface {
+ io.Reader
+ io.ByteReader
}
-// encode takes an offset argument to implement adaptive message coding
-// the encoded message is memoized to make msgs relayed to several peers more efficient
-func (self *Msg) Encode(offset MsgCode) (res []byte) {
- if len(self.encoded) == 0 {
- res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
- self.encoded = res
+// readMsg reads a message header.
+func readMsg(r byteReader) (msg Msg, err error) {
+ // read magic and payload size
+ start := make([]byte, 8)
+ if _, err = io.ReadFull(r, start); err != nil {
+ return msg, NewPeerError(ReadError, "%v", err)
+ }
+ if !bytes.HasPrefix(start, magicToken) {
+ return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
+ }
+ size := binary.BigEndian.Uint32(start[4:])
+
+ // decode start of RLP message to get the message code
+ _, hdrlen, err := readListHeader(r)
+ if err != nil {
+ return msg, err
+ }
+ code, codelen, err := readMsgCode(r)
+ if err != nil {
+ return msg, err
+ }
+
+ rlpsize := size - hdrlen - codelen
+ return Msg{
+ Code: code,
+ Size: rlpsize,
+ Payload: io.LimitReader(r, int64(rlpsize)),
+ }, nil
+}
+
+// readListHeader reads an RLP list header from r.
+func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
+ b, err := r.ReadByte()
+ if err != nil {
+ return 0, 0, err
+ }
+ if b < 0xC0 {
+ return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b)
+ } else if b < 0xF7 {
+ len = uint64(b - 0xc0)
+ hdrlen = 1
} else {
- res = self.encoded
+ lenlen := b - 0xF7
+ lenbuf := make([]byte, 8)
+ if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil {
+ return 0, 0, err
+ }
+ len = binary.BigEndian.Uint64(lenbuf)
+ hdrlen = 1 + uint32(lenlen)
+ }
+ return len, hdrlen, nil
+}
+
+// readUint reads an RLP-encoded unsigned integer from r.
+func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
+ b, err := r.ReadByte()
+ if err != nil {
+ return 0, 0, err
+ }
+ if b < 0x80 {
+ return MsgCode(b), 1, nil
+ } else if b < 0x89 { // max length for uint64 is 8 bytes
+ codelen = uint32(b - 0x80)
+ if codelen == 0 {
+ return 0, 1, nil
+ }
+ buf := make([]byte, 8)
+ if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
+ return 0, 0, err
+ }
+ return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil
}
- return
+ return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
}
diff --git a/p2p/message_test.go b/p2p/message_test.go
index e9d46f2c3..1edabc4e7 100644
--- a/p2p/message_test.go
+++ b/p2p/message_test.go
@@ -1,38 +1,67 @@
package p2p
import (
+ "bytes"
+ "io/ioutil"
"testing"
+
+ "github.com/ethereum/go-ethereum/ethutil"
)
func TestNewMsg(t *testing.T) {
- msg, _ := NewMsg(3, 1, "000")
- if msg.Code() != 3 {
- t.Errorf("incorrect code %v", msg.Code())
+ msg := NewMsg(3, 1, "000")
+ if msg.Code != 3 {
+ t.Errorf("incorrect code %d, want %d", msg.Code)
}
- data0 := msg.Data().Get(0).Uint()
- data1 := string(msg.Data().Get(1).Bytes())
- if data0 != 1 {
- t.Errorf("incorrect data %v", data0)
+ if msg.Size != 5 {
+ t.Errorf("incorrect size %d, want %d", msg.Size, 5)
}
- if data1 != "000" {
- t.Errorf("incorrect data %v", data1)
+ pl, _ := ioutil.ReadAll(msg.Payload)
+ expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
+ if !bytes.Equal(pl, expect) {
+ t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
}
}
func TestEncodeDecodeMsg(t *testing.T) {
- msg, _ := NewMsg(3, 1, "000")
- encoded := msg.Encode(3)
- msg, _ = NewMsgFromBytes(encoded)
- msg.Decode(3)
- if msg.Code() != 3 {
- t.Errorf("incorrect code %v", msg.Code())
- }
- data0 := msg.Data().Get(0).Uint()
- data1 := msg.Data().Get(1).Str()
- if data0 != 1 {
- t.Errorf("incorrect data %v", data0)
- }
- if data1 != "000" {
- t.Errorf("incorrect data %v", data1)
+ msg := NewMsg(3, 1, "000")
+ buf := new(bytes.Buffer)
+ if err := writeMsg(buf, msg); err != nil {
+ t.Fatalf("encodeMsg error: %v", err)
+ }
+
+ t.Logf("encoded: %x", buf.Bytes())
+
+ decmsg, err := readMsg(buf)
+ if err != nil {
+ t.Fatalf("readMsg error: %v", err)
+ }
+ if decmsg.Code != 3 {
+ t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
+ }
+ if decmsg.Size != 5 {
+ t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
+ }
+ data, err := decmsg.Data()
+ if err != nil {
+ t.Fatalf("first payload item decode error: %v", err)
+ }
+ if v := data.Get(0).Uint(); v != 1 {
+ t.Errorf("incorrect data[0]: got %v, expected %d", v, 1)
+ }
+ if v := data.Get(1).Str(); v != "000" {
+ t.Errorf("incorrect data[1]: got %q, expected %q", v, "000")
+ }
+}
+
+func TestDecodeRealMsg(t *testing.T) {
+ data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
+ msg, err := readMsg(bytes.NewReader(data))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if msg.Code != 0 {
+ t.Errorf("incorrect code %d, want %d", msg.Code, 0)
}
}
diff --git a/p2p/messenger.go b/p2p/messenger.go
index d42ba1720..7375ecc07 100644
--- a/p2p/messenger.go
+++ b/p2p/messenger.go
@@ -1,220 +1,221 @@
package p2p
import (
+ "bufio"
+ "bytes"
"fmt"
+ "io"
+ "io/ioutil"
+ "net"
"sync"
"time"
)
-const (
- handlerTimeout = 1000
-)
+type Handlers map[string]func() Protocol
-type Handlers map[string](func(p *Peer) Protocol)
-
-type Messenger struct {
- conn *Connection
- peer *Peer
- handlers Handlers
- protocolLock sync.RWMutex
- protocols []Protocol
- offsets []MsgCode // offsets for adaptive message idss
- protocolTable map[string]int
- quit chan chan bool
- err chan *PeerError
- pulse chan bool
-}
-
-func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger {
- baseProtocol := NewBaseProtocol(peer)
- return &Messenger{
- conn: conn,
- peer: peer,
- offsets: []MsgCode{baseProtocol.Offset()},
- handlers: handlers,
- protocols: []Protocol{baseProtocol},
- protocolTable: make(map[string]int),
- err: errchan,
- pulse: make(chan bool, 1),
- quit: make(chan chan bool, 1),
- }
+type proto struct {
+ in chan Msg
+ maxcode, offset MsgCode
+ messenger *messenger
}
-func (self *Messenger) Start() {
- self.conn.Open()
- go self.messenger()
- self.protocolLock.RLock()
- defer self.protocolLock.RUnlock()
- self.protocols[0].Start()
+func (rw *proto) WriteMsg(msg Msg) error {
+ if msg.Code >= rw.maxcode {
+ return NewPeerError(InvalidMsgCode, "not handled")
+ }
+ return rw.messenger.writeMsg(msg)
}
-func (self *Messenger) Stop() {
- // close pulse to stop ping pong monitoring
- close(self.pulse)
- self.protocolLock.RLock()
- defer self.protocolLock.RUnlock()
- for _, protocol := range self.protocols {
- protocol.Stop() // could be parallel
+func (rw *proto) ReadMsg() (Msg, error) {
+ msg, ok := <-rw.in
+ if !ok {
+ return msg, io.EOF
}
- q := make(chan bool)
- self.quit <- q
- <-q
- self.conn.Close()
+ return msg, nil
}
-func (self *Messenger) messenger() {
- in := self.conn.Read()
- for {
- select {
- case payload, ok := <-in:
- //dispatches message to the protocol asynchronously
- if ok {
- go self.handle(payload)
- } else {
- return
- }
- case q := <-self.quit:
- q <- true
- return
- }
- }
+// eofSignal is used to 'lend' the network connection
+// to a protocol. when the protocol's read loop has read the
+// whole payload, the done channel is closed.
+type eofSignal struct {
+ wrapped io.Reader
+ eof chan struct{}
}
-// handles each message by dispatching to the appropriate protocol
-// using adaptive message codes
-// this function is started as a separate go routine for each message
-// it waits for the protocol response
-// then encodes and sends outgoing messages to the connection's write channel
-func (self *Messenger) handle(payload []byte) {
- // send ping to heartbeat channel signalling time of last message
- // select {
- // case self.pulse <- true:
- // default:
- // }
- self.pulse <- true
- // initialise message from payload
- msg, err := NewMsgFromBytes(payload)
+func (r *eofSignal) Read(buf []byte) (int, error) {
+ n, err := r.wrapped.Read(buf)
if err != nil {
- self.err <- NewPeerError(MiscError, " %v", err)
- return
+ close(r.eof) // tell messenger that msg has been consumed
}
- // retrieves protocol based on message Code
- protocol, offset, peerErr := self.getProtocol(msg.Code())
- if err != nil {
- self.err <- peerErr
- return
+ return n, err
+}
+
+// messenger represents a message-oriented peer connection.
+// It keeps track of the set of protocols understood
+// by the remote peer.
+type messenger struct {
+ peer *Peer
+ handlers Handlers
+
+ // the mutex protects the connection
+ // so only one protocol can write at a time.
+ writeMu sync.Mutex
+ conn net.Conn
+ bufconn *bufio.ReadWriter
+
+ protocolLock sync.RWMutex
+ protocols map[string]*proto
+ offsets map[MsgCode]*proto
+ protoWG sync.WaitGroup
+
+ err chan error
+ pulse chan bool
+}
+
+func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
+ return &messenger{
+ conn: conn,
+ bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
+ peer: peer,
+ handlers: handlers,
+ protocols: make(map[string]*proto),
+ err: errchan,
+ pulse: make(chan bool, 1),
}
- // reset message code based on adaptive offset
- msg.Decode(offset)
- // dispatches
- response := make(chan *Msg)
- go protocol.HandleIn(msg, response)
- // protocol reponse timeout to prevent leaks
- timer := time.After(handlerTimeout * time.Millisecond)
+}
+
+func (m *messenger) Start() {
+ m.protocols[""] = m.startProto(0, "", &baseProtocol{})
+ go m.readLoop()
+}
+
+func (m *messenger) Stop() {
+ m.conn.Close()
+ m.protoWG.Wait()
+}
+
+const (
+ // maximum amount of time allowed for reading a message
+ msgReadTimeout = 5 * time.Second
+
+ // messages smaller than this many bytes will be read at
+ // once before passing them to a protocol.
+ wholePayloadSize = 64 * 1024
+)
+
+func (m *messenger) readLoop() {
+ defer m.closeProtocols()
for {
- select {
- case outgoing, ok := <-response:
- // we check if response channel is not closed
- if ok {
- self.conn.Write() <- outgoing.Encode(offset)
- } else {
+ m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
+ msg, err := readMsg(m.bufconn)
+ if err != nil {
+ m.err <- err
+ return
+ }
+ // send ping to heartbeat channel signalling time of last message
+ m.pulse <- true
+ proto, err := m.getProto(msg.Code)
+ if err != nil {
+ m.err <- err
+ return
+ }
+ msg.Code -= proto.offset
+ if msg.Size <= wholePayloadSize {
+ // optimization: msg is small enough, read all
+ // of it and move on to the next message
+ buf, err := ioutil.ReadAll(msg.Payload)
+ if err != nil {
+ m.err <- err
return
}
- case <-timer:
- return
+ msg.Payload = bytes.NewReader(buf)
+ proto.in <- msg
+ } else {
+ pr := &eofSignal{msg.Payload, make(chan struct{})}
+ msg.Payload = pr
+ proto.in <- msg
+ <-pr.eof
}
}
}
-// negotiated protocols
-// stores offsets needed for adaptive message id scheme
-
-// based on offsets set at handshake
-// get the right protocol to handle the message
-func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) {
- self.protocolLock.RLock()
- defer self.protocolLock.RUnlock()
- base := MsgCode(0)
- for index, offset := range self.offsets {
- if code < offset {
- return self.protocols[index], base, nil
- }
- base = offset
+func (m *messenger) closeProtocols() {
+ m.protocolLock.RLock()
+ for _, p := range m.protocols {
+ close(p.in)
}
- return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code)
+ m.protocolLock.RUnlock()
}
-func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) {
- fmt.Printf("pingpong keepalive started at %v", time.Now())
+func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
+ proto := &proto{
+ in: make(chan Msg),
+ offset: offset,
+ maxcode: impl.Offset(),
+ messenger: m,
+ }
+ m.protoWG.Add(1)
+ go func() {
+ if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
+ logger.Errorf("protocol %q error: %v\n", name, err)
+ m.err <- err
+ }
+ m.protoWG.Done()
+ }()
+ return proto
+}
- timer := time.After(timeout)
- pinged := false
- for {
- select {
- case _, ok := <-self.pulse:
- if ok {
- pinged = false
- timer = time.After(timeout)
- } else {
- // pulse is closed, stop monitoring
- return
- }
- case <-timer:
- if pinged {
- fmt.Printf("timeout at %v", time.Now())
- timeoutCallback()
- return
- } else {
- fmt.Printf("pinged at %v", time.Now())
- pingCallback()
- timer = time.After(gracePeriod)
- pinged = true
- }
+// getProto finds the protocol responsible for handling
+// the given message code.
+func (m *messenger) getProto(code MsgCode) (*proto, error) {
+ m.protocolLock.RLock()
+ defer m.protocolLock.RUnlock()
+ for _, proto := range m.protocols {
+ if code >= proto.offset && code < proto.offset+proto.maxcode {
+ return proto, nil
}
}
+ return nil, NewPeerError(InvalidMsgCode, "%d", code)
}
-func (self *Messenger) AddProtocols(protocols []string) {
- self.protocolLock.Lock()
- defer self.protocolLock.Unlock()
- i := len(self.offsets)
- offset := self.offsets[i-1]
+// setProtocols starts all subprotocols shared with the
+// remote peer. the protocols must be sorted alphabetically.
+func (m *messenger) setRemoteProtocols(protocols []string) {
+ m.protocolLock.Lock()
+ defer m.protocolLock.Unlock()
+ offset := baseProtocolOffset
for _, name := range protocols {
- protocolFunc, ok := self.handlers[name]
- if ok {
- protocol := protocolFunc(self.peer)
- self.protocolTable[name] = i
- i++
- offset += protocol.Offset()
- fmt.Println("offset ", name, offset)
-
- self.offsets = append(self.offsets, offset)
- self.protocols = append(self.protocols, protocol)
- protocol.Start()
- } else {
- fmt.Println("no ", name)
- // protocol not handled
+ protocolFunc, ok := m.handlers[name]
+ if !ok {
+ continue // not handled
}
+ inst := protocolFunc()
+ m.protocols[name] = m.startProto(offset, name, inst)
+ offset += inst.Offset()
}
}
-func (self *Messenger) Write(protocol string, msg *Msg) error {
- self.protocolLock.RLock()
- defer self.protocolLock.RUnlock()
- i := 0
- offset := MsgCode(0)
- if len(protocol) > 0 {
- var ok bool
- i, ok = self.protocolTable[protocol]
- if !ok {
- return fmt.Errorf("protocol %v not handled by peer", protocol)
- }
- offset = self.offsets[i-1]
+// writeProtoMsg sends the given message on behalf of the given named protocol.
+func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
+ m.protocolLock.RLock()
+ proto, ok := m.protocols[protoName]
+ m.protocolLock.RUnlock()
+ if !ok {
+ return fmt.Errorf("protocol %s not handled by peer", protoName)
}
- handler := self.protocols[i]
- // checking if protocol status/caps allows the message to be sent out
- if handler.HandleOut(msg) {
- self.conn.Write() <- msg.Encode(offset)
+ if msg.Code >= proto.maxcode {
+ return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
+ }
+ msg.Code += proto.offset
+ return m.writeMsg(msg)
+}
+
+// writeMsg writes a message to the connection.
+func (m *messenger) writeMsg(msg Msg) error {
+ m.writeMu.Lock()
+ defer m.writeMu.Unlock()
+ if err := writeMsg(m.bufconn, msg); err != nil {
+ return err
}
- return nil
+ return m.bufconn.Flush()
}
diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go
index 472d74515..f10469e2f 100644
--- a/p2p/messenger_test.go
+++ b/p2p/messenger_test.go
@@ -1,147 +1,157 @@
package p2p
import (
- // "fmt"
- "bytes"
+ "bufio"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "os"
+ "reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
-func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
- errchan := NewPeerErrorChannel()
- addr := &TestAddr{"test:30303"}
- net := NewTestNetworkConnection(addr)
- conn := NewConnection(net, errchan)
- mess := NewMessenger(nil, conn, errchan, handlers)
- mess.Start()
- return net, errchan, mess
+func init() {
+ ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel))
}
-type TestProtocol struct {
- Msgs []*Msg
+func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
+ conn1, conn2 := net.Pipe()
+ id := NewSimpleClientIdentity("test", "0", "0", "public key")
+ server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
+ peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
+ return conn2, peer, peer.messenger
}
-func (self *TestProtocol) Start() {
-}
-
-func (self *TestProtocol) Stop() {
-}
-
-func (self *TestProtocol) Offset() MsgCode {
- return MsgCode(5)
+func performTestHandshake(r *bufio.Reader, w io.Writer) error {
+ // read remote handshake
+ msg, err := readMsg(r)
+ if err != nil {
+ return fmt.Errorf("read error: %v", err)
+ }
+ if msg.Code != handshakeMsg {
+ return fmt.Errorf("first message should be handshake, got %x", msg.Code)
+ }
+ if err := msg.Discard(); err != nil {
+ return err
+ }
+ // send empty handshake
+ pubkey := make([]byte, 64)
+ msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
+ return writeMsg(w, msg)
}
-func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
- self.Msgs = append(self.Msgs, msg)
- close(response)
+type testMsg struct {
+ code MsgCode
+ data *ethutil.Value
}
-func (self *TestProtocol) HandleOut(msg *Msg) bool {
- if msg.Code() > 3 {
- return false
- } else {
- return true
- }
+type testProto struct {
+ recv chan testMsg
}
-func (self *TestProtocol) Name() string {
- return "a"
-}
+func (*testProto) Offset() MsgCode { return 5 }
-func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
- msg, _ := NewMsg(code, params...)
- encoded := msg.Encode(offset)
- packet := []byte{34, 64, 8, 145}
- packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
- return append(packet, encoded...)
+func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error {
+ return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error {
+ logger.Debugf("testprotocol got msg: %d\n", code)
+ tp.recv <- testMsg{code, data}
+ return nil
+ })
}
func TestRead(t *testing.T) {
- handlers := make(Handlers)
- testProtocol := &TestProtocol{Msgs: []*Msg{}}
- handlers["a"] = func(p *Peer) Protocol { return testProtocol }
- net, _, mess := setupMessenger(handlers)
- mess.AddProtocols([]string{"a"})
- defer mess.Stop()
- wait := 1 * time.Millisecond
- packet := Packet(16, 1, uint32(1), "000")
- go net.In(0, packet)
- time.Sleep(wait)
- if len(testProtocol.Msgs) != 1 {
- t.Errorf("msg not relayed to correct protocol")
- } else {
- if testProtocol.Msgs[0].Code() != 1 {
- t.Errorf("incorrect msg code relayed to protocol")
+ testProtocol := &testProto{make(chan testMsg)}
+ handlers := Handlers{"a": func() Protocol { return testProtocol }}
+ net, peer, mess := setupMessenger(handlers)
+ bufr := bufio.NewReader(net)
+ defer peer.Stop()
+ if err := performTestHandshake(bufr, net); err != nil {
+ t.Fatalf("handshake failed: %v", err)
+ }
+
+ mess.setRemoteProtocols([]string{"a"})
+ writeMsg(net, NewMsg(17, uint32(1), "000"))
+ select {
+ case msg := <-testProtocol.recv:
+ if msg.code != 1 {
+ t.Errorf("incorrect msg code %d relayed to protocol", msg.code)
+ }
+ expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
+ if !reflect.DeepEqual(msg.data.Slice(), expdata) {
+ t.Errorf("incorrect msg data %#v", msg.data.Slice())
}
+ case <-time.After(2 * time.Second):
+ t.Errorf("receive timeout")
}
}
-func TestWrite(t *testing.T) {
+func TestWriteProtoMsg(t *testing.T) {
handlers := make(Handlers)
- testProtocol := &TestProtocol{Msgs: []*Msg{}}
- handlers["a"] = func(p *Peer) Protocol { return testProtocol }
- net, _, mess := setupMessenger(handlers)
- mess.AddProtocols([]string{"a"})
- defer mess.Stop()
- wait := 1 * time.Millisecond
- msg, _ := NewMsg(3, uint32(1), "000")
- err := mess.Write("b", msg)
- if err == nil {
- t.Errorf("expect error for unknown protocol")
+ testProtocol := &testProto{recv: make(chan testMsg, 1)}
+ handlers["a"] = func() Protocol { return testProtocol }
+ net, peer, mess := setupMessenger(handlers)
+ defer peer.Stop()
+ bufr := bufio.NewReader(net)
+ if err := performTestHandshake(bufr, net); err != nil {
+ t.Fatalf("handshake failed: %v", err)
}
- err = mess.Write("a", msg)
- if err != nil {
- t.Errorf("expect no error for known protocol: %v", err)
- } else {
- time.Sleep(wait)
- if len(net.Out) != 1 {
- t.Errorf("msg not written")
+ mess.setRemoteProtocols([]string{"a"})
+
+ // test write errors
+ if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
+ t.Errorf("expected error for unknown protocol, got nil")
+ }
+ if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
+ t.Errorf("expected error for out-of-range msg code, got nil")
+ } else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
+ t.Errorf("wrong error for out-of-range msg code, got %#v")
+ }
+
+ // test succcessful write
+ read, readerr := make(chan Msg), make(chan error)
+ go func() {
+ if msg, err := readMsg(bufr); err != nil {
+ readerr <- err
} else {
- out := net.Out[0]
- packet := Packet(16, 3, uint32(1), "000")
- if bytes.Compare(out, packet) != 0 {
- t.Errorf("incorrect packet %v", out)
- }
+ read <- msg
+ }
+ }()
+ if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil {
+ t.Errorf("expect no error for known protocol: %v", err)
+ }
+ select {
+ case msg := <-read:
+ if msg.Code != 19 {
+ t.Errorf("wrong code, got %d, expected %d", msg.Code, 19)
}
+ msg.Discard()
+ case err := <-readerr:
+ t.Errorf("read error: %v", err)
}
}
func TestPulse(t *testing.T) {
- net, _, mess := setupMessenger(make(Handlers))
- defer mess.Stop()
- ping := false
- timeout := false
- pingTimeout := 10 * time.Millisecond
- gracePeriod := 200 * time.Millisecond
- go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
- net.In(0, Packet(0, 1))
- if ping {
- t.Errorf("ping sent too early")
- }
- time.Sleep(pingTimeout + 100*time.Millisecond)
- if !ping {
- t.Errorf("no ping sent after timeout")
- }
- if timeout {
- t.Errorf("timeout too early")
+ net, peer, _ := setupMessenger(nil)
+ defer peer.Stop()
+ bufr := bufio.NewReader(net)
+ if err := performTestHandshake(bufr, net); err != nil {
+ t.Fatalf("handshake failed: %v", err)
}
- ping = false
- net.In(0, Packet(0, 1))
- time.Sleep(pingTimeout + 100*time.Millisecond)
- if !ping {
- t.Errorf("no ping sent after timeout")
- }
- if timeout {
- t.Errorf("timeout too early")
+
+ before := time.Now()
+ msg, err := readMsg(bufr)
+ if err != nil {
+ t.Fatalf("read error: %v", err)
}
- ping = false
- time.Sleep(gracePeriod)
- if ping {
- t.Errorf("ping called twice")
+ after := time.Now()
+ if msg.Code != pingMsg {
+ t.Errorf("expected ping message, got %x", msg.Code)
}
- if !timeout {
- t.Errorf("no timeout after grace period")
+ if d := after.Sub(before); d < pingTimeout {
+ t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
}
}
diff --git a/p2p/peer.go b/p2p/peer.go
index f4b68a007..34b6152a3 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -7,7 +7,6 @@ import (
)
type Peer struct {
- // quit chan chan bool
Inbound bool // inbound (via listener) or outbound (via dialout)
Address net.Addr
Host []byte
@@ -15,24 +14,12 @@ type Peer struct {
Pubkey []byte
Id string
Caps []string
- peerErrorChan chan *PeerError
- messenger *Messenger
+ peerErrorChan chan error
+ messenger *messenger
peerErrorHandler *PeerErrorHandler
server *Server
}
-func (self *Peer) Messenger() *Messenger {
- return self.messenger
-}
-
-func (self *Peer) PeerErrorChan() chan *PeerError {
- return self.peerErrorChan
-}
-
-func (self *Peer) Server() *Server {
- return self.server
-}
-
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
peerErrorChan := NewPeerErrorChannel()
host, port, _ := net.SplitHostPort(address.String())
@@ -45,9 +32,8 @@ func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Pee
peerErrorChan: peerErrorChan,
server: server,
}
- connection := NewConnection(conn, peerErrorChan)
- peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers())
- peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist())
+ peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers())
+ peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan)
return peer
}
@@ -61,8 +47,8 @@ func (self *Peer) String() string {
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
}
-func (self *Peer) Write(protocol string, msg *Msg) error {
- return self.messenger.Write(protocol, msg)
+func (self *Peer) Write(protocol string, msg Msg) error {
+ return self.messenger.writeProtoMsg(protocol, msg)
}
func (self *Peer) Start() {
@@ -73,9 +59,6 @@ func (self *Peer) Start() {
func (self *Peer) Stop() {
self.peerErrorHandler.Stop()
self.messenger.Stop()
- // q := make(chan bool)
- // self.quit <- q
- // <-q
}
func (p *Peer) Encode() []interface{} {
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
index de921878a..f3ef98d98 100644
--- a/p2p/peer_error.go
+++ b/p2p/peer_error.go
@@ -9,10 +9,9 @@ type ErrorCode int
const errorChanCapacity = 10
const (
- PacketTooShort = iota
+ PacketTooLong = iota
PayloadTooShort
MagicTokenMismatch
- EmptyPayload
ReadError
WriteError
MiscError
@@ -31,10 +30,9 @@ const (
)
var errorToString = map[ErrorCode]string{
- PacketTooShort: "Packet too short",
+ PacketTooLong: "Packet too long",
PayloadTooShort: "Payload too short",
MagicTokenMismatch: "Magic token mismatch",
- EmptyPayload: "Empty payload",
ReadError: "Read error",
WriteError: "Write error",
MiscError: "Misc error",
@@ -71,6 +69,6 @@ func (self *PeerError) Error() string {
return self.message
}
-func NewPeerErrorChannel() chan *PeerError {
- return make(chan *PeerError, errorChanCapacity)
+func NewPeerErrorChannel() chan error {
+ return make(chan error, errorChanCapacity)
}
diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go
index ca6cae4db..47dcd14ff 100644
--- a/p2p/peer_error_handler.go
+++ b/p2p/peer_error_handler.go
@@ -18,17 +18,15 @@ type PeerErrorHandler struct {
address net.Addr
peerDisconnect chan DisconnectRequest
severity int
- peerErrorChan chan *PeerError
- blacklist Blacklist
+ errc chan error
}
-func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
+func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
return &PeerErrorHandler{
quit: make(chan chan bool),
address: address,
peerDisconnect: peerDisconnect,
- peerErrorChan: peerErrorChan,
- blacklist: blacklist,
+ errc: errc,
}
}
@@ -45,10 +43,10 @@ func (self *PeerErrorHandler) Stop() {
func (self *PeerErrorHandler) listen() {
for {
select {
- case peerError, ok := <-self.peerErrorChan:
+ case err, ok := <-self.errc:
if ok {
- logger.Debugf("error %v\n", peerError)
- go self.handle(peerError)
+ logger.Debugf("error %v\n", err)
+ go self.handle(err)
} else {
return
}
@@ -59,8 +57,12 @@ func (self *PeerErrorHandler) listen() {
}
}
-func (self *PeerErrorHandler) handle(peerError *PeerError) {
+func (self *PeerErrorHandler) handle(err error) {
reason := DiscReason(' ')
+ peerError, ok := err.(*PeerError)
+ if !ok {
+ peerError = NewPeerError(MiscError, " %v", err)
+ }
switch peerError.Code {
case P2PVersionMismatch:
reason = DiscIncompatibleVersion
@@ -68,11 +70,11 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) {
reason = DiscInvalidIdentity
case PubkeyForbidden:
reason = DiscUselessPeer
- case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
+ case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
reason = DiscProtocolError
case PingTimeout:
reason = DiscReadTimeout
- case WriteError, MiscError:
+ case ReadError, WriteError, MiscError:
reason = DiscNetworkError
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
reason = DiscSubprotocolError
@@ -92,10 +94,5 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) {
}
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
- switch peerError.Code {
- case ReadError:
- return 4 //tolerate 3 :)
- default:
- return 1
- }
+ return 1
}
diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go
index 790a7443b..b93252f6a 100644
--- a/p2p/peer_error_handler_test.go
+++ b/p2p/peer_error_handler_test.go
@@ -11,7 +11,7 @@ func TestPeerErrorHandler(t *testing.T) {
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
peerDisconnect := make(chan DisconnectRequest)
peerErrorChan := NewPeerErrorChannel()
- peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
+ peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
peh.Start()
defer peh.Stop()
for i := 0; i < 11; i++ {
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index c37540bef..da62cc380 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -1,96 +1,90 @@
package p2p
-import (
- "bytes"
- "fmt"
- // "net"
- "testing"
- "time"
-)
+// "net"
-func TestPeer(t *testing.T) {
- handlers := make(Handlers)
- testProtocol := &TestProtocol{Msgs: []*Msg{}}
- handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
- handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
- addr := &TestAddr{"test:30"}
- conn := NewTestNetworkConnection(addr)
- _, server := SetupTestServer(handlers)
- server.Handshake()
- peer := NewPeer(conn, addr, true, server)
- // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
- peer.Start()
- defer peer.Stop()
- time.Sleep(2 * time.Millisecond)
- if len(conn.Out) != 1 {
- t.Errorf("handshake not sent")
- } else {
- out := conn.Out[0]
- packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
- if bytes.Compare(out, packet) != 0 {
- t.Errorf("incorrect handshake packet %v != %v", out, packet)
- }
- }
+// func TestPeer(t *testing.T) {
+// handlers := make(Handlers)
+// testProtocol := &TestProtocol{recv: make(chan testMsg)}
+// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
+// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
+// addr := &TestAddr{"test:30"}
+// conn := NewTestNetworkConnection(addr)
+// _, server := SetupTestServer(handlers)
+// server.Handshake()
+// peer := NewPeer(conn, addr, true, server)
+// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
+// peer.Start()
+// defer peer.Stop()
+// time.Sleep(2 * time.Millisecond)
+// if len(conn.Out) != 1 {
+// t.Errorf("handshake not sent")
+// } else {
+// out := conn.Out[0]
+// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
+// if bytes.Compare(out, packet) != 0 {
+// t.Errorf("incorrect handshake packet %v != %v", out, packet)
+// }
+// }
- packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
- conn.In(0, packet)
- time.Sleep(10 * time.Millisecond)
+// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
+// conn.In(0, packet)
+// time.Sleep(10 * time.Millisecond)
- pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
- if pro.state != handshakeReceived {
- t.Errorf("handshake not received")
- }
- if peer.Port != 30 {
- t.Errorf("port incorrectly set")
- }
- if peer.Id != "peer" {
- t.Errorf("id incorrectly set")
- }
- if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
- t.Errorf("pubkey incorrectly set")
- }
- fmt.Println(peer.Caps)
- if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
- t.Errorf("protocols incorrectly set")
- }
+// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
+// if pro.state != handshakeReceived {
+// t.Errorf("handshake not received")
+// }
+// if peer.Port != 30 {
+// t.Errorf("port incorrectly set")
+// }
+// if peer.Id != "peer" {
+// t.Errorf("id incorrectly set")
+// }
+// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
+// t.Errorf("pubkey incorrectly set")
+// }
+// fmt.Println(peer.Caps)
+// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
+// t.Errorf("protocols incorrectly set")
+// }
- msg, _ := NewMsg(3)
- err := peer.Write("aaa", msg)
- if err != nil {
- t.Errorf("expect no error for known protocol: %v", err)
- } else {
- time.Sleep(1 * time.Millisecond)
- if len(conn.Out) != 2 {
- t.Errorf("msg not written")
- } else {
- out := conn.Out[1]
- packet := Packet(16, 3)
- if bytes.Compare(out, packet) != 0 {
- t.Errorf("incorrect packet %v != %v", out, packet)
- }
- }
- }
+// msg := NewMsg(3)
+// err := peer.Write("aaa", msg)
+// if err != nil {
+// t.Errorf("expect no error for known protocol: %v", err)
+// } else {
+// time.Sleep(1 * time.Millisecond)
+// if len(conn.Out) != 2 {
+// t.Errorf("msg not written")
+// } else {
+// out := conn.Out[1]
+// packet := Packet(16, 3)
+// if bytes.Compare(out, packet) != 0 {
+// t.Errorf("incorrect packet %v != %v", out, packet)
+// }
+// }
+// }
- msg, _ = NewMsg(2)
- err = peer.Write("ccc", msg)
- if err != nil {
- t.Errorf("expect no error for known protocol: %v", err)
- } else {
- time.Sleep(1 * time.Millisecond)
- if len(conn.Out) != 3 {
- t.Errorf("msg not written")
- } else {
- out := conn.Out[2]
- packet := Packet(21, 2)
- if bytes.Compare(out, packet) != 0 {
- t.Errorf("incorrect packet %v != %v", out, packet)
- }
- }
- }
+// msg = NewMsg(2)
+// err = peer.Write("ccc", msg)
+// if err != nil {
+// t.Errorf("expect no error for known protocol: %v", err)
+// } else {
+// time.Sleep(1 * time.Millisecond)
+// if len(conn.Out) != 3 {
+// t.Errorf("msg not written")
+// } else {
+// out := conn.Out[2]
+// packet := Packet(21, 2)
+// if bytes.Compare(out, packet) != 0 {
+// t.Errorf("incorrect packet %v != %v", out, packet)
+// }
+// }
+// }
- err = peer.Write("bbb", msg)
- time.Sleep(1 * time.Millisecond)
- if err == nil {
- t.Errorf("expect error for unknown protocol")
- }
-}
+// err = peer.Write("bbb", msg)
+// time.Sleep(1 * time.Millisecond)
+// if err == nil {
+// t.Errorf("expect error for unknown protocol")
+// }
+// }
diff --git a/p2p/protocol.go b/p2p/protocol.go
index 5d05ced7d..ccc275287 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -2,43 +2,101 @@ package p2p
import (
"bytes"
- "fmt"
"net"
"sort"
- "sync"
"time"
+
+ "github.com/ethereum/go-ethereum/ethutil"
)
+// Protocol is implemented by P2P subprotocols.
type Protocol interface {
- Start()
- Stop()
- HandleIn(*Msg, chan *Msg)
- HandleOut(*Msg) bool
+ // Start is called when the protocol becomes active.
+ // It should read and write messages from rw.
+ // Messages must be fully consumed.
+ //
+ // The connection is closed when Start returns. It should return
+ // any protocol-level error (such as an I/O error) that is
+ // encountered.
+ Start(peer *Peer, rw MsgReadWriter) error
+
+ // Offset should return the number of message codes
+ // used by the protocol.
Offset() MsgCode
- Name() string
+}
+
+type MsgReader interface {
+ ReadMsg() (Msg, error)
+}
+
+type MsgWriter interface {
+ WriteMsg(Msg) error
+}
+
+// MsgReadWriter is passed to protocols. Protocol implementations can
+// use it to write messages back to a connected peer.
+type MsgReadWriter interface {
+ MsgReader
+ MsgWriter
+}
+
+type MsgHandler func(code MsgCode, data *ethutil.Value) error
+
+// MsgLoop reads messages off the given reader and
+// calls the handler function for each decoded message until
+// it returns an error or the peer connection is closed.
+//
+// If a message is larger than the given maximum size, RunProtocol
+// returns an appropriate error.n
+func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error {
+ for {
+ msg, err := r.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Size > maxsize {
+ return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
+ }
+ value, err := msg.Data()
+ if err != nil {
+ return err
+ }
+ if err := handler(msg.Code, value); err != nil {
+ return err
+ }
+ }
+}
+
+// the ÐΞVp2p base protocol
+type baseProtocol struct {
+ rw MsgReadWriter
+ peer *Peer
+}
+
+type bpMsg struct {
+ code MsgCode
+ data *ethutil.Value
}
const (
- P2PVersion = 0
- pingTimeout = 2
- pingGracePeriod = 2
+ p2pVersion = 0
+ pingTimeout = 2 * time.Second
+ pingGracePeriod = 2 * time.Second
)
const (
- HandshakeMsg = iota
- DiscMsg
- PingMsg
- PongMsg
- GetPeersMsg
- PeersMsg
- offset = 16
+ // message codes
+ handshakeMsg = iota
+ discMsg
+ pingMsg
+ pongMsg
+ getPeersMsg
+ peersMsg
)
-type ProtocolState uint8
-
const (
- nullState = iota
- handshakeReceived
+ baseProtocolOffset MsgCode = 16
+ baseProtocolMaxMsgSize = 500 * 1024
)
type DiscReason byte
@@ -62,7 +120,7 @@ const (
DiscSubprotocolError = 0x10
)
-var discReasonToString = map[DiscReason]string{
+var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
@@ -82,197 +140,178 @@ func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return "Unknown"
}
-
return discReasonToString[d]
}
-type BaseProtocol struct {
- peer *Peer
- state ProtocolState
- stateLock sync.RWMutex
+func (bp *baseProtocol) Ping() {
}
-func NewBaseProtocol(peer *Peer) *BaseProtocol {
- self := &BaseProtocol{
- peer: peer,
- }
-
- return self
+func (bp *baseProtocol) Offset() MsgCode {
+ return baseProtocolOffset
}
-func (self *BaseProtocol) Start() {
- if self.peer != nil {
- self.peer.Write("", self.peer.Server().Handshake())
- go self.peer.Messenger().PingPong(
- pingTimeout*time.Second,
- pingGracePeriod*time.Second,
- self.Ping,
- self.Timeout,
- )
+func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error {
+ bp.peer, bp.rw = peer, rw
+
+ // Do the handshake.
+ // TODO: disconnect is valid before handshake, too.
+ rw.WriteMsg(bp.peer.server.handshakeMsg())
+ msg, err := rw.ReadMsg()
+ if err != nil {
+ return err
+ }
+ if msg.Code != handshakeMsg {
+ return NewPeerError(ProtocolBreach, " first message must be handshake")
+ }
+ data, err := msg.Data()
+ if err != nil {
+ return NewPeerError(InvalidMsg, "%v", err)
+ }
+ if err := bp.handleHandshake(data); err != nil {
+ return err
}
-}
-func (self *BaseProtocol) Stop() {
+ msgin := make(chan bpMsg)
+ done := make(chan error, 1)
+ go func() {
+ done <- MsgLoop(rw, baseProtocolMaxMsgSize,
+ func(code MsgCode, data *ethutil.Value) error {
+ msgin <- bpMsg{code, data}
+ return nil
+ })
+ }()
+ return bp.loop(msgin, done)
}
-func (self *BaseProtocol) Ping() {
- msg, _ := NewMsg(PingMsg)
- self.peer.Write("", msg)
+func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error {
+ logger.Debugf("pingpong keepalive started at %v\n", time.Now())
+ messenger := bp.rw.(*proto).messenger
+ pingTimer := time.NewTimer(pingTimeout)
+ pinged := true
+
+ for {
+ select {
+ case msg := <-msgin:
+ if err := bp.handle(msg.code, msg.data); err != nil {
+ return err
+ }
+ case err := <-quit:
+ return err
+ case <-messenger.pulse:
+ pingTimer.Reset(pingTimeout)
+ pinged = false
+ case <-pingTimer.C:
+ if pinged {
+ return NewPeerError(PingTimeout, "")
+ }
+ logger.Debugf("pinging at %v\n", time.Now())
+ if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil {
+ return NewPeerError(WriteError, "%v", err)
+ }
+ pinged = true
+ pingTimer.Reset(pingTimeout)
+ }
+ }
}
-func (self *BaseProtocol) Timeout() {
- self.peerError(PingTimeout, "")
-}
+func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
+ switch code {
+ case handshakeMsg:
+ return NewPeerError(ProtocolBreach, " extra handshake received")
-func (self *BaseProtocol) Name() string {
- return ""
-}
+ case discMsg:
+ logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint()))
+ bp.peer.server.PeerDisconnect() <- DisconnectRequest{
+ addr: bp.peer.Address,
+ reason: DiscRequested,
+ }
-func (self *BaseProtocol) Offset() MsgCode {
- return offset
-}
+ case pingMsg:
+ return bp.rw.WriteMsg(NewMsg(pongMsg))
-func (self *BaseProtocol) CheckState(state ProtocolState) bool {
- self.stateLock.RLock()
- self.stateLock.RUnlock()
- if self.state != state {
- return false
- } else {
- return true
- }
-}
+ case pongMsg:
+ // reply for ping
-func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) {
- if msg.Code() == HandshakeMsg {
- self.handleHandshake(msg)
- } else {
- if !self.CheckState(handshakeReceived) {
- self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code())
- close(response)
- return
- }
- switch msg.Code() {
- case DiscMsg:
- logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint()))
- self.peer.Server().PeerDisconnect() <- DisconnectRequest{
- addr: self.peer.Address,
- reason: DiscRequested,
- }
- case PingMsg:
- out, _ := NewMsg(PongMsg)
- response <- out
- case PongMsg:
- case GetPeersMsg:
- // Peer asked for list of connected peers
- if out, err := self.peer.Server().PeersMessage(); err != nil {
- response <- out
+ case getPeersMsg:
+ // Peer asked for list of connected peers.
+ peersRLP := bp.peer.server.encodedPeerList()
+ if peersRLP != nil {
+ msg := Msg{
+ Code: peersMsg,
+ Size: uint32(len(peersRLP)),
+ Payload: bytes.NewReader(peersRLP),
}
- case PeersMsg:
- self.handlePeers(msg)
- default:
- self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code())
+ return bp.rw.WriteMsg(msg)
}
- }
- close(response)
-}
-func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) {
- // somewhat overly paranoid
- allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived)
- return
-}
+ case peersMsg:
+ bp.handlePeers(data)
-func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) {
- err := NewPeerError(errorCode, format, v...)
- logger.Warnln(err)
- fmt.Println(self.peer, err)
- if self.peer != nil {
- self.peer.PeerErrorChan() <- err
+ default:
+ return NewPeerError(InvalidMsgCode, "unknown message code %v", code)
}
+ return nil
}
-func (self *BaseProtocol) handlePeers(msg *Msg) {
- it := msg.Data().NewIterator()
+func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
+ it := data.NewIterator()
for it.Next() {
ip := net.IP(it.Value().Get(0).Bytes())
port := it.Value().Get(1).Uint()
address := &net.TCPAddr{IP: ip, Port: int(port)}
- go self.peer.Server().PeerConnect(address)
+ go bp.peer.server.PeerConnect(address)
}
}
-func (self *BaseProtocol) handleHandshake(msg *Msg) {
- self.stateLock.Lock()
- defer self.stateLock.Unlock()
- if self.state != nullState {
- self.peerError(ProtocolBreach, "extra handshake")
- return
- }
-
- c := msg.Data()
-
+func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
var (
- p2pVersion = c.Get(0).Uint()
- id = c.Get(1).Str()
- caps = c.Get(2)
- port = c.Get(3).Uint()
- pubkey = c.Get(4).Bytes()
+ remoteVersion = c.Get(0).Uint()
+ id = c.Get(1).Str()
+ caps = c.Get(2)
+ port = c.Get(3).Uint()
+ pubkey = c.Get(4).Bytes()
)
- fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey)
-
// Check correctness of p2p protocol version
- if p2pVersion != P2PVersion {
- self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion)
- return
+ if remoteVersion != p2pVersion {
+ return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion)
}
// Handle the pub key (validation, uniqueness)
if len(pubkey) == 0 {
- self.peerError(PubkeyMissing, "not supplied in handshake.")
- return
+ return NewPeerError(PubkeyMissing, "not supplied in handshake.")
}
if len(pubkey) != 64 {
- self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
- return
+ return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
}
- // Self connect detection
- if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 {
- self.peerError(PubkeyForbidden, "not allowed to connect to self")
- return
+ // self connect detection
+ if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
+ return NewPeerError(PubkeyForbidden, "not allowed to connect to bp")
}
// register pubkey on server. this also sets the pubkey on the peer (need lock)
- if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil {
- self.peerError(PubkeyForbidden, err.Error())
- return
+ if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil {
+ return NewPeerError(PubkeyForbidden, err.Error())
}
// check port
- if self.peer.Inbound {
+ if bp.peer.Inbound {
uint16port := uint16(port)
- if self.peer.Port > 0 && self.peer.Port != uint16port {
- self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port)
- return
+ if bp.peer.Port > 0 && bp.peer.Port != uint16port {
+ return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port)
} else {
- self.peer.Port = uint16port
+ bp.peer.Port = uint16port
}
}
capsIt := caps.NewIterator()
for capsIt.Next() {
cap := capsIt.Value().Str()
- self.peer.Caps = append(self.peer.Caps, cap)
+ bp.peer.Caps = append(bp.peer.Caps, cap)
}
- sort.Strings(self.peer.Caps)
- self.peer.Messenger().AddProtocols(self.peer.Caps)
-
- self.peer.Id = id
-
- self.state = handshakeReceived
-
- //p.ethereum.PushPeer(p)
- // p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
- return
+ sort.Strings(bp.peer.Caps)
+ bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps)
+ bp.peer.Id = id
+ return nil
}
diff --git a/p2p/server.go b/p2p/server.go
index 91bc4af5c..54d2cde30 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -80,12 +80,12 @@ type Server struct {
quit chan chan bool
peersLock sync.RWMutex
- maxPeers int
- peers []*Peer
- peerSlots chan int
- peersTable map[string]int
- peersMsg *Msg
- peerCount int
+ maxPeers int
+ peers []*Peer
+ peerSlots chan int
+ peersTable map[string]int
+ peerCount int
+ cachedEncodedPeers []byte
peerConnect chan net.Addr
peerDisconnect chan DisconnectRequest
@@ -147,27 +147,6 @@ func (self *Server) ClientIdentity() ClientIdentity {
return self.identity
}
-func (self *Server) PeersMessage() (msg *Msg, err error) {
- // TODO: memoize and reset when peers change
- self.peersLock.RLock()
- defer self.peersLock.RUnlock()
- msg = self.peersMsg
- if msg == nil {
- var peerData []interface{}
- for _, i := range self.peersTable {
- peer := self.peers[i]
- peerData = append(peerData, peer.Encode())
- }
- if len(peerData) == 0 {
- err = fmt.Errorf("no peers")
- } else {
- msg, err = NewMsg(PeersMsg, peerData...)
- self.peersMsg = msg //memoize
- }
- }
- return
-}
-
func (self *Server) Peers() (peers []*Peer) {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
@@ -185,8 +164,6 @@ func (self *Server) PeerCount() int {
return self.peerCount
}
-var getPeersMsg, _ = NewMsg(GetPeersMsg)
-
func (self *Server) PeerConnect(addr net.Addr) {
// TODO: should buffer, filter and uniq
// send GetPeersMsg if not blocking
@@ -209,12 +186,21 @@ func (self *Server) Handlers() Handlers {
return self.handlers
}
-func (self *Server) Broadcast(protocol string, msg *Msg) {
+func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) {
+ var payload []byte
+ if data != nil {
+ payload = encodePayload(data...)
+ }
self.peersLock.RLock()
defer self.peersLock.RUnlock()
for _, peer := range self.peers {
if peer != nil {
- peer.Write(protocol, msg)
+ var msg = Msg{Code: code}
+ if data != nil {
+ msg.Payload = bytes.NewReader(payload)
+ msg.Size = uint32(len(payload))
+ }
+ peer.messenger.writeProtoMsg(protocol, msg)
}
}
}
@@ -296,7 +282,7 @@ FOR:
select {
case slot := <-self.peerSlots:
i++
- fmt.Printf("%v: found slot %v", i, slot)
+ fmt.Printf("%v: found slot %v\n", i, slot)
if i == self.maxPeers {
break FOR
}
@@ -358,70 +344,68 @@ func (self *Server) outboundPeerHandler(dialer Dialer) {
}
// check if peer address already connected
-func (self *Server) connected(address net.Addr) (err error) {
+func (self *Server) isConnected(address net.Addr) bool {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
- // fmt.Printf("address: %v\n", address)
- slot, found := self.peersTable[address.String()]
- if found {
- err = fmt.Errorf("already connected as peer %v (%v)", slot, address)
- }
- return
+ _, found := self.peersTable[address.String()]
+ return found
}
// connect to peer via listener.Accept()
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
var address net.Addr
conn, err := listener.Accept()
- if err == nil {
- address = conn.RemoteAddr()
- err = self.connected(address)
- if err != nil {
- conn.Close()
- }
- }
if err != nil {
logger.Debugln(err)
self.peerSlots <- slot
- } else {
- fmt.Printf("adding %v\n", address)
- go self.addPeer(conn, address, true, slot)
+ return
+ }
+ address = conn.RemoteAddr()
+ // XXX: this won't work because the remote socket
+ // address does not identify the peer. we should
+ // probably get rid of this check and rely on public
+ // key detection in the base protocol.
+ if self.isConnected(address) {
+ conn.Close()
+ self.peerSlots <- slot
+ return
}
+ fmt.Printf("adding %v\n", address)
+ go self.addPeer(conn, address, true, slot)
}
// connect to peer via dial out
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
- var conn net.Conn
- err := self.connected(address)
- if err == nil {
- conn, err = dialer.Dial(address.Network(), address.String())
+ if self.isConnected(address) {
+ return
}
+ conn, err := dialer.Dial(address.Network(), address.String())
if err != nil {
- logger.Debugln(err)
self.peerSlots <- slot
- } else {
- go self.addPeer(conn, address, false, slot)
+ return
}
+ go self.addPeer(conn, address, false, slot)
}
// creates the new peer object and inserts it into its slot
-func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) {
+func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer {
self.peersLock.Lock()
defer self.peersLock.Unlock()
if self.closed {
fmt.Println("oopsy, not no longer need peer")
conn.Close() //oopsy our bad
self.peerSlots <- slot // release slot
- } else {
- peer := NewPeer(conn, address, inbound, self)
- self.peers[slot] = peer
- self.peersTable[address.String()] = slot
- self.peerCount++
- // reset peersmsg
- self.peersMsg = nil
- fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
- peer.Start()
+ return nil
}
+ logger.Infoln("adding new peer", address)
+ peer := NewPeer(conn, address, inbound, self)
+ self.peers[slot] = peer
+ self.peersTable[address.String()] = slot
+ self.peerCount++
+ self.cachedEncodedPeers = nil
+ fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
+ peer.Start()
+ return peer
}
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
@@ -441,13 +425,12 @@ func (self *Server) removePeer(request DisconnectRequest) {
self.peerCount--
self.peers[slot] = nil
delete(self.peersTable, address.String())
- // reset peersmsg
- self.peersMsg = nil
+ self.cachedEncodedPeers = nil
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
self.peersLock.Unlock()
// sending disconnect message
- disconnectMsg, _ := NewMsg(DiscMsg, request.reason)
+ disconnectMsg := NewMsg(discMsg, request.reason)
peer.Write("", disconnectMsg)
// be nice and wait
time.Sleep(disconnectGracePeriod * time.Second)
@@ -459,11 +442,32 @@ func (self *Server) removePeer(request DisconnectRequest) {
self.peerSlots <- slot
}
+// encodedPeerList returns an RLP-encoded list of peers.
+// the returned slice will be nil if there are no peers.
+func (self *Server) encodedPeerList() []byte {
+ // TODO: memoize and reset when peers change
+ self.peersLock.RLock()
+ defer self.peersLock.RUnlock()
+ if self.cachedEncodedPeers == nil && self.peerCount > 0 {
+ var peerData []interface{}
+ for _, i := range self.peersTable {
+ peer := self.peers[i]
+ peerData = append(peerData, peer.Encode())
+ }
+ self.cachedEncodedPeers = encodePayload(peerData)
+ }
+ return self.cachedEncodedPeers
+}
+
// fix handshake message to push to peers
-func (self *Server) Handshake() *Msg {
- fmt.Println(self.identity.Pubkey()[1:])
- msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:])
- return msg
+func (self *Server) handshakeMsg() Msg {
+ return NewMsg(handshakeMsg,
+ p2pVersion,
+ []byte(self.identity.String()),
+ []interface{}{self.protocols},
+ self.port,
+ self.identity.Pubkey()[1:],
+ )
}
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
diff --git a/p2p/server_test.go b/p2p/server_test.go
index f749cc490..472759231 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -1,8 +1,8 @@
package p2p
import (
- "bytes"
"fmt"
+ "io"
"net"
"testing"
"time"
@@ -32,6 +32,7 @@ func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
connections: self.connections,
addr: addr,
max: self.maxinbound,
+ close: make(chan struct{}),
}, nil
}
@@ -76,24 +77,25 @@ type TestListener struct {
addr net.Addr
max int
i int
+ close chan struct{}
}
-func (self *TestListener) Accept() (conn net.Conn, err error) {
+func (self *TestListener) Accept() (net.Conn, error) {
self.i++
if self.i > self.max {
- err = fmt.Errorf("no more")
- } else {
- addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
- tconn := NewTestNetworkConnection(addr)
- key := tconn.RemoteAddr().String()
- self.connections[key] = tconn
- conn = net.Conn(tconn)
- fmt.Printf("accepted connection from: %v \n", addr)
+ <-self.close
+ return nil, io.EOF
}
- return
+ addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
+ tconn := NewTestNetworkConnection(addr)
+ key := tconn.RemoteAddr().String()
+ self.connections[key] = tconn
+ fmt.Printf("accepted connection from: %v \n", addr)
+ return tconn, nil
}
func (self *TestListener) Close() error {
+ close(self.close)
return nil
}
@@ -101,6 +103,86 @@ func (self *TestListener) Addr() net.Addr {
return self.addr
}
+type TestNetworkConnection struct {
+ in chan []byte
+ close chan struct{}
+ current []byte
+ Out [][]byte
+ addr net.Addr
+}
+
+func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
+ return &TestNetworkConnection{
+ in: make(chan []byte),
+ close: make(chan struct{}),
+ current: []byte{},
+ Out: [][]byte{},
+ addr: addr,
+ }
+}
+
+func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
+ time.Sleep(latency)
+ for _, s := range packets {
+ self.in <- s
+ }
+}
+
+func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
+ if len(self.current) == 0 {
+ var ok bool
+ select {
+ case self.current, ok = <-self.in:
+ if !ok {
+ return 0, io.EOF
+ }
+ case <-self.close:
+ return 0, io.EOF
+ }
+ }
+ length := len(self.current)
+ if length > len(buff) {
+ copy(buff[:], self.current[:len(buff)])
+ self.current = self.current[len(buff):]
+ return len(buff), nil
+ } else {
+ copy(buff[:length], self.current[:])
+ self.current = []byte{}
+ return length, io.EOF
+ }
+}
+
+func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
+ self.Out = append(self.Out, buff)
+ fmt.Printf("net write(%d): %x\n", len(self.Out), buff)
+ return len(buff), nil
+}
+
+func (self *TestNetworkConnection) Close() error {
+ close(self.close)
+ return nil
+}
+
+func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
+ return
+}
+
+func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
+ return self.addr
+}
+
+func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
+ return
+}
+
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
network = NewTestNetwork(1)
addr := &TestAddr{"test:30303"}
@@ -124,12 +206,10 @@ func TestServerListener(t *testing.T) {
if !ok {
t.Error("not found inbound peer 1")
} else {
- fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
- t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
+ t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
}
}
-
}
func TestServerDialer(t *testing.T) {
@@ -142,65 +222,63 @@ func TestServerDialer(t *testing.T) {
if !ok {
t.Error("not found outbound peer 1")
} else {
- fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
- t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
+ t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
}
}
}
-func TestServerBroadcast(t *testing.T) {
- handlers := make(Handlers)
- testProtocol := &TestProtocol{Msgs: []*Msg{}}
- handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
- network, server := SetupTestServer(handlers)
- server.Start(true, true)
- server.peerConnect <- &TestAddr{"outboundpeer-1"}
- time.Sleep(10 * time.Millisecond)
- msg, _ := NewMsg(0)
- server.Broadcast("", msg)
- packet := Packet(0, 0)
- time.Sleep(10 * time.Millisecond)
- server.Stop()
- peer1, ok := network.connections["outboundpeer-1"]
- if !ok {
- t.Error("not found outbound peer 1")
- } else {
- fmt.Printf("out: %v\n", peer1.Out)
- if len(peer1.Out) != 3 {
- t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
- } else {
- if bytes.Compare(peer1.Out[1], packet) != 0 {
- t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
- }
- }
- }
- peer2, ok := network.connections["inboundpeer-1"]
- if !ok {
- t.Error("not found inbound peer 2")
- } else {
- fmt.Printf("out: %v\n", peer2.Out)
- if len(peer1.Out) != 3 {
- t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
- } else {
- if bytes.Compare(peer2.Out[1], packet) != 0 {
- t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
- }
- }
- }
-}
+// func TestServerBroadcast(t *testing.T) {
+// handlers := make(Handlers)
+// testProtocol := &TestProtocol{Msgs: []*Msg{}}
+// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
+// network, server := SetupTestServer(handlers)
+// server.Start(true, true)
+// server.peerConnect <- &TestAddr{"outboundpeer-1"}
+// time.Sleep(10 * time.Millisecond)
+// msg := NewMsg(0)
+// server.Broadcast("", msg)
+// packet := Packet(0, 0)
+// time.Sleep(10 * time.Millisecond)
+// server.Stop()
+// peer1, ok := network.connections["outboundpeer-1"]
+// if !ok {
+// t.Error("not found outbound peer 1")
+// } else {
+// fmt.Printf("out: %v\n", peer1.Out)
+// if len(peer1.Out) != 3 {
+// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
+// } else {
+// if bytes.Compare(peer1.Out[1], packet) != 0 {
+// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
+// }
+// }
+// }
+// peer2, ok := network.connections["inboundpeer-1"]
+// if !ok {
+// t.Error("not found inbound peer 2")
+// } else {
+// fmt.Printf("out: %v\n", peer2.Out)
+// if len(peer1.Out) != 3 {
+// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
+// } else {
+// if bytes.Compare(peer2.Out[1], packet) != 0 {
+// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
+// }
+// }
+// }
+// }
func TestServerPeersMessage(t *testing.T) {
- handlers := make(Handlers)
- _, server := SetupTestServer(handlers)
+ _, server := SetupTestServer(nil)
server.Start(true, true)
defer server.Stop()
server.peerConnect <- &TestAddr{"outboundpeer-1"}
- time.Sleep(10 * time.Millisecond)
- peersMsg, err := server.PeersMessage()
- fmt.Println(peersMsg)
- if err != nil {
- t.Errorf("expect no error, got %v", err)
+ time.Sleep(2000 * time.Millisecond)
+
+ pl := server.encodedPeerList()
+ if pl == nil {
+ t.Errorf("expect non-nil peer list")
}
if c := server.PeerCount(); c != 2 {
t.Errorf("expect 2 peers, got %v", c)