diff options
Diffstat (limited to 'eth/fetcher/fetcher_test.go')
-rw-r--r-- | eth/fetcher/fetcher_test.go | 67 |
1 files changed, 60 insertions, 7 deletions
diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index d594d830c..b9f0f36a5 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -399,15 +399,15 @@ func TestDistantDiscarding(t *testing.T) { // Tests that a peer is unable to use unbounded memory with sending infinite // block announcements to a node, but that even in the face of such an attack, // the fetcher remains operational. -func TestAnnounceMemoryExhaustionAttack(t *testing.T) { +func TestHashMemoryExhaustionAttack(t *testing.T) { tester := newTester() // Create a valid chain and an infinite junk chain - hashes := createHashes(announceLimit+2*maxQueueDist, knownHash) + hashes := createHashes(hashLimit+2*maxQueueDist, knownHash) blocks := createBlocksFromHashes(hashes) valid := tester.makeFetcher(blocks) - attack := createHashes(announceLimit+2*maxQueueDist, unknownHash) + attack := createHashes(hashLimit+2*maxQueueDist, unknownHash) attacker := tester.makeFetcher(nil) // Feed the tester a huge hashset from the attacker, and a limited from the valid peer @@ -417,8 +417,8 @@ func TestAnnounceMemoryExhaustionAttack(t *testing.T) { } tester.fetcher.Notify("attacker", attack[i], time.Now().Add(arriveTimeout/2), attacker) } - if len(tester.fetcher.announced) != announceLimit+maxQueueDist { - t.Fatalf("queued announce count mismatch: have %d, want %d", len(tester.fetcher.announced), announceLimit+maxQueueDist) + if len(tester.fetcher.announced) != hashLimit+maxQueueDist { + t.Fatalf("queued announce count mismatch: have %d, want %d", len(tester.fetcher.announced), hashLimit+maxQueueDist) } // Wait for synchronisation to complete and check success for the valid peer time.Sleep(2 * arriveTimeout) @@ -431,10 +431,63 @@ func TestAnnounceMemoryExhaustionAttack(t *testing.T) { tester.fetcher.Notify("valid", hashes[i], time.Now().Add(time.Millisecond), valid) i-- } - time.Sleep(256 * time.Millisecond) + time.Sleep(500 * time.Millisecond) } - time.Sleep(256 * time.Millisecond) + time.Sleep(500 * time.Millisecond) if imported := len(tester.blocks); imported != len(hashes) { t.Fatalf("fully synchronised block mismatch: have %v, want %v", imported, len(hashes)) } } + +// Tests that blocks sent to the fetcher (either through propagation or via hash +// announces and retrievals) don't pile up indefinitely, exhausting available +// system memory. +func TestBlockMemoryExhaustionAttack(t *testing.T) { + tester := newTester() + + // Create a valid chain and a batch of dangling (but in range) blocks + hashes := createHashes(blockLimit, knownHash) + blocks := createBlocksFromHashes(hashes) + + attack := make(map[common.Hash]*types.Block) + for i := 0; i < 16; i++ { + hashes := createHashes(maxQueueDist-1, unknownHash) + blocks := createBlocksFromHashes(hashes) + for _, hash := range hashes[:maxQueueDist-2] { + attack[hash] = blocks[hash] + } + } + // Try to feed all the attacker blocks make sure only a limited batch is accepted + for _, block := range attack { + tester.fetcher.Enqueue("attacker", block) + } + time.Sleep(100 * time.Millisecond) + if queued := tester.fetcher.queue.Size(); queued != blockLimit { + t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit) + } + // Queue up a batch of valid blocks, and check that a new peer is allowed to do so + for i := 0; i < maxQueueDist-1; i++ { + tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]]) + } + time.Sleep(100 * time.Millisecond) + if queued := tester.fetcher.queue.Size(); queued != blockLimit+maxQueueDist-1 { + t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1) + } + // Insert the missing piece (and sanity check the import) + tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2]]) + time.Sleep(500 * time.Millisecond) + if imported := len(tester.blocks); imported != maxQueueDist+1 { + t.Fatalf("synchronised block mismatch: have %v, want %v", imported, maxQueueDist+1) + } + // Insert the remaining blocks in chunks to ensure clean DOS protection + for i := maxQueueDist; i < len(hashes)-1; i++ { + tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2-i]]) + if i%maxQueueDist == 0 { + time.Sleep(500 * time.Millisecond) + } + } + time.Sleep(500 * time.Millisecond) + if imported := len(tester.blocks); imported != len(hashes) { + t.Fatalf("synchronised block mismatch: have %v, want %v", imported, len(hashes)) + } +} |