From e402e1dc2e72df2a433b984caeaba771085b2b66 Mon Sep 17 00:00:00 2001 From: Taylor Gerring Date: Thu, 2 Apr 2015 13:17:55 +0200 Subject: New args types with stricter checking --- rpc/api.go | 12 ++++---- rpc/args.go | 88 +++++++++++++++++++++++++++++++++++++++----------------- rpc/args_test.go | 12 +++++++- 3 files changed, 79 insertions(+), 33 deletions(-) (limited to 'rpc') diff --git a/rpc/api.go b/rpc/api.go index 940b80758..478ca8752 100644 --- a/rpc/api.go +++ b/rpc/api.go @@ -108,15 +108,15 @@ func (api *EthereumApi) GetRequestReply(req *RpcRequest, reply *interface{}) err count := api.xethAtStateNum(args.BlockNumber).TxCountAt(args.Address) *reply = common.ToHex(big.NewInt(int64(count)).Bytes()) case "eth_getBlockTransactionCountByHash": - args := new(GetBlockByHashArgs) + args := new(HashArgs) if err := json.Unmarshal(req.Params, &args); err != nil { return err } - block := NewBlockRes(api.xeth().EthBlockByHash(args.BlockHash), false) + block := NewBlockRes(api.xeth().EthBlockByHash(args.Hash), false) *reply = common.ToHex(big.NewInt(int64(len(block.Transactions))).Bytes()) case "eth_getBlockTransactionCountByNumber": - args := new(GetBlockByNumberArgs) + args := new(BlockNumArg) if err := json.Unmarshal(req.Params, &args); err != nil { return err } @@ -124,16 +124,16 @@ func (api *EthereumApi) GetRequestReply(req *RpcRequest, reply *interface{}) err block := NewBlockRes(api.xeth().EthBlockByNumber(args.BlockNumber), false) *reply = common.ToHex(big.NewInt(int64(len(block.Transactions))).Bytes()) case "eth_getUncleCountByBlockHash": - args := new(GetBlockByHashArgs) + args := new(HashArgs) if err := json.Unmarshal(req.Params, &args); err != nil { return err } - block := api.xeth().EthBlockByHash(args.BlockHash) + block := api.xeth().EthBlockByHash(args.Hash) br := NewBlockRes(block, false) *reply = common.ToHex(big.NewInt(int64(len(br.Uncles))).Bytes()) case "eth_getUncleCountByBlockNumber": - args := new(GetBlockByNumberArgs) + args := new(BlockNumArg) if err := json.Unmarshal(req.Params, &args); err != nil { return err } diff --git a/rpc/args.go b/rpc/args.go index 220daf960..b43c465c0 100644 --- a/rpc/args.go +++ b/rpc/args.go @@ -108,8 +108,8 @@ func (args *GetBlockByHashArgs) UnmarshalJSON(b []byte) (err error) { return NewDecodeParamError(err.Error()) } - if len(obj) < 1 { - return NewInsufficientParamsError(len(obj), 1) + if len(obj) < 2 { + return NewInsufficientParamsError(len(obj), 2) } argstr, ok := obj[0].(string) @@ -118,9 +118,7 @@ func (args *GetBlockByHashArgs) UnmarshalJSON(b []byte) (err error) { } args.BlockHash = argstr - if len(obj) > 1 { - args.IncludeTxs = obj[1].(bool) - } + args.IncludeTxs = obj[1].(bool) return nil } @@ -136,8 +134,8 @@ func (args *GetBlockByNumberArgs) UnmarshalJSON(b []byte) (err error) { return NewDecodeParamError(err.Error()) } - if len(obj) < 1 { - return NewInsufficientParamsError(len(obj), 1) + if len(obj) < 2 { + return NewInsufficientParamsError(len(obj), 2) } if v, ok := obj[0].(float64); ok { @@ -148,9 +146,7 @@ func (args *GetBlockByNumberArgs) UnmarshalJSON(b []byte) (err error) { return NewInvalidTypeError("blockNumber", "not a number or string") } - if len(obj) > 1 { - args.IncludeTxs = obj[1].(bool) - } + args.IncludeTxs = obj[1].(bool) return nil } @@ -496,6 +492,27 @@ func (args *GetDataArgs) UnmarshalJSON(b []byte) (err error) { return nil } +type BlockNumArg struct { + BlockNumber int64 +} + +func (args *BlockNumArg) UnmarshalJSON(b []byte) (err error) { + var obj []interface{} + if err := json.Unmarshal(b, &obj); err != nil { + return NewDecodeParamError(err.Error()) + } + + if len(obj) < 1 { + return NewInsufficientParamsError(len(obj), 1) + } + + if err := blockHeight(obj[0], &args.BlockNumber); err != nil { + return err + } + + return nil +} + type BlockNumIndexArgs struct { BlockNumber int64 Index int64 @@ -507,21 +524,42 @@ func (args *BlockNumIndexArgs) UnmarshalJSON(b []byte) (err error) { return NewDecodeParamError(err.Error()) } - if len(obj) < 1 { - return NewInsufficientParamsError(len(obj), 1) + if len(obj) < 2 { + return NewInsufficientParamsError(len(obj), 2) } if err := blockHeight(obj[0], &args.BlockNumber); err != nil { return err } - if len(obj) > 1 { - arg1, ok := obj[1].(string) - if !ok { - return NewInvalidTypeError("index", "not a string") - } - args.Index = common.Big(arg1).Int64() + arg1, ok := obj[1].(string) + if !ok { + return NewInvalidTypeError("index", "not a string") } + args.Index = common.Big(arg1).Int64() + + return nil +} + +type HashArgs struct { + Hash string +} + +func (args *HashArgs) UnmarshalJSON(b []byte) (err error) { + var obj []interface{} + if err := json.Unmarshal(b, &obj); err != nil { + return NewDecodeParamError(err.Error()) + } + + if len(obj) < 1 { + return NewInsufficientParamsError(len(obj), 1) + } + + arg0, ok := obj[0].(string) + if !ok { + return NewInvalidTypeError("hash", "not a string") + } + args.Hash = arg0 return nil } @@ -537,8 +575,8 @@ func (args *HashIndexArgs) UnmarshalJSON(b []byte) (err error) { return NewDecodeParamError(err.Error()) } - if len(obj) < 1 { - return NewInsufficientParamsError(len(obj), 1) + if len(obj) < 2 { + return NewInsufficientParamsError(len(obj), 2) } arg0, ok := obj[0].(string) @@ -547,13 +585,11 @@ func (args *HashIndexArgs) UnmarshalJSON(b []byte) (err error) { } args.Hash = arg0 - if len(obj) > 1 { - arg1, ok := obj[1].(string) - if !ok { - return NewInvalidTypeError("index", "not a string") - } - args.Index = common.Big(arg1).Int64() + arg1, ok := obj[1].(string) + if !ok { + return NewInvalidTypeError("index", "not a string") } + args.Index = common.Big(arg1).Int64() return nil } diff --git a/rpc/args_test.go b/rpc/args_test.go index 3635882c0..f00899b79 100644 --- a/rpc/args_test.go +++ b/rpc/args_test.go @@ -225,7 +225,7 @@ func TestGetBlockByHashArgsHashInt(t *testing.T) { input := `[8]` args := new(GetBlockByHashArgs) - str := ExpectInvalidTypeError(json.Unmarshal([]byte(input), &args)) + str := ExpectInsufficientParamsError(json.Unmarshal([]byte(input), &args)) if len(str) > 0 { t.Error(str) } @@ -281,6 +281,16 @@ func TestGetBlockByNumberEmpty(t *testing.T) { } } +func TestGetBlockByNumberShort(t *testing.T) { + input := `["0xbbb"]` + + args := new(GetBlockByNumberArgs) + str := ExpectInsufficientParamsError(json.Unmarshal([]byte(input), &args)) + if len(str) > 0 { + t.Error(str) + } +} + func TestGetBlockByNumberBool(t *testing.T) { input := `[true, true]` -- cgit