aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/protocol_test.go
blob: ce25b3e1b541ccff6c760bbcbb8b0ce8616c86a4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package p2p

import (
    "fmt"
    "net"
    "reflect"
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/crypto"
)

// peerId is a test identity whose public key is created on demand: the
// first call to Pubkey on an empty key generates and stores a new one.
type peerId struct{ pubkey []byte }

// String renders the identity as "test peer " followed by the first four
// bytes of its public key in hex, generating the key first if needed.
// NOTE: slicing with [:4] panics if a preset key is shorter than 4 bytes.
func (id *peerId) String() string {
    key := id.Pubkey()
    return fmt.Sprintf("test peer %x", key[:4])
}

// Pubkey returns the stored public key. When none is stored yet, it takes
// the PublicKey of a freshly generated crypto key pair, caches it on the
// identity and returns it, so later calls see the same key.
func (id *peerId) Pubkey() (pubkey []byte) {
    if len(id.pubkey) == 0 {
        id.pubkey = crypto.GenerateNewKeyPair().PublicKey
    }
    return id.pubkey
}

// newTestPeer builds a Peer with a lazily-keyed peerId identity and no
// capabilities, then stubs the fields the base protocol reads: pubkeyHook
// always succeeds, ourID is a second lazily-keyed peerId, listenAddr is an
// empty peerAddr, and otherPeers reports no peers.
func newTestPeer() (peer *Peer) {
    p := NewPeer(&peerId{}, []Cap{})
    p.ourID = &peerId{}
    p.listenAddr = &peerAddr{}
    p.otherPeers = func() []*Peer { return nil }
    p.pubkeyHook = func(*peerAddr) error { return nil }
    return p
}

// TestBaseProtocolPeers connects two peers through an in-memory pipe and
// checks the peer suggestions peer2 receives from peer1. They should be the
// listen addresses of peer1's other peers (cannedPeerList, in order), then
// peer1's own listen address, and nothing else.
func TestBaseProtocolPeers(t *testing.T) {
    cannedPeerList := []*peerAddr{
        {IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
        {IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
    }
    ownAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
    rw1, rw2 := MsgPipe()

    // run first peer, which reports cannedPeerList as its other peers
    peer1 := newTestPeer()
    peer1.ourListenAddr = ownAddr
    peer1.otherPeers = func() []*Peer {
        pl := make([]*Peer, len(cannedPeerList))
        for i, addr := range cannedPeerList {
            pl[i] = &Peer{listenAddr: addr}
        }
        return pl
    }
    go runBaseProtocol(peer1, rw1)

    // Run the second peer in the background. It feeds the peer suggestions
    // it receives into addrChan, whose buffer holds every expected
    // suggestion. The test never closes addrChan: peer2 is its sender, and
    // closing it from the receiving side could make peer2 panic with a send
    // on a closed channel.
    addrChan := make(chan *peerAddr, len(cannedPeerList)+1)
    peer2 := newTestPeer()
    peer2.newPeerAddr = addrChan
    peer2Err := make(chan error, 1)
    go func() { peer2Err <- runBaseProtocol(peer2, rw2) }()

    // Match the suggestions in this goroutine, with a timeout, so that a
    // missing one fails the test instead of hanging it.
    timeout := time.After(5 * time.Second)
match:
    for i := 0; i <= len(cannedPeerList); i++ {
        select {
        case got := <-addrChan:
            t.Logf("got peer: %+v", got)
            if i < len(cannedPeerList) {
                if want := cannedPeerList[i]; !reflect.DeepEqual(want, got) {
                    t.Errorf("mismatch:  got %#v, want %#v", got, want)
                }
            } else if !reflect.DeepEqual(ownAddr, got) {
                t.Errorf("mismatch: peers own address is incorrectly given, got %#v, want %#v", got, ownAddr)
            }
        case <-timeout:
            if i < len(cannedPeerList) {
                t.Errorf("timed out waiting for peer suggestion %d, want %#v", i, cannedPeerList[i])
            } else {
                t.Errorf("mismatch: peers own address is not given, want %#v", ownAddr)
            }
            break match
        }
    }

    // closing the pipe is expected to terminate peer2 with ErrPipeClosed
    rw2.Close()
    if err := <-peer2Err; err != ErrPipeClosed {
        t.Errorf("peer2 terminated with unexpected error: %v", err)
    }
    // peer1's own address must be suggested exactly once. Any further
    // suggestion peer2 delivered before terminating is still buffered.
    select {
    case extra := <-addrChan:
        t.Errorf("mismatch: unexpected extra peer suggestion %#v", extra)
    default:
    }
}

// TestBaseProtocolDisconnect plays the remote side of the base protocol by
// hand over a message pipe. It expects a handshake, answers with its own
// handshake, expects a getPeers request, then sends a disconnect with
// DiscQuitting. runBaseProtocol must return a discRequestedError carrying
// that reason.
func TestBaseProtocolDisconnect(t *testing.T) {
    p := NewPeer(&peerId{}, nil)
    p.pubkeyHook = func(*peerAddr) error { return nil }
    p.ourID = &peerId{}

    local, remote := MsgPipe()
    done := make(chan struct{})
    go func() {
        defer close(done)
        // Each step runs in order, and a failing step is reported without
        // stopping the script.
        script := []func() error{
            func() error { return expectMsg(remote, handshakeMsg) },
            func() error {
                return remote.EncodeMsg(handshakeMsg,
                    baseProtocolVersion,
                    "",
                    []interface{}{},
                    0,
                    make([]byte, 64),
                )
            },
            func() error { return expectMsg(remote, getPeersMsg) },
            func() error { return remote.EncodeMsg(discMsg, DiscQuitting) },
        }
        for _, step := range script {
            if err := step(); err != nil {
                t.Error(err)
            }
        }
    }()

    err := runBaseProtocol(p, local)
    switch reason, ok := err.(discRequestedError); {
    case err == nil:
        t.Errorf("base protocol returned without error")
    case !ok || reason != DiscQuitting:
        t.Errorf("base protocol returned wrong error: %v", err)
    }
    <-done
}

// expectMsg reads the next message from r and discards its payload. It
// returns the read or discard error if one occurs. Otherwise it returns an
// error when the message code differs from code, and nil when it matches.
func expectMsg(r MsgReader, code uint64) error {
    msg, err := r.ReadMsg()
    if err == nil {
        // the payload is always drained, even when the code is wrong
        err = msg.Discard()
    }
    switch {
    case err != nil:
        return err
    case msg.Code != code:
        return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
    default:
        return nil
    }
}