diff options
Diffstat (limited to 'common')
-rw-r--r-- | common/hexutil/json.go | 32 | ||||
-rw-r--r-- | common/hexutil/json_test.go | 35 | ||||
-rw-r--r-- | common/math/big.go | 22 | ||||
-rw-r--r-- | common/math/big_test.go | 13 | ||||
-rw-r--r-- | common/math/integer.go | 23 | ||||
-rw-r--r-- | common/math/integer_test.go | 11 | ||||
-rw-r--r-- | common/types.go | 44 |
7 files changed, 149 insertions, 31 deletions
diff --git a/common/hexutil/json.go b/common/hexutil/json.go index 23393ed2c..1bc1d014c 100644 --- a/common/hexutil/json.go +++ b/common/hexutil/json.go @@ -51,7 +51,7 @@ func (b *Bytes) UnmarshalJSON(input []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (b *Bytes) UnmarshalText(input []byte) error { - raw, err := checkText(input) + raw, err := checkText(input, true) if err != nil { return err } @@ -73,7 +73,28 @@ func (b Bytes) String() string { // determines the required input length. This function is commonly used to implement the // UnmarshalText method for fixed-size types. func UnmarshalFixedText(typname string, input, out []byte) error { - raw, err := checkText(input) + raw, err := checkText(input, true) + if err != nil { + return err + } + if len(raw)/2 != len(out) { + return fmt.Errorf("hex string has length %d, want %d for %s", len(raw), len(out)*2, typname) + } + // Pre-verify syntax before modifying out. + for _, b := range raw { + if decodeNibble(b) == badNibble { + return ErrSyntax + } + } + hex.Decode(out, raw) + return nil +} + +// UnmarshalFixedUnprefixedText decodes the input as a string with optional 0x prefix. The +// length of out determines the required input length. This function is commonly used to +// implement the UnmarshalText method for fixed-size types. +func UnmarshalFixedUnprefixedText(typname string, input, out []byte) error { + raw, err := checkText(input, false) if err != nil { return err } @@ -243,14 +264,15 @@ func bytesHave0xPrefix(input []byte) bool { return len(input) >= 2 && input[0] == '0' && (input[1] == 'x' || input[1] == 'X') } -func checkText(input []byte) ([]byte, error) { +func checkText(input []byte, wantPrefix bool) ([]byte, error) { if len(input) == 0 { return nil, nil // empty strings are allowed } - if !bytesHave0xPrefix(input) { + if bytesHave0xPrefix(input) { + input = input[2:] + } else if wantPrefix { return nil, ErrMissingPrefix } - input = input[2:] if len(input)%2 != 0 { return nil, ErrOddLength } diff --git a/common/hexutil/json_test.go b/common/hexutil/json_test.go index af7f44915..e4e827491 100644 --- a/common/hexutil/json_test.go +++ b/common/hexutil/json_test.go @@ -337,3 +337,38 @@ func TestUnmarshalUint(t *testing.T) { } } } + +func TestUnmarshalFixedUnprefixedText(t *testing.T) { + tests := []struct { + input string + want []byte + wantErr error + }{ + {input: "0x2", wantErr: ErrOddLength}, + {input: "2", wantErr: ErrOddLength}, + {input: "4444", wantErr: errors.New("hex string has length 4, want 8 for x")}, + {input: "4444", wantErr: errors.New("hex string has length 4, want 8 for x")}, + // check that output is not modified for partially correct input + {input: "444444gg", wantErr: ErrSyntax, want: []byte{0, 0, 0, 0}}, + {input: "0x444444gg", wantErr: ErrSyntax, want: []byte{0, 0, 0, 0}}, + // valid inputs + {input: "44444444", want: []byte{0x44, 0x44, 0x44, 0x44}}, + {input: "0x44444444", want: []byte{0x44, 0x44, 0x44, 0x44}}, + } + + for _, test := range tests { + out := make([]byte, 4) + err := UnmarshalFixedUnprefixedText("x", []byte(test.input), out) + switch { + case err == nil && test.wantErr != nil: + t.Errorf("%q: got no error, expected %q", test.input, test.wantErr) + case err != nil && test.wantErr == nil: + t.Errorf("%q: unexpected error %q", test.input, err) + case err != nil && err.Error() != test.wantErr.Error(): + t.Errorf("%q: error mismatch: got %q, want %q", test.input, err, test.wantErr) + } + if test.want != nil && !bytes.Equal(out, test.want) { + t.Errorf("%q: output mismatch: got %x, want %x", test.input, out, test.want) + } + } +} diff --git a/common/math/big.go b/common/math/big.go index 704ca40a9..5255a88e9 100644 --- a/common/math/big.go +++ b/common/math/big.go @@ -18,6 +18,7 @@ package math import ( + "fmt" "math/big" ) @@ -35,6 +36,27 @@ const ( wordBytes = wordBits / 8 ) +// HexOrDecimal256 marshals big.Int as hex or decimal. +type HexOrDecimal256 big.Int + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *HexOrDecimal256) UnmarshalText(input []byte) error { + bigint, ok := ParseBig256(string(input)) + if !ok { + return fmt.Errorf("invalid hex or decimal integer %q", input) + } + *i = HexOrDecimal256(*bigint) + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i *HexOrDecimal256) MarshalText() ([]byte, error) { + if i == nil { + return []byte("0x0"), nil + } + return []byte(fmt.Sprintf("%#x", (*big.Int)(i))), nil +} + // ParseBig256 parses s as a 256 bit integer in decimal or hexadecimal syntax. // Leading zeros are accepted. The empty string parses as zero. func ParseBig256(s string) (*big.Int, bool) { diff --git a/common/math/big_test.go b/common/math/big_test.go index 6eb13f4f1..deff25465 100644 --- a/common/math/big_test.go +++ b/common/math/big_test.go @@ -23,7 +23,7 @@ import ( "testing" ) -func TestParseBig256(t *testing.T) { +func TestHexOrDecimal256(t *testing.T) { tests := []struct { input string num *big.Int @@ -47,13 +47,14 @@ func TestParseBig256(t *testing.T) { {"115792089237316195423570985008687907853269984665640564039457584007913129639936", nil, false}, } for _, test := range tests { - num, ok := ParseBig256(test.input) - if ok != test.ok { - t.Errorf("ParseBig(%q) -> ok = %t, want %t", test.input, ok, test.ok) + var num HexOrDecimal256 + err := num.UnmarshalText([]byte(test.input)) + if (err == nil) != test.ok { + t.Errorf("ParseBig(%q) -> (err == nil) == %t, want %t", test.input, err == nil, test.ok) continue } - if num != nil && test.num != nil && num.Cmp(test.num) != 0 { - t.Errorf("ParseBig(%q) -> %d, want %d", test.input, num, test.num) + if test.num != nil && (*big.Int)(&num).Cmp(test.num) != 0 { + t.Errorf("ParseBig(%q) -> %d, want %d", test.input, (*big.Int)(&num), test.num) } } } diff --git a/common/math/integer.go b/common/math/integer.go index a3eeee27e..7eff4d3b0 100644 --- a/common/math/integer.go +++ b/common/math/integer.go @@ -16,7 +16,10 @@ package math -import "strconv" +import ( + "fmt" + "strconv" +) const ( // Integer limit values. @@ -34,6 +37,24 @@ const ( MaxUint64 = 1<<64 - 1 ) +// HexOrDecimal64 marshals uint64 as hex or decimal. +type HexOrDecimal64 uint64 + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *HexOrDecimal64) UnmarshalText(input []byte) error { + int, ok := ParseUint64(string(input)) + if !ok { + return fmt.Errorf("invalid hex or decimal integer %q", input) + } + *i = HexOrDecimal64(int) + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i HexOrDecimal64) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("%#x", uint64(i))), nil +} + // ParseUint64 parses s as an integer in decimal or hexadecimal syntax. // Leading zeros are accepted. The empty string parses as zero. func ParseUint64(s string) (uint64, bool) { diff --git a/common/math/integer_test.go b/common/math/integer_test.go index 05bba221f..b31c7c26c 100644 --- a/common/math/integer_test.go +++ b/common/math/integer_test.go @@ -65,7 +65,7 @@ func TestOverflow(t *testing.T) { } } -func TestParseUint64(t *testing.T) { +func TestHexOrDecimal64(t *testing.T) { tests := []struct { input string num uint64 @@ -88,12 +88,13 @@ func TestParseUint64(t *testing.T) { {"18446744073709551617", 0, false}, } for _, test := range tests { - num, ok := ParseUint64(test.input) - if ok != test.ok { - t.Errorf("ParseUint64(%q) -> ok = %t, want %t", test.input, ok, test.ok) + var num HexOrDecimal64 + err := num.UnmarshalText([]byte(test.input)) + if (err == nil) != test.ok { + t.Errorf("ParseUint64(%q) -> (err == nil) = %t, want %t", test.input, err == nil, test.ok) continue } - if ok && num != test.num { + if err == nil && uint64(num) != test.num { t.Errorf("ParseUint64(%q) -> %d, want %d", test.input, num, test.num) } } diff --git a/common/types.go b/common/types.go index 9c50beb13..05288bf46 100644 --- a/common/types.go +++ b/common/types.go @@ -17,6 +17,7 @@ package common import ( + "encoding/hex" "fmt" "math/big" "math/rand" @@ -30,13 +31,8 @@ const ( AddressLength = 20 ) -type ( - // Hash represents the 32 byte Keccak256 hash of arbitrary data. - Hash [HashLength]byte - - // Address represents the 20 byte address of an Ethereum account. - Address [AddressLength]byte -) +// Hash represents the 32 byte Keccak256 hash of arbitrary data. +type Hash [HashLength]byte func BytesToHash(b []byte) Hash { var h Hash @@ -113,7 +109,24 @@ func EmptyHash(h Hash) bool { return h == Hash{} } +// UnprefixedHash allows marshaling a Hash without 0x prefix. +type UnprefixedHash Hash + +// UnmarshalText decodes the hash from hex. The 0x prefix is optional. +func (h *UnprefixedHash) UnmarshalText(input []byte) error { + return hexutil.UnmarshalFixedUnprefixedText("UnprefixedHash", input, h[:]) +} + +// MarshalText encodes the hash as hex. +func (h UnprefixedHash) MarshalText() ([]byte, error) { + return []byte(hex.EncodeToString(h[:])), nil +} + /////////// Address + +// Address represents the 20 byte address of an Ethereum account. +type Address [AddressLength]byte + func BytesToAddress(b []byte) Address { var a Address a.SetBytes(b) @@ -181,12 +194,15 @@ func (a *Address) UnmarshalText(input []byte) error { return hexutil.UnmarshalFixedText("Address", input, a[:]) } -// PP Pretty Prints a byte slice in the following format: -// hex(value[:4])...(hex[len(value)-4:]) -func PP(value []byte) string { - if len(value) <= 8 { - return Bytes2Hex(value) - } +// UnprefixedHash allows marshaling an Address without 0x prefix. +type UnprefixedAddress Address + +// UnmarshalText decodes the address from hex. The 0x prefix is optional. +func (a *UnprefixedAddress) UnmarshalText(input []byte) error { + return hexutil.UnmarshalFixedUnprefixedText("UnprefixedAddress", input, a[:]) +} - return fmt.Sprintf("%x...%x", value[:4], value[len(value)-4]) +// MarshalText encodes the address as hex. +func (a UnprefixedAddress) MarshalText() ([]byte, error) { + return []byte(hex.EncodeToString(a[:])), nil } |