aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/connection_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/connection_test.go')
-rw-r--r--p2p/connection_test.go222
1 files changed, 222 insertions, 0 deletions
diff --git a/p2p/connection_test.go b/p2p/connection_test.go
new file mode 100644
index 000000000..76ee8021c
--- /dev/null
+++ b/p2p/connection_test.go
@@ -0,0 +1,222 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net"
+ "testing"
+ "time"
+)
+
+type TestNetworkConnection struct {
+ in chan []byte
+ current []byte
+ Out [][]byte
+ addr net.Addr
+}
+
+func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
+ return &TestNetworkConnection{
+ in: make(chan []byte),
+ current: []byte{},
+ Out: [][]byte{},
+ addr: addr,
+ }
+}
+
+func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
+ time.Sleep(latency)
+ for _, s := range packets {
+ self.in <- s
+ }
+}
+
+func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
+ if len(self.current) == 0 {
+ select {
+ case self.current = <-self.in:
+ default:
+ return 0, io.EOF
+ }
+ }
+ length := len(self.current)
+ if length > len(buff) {
+ copy(buff[:], self.current[:len(buff)])
+ self.current = self.current[len(buff):]
+ return len(buff), nil
+ } else {
+ copy(buff[:length], self.current[:])
+ self.current = []byte{}
+ return length, io.EOF
+ }
+}
+
+func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
+ self.Out = append(self.Out, buff)
+ fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
+ return len(buff), nil
+}
+
+func (self *TestNetworkConnection) Close() (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
+ return
+}
+
+func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
+ return self.addr
+}
+
+func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
+ return
+}
+
+func setupConnection() (*Connection, *TestNetworkConnection) {
+ addr := &TestAddr{"test:30303"}
+ net := NewTestNetworkConnection(addr)
+ conn := NewConnection(net, NewPeerErrorChannel())
+ conn.Open()
+ return conn, net
+}
+
+func TestReadingNilPacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{})
+ // time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ t.Errorf("incorrect error %v", err)
+ default:
+ }
+ conn.Close()
+}
+
+func TestReadingShortPacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{0})
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ if err.Code != PacketTooShort {
+ t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
+ }
+ }
+ conn.Close()
+}
+
+func TestReadingInvalidPacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ if err.Code != MagicTokenMismatch {
+ t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
+ }
+ }
+ conn.Close()
+}
+
+func TestReadingInvalidPayload(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ if err.Code != PayloadTooShort {
+ t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
+ }
+ }
+ conn.Close()
+}
+
+func TestReadingEmptyPayload(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ default:
+ }
+ select {
+ case err := <-conn.Error():
+ code := err.Code
+ if code != EmptyPayload {
+ t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
+ }
+ default:
+ t.Errorf("no error, expected EmptyPayload")
+ }
+ conn.Close()
+}
+
+func TestReadingCompletePacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ if bytes.Compare(packet, []byte{1}) != 0 {
+ t.Errorf("incorrect payload read")
+ }
+ case err := <-conn.Error():
+ t.Errorf("incorrect error %v", err)
+ default:
+ t.Errorf("nothing read")
+ }
+ conn.Close()
+}
+
+func TestReadingTwoCompletePackets(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
+
+ for i := 0; i < 2; i++ {
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ if bytes.Compare(packet, []byte{byte(i)}) != 0 {
+ t.Errorf("incorrect payload read")
+ }
+ case err := <-conn.Error():
+ t.Errorf("incorrect error %v", err)
+ default:
+ t.Errorf("nothing read")
+ }
+ }
+ conn.Close()
+}
+
+func TestWriting(t *testing.T) {
+ conn, net := setupConnection()
+ conn.Write() <- []byte{0}
+ time.Sleep(10 * time.Millisecond)
+ if len(net.Out) == 0 {
+ t.Errorf("no output")
+ } else {
+ out := net.Out[0]
+ if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
+ t.Errorf("incorrect packet %v", out)
+ }
+ }
+ conn.Close()
+}
+
+// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243