aboutsummaryrefslogtreecommitdiffstats
path: root/eth/downloader/downloader_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'eth/downloader/downloader_test.go')
-rw-r--r--eth/downloader/downloader_test.go64
1 files changed, 49 insertions, 15 deletions
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index 60dcc06cd..d55664314 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -23,25 +23,26 @@ func createHashes(start, amount int) (hashes []common.Hash) {
for i := range hashes[:len(hashes)-1] {
binary.BigEndian.PutUint64(hashes[i][:8], uint64(i+2))
}
-
return
}
-func createBlock(i int, prevHash, hash common.Hash) *types.Block {
+func createBlock(i int, parent, hash common.Hash) *types.Block {
header := &types.Header{Number: big.NewInt(int64(i))}
block := types.NewBlockWithHeader(header)
block.HeaderHash = hash
- block.ParentHeaderHash = prevHash
+ block.ParentHeaderHash = parent
return block
}
func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block {
blocks := make(map[common.Hash]*types.Block)
-
- for i, hash := range hashes {
- blocks[hash] = createBlock(len(hashes)-i, knownHash, hash)
+ for i := 0; i < len(hashes); i++ {
+ parent := knownHash
+ if i < len(hashes)-1 {
+ parent = hashes[i+1]
+ }
+ blocks[hashes[i]] = createBlock(len(hashes)-i, parent, hashes[i])
}
-
return blocks
}
@@ -136,6 +137,7 @@ func (dl *downloadTester) getHashes(head common.Hash) error {
hashes := make([]common.Hash, 0, maxHashFetch)
for i, hash := range dl.hashes {
if hash == head {
+ i++
for len(hashes) < cap(hashes) && i < len(dl.hashes) {
hashes = append(hashes, dl.hashes[i])
i++
@@ -144,9 +146,11 @@ func (dl *downloadTester) getHashes(head common.Hash) error {
}
}
// Delay delivery a bit to allow attacks to unfold
- time.Sleep(time.Millisecond)
-
- dl.downloader.DeliverHashes(dl.activePeerId, hashes)
+ id := dl.activePeerId
+ go func() {
+ time.Sleep(time.Millisecond)
+ dl.downloader.DeliverHashes(id, hashes)
+ }()
return nil
}
@@ -424,12 +428,15 @@ func TestInvalidHashOrderAttack(t *testing.T) {
hashes := createHashes(0, 4*blockCacheLimit)
blocks := createBlocksFromHashes(hashes)
+ chunk1 := make([]common.Hash, blockCacheLimit)
+ chunk2 := make([]common.Hash, blockCacheLimit)
+ copy(chunk1, hashes[blockCacheLimit:2*blockCacheLimit])
+ copy(chunk2, hashes[2*blockCacheLimit:3*blockCacheLimit])
+
reverse := make([]common.Hash, len(hashes))
copy(reverse, hashes)
-
- for i := len(hashes) / 4; i < 2*len(hashes)/4; i++ {
- reverse[i], reverse[len(hashes)-i-1] = reverse[len(hashes)-i-1], reverse[i]
- }
+ copy(reverse[2*blockCacheLimit:], chunk1)
+ copy(reverse[blockCacheLimit:], chunk2)
// Try and sync with the malicious node and check that it fails
tester := newTester(t, reverse, blocks)
@@ -453,7 +460,6 @@ func TestMadeupHashChainAttack(t *testing.T) {
// Create a long chain of hashes without backing blocks
hashes := createHashes(0, 1024*blockCacheLimit)
- hashes = hashes[:len(hashes)-1]
// Try and sync with the malicious node and check that it fails
tester := newTester(t, hashes, nil)
@@ -462,3 +468,31 @@ func TestMadeupHashChainAttack(t *testing.T) {
t.Fatalf("synchronisation error mismatch: have %v, want %v", err, ErrCrossCheckFailed)
}
}
+
+// Tests that if a malicious peer makes up a random block chain, and tried to
+// push indefinitely, it actually gets caught with it.
+func TestMadeupBlockChainAttack(t *testing.T) {
+ blockTTL = 100 * time.Millisecond
+ crossCheckCycle = 25 * time.Millisecond
+
+ // Create a long chain of blocks and simulate an invalid chain by dropping every second
+ hashes := createHashes(0, 32*blockCacheLimit)
+ blocks := createBlocksFromHashes(hashes)
+
+ gapped := make([]common.Hash, len(hashes)/2)
+ for i := 0; i < len(gapped); i++ {
+ gapped[i] = hashes[2*i]
+ }
+ // Try and sync with the malicious node and check that it fails
+ tester := newTester(t, gapped, blocks)
+ tester.newPeer("attack", big.NewInt(10000), gapped[0])
+ if _, err := tester.syncTake("attack", gapped[0]); err != ErrCrossCheckFailed {
+ t.Fatalf("synchronisation error mismatch: have %v, want %v", err, ErrCrossCheckFailed)
+ }
+ // Ensure that a valid chain can still pass sync
+ tester.hashes = hashes
+ tester.newPeer("valid", big.NewInt(20000), hashes[0])
+ if _, err := tester.syncTake("valid", hashes[0]); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+}