aboutsummaryrefslogtreecommitdiffstats
path: root/p2p
diff options
context:
space:
mode:
Diffstat (limited to 'p2p')
-rw-r--r--p2p/protocols/protocol.go89
-rw-r--r--p2p/protocols/protocol_test.go4
-rw-r--r--p2p/testing/protocoltester.go21
3 files changed, 101 insertions, 13 deletions
diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go
index d5c0375ac..615f74b56 100644
--- a/p2p/protocols/protocol.go
+++ b/p2p/protocols/protocol.go
@@ -29,6 +29,8 @@ devp2p subprotocols by abstracting away code standardly shared by protocols.
package protocols
import (
+ "bufio"
+ "bytes"
"context"
"fmt"
"io"
@@ -39,6 +41,10 @@ import (
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/swarm/spancontext"
+ "github.com/ethereum/go-ethereum/swarm/tracing"
+ opentracing "github.com/opentracing/opentracing-go"
)
// error codes used by this protocol scheme
@@ -109,6 +115,13 @@ func errorf(code int, format string, params ...interface{}) *Error {
}
}
+// WrappedMsg is used to propagate marshalled context alongside message payloads
+type WrappedMsg struct {
+ Context []byte
+ Size uint32
+ Payload []byte
+}
+
// Spec is a protocol specification including its name and version as well as
// the types of messages which are exchanged
type Spec struct {
@@ -201,7 +214,7 @@ func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
// the handler argument is a function which is called for each message received
// from the remote peer, a returned error causes the loop to exit
// resulting in disconnection
-func (p *Peer) Run(handler func(msg interface{}) error) error {
+func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) error {
for {
if err := p.handleIncoming(handler); err != nil {
if err != io.EOF {
@@ -225,14 +238,47 @@ func (p *Peer) Drop(err error) {
// message off to the peer
// this low level call will be wrapped by libraries providing routed or broadcast sends
// but often just used to forward and push messages to directly connected peers
-func (p *Peer) Send(msg interface{}) error {
+func (p *Peer) Send(ctx context.Context, msg interface{}) error {
defer metrics.GetOrRegisterResettingTimer("peer.send_t", nil).UpdateSince(time.Now())
metrics.GetOrRegisterCounter("peer.send", nil).Inc(1)
+
+ var b bytes.Buffer
+ if tracing.Enabled {
+ writer := bufio.NewWriter(&b)
+
+ tracer := opentracing.GlobalTracer()
+
+ sctx := spancontext.FromContext(ctx)
+
+ if sctx != nil {
+ err := tracer.Inject(
+ sctx,
+ opentracing.Binary,
+ writer)
+ if err != nil {
+ return err
+ }
+ }
+
+ writer.Flush()
+ }
+
+ r, err := rlp.EncodeToBytes(msg)
+ if err != nil {
+ return err
+ }
+
+ wmsg := WrappedMsg{
+ Context: b.Bytes(),
+ Size: uint32(len(r)),
+ Payload: r,
+ }
+
code, found := p.spec.GetCode(msg)
if !found {
return errorf(ErrInvalidMsgType, "%v", code)
}
- return p2p.Send(p.rw, code, msg)
+ return p2p.Send(p.rw, code, wmsg)
}
// handleIncoming(code)
@@ -243,7 +289,7 @@ func (p *Peer) Send(msg interface{}) error {
// * checks for out-of-range message codes,
// * handles decoding with reflection,
// * call handlers as callbacks
-func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
+func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{}) error) error {
msg, err := p.rw.ReadMsg()
if err != nil {
return err
@@ -255,11 +301,38 @@ func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
}
+ // unmarshal wrapped msg, which might contain context
+ var wmsg WrappedMsg
+ err = msg.Decode(&wmsg)
+ if err != nil {
+ log.Error(err.Error())
+ return err
+ }
+
+ ctx := context.Background()
+
+ // if tracing is enabled and the context coming within the request is
+ // not empty, try to unmarshal it
+ if tracing.Enabled && len(wmsg.Context) > 0 {
+ var sctx opentracing.SpanContext
+
+ tracer := opentracing.GlobalTracer()
+ sctx, err = tracer.Extract(
+ opentracing.Binary,
+ bytes.NewReader(wmsg.Context))
+ if err != nil {
+ log.Error(err.Error())
+ return err
+ }
+
+ ctx = spancontext.WithContext(ctx, sctx)
+ }
+
val, ok := p.spec.NewMsg(msg.Code)
if !ok {
return errorf(ErrInvalidMsgCode, "%v", msg.Code)
}
- if err := msg.Decode(val); err != nil {
+ if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil {
return errorf(ErrDecode, "<= %v: %v", msg, err)
}
@@ -268,7 +341,7 @@ func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
// which the handler is supposed to cast to the appropriate type
// it is entirely safe not to check the cast in the handler since the handler is
// chosen based on the proper type in the first place
- if err := handle(val); err != nil {
+ if err := handle(ctx, val); err != nil {
return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
}
return nil
@@ -288,14 +361,14 @@ func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interf
return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs)
}
errc := make(chan error, 2)
- handle := func(msg interface{}) error {
+ handle := func(ctx context.Context, msg interface{}) error {
rhs = msg
if verify != nil {
return verify(rhs)
}
return nil
}
- send := func() { errc <- p.Send(hs) }
+ send := func() { errc <- p.Send(ctx, hs) }
receive := func() { errc <- p.handleIncoming(handle) }
go func() {
diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go
index aaae7502b..11df8ff39 100644
--- a/p2p/protocols/protocol_test.go
+++ b/p2p/protocols/protocol_test.go
@@ -104,7 +104,7 @@ func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) er
return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C)
}
- handle := func(msg interface{}) error {
+ handle := func(ctx context.Context, msg interface{}) error {
switch msg := msg.(type) {
case *protoHandshake:
@@ -116,7 +116,7 @@ func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) er
return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C)
}
lhs.C += rhs.C
- return peer.Send(lhs)
+ return peer.Send(ctx, lhs)
case *kill:
// demonstrates use of peerPool, killing another peer connection as a response to a message
diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go
index 636613c57..c99578fe0 100644
--- a/p2p/testing/protocoltester.go
+++ b/p2p/testing/protocoltester.go
@@ -180,7 +180,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
for {
select {
case trig := <-m.trigger:
- m.err <- p2p.Send(rw, trig.Code, trig.Msg)
+ wmsg := Wrap(trig.Msg)
+ m.err <- p2p.Send(rw, trig.Code, wmsg)
case exps := <-m.expect:
m.err <- expectMsgs(rw, exps)
case <-m.stop:
@@ -220,7 +221,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
}
var found bool
for i, exp := range exps {
- if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
+ if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(Wrap(exp.Msg))) {
if matched[i] {
return fmt.Errorf("message #%d received two times", i)
}
@@ -235,7 +236,7 @@ func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
if matched[i] {
continue
}
- expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
+ expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(Wrap(exp.Msg))))
}
return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
}
@@ -267,3 +268,17 @@ func mustEncodeMsg(msg interface{}) []byte {
}
return contentEnc
}
+
+type WrappedMsg struct {
+ Context []byte
+ Size uint32
+ Payload []byte
+}
+
+func Wrap(msg interface{}) interface{} {
+ data, _ := rlp.EncodeToBytes(msg)
+ return &WrappedMsg{
+ Size: uint32(len(data)),
+ Payload: data,
+ }
+}