aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/testing/protocolsession.go
blob: e3ec41ad67c73b9c7d0f91ed1bdbadd630ba849f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package testing

import (
    "errors"
    "fmt"
    "sync"
    "time"

    "github.com/ethereum/go-ethereum/log"
    "github.com/ethereum/go-ethereum/p2p"
    "github.com/ethereum/go-ethereum/p2p/discover"
    "github.com/ethereum/go-ethereum/p2p/simulations/adapters"
)

var (
    // errTimedOut is returned when an expectation or a whole exchange does
    // not complete before its deadline.
    errTimedOut = errors.New("timed out")
)

// ProtocolSession is a quasi-simulation of a pivot node running a service
// together with a number of dummy peers. The dummy peers can send (trigger)
// messages to the pivot node or receive (expect) messages from it.
type ProtocolSession struct {
    Server  *p2p.Server           // NOTE(review): presumably the pivot node's server; not used in this section
    IDs     []discover.NodeID     // dummy peer IDs (presumably); only its length appears in error messages here
    adapter *adapters.SimAdapter  // resolves dummy peer IDs to their simulated nodes
    events  chan *p2p.PeerEvent   // peer events, consumed by TestDisconnected to observe drops
}

// Exchange is the basic unit of a protocol test; an exchange is run on a
// session.
//
// The Triggers are sent first, one after another. Once every trigger has
// been delivered, the Expects are checked concurrently. The whole exchange
// must finish within Timeout (2 seconds when zero).
//
// Because expectations are checked asynchronously, one cannot reliably have
// multiple expects for the SAME peer with DIFFERENT message types: it is
// unpredictable which expect receives which message. For example, with
// expects #1 and #2 the messages might arrive as #2 and #1, and both expects
// would complain about the wrong message code.
// NOTE(review): expect() hands all Expects of one peer to a single
// mockNode.Expect call, which may lift this restriction. Confirm against
// mockNode.Expect.
type Exchange struct {
    Label    string        // identifies the exchange in errors and trace logs
    Triggers []Trigger     // messages sent by dummy peers to the pivot node, in order
    Expects  []Expect      // messages the pivot node must send to dummy peers
    Timeout  time.Duration // deadline for the whole exchange; 2s when zero
}

// Trigger is part of an Exchange. It describes a message that a dummy peer
// sends to the pivot node, i.e. an incoming message from the pivot node's
// point of view.
type Trigger struct {
    Msg     interface{}     // the message to be sent
    Code    uint64          // message code under which Msg is sent
    Peer    discover.NodeID // the dummy peer that sends the message (not the recipient)
    Timeout time.Duration   // time allowed for sending; 1s when zero
}

// Expect is part of an Exchange. It describes a message the pivot node is
// expected to send to a dummy peer, i.e. an outgoing message from the pivot
// node's point of view.
//
// The timeouts of all Expects addressed to the same peer are summed into a
// single deadline for that peer, because the peer's mock node checks them
// as one batch.
type Expect struct {
    Msg     interface{}     // the expected message; must not be nil
    Code    uint64          // expected message code
    Peer    discover.NodeID // the dummy peer that should receive the message
    Timeout time.Duration   // time allowed for receiving; 2s when zero
}

// Disconnect describes an expected peer disconnect (drop event). It is
// checked by TestDisconnected.
type Disconnect struct {
    Peer  discover.NodeID // peer expected to be disconnected
    Error error           // expected disconnect reason; nil means the drop must carry no error
}

