diff options
Diffstat (limited to 'p2p/peer_test.go')
-rw-r--r-- | p2p/peer_test.go | 297 |
1 files changed, 162 insertions, 135 deletions
diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 4ee88f112..68c9910a2 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -1,15 +1,17 @@ package p2p import ( - "bufio" "bytes" - "encoding/hex" - "io" + "fmt" "io/ioutil" "net" "reflect" + "sort" "testing" "time" + + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" ) var discard = Protocol{ @@ -28,17 +30,13 @@ var discard = Protocol{ }, } -func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { +func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { conn1, conn2 := net.Pipe() - peer := newPeer(conn1, protos, nil) - peer.ourID = &peerId{} - peer.pubkeyHook = func(*peerAddr) error { return nil } - errc := make(chan error, 1) - go func() { - _, err := peer.loop() - errc <- err - }() - return conn2, peer, errc + peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) + peer.noHandshake = noHandshake + errc := make(chan DiscReason, 1) + go func() { errc <- peer.run() }() + return newFrameRW(conn2, msgWriteTimeout), peer, errc } func TestPeerProtoReadMsg(t *testing.T) { @@ -49,31 +47,28 @@ func TestPeerProtoReadMsg(t *testing.T) { Name: "a", Length: 5, Run: func(peer *Peer, rw MsgReadWriter) error { - msg, err := rw.ReadMsg() - if err != nil { - t.Errorf("read error: %v", err) + if err := expectMsg(rw, 2, []uint{1}); err != nil { + t.Error(err) } - if msg.Code != 2 { - t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) - } - data, err := ioutil.ReadAll(msg.Payload) - if err != nil { - t.Errorf("payload read error: %v", err) + if err := expectMsg(rw, 3, []uint{2}); err != nil { + t.Error(err) } - expdata, _ := hex.DecodeString("0183303030") - if !bytes.Equal(expdata, data) { - t.Errorf("incorrect msg data %x", data) + if err := expectMsg(rw, 4, []uint{3}); err != nil { + t.Error(err) } close(done) return nil }, } - net, peer, errc := testPeer([]Protocol{proto}) - defer net.Close() + rw, peer, errc := testPeer(true, []Protocol{proto}) + defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) - writeMsg(net, NewMsg(18, 1, "000")) + EncodeMsg(rw, baseProtocolLength+2, 1) + EncodeMsg(rw, baseProtocolLength+3, 2) + EncodeMsg(rw, baseProtocolLength+4, 3) + select { case <-done: case err := <-errc: @@ -105,11 +100,11 @@ func TestPeerProtoReadLargeMsg(t *testing.T) { }, } - net, peer, errc := testPeer([]Protocol{proto}) - defer net.Close() + rw, peer, errc := testPeer(true, []Protocol{proto}) + defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) - writeMsg(net, NewMsg(18, make([]byte, msgsize))) + EncodeMsg(rw, 18, make([]byte, msgsize)) select { case <-done: case err := <-errc: @@ -135,32 +130,20 @@ func TestPeerProtoEncodeMsg(t *testing.T) { return nil }, } - net, peer, _ := testPeer([]Protocol{proto}) - defer net.Close() + rw, peer, _ := testPeer(true, []Protocol{proto}) + defer rw.Close() peer.startSubprotocols([]Cap{proto.cap()}) - bufr := bufio.NewReader(net) - msg, err := readMsg(bufr) - if err != nil { - t.Errorf("read error: %v", err) - } - if msg.Code != 17 { - t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) - } - var data []string - if err := msg.Decode(&data); err != nil { - t.Errorf("payload decode error: %v", err) - } - if !reflect.DeepEqual(data, []string{"foo", "bar"}) { - t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"}) + if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { + t.Error(err) } } -func TestPeerWrite(t *testing.T) { +func TestPeerWriteForBroadcast(t *testing.T) { defer testlog(t).detach() - net, peer, peerErr := testPeer([]Protocol{discard}) - defer net.Close() + rw, peer, peerErr := testPeer(true, []Protocol{discard}) + defer rw.Close() peer.startSubprotocols([]Cap{discard.cap()}) // test write errors @@ -176,18 +159,13 @@ func TestPeerWrite(t *testing.T) { // setup for reading the message on the other end read := make(chan struct{}) go func() { - bufr := bufio.NewReader(net) - msg, err := readMsg(bufr) - if err != nil { - t.Errorf("read error: %v", err) - } else if msg.Code != 16 { - t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) + if err := expectMsg(rw, 16, nil); err != nil { + t.Error() } - msg.Discard() close(read) }() - // test succcessful write + // test successful write if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { t.Errorf("expect no error for known protocol: %v", err) } @@ -198,104 +176,153 @@ func TestPeerWrite(t *testing.T) { } } -func TestPeerActivity(t *testing.T) { - // shorten inactivityTimeout while this test is running - oldT := inactivityTimeout - defer func() { inactivityTimeout = oldT }() - inactivityTimeout = 20 * time.Millisecond +func TestPeerPing(t *testing.T) { + defer testlog(t).detach() - net, peer, peerErr := testPeer([]Protocol{discard}) - defer net.Close() - peer.startSubprotocols([]Cap{discard.cap()}) + rw, _, _ := testPeer(true, nil) + defer rw.Close() + if err := EncodeMsg(rw, pingMsg); err != nil { + t.Fatal(err) + } + if err := expectMsg(rw, pongMsg, nil); err != nil { + t.Error(err) + } +} - sub := peer.activity.Subscribe(time.Time{}) - defer sub.Unsubscribe() +func TestPeerDisconnect(t *testing.T) { + defer testlog(t).detach() - for i := 0; i < 6; i++ { - writeMsg(net, NewMsg(16)) - select { - case <-sub.Chan(): - case <-time.After(inactivityTimeout / 2): - t.Fatal("no event within ", inactivityTimeout/2) - case err := <-peerErr: - t.Fatal("peer error", err) - } + rw, _, disc := testPeer(true, nil) + defer rw.Close() + if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { + t.Fatal(err) } - - select { - case <-time.After(inactivityTimeout * 2): - case <-sub.Chan(): - t.Fatal("got activity event while connection was inactive") - case err := <-peerErr: - t.Fatal("peer error", err) + if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { + t.Error(err) + } + rw.Close() // make test end faster + if reason := <-disc; reason != DiscRequested { + t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) } } -func TestNewPeer(t *testing.T) { - caps := []Cap{{"foo", 2}, {"bar", 3}} - id := &peerId{} - p := NewPeer(id, caps) - if !reflect.DeepEqual(p.Caps(), caps) { - t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) +func TestPeerHandshake(t *testing.T) { + defer testlog(t).detach() + + // remote has two matching protocols: a and c + remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}}) + remoteID := randomID() + remote.ourID = &remoteID + remote.ourName = "remote peer" + + start := make(chan string) + stop := make(chan struct{}) + run := func(p *Peer, rw MsgReadWriter) error { + name := rw.(*proto).name + if name != "a" && name != "c" { + t.Errorf("protocol %q should not be started", name) + } else { + start <- name + } + <-stop + return nil } - if p.Identity() != id { - t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) + protocols := []Protocol{ + {Name: "a", Version: 1, Length: 1, Run: run}, + {Name: "b", Version: 2, Length: 1, Run: run}, + {Name: "c", Version: 3, Length: 1, Run: run}, + {Name: "d", Version: 4, Length: 1, Run: run}, } - // Should not hang. - p.Disconnect(DiscAlreadyConnected) -} + rw, p, disc := testPeer(false, protocols) + p.remoteID = remote.ourID + defer rw.Close() -func TestEOFSignal(t *testing.T) { - rb := make([]byte, 10) + // run the handshake + remoteProtocols := []Protocol{protocols[0], protocols[2]} + if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil { + t.Fatalf("handshake write error: %v", err) + } + if err := readProtocolHandshake(remote, rw); err != nil { + t.Fatalf("handshake read error: %v", err) + } - // empty reader - eof := make(chan struct{}, 1) - sig := &eofSignal{new(bytes.Buffer), 0, eof} - if n, err := sig.Read(rb); n != 0 || err != io.EOF { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + // check that all protocols have been started + var started []string + for i := 0; i < 2; i++ { + select { + case name := <-start: + started = append(started, name) + case <-time.After(100 * time.Millisecond): + } } - select { - case <-eof: - default: - t.Error("EOF chan not signaled") + sort.Strings(started) + if !reflect.DeepEqual(started, []string{"a", "c"}) { + t.Errorf("wrong protocols started: %v", started) } - // count before error - eof = make(chan struct{}, 1) - sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} - if n, err := sig.Read(rb); n != 8 || err != nil { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + // check that metadata has been set + if p.ID() != remoteID { + t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID) } - select { - case <-eof: - default: - t.Error("EOF chan not signaled") + if p.Name() != remote.ourName { + t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName) } - // error before count - eof = make(chan struct{}, 1) - sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} - if n, err := sig.Read(rb); n != 4 || err != nil { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + close(stop) + expectMsg(rw, discMsg, nil) + t.Logf("disc reason: %v", <-disc) +} + +func TestNewPeer(t *testing.T) { + name := "nodename" + caps := []Cap{{"foo", 2}, {"bar", 3}} + id := randomID() + p := NewPeer(id, name, caps) + if p.ID() != id { + t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id) } - if n, err := sig.Read(rb); n != 0 || err != io.EOF { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + if p.Name() != name { + t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name) } - select { - case <-eof: - default: - t.Error("EOF chan not signaled") + if !reflect.DeepEqual(p.Caps(), caps) { + t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) } - // no signal if neither occurs - eof = make(chan struct{}, 1) - sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} - if n, err := sig.Read(rb); n != 10 || err != nil { - t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + p.Disconnect(DiscAlreadyConnected) // Should not hang +} + +// expectMsg reads a message from r and verifies that its +// code and encoded RLP content match the provided values. +// If content is nil, the payload is discarded and not verified. +func expectMsg(r MsgReader, code uint64, content interface{}) error { + msg, err := r.ReadMsg() + if err != nil { + return err } - select { - case <-eof: - t.Error("unexpected EOF signal") - default: + if msg.Code != code { + return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code) + } + if content == nil { + return msg.Discard() + } else { + contentEnc, err := rlp.EncodeToBytes(content) + if err != nil { + panic("content encode error: " + err.Error()) + } + // skip over list header in encoded value. this is temporary. + contentEncR := bytes.NewReader(contentEnc) + if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil { + panic("content must encode as RLP list") + } + contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():] + + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + if !bytes.Equal(actualContent, contentEnc) { + return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) + } } + return nil } |