diff options
author | Felix Lange <fjl@twurst.com> | 2014-12-12 18:38:42 +0800 |
---|---|---|
committer | Felix Lange <fjl@twurst.com> | 2014-12-12 18:40:02 +0800 |
commit | e28c60caf9a31669451124a9add2b9036bec1e73 (patch) | |
tree | 345d075dfabd629ac22cce150317116491a97b7f | |
parent | 9423401d73562b2a55398559a9d21af75210d955 (diff) | |
download | go-tangerine-e28c60caf9a31669451124a9add2b9036bec1e73.tar.gz go-tangerine-e28c60caf9a31669451124a9add2b9036bec1e73.tar.zst go-tangerine-e28c60caf9a31669451124a9add2b9036bec1e73.zip |
p2p: improve and test eofSignal
-rw-r--r-- | p2p/peer.go | 17 | ||||
-rw-r--r-- | p2p/peer_test.go | 56 |
2 files changed, 68 insertions, 5 deletions
diff --git a/p2p/peer.go b/p2p/peer.go index 893ba86d7..86c4d7ab5 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) proto.in <- msg } else { wait = true - pr := &eofSignal{msg.Payload, protoDone} + pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone} msg.Payload = pr proto.in <- msg } @@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) { return msg, nil } -// eofSignal wraps a reader with eof signaling. -// the eof channel is closed when the wrapped reader -// reaches EOF. +// eofSignal wraps a reader with eof signaling. the eof channel is +// closed when the wrapped reader returns an error or when count bytes +// have been read. +// type eofSignal struct { wrapped io.Reader + count int64 eof chan<- struct{} } +// note: when using eofSignal to detect whether a message payload +// has been read, Read might not be called for zero sized messages. + func (r *eofSignal) Read(buf []byte) (int, error) { n, err := r.wrapped.Read(buf) - if err != nil { + r.count -= int64(n) + if (err != nil || r.count <= 0) && r.eof != nil { r.eof <- struct{}{} // tell Peer that msg has been consumed + r.eof = nil } return n, err } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index d9640292f..f7759786e 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/hex" + "io" "io/ioutil" "net" "reflect" @@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) { // Should not hang. p.Disconnect(DiscAlreadyConnected) } + +func TestEOFSignal(t *testing.T) { + rb := make([]byte, 10) + + // empty reader + eof := make(chan struct{}, 1) + sig := &eofSignal{new(bytes.Buffer), 0, eof} + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // count before error + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} + if n, err := sig.Read(rb); n != 8 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // error before count + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // no signal if neither occurs + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} + if n, err := sig.Read(rb); n != 10 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + t.Error("unexpected EOF signal") + default: + } +} |