// trigger makes the dummy peer trig.Peer send trig.Msg (code trig.Code) to
// the pivot node through the peer's mock node service.
//
// It returns:
//   - an error if the peer is unknown or is not backed by a mock node;
//   - the error reported by the mock node's Trigger;
//   - a timeout error if sending does not finish within trig.Timeout
//     (1s when zero).
func (s *ProtocolSession) trigger(trig Trigger) error {
    simNode, ok := s.adapter.GetNode(trig.Peer)
    if !ok {
        return fmt.Errorf("trigger: peer %v does not exist (1- %v)", trig.Peer, len(s.IDs))
    }
    // Guard the index so that a node without services yields an error
    // instead of a panic. A failed assertion leaves mock nil, which is
    // reported just below.
    var mock *mockNode
    if services := simNode.Services(); len(services) > 0 {
        mock, _ = services[0].(*mockNode)
    }
    if mock == nil {
        return fmt.Errorf("trigger: peer %v is not a mock", trig.Peer)
    }

    // Buffered so that the sending goroutine can always deliver its result
    // and exit, even after we stopped waiting because of a timeout.
    errc := make(chan error, 1)
    go func() {
        log.Trace(fmt.Sprintf("trigger %v (%v)....", trig.Msg, trig.Code))
        errc <- mock.Trigger(&trig)
        log.Trace(fmt.Sprintf("triggered %v (%v)", trig.Msg, trig.Code))
    }()

    t := trig.Timeout
    if t == 0 {
        t = 1000 * time.Millisecond
    }
    timer := time.NewTimer(t)
    defer timer.Stop()

    select {
    case err := <-errc:
        return err
    case <-timer.C:
        return fmt.Errorf("timeout triggering %v from peer %v", trig.Msg, trig.Peer)
    }
}

// expect checks that the pivot node sends out the expected messages.
//
// Expectations are grouped by the peer that should receive them. Each group
// is checked concurrently by that peer's mock node. expect returns:
//   - an error immediately if any Expect has a nil Msg, or if a peer is
//     unknown or not backed by a mock node;
//   - otherwise the first error reported by any peer (errTimedOut if a
//     peer's deadline passes first);
//   - nil once every peer's expectations are met, or when exps is empty.
//
// A peer's deadline is the sum of its Expects' timeouts, counting 2s for
// each zero Timeout. mockNode.Expect checks the whole batch in one call, so
// individual messages cannot be timed separately.
func (s *ProtocolSession) expect(exps []Expect) error {
    // Group the expectations by peer.
    peerExpects := make(map[discover.NodeID][]Expect)
    for _, exp := range exps {
        if exp.Msg == nil {
            return errors.New("no message to expect")
        }
        peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp)
    }

    // Resolve every peer's mock node up front, so that lookup errors are
    // reported before any goroutine is started.
    mockNodes := make(map[discover.NodeID]*mockNode, len(peerExpects))
    for nodeID := range peerExpects {
        simNode, ok := s.adapter.GetNode(nodeID)
        if !ok {
            return fmt.Errorf("expect: peer %v does not exist (1- %v)", nodeID, len(s.IDs))
        }
        // Guard the index so that a node without services yields an error
        // instead of a panic. A failed assertion leaves mock nil.
        var mock *mockNode
        if services := simNode.Services(); len(services) > 0 {
            mock, _ = services[0].(*mockNode)
        }
        if mock == nil {
            return fmt.Errorf("expect: peer %v is not a mock", nodeID)
        }
        mockNodes[nodeID] = mock
    }

    // done is closed when expect returns. It releases any per-peer goroutine
    // that is still waiting, so none of them block forever trying to report
    // to errc after the result has already been decided.
    done := make(chan struct{})
    defer close(done)
    // errc carries the first error reported by any peer. It is closed once
    // all per-peer goroutines have finished, so a receive yields nil when
    // every expectation was met.
    errc := make(chan error)

    var wg sync.WaitGroup
    wg.Add(len(mockNodes))
    for nodeID, mock := range mockNodes {
        nodeID, mock := nodeID, mock // capture per iteration (pre-Go 1.22 semantics)
        go func() {
            defer wg.Done()

            expects := peerExpects[nodeID]
            var t time.Duration
            for _, exp := range expects {
                if exp.Timeout == 0 {
                    t += 2000 * time.Millisecond
                } else {
                    t += exp.Timeout
                }
            }
            alarm := time.NewTimer(t)
            defer alarm.Stop()

            // Buffered so the goroutine running mockNode.Expect can deliver
            // its result and exit even after we stopped waiting for it.
            // Only the waiting is bounded; if mockNode.Expect itself never
            // returns, that goroutine stays blocked.
            expectErrc := make(chan error, 1)
            go func() {
                expectErrc <- mock.Expect(expects...)
            }()

            var err error
            select {
            case err = <-expectErrc:
            case <-alarm.C:
                err = errTimedOut
            case <-done:
                return
            }
            if err == nil {
                return
            }
            // Never send unconditionally: once another peer's error has been
            // returned, nobody reads errc any more.
            select {
            case errc <- err:
            case <-done:
            }
        }()
    }

    go func() {
        wg.Wait()
        close(errc)
    }()

    return <-errc
}

