From d3f056bd68fb6a8e9ffa3938d5404c6a209e0313 Mon Sep 17 00:00:00 2001
From: Ferenc Szabo <frncmx@gmail.com>
Date: Fri, 21 Sep 2018 12:56:43 +0200
Subject: swarm/network/stream: fix DoS invalid hash length (#927)

---
 swarm/network/stream/messages.go      | 14 +++++--
 swarm/network/stream/streamer_test.go | 77 ++++++++++++++++++++++++++++++++---
 2 files changed, 82 insertions(+), 9 deletions(-)

(limited to 'swarm/network')

diff --git a/swarm/network/stream/messages.go b/swarm/network/stream/messages.go
index 2e1a81e82..62c46b120 100644
--- a/swarm/network/stream/messages.go
+++ b/swarm/network/stream/messages.go
@@ -26,7 +26,7 @@ import (
 	bv "github.com/ethereum/go-ethereum/swarm/network/bitvector"
 	"github.com/ethereum/go-ethereum/swarm/spancontext"
 	"github.com/ethereum/go-ethereum/swarm/storage"
-	opentracing "github.com/opentracing/opentracing-go"
+	"github.com/opentracing/opentracing-go"
 )
 
 var syncBatchTimeout = 30 * time.Second
@@ -195,10 +195,16 @@ func (p *Peer) handleOfferedHashesMsg(ctx context.Context, req *OfferedHashesMsg
 	if err != nil {
 		return err
 	}
+
 	hashes := req.Hashes
-	want, err := bv.New(len(hashes) / HashSize)
+	lenHashes := len(hashes)
+	if lenHashes%HashSize != 0 {
+		return fmt.Errorf("error invalid hashes length (len: %v)", lenHashes)
+	}
+
+	want, err := bv.New(lenHashes / HashSize)
 	if err != nil {
-		return fmt.Errorf("error initiaising bitvector of length %v: %v", len(hashes)/HashSize, err)
+		return fmt.Errorf("error initiaising bitvector of length %v: %v", lenHashes/HashSize, err)
 	}
 
 	ctr := 0
@@ -206,7 +212,7 @@ func (p *Peer) handleOfferedHashesMsg(ctx context.Context, req *OfferedHashesMsg
 	ctx, cancel := context.WithTimeout(ctx, syncBatchTimeout)
 
 	ctx = context.WithValue(ctx, "source", p.ID().String())
-	for i := 0; i < len(hashes); i += HashSize {
+	for i := 0; i < lenHashes; i += HashSize {
 		hash := hashes[i : i+HashSize]
 
 		if wait := c.NeedData(ctx, hash); wait != nil {
diff --git a/swarm/network/stream/streamer_test.go b/swarm/network/stream/streamer_test.go
index 06e96b9a9..dba4e7ea3 100644
--- a/swarm/network/stream/streamer_test.go
+++ b/swarm/network/stream/streamer_test.go
@@ -19,6 +19,7 @@ package stream
 import (
 	"bytes"
 	"context"
+	"errors"
 	"testing"
 	"time"
 
@@ -55,11 +56,12 @@ func TestStreamerRequestSubscription(t *testing.T) {
 }
 
 var (
-	hash0     = sha3.Sum256([]byte{0})
-	hash1     = sha3.Sum256([]byte{1})
-	hash2     = sha3.Sum256([]byte{2})
-	hashesTmp = append(hash0[:], hash1[:]...)
-	hashes    = append(hashesTmp, hash2[:]...)
+	hash0         = sha3.Sum256([]byte{0})
+	hash1         = sha3.Sum256([]byte{1})
+	hash2         = sha3.Sum256([]byte{2})
+	hashesTmp     = append(hash0[:], hash1[:]...)
+	hashes        = append(hashesTmp, hash2[:]...)
+	corruptHashes = append(hashes[:40])
 )
 
 type testClient struct {
@@ -459,6 +461,71 @@ func TestStreamerUpstreamSubscribeLiveAndHistory(t *testing.T) {
 	}
 }
 
+func TestStreamerDownstreamCorruptHashesMsgExchange(t *testing.T) {
+	tester, streamer, _, teardown, err := newStreamerTester(t)
+	defer teardown()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	stream := NewStream("foo", "", true)
+
+	var tc *testClient
+
+	streamer.RegisterClientFunc("foo", func(p *Peer, t string, live bool) (Client, error) {
+		tc = newTestClient(t)
+		return tc, nil
+	})
+
+	peerID := tester.IDs[0]
+
+	err = streamer.Subscribe(peerID, stream, NewRange(5, 8), Top)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+
+	err = tester.TestExchanges(p2ptest.Exchange{
+		Label: "Subscribe message",
+		Expects: []p2ptest.Expect{
+			{
+				Code: 4,
+				Msg: &SubscribeMsg{
+					Stream:   stream,
+					History:  NewRange(5, 8),
+					Priority: Top,
+				},
+				Peer: peerID,
+			},
+		},
+	},
+		p2ptest.Exchange{
+			Label: "Corrupt offered hash message",
+			Triggers: []p2ptest.Trigger{
+				{
+					Code: 1,
+					Msg: &OfferedHashesMsg{
+						HandoverProof: &HandoverProof{
+							Handover: &Handover{},
+						},
+						Hashes: corruptHashes,
+						From:   5,
+						To:     8,
+						Stream: stream,
+					},
+					Peer: peerID,
+				},
+			},
+		})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expectedError := errors.New("Message handler error: (msg code 1): error invalid hashes length (len: 40)")
+	if err := tester.TestDisconnected(&p2ptest.Disconnect{Peer: tester.IDs[0], Error: expectedError}); err != nil {
+		t.Fatal(err)
+	}
+}
+
 func TestStreamerDownstreamOfferedHashesMsgExchange(t *testing.T) {
 	tester, streamer, _, teardown, err := newStreamerTester(t)
 	defer teardown()
-- 
cgit 


From 3f7acbbeb929bc3a2a3073bae15977ec69761bab Mon Sep 17 00:00:00 2001
From: Balint Gabor <balint.g@gmail.com>
Date: Tue, 25 Sep 2018 17:35:54 +0200
Subject: swarm: prevent forever running retrieve request loops

---
 swarm/network/fetcher.go         | 24 ++++++++++++------
 swarm/network/fetcher_test.go    | 54 ++++++++++++++++++++++++++++++----------
 swarm/network/stream/delivery.go |  6 ++++-
 swarm/network/stream/stream.go   |  2 +-
 4 files changed, 64 insertions(+), 22 deletions(-)

(limited to 'swarm/network')

diff --git a/swarm/network/fetcher.go b/swarm/network/fetcher.go
index 5b4b61c7e..6aed57e22 100644
--- a/swarm/network/fetcher.go
+++ b/swarm/network/fetcher.go
@@ -32,6 +32,8 @@ var searchTimeout = 1 * time.Second
 // Also used in stream delivery.
 var RequestTimeout = 10 * time.Second
 
+var maxHopCount uint8 = 20 // maximum number of forwarded requests (hops), to make sure requests are not forwarded forever in peer loops
+
 type RequestFunc func(context.Context, *Request) (*enode.ID, chan struct{}, error)
 
 // Fetcher is created when a chunk is not found locally. It starts a request handler loop once and
@@ -44,7 +46,7 @@ type Fetcher struct {
 	protoRequestFunc RequestFunc     // request function fetcher calls to issue retrieve request for a chunk
 	addr             storage.Address // the address of the chunk to be fetched
 	offerC           chan *enode.ID  // channel of sources (peer node id strings)
-	requestC         chan struct{}
+	requestC         chan uint8      // channel for incoming requests (with the hopCount value in it)
 	skipCheck        bool
 }
 
@@ -53,6 +55,7 @@ type Request struct {
 	Source      *enode.ID       // nodeID of peer to request from (can be nil)
 	SkipCheck   bool            // whether to offer the chunk first or deliver directly
 	peersToSkip *sync.Map       // peers not to request chunk from (only makes sense if source is nil)
+	HopCount    uint8           // number of forwarded requests (hops)
 }
 
 // NewRequest returns a new instance of Request based on chunk address skip check and
@@ -113,7 +116,7 @@ func NewFetcher(addr storage.Address, rf RequestFunc, skipCheck bool) *Fetcher {
 		addr:             addr,
 		protoRequestFunc: rf,
 		offerC:           make(chan *enode.ID),
-		requestC:         make(chan struct{}),
+		requestC:         make(chan uint8),
 		skipCheck:        skipCheck,
 	}
 }
@@ -136,7 +139,7 @@ func (f *Fetcher) Offer(ctx context.Context, source *enode.ID) {
 }
 
 // Request is called when an upstream peer request the chunk as part of `RetrieveRequestMsg`, or from a local request through FileStore, and the node does not have the chunk locally.
-func (f *Fetcher) Request(ctx context.Context) {
+func (f *Fetcher) Request(ctx context.Context, hopCount uint8) {
 	// First we need to have this select to make sure that we return if context is done
 	select {
 	case <-ctx.Done():
@@ -144,10 +147,15 @@ func (f *Fetcher) Request(ctx context.Context) {
 	default:
 	}
 
+	if hopCount >= maxHopCount {
+		log.Debug("fetcher request hop count limit reached", "hops", hopCount)
+		return
+	}
+
 	// This select alone would not guarantee that we return of context is done, it could potentially
 	// push to offerC instead if offerC is available (see number 2 in https://golang.org/ref/spec#Select_statements)
 	select {
-	case f.requestC <- struct{}{}:
+	case f.requestC <- hopCount + 1:
 	case <-ctx.Done():
 	}
 }
@@ -161,6 +169,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
 		waitC     <-chan time.Time // timer channel
 		sources   []*enode.ID      // known sources, ie. peers that offered the chunk
 		requested bool             // true if the chunk was actually requested
+		hopCount  uint8
 	)
 	gone := make(chan *enode.ID) // channel to signal that a peer we requested from disconnected
 
@@ -183,7 +192,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
 			doRequest = requested
 
 		// incoming request
-		case <-f.requestC:
+		case hopCount = <-f.requestC:
 			log.Trace("new request", "request addr", f.addr)
 			// 2) chunk is requested, set requested flag
 			// launch a request iff none been launched yet
@@ -213,7 +222,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
 		// need to issue a new request
 		if doRequest {
 			var err error
-			sources, err = f.doRequest(ctx, gone, peers, sources)
+			sources, err = f.doRequest(ctx, gone, peers, sources, hopCount)
 			if err != nil {
 				log.Info("unable to request", "request addr", f.addr, "err", err)
 			}
@@ -251,7 +260,7 @@ func (f *Fetcher) run(ctx context.Context, peers *sync.Map) {
 // * the peer's address is added to the set of peers to skip
 // * the peer's address is removed from prospective sources, and
 // * a go routine is started that reports on the gone channel if the peer is disconnected (or terminated their streamer)
-func (f *Fetcher) doRequest(ctx context.Context, gone chan *enode.ID, peersToSkip *sync.Map, sources []*enode.ID) ([]*enode.ID, error) {
+func (f *Fetcher) doRequest(ctx context.Context, gone chan *enode.ID, peersToSkip *sync.Map, sources []*enode.ID, hopCount uint8) ([]*enode.ID, error) {
 	var i int
 	var sourceID *enode.ID
 	var quit chan struct{}
@@ -260,6 +269,7 @@ func (f *Fetcher) doRequest(ctx context.Context, gone chan *enode.ID, peersToSki
 		Addr:        f.addr,
 		SkipCheck:   f.skipCheck,
 		peersToSkip: peersToSkip,
+		HopCount:    hopCount,
 	}
 
 	foundSource := false
diff --git a/swarm/network/fetcher_test.go b/swarm/network/fetcher_test.go
index b2316b097..3a926f475 100644
--- a/swarm/network/fetcher_test.go
+++ b/swarm/network/fetcher_test.go
@@ -33,7 +33,7 @@ type mockRequester struct {
 	// requests []Request
 	requestC  chan *Request   // when a request is coming it is pushed to requestC
 	waitTimes []time.Duration // with waitTimes[i] you can define how much to wait on the ith request (optional)
-	ctr       int             //counts the number of requests
+	count     int             //counts the number of requests
 	quitC     chan struct{}
 }
 
@@ -47,9 +47,9 @@ func newMockRequester(waitTimes ...time.Duration) *mockRequester {
 
 func (m *mockRequester) doRequest(ctx context.Context, request *Request) (*enode.ID, chan struct{}, error) {
 	waitTime := time.Duration(0)
-	if m.ctr < len(m.waitTimes) {
-		waitTime = m.waitTimes[m.ctr]
-		m.ctr++
+	if m.count < len(m.waitTimes) {
+		waitTime = m.waitTimes[m.count]
+		m.count++
 	}
 	time.Sleep(waitTime)
 	m.requestC <- request
@@ -83,7 +83,7 @@ func TestFetcherSingleRequest(t *testing.T) {
 	go fetcher.run(ctx, peersToSkip)
 
 	rctx := context.Background()
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	select {
 	case request := <-requester.requestC:
@@ -100,6 +100,11 @@ func TestFetcherSingleRequest(t *testing.T) {
 			t.Fatalf("request.peersToSkip does not contain peer returned by the request function")
 		}
 
+		// hopCount in the forwarded request should be incremented
+		if request.HopCount != 1 {
+			t.Fatalf("Expected request.HopCount 1 got %v", request.HopCount)
+		}
+
 		// fetch should trigger a request, if it doesn't happen in time, test should fail
 	case <-time.After(200 * time.Millisecond):
 		t.Fatalf("fetch timeout")
@@ -123,7 +128,7 @@ func TestFetcherCancelStopsFetcher(t *testing.T) {
 	rctx, rcancel := context.WithTimeout(ctx, 100*time.Millisecond)
 	defer rcancel()
 	// we call Request with an active context
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	// fetcher should not initiate request, we can only check by waiting a bit and making sure no request is happening
 	select {
@@ -151,7 +156,7 @@ func TestFetcherCancelStopsRequest(t *testing.T) {
 	rcancel()
 
 	// we call Request with a cancelled context
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	// fetcher should not initiate request, we can only check by waiting a bit and making sure no request is happening
 	select {
@@ -162,7 +167,7 @@ func TestFetcherCancelStopsRequest(t *testing.T) {
 
 	// if there is another Request with active context, there should be a request, because the fetcher itself is not cancelled
 	rctx = context.Background()
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	select {
 	case <-requester.requestC:
@@ -200,7 +205,7 @@ func TestFetcherOfferUsesSource(t *testing.T) {
 
 	// call Request after the Offer
 	rctx = context.Background()
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	// there should be exactly 1 request coming from fetcher
 	var request *Request
@@ -241,7 +246,7 @@ func TestFetcherOfferAfterRequestUsesSourceFromContext(t *testing.T) {
 
 	// call Request first
 	rctx := context.Background()
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	// there should be a request coming from fetcher
 	var request *Request
@@ -296,7 +301,7 @@ func TestFetcherRetryOnTimeout(t *testing.T) {
 
 	// call the fetch function with an active context
 	rctx := context.Background()
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	// after 100ms the first request should be initiated
 	time.Sleep(100 * time.Millisecond)
@@ -338,7 +343,7 @@ func TestFetcherFactory(t *testing.T) {
 
 	fetcher := fetcherFactory.New(context.Background(), addr, peersToSkip)
 
-	fetcher.Request(context.Background())
+	fetcher.Request(context.Background(), 0)
 
 	// check if the created fetchFunction really starts a fetcher and initiates a request
 	select {
@@ -368,7 +373,7 @@ func TestFetcherRequestQuitRetriesRequest(t *testing.T) {
 	go fetcher.run(ctx, peersToSkip)
 
 	rctx := context.Background()
-	fetcher.Request(rctx)
+	fetcher.Request(rctx, 0)
 
 	select {
 	case <-requester.requestC:
@@ -457,3 +462,26 @@ func TestRequestSkipPeerPermanent(t *testing.T) {
 		t.Errorf("peer not skipped")
 	}
 }
+
+func TestFetcherMaxHopCount(t *testing.T) {
+	requester := newMockRequester()
+	addr := make([]byte, 32)
+	fetcher := NewFetcher(addr, requester.doRequest, true)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	peersToSkip := &sync.Map{}
+
+	go fetcher.run(ctx, peersToSkip)
+
+	rctx := context.Background()
+	fetcher.Request(rctx, maxHopCount)
+
+	// if hopCount is already at max no request should be initiated
+	select {
+	case <-requester.requestC:
+		t.Fatalf("cancelled fetcher initiated request")
+	case <-time.After(200 * time.Millisecond):
+	}
+}
diff --git a/swarm/network/stream/delivery.go b/swarm/network/stream/delivery.go
index 431136ab1..c2adb1009 100644
--- a/swarm/network/stream/delivery.go
+++ b/swarm/network/stream/delivery.go
@@ -128,6 +128,7 @@ func (s *SwarmChunkServer) GetData(ctx context.Context, key []byte) ([]byte, err
 type RetrieveRequestMsg struct {
 	Addr      storage.Address
 	SkipCheck bool
+	HopCount  uint8
 }
 
 func (d *Delivery) handleRetrieveRequestMsg(ctx context.Context, sp *Peer, req *RetrieveRequestMsg) error {
@@ -148,7 +149,9 @@ func (d *Delivery) handleRetrieveRequestMsg(ctx context.Context, sp *Peer, req *
 
 	var cancel func()
 	// TODO: do something with this hardcoded timeout, maybe use TTL in the future
-	ctx, cancel = context.WithTimeout(context.WithValue(ctx, "peer", sp.ID().String()), network.RequestTimeout)
+	ctx = context.WithValue(ctx, "peer", sp.ID().String())
+	ctx = context.WithValue(ctx, "hopcount", req.HopCount)
+	ctx, cancel = context.WithTimeout(ctx, network.RequestTimeout)
 
 	go func() {
 		select {
@@ -247,6 +250,7 @@ func (d *Delivery) RequestFromPeers(ctx context.Context, req *network.Request) (
 	err := sp.SendPriority(ctx, &RetrieveRequestMsg{
 		Addr:      req.Addr,
 		SkipCheck: req.SkipCheck,
+		HopCount:  req.HopCount,
 	}, Top)
 	if err != nil {
 		return nil, nil, err
diff --git a/swarm/network/stream/stream.go b/swarm/network/stream/stream.go
index ea7cce8cb..65b8dff5a 100644
--- a/swarm/network/stream/stream.go
+++ b/swarm/network/stream/stream.go
@@ -639,7 +639,7 @@ func (c *clientParams) clientCreated() {
 // Spec is the spec of the streamer protocol
 var Spec = &protocols.Spec{
 	Name:       "stream",
-	Version:    6,
+	Version:    7,
 	MaxMsgSize: 10 * 1024 * 1024,
 	Messages: []interface{}{
 		UnsubscribeMsg{},
-- 
cgit