diff options
Diffstat (limited to 'light/odr_test.go')
-rw-r--r-- | light/odr_test.go | 164 |
1 files changed, 81 insertions, 83 deletions
diff --git a/light/odr_test.go b/light/odr_test.go index 576e3abc9..544b64eff 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -86,11 +86,11 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { return nil } -type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte +type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) -func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetBlock) } +func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, odrGetBlock) } -func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var block *types.Block if bc != nil { block = bc.GetBlockByHash(bhash) @@ -98,15 +98,15 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc block, _ = lc.GetBlockByHash(ctx, bhash) } if block == nil { - return nil + return nil, nil } rlp, _ := rlp.EncodeToBytes(block) - return rlp + return rlp, nil } -func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetReceipts) } +func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) } -func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var receipts types.Receipts if bc != nil { receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) @@ -114,43 +114,37 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) } if receipts == nil { - return nil + return nil, nil } rlp, _ := rlp.EncodeToBytes(receipts) - return rlp + return rlp, nil } -func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrAccounts) } +func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, odrAccounts) } -func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} + var st *state.StateDB + if bc == nil { + header := lc.GetHeaderByHash(bhash) + st = NewState(ctx, header, lc.Odr()) + } else { + header := bc.GetHeaderByHash(bhash) + st, _ = state.New(header.Root, state.NewDatabase(db)) + } + var res []byte for _, addr := range acc { - if bc != nil { - header := bc.GetHeaderByHash(bhash) - st, err := state.New(header.Root, db) - if err == nil { - bal := st.GetBalance(addr) - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } - } else { - header := lc.GetHeaderByHash(bhash) - st := NewLightState(StateTrieID(header), lc.Odr()) - bal, err := st.GetBalance(ctx, addr) - if err == nil { - rlp, _ := rlp.EncodeToBytes(bal) - res = append(res, rlp...) - } - } + bal := st.GetBalance(addr) + rlp, _ := rlp.EncodeToBytes(bal) + res = append(res, rlp...) } - - return res + return res, st.Error() } -func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, 2, odrContractCall) } +func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, odrContractCall) } type callmsg struct { types.Message @@ -158,50 +152,42 @@ type callmsg struct { func (callmsg) CheckNonce() bool { return false } -func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { +func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000") - config := params.TestChainConfig var res []byte for i := 0; i < 3; i++ { data[35] = byte(i) - if bc != nil { - header := bc.GetHeaderByHash(bhash) - statedb, err := state.New(header.Root, db) - if err == nil { - from := statedb.GetOrNewStateObject(testBankAddress) - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, bc, nil) - vmenv := vm.NewEVM(context, statedb, config, vm.Config{}) - - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - res = append(res, ret...) - } + var ( + st *state.StateDB + header *types.Header + chain core.ChainContext + ) + if bc == nil { + chain = lc + header = lc.GetHeaderByHash(bhash) + st = NewState(ctx, header, lc.Odr()) } else { - header := lc.GetHeaderByHash(bhash) - state := NewLightState(StateTrieID(header), lc.Odr()) - vmstate := NewVMState(ctx, state) - from, err := state.GetOrNewStateObject(ctx, testBankAddress) - if err == nil { - from.SetBalance(math.MaxBig256) - - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, lc, nil) - vmenv := vm.NewEVM(context, vmstate, config, vm.Config{}) - gp := new(core.GasPool).AddGas(math.MaxBig256) - ret, _, _ := core.ApplyMessage(vmenv, msg, gp) - if vmstate.Error() == nil { - res = append(res, ret...) - } - } + chain = bc + header = bc.GetHeaderByHash(bhash) + st, _ = state.New(header.Root, state.NewDatabase(db)) + } + + // Perform read-only call. + st.SetBalance(testBankAddress, math.MaxBig256) + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} + context := core.NewEVMContext(msg, header, chain, nil) + vmenv := vm.NewEVM(context, st, config, vm.Config{}) + gp := new(core.GasPool).AddGas(math.MaxBig256) + ret, _, _ := core.ApplyMessage(vmenv, msg, gp) + res = append(res, ret...) + if st.Error() != nil { + return res, st.Error() } } - return res + return res, nil } func testChainGen(i int, block *core.BlockGen) { @@ -245,7 +231,7 @@ func testChainGen(i int, block *core.BlockGen) { } } -func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { +func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { var ( evmux = new(event.TypeMux) sdb, _ = ethdb.NewMemDatabase() @@ -258,46 +244,58 @@ func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), evmux, vm.Config{}) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, sdb, 4, testChainGen) if _, err := blockchain.InsertChain(gchain); err != nil { - panic(err) + t.Fatal(err) } odr := &testOdr{sdb: sdb, ldb: ldb} - lightchain, _ := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux) + lightchain, err := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux) + if err != nil { + t.Fatal(err) + } headers := make([]*types.Header, len(gchain)) for i, block := range gchain { headers[i] = block.Header() } if _, err := lightchain.InsertHeaderChain(headers, 1); err != nil { - panic(err) + t.Fatal(err) } - test := func(expFail uint64) { + test := func(expFail int) { for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ { bhash := core.GetCanonicalHash(sdb, i) - b1 := fn(NoOdr, sdb, blockchain, nil, bhash) + b1, err := fn(NoOdr, sdb, blockchain, nil, bhash) + if err != nil { + t.Fatalf("error in full-node test for block %d: %v", i, err) + } ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - b2 := fn(ctx, ldb, nil, lightchain, bhash) + + exp := i < uint64(expFail) + b2, err := fn(ctx, ldb, nil, lightchain, bhash) + if err != nil && exp { + t.Errorf("error in ODR test for block %d: %v", i, err) + } eq := bytes.Equal(b1, b2) - exp := i < expFail if exp && !eq { - t.Errorf("odr mismatch") - } - if !exp && eq { - t.Errorf("unexpected odr match") + t.Errorf("ODR test output for block %d doesn't match full node", i) } } } - odr.disable = true // expect retrievals to fail (except genesis block) without a les peer - test(expFail) - odr.disable = false - // expect all retrievals to pass - test(5) + t.Log("checking without ODR") odr.disable = true + test(1) + + // expect all retrievals to pass with ODR enabled + t.Log("checking with ODR") + odr.disable = false + test(len(gchain)) + // still expect all retrievals to pass, now data should be cached locally - test(5) + t.Log("checking without ODR, should be cached") + odr.disable = true + test(len(gchain)) } |