// TestExchanges runs the given exchanges against the session one after
// another. It stops at the first failing exchange and returns that error,
// prefixed with the exchange's index and label. It returns nil if all
// exchanges succeed.
func (s *ProtocolSession) TestExchanges(exchanges ...Exchange) error {
    for idx := range exchanges {
        ex := exchanges[idx]
        err := s.testExchange(ex)
        if err != nil {
            return fmt.Errorf("exchange #%d %q: %v", idx, ex.Label, err)
        }
        log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", idx, ex.Label))
    }
    return nil
}

// testExchange runs a single Exchange. It sends every trigger in order, then
// checks all expectations.
//
// It returns the first trigger or expectation error. It returns errTimedOut
// if the whole exchange does not finish within e.Timeout (2 seconds when
// zero).
func (s *ProtocolSession) testExchange(e Exchange) error {
    // Exactly one result is sent per run, so a buffer of one lets the worker
    // goroutine always deliver it and exit. This holds even after the global
    // timeout fired and nobody is receiving any more.
    errc := make(chan error, 1)

    go func() {
        for _, trig := range e.Triggers {
            if err := s.trigger(trig); err != nil {
                errc <- err
                return
            }
        }
        errc <- s.expect(e.Expects)
    }()

    // Time out globally, or finish when all expectations are satisfied.
    t := e.Timeout
    if t == 0 {
        t = 2000 * time.Millisecond
    }
    alarm := time.NewTimer(t)
    defer alarm.Stop()

    select {
    case err := <-errc:
        return err
    case <-alarm.C:
        return errTimedOut
    }
}

// TestDisconnected waits for each peer named in disconnects to be dropped,
// and checks the drop reason.
//
// A nil Disconnect.Error requires the drop event to carry an empty error
// string. Otherwise the event's error string must equal Error.Error().
// Non-drop events, and drops of peers that are not listed, are ignored. The
// call fails if the listed peers are not all dropped within one second.
func (s *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error {
    pending := make(map[discover.NodeID]error, len(disconnects))
    for _, d := range disconnects {
        pending[d.Peer] = d.Error
    }

    // matches reports whether a dropped peer's error string agrees with the
    // expected error. A nil expectation matches only the empty string.
    matches := func(want error, got string) bool {
        if want == nil {
            return got == ""
        }
        return want.Error() == got
    }

    deadline := time.After(time.Second)
    for len(pending) != 0 {
        var ev *p2p.PeerEvent
        select {
        case ev = <-s.events:
        case <-deadline:
            return fmt.Errorf("timed out waiting for peers to disconnect")
        }

        if ev.Type != p2p.PeerEventTypeDrop {
            continue
        }
        want, tracked := pending[ev.Peer]
        if !tracked {
            continue
        }
        if !matches(want, ev.Error) {
            return fmt.Errorf("unexpected error on peer %v. expected '%v', got '%v'", ev.Peer, want, ev.Error)
        }
        delete(pending, ev.Peer)
    }
    return nil
}