aboutsummaryrefslogtreecommitdiffstats
path: root/common/hexutil
diff options
context:
space:
mode:
authorFelix Lange <fjl@twurst.com>2017-03-08 06:19:27 +0800
committerFelix Lange <fjl@twurst.com>2017-03-23 22:58:42 +0800
commitb4547a560b861e2e5463bf6fed6d61958c4e9411 (patch)
tree9b7e2171f4e8b9671ee3b67c5c591b440a2c0a94 /common/hexutil
parent04fa6a374499dcefeb3f854c4cf6cfcdfb6c8c76 (diff)
downloaddexon-b4547a560b861e2e5463bf6fed6d61958c4e9411.tar.gz
dexon-b4547a560b861e2e5463bf6fed6d61958c4e9411.tar.zst
dexon-b4547a560b861e2e5463bf6fed6d61958c4e9411.zip
common/hexutil: add UnmarshalFixedUnprefixedText
Diffstat (limited to 'common/hexutil')
-rw-r--r--common/hexutil/json.go32
-rw-r--r--common/hexutil/json_test.go35
2 files changed, 62 insertions, 5 deletions
diff --git a/common/hexutil/json.go b/common/hexutil/json.go
index 23393ed2c..1bc1d014c 100644
--- a/common/hexutil/json.go
+++ b/common/hexutil/json.go
@@ -51,7 +51,7 @@ func (b *Bytes) UnmarshalJSON(input []byte) error {
// UnmarshalText implements encoding.TextUnmarshaler.
func (b *Bytes) UnmarshalText(input []byte) error {
- raw, err := checkText(input)
+ raw, err := checkText(input, true)
if err != nil {
return err
}
@@ -73,7 +73,28 @@ func (b Bytes) String() string {
// determines the required input length. This function is commonly used to implement the
// UnmarshalText method for fixed-size types.
func UnmarshalFixedText(typname string, input, out []byte) error {
- raw, err := checkText(input)
+ raw, err := checkText(input, true)
+ if err != nil {
+ return err
+ }
+ if len(raw)/2 != len(out) {
+ return fmt.Errorf("hex string has length %d, want %d for %s", len(raw), len(out)*2, typname)
+ }
+ // Pre-verify syntax before modifying out.
+ for _, b := range raw {
+ if decodeNibble(b) == badNibble {
+ return ErrSyntax
+ }
+ }
+ hex.Decode(out, raw)
+ return nil
+}
+
+// UnmarshalFixedUnprefixedText decodes the input as a string with optional 0x prefix. The
+// length of out determines the required input length. This function is commonly used to
+// implement the UnmarshalText method for fixed-size types.
+func UnmarshalFixedUnprefixedText(typname string, input, out []byte) error {
+ raw, err := checkText(input, false)
if err != nil {
return err
}
@@ -243,14 +264,15 @@ func bytesHave0xPrefix(input []byte) bool {
return len(input) >= 2 && input[0] == '0' && (input[1] == 'x' || input[1] == 'X')
}
-func checkText(input []byte) ([]byte, error) {
+func checkText(input []byte, wantPrefix bool) ([]byte, error) {
if len(input) == 0 {
return nil, nil // empty strings are allowed
}
- if !bytesHave0xPrefix(input) {
+ if bytesHave0xPrefix(input) {
+ input = input[2:]
+ } else if wantPrefix {
return nil, ErrMissingPrefix
}
- input = input[2:]
if len(input)%2 != 0 {
return nil, ErrOddLength
}
diff --git a/common/hexutil/json_test.go b/common/hexutil/json_test.go
index af7f44915..e4e827491 100644
--- a/common/hexutil/json_test.go
+++ b/common/hexutil/json_test.go
@@ -337,3 +337,38 @@ func TestUnmarshalUint(t *testing.T) {
}
}
}
+
+func TestUnmarshalFixedUnprefixedText(t *testing.T) {
+ tests := []struct {
+ input string
+ want []byte
+ wantErr error
+ }{
+ {input: "0x2", wantErr: ErrOddLength},
+ {input: "2", wantErr: ErrOddLength},
+ {input: "4444", wantErr: errors.New("hex string has length 4, want 8 for x")},
+ {input: "4444", wantErr: errors.New("hex string has length 4, want 8 for x")},
+ // check that output is not modified for partially correct input
+ {input: "444444gg", wantErr: ErrSyntax, want: []byte{0, 0, 0, 0}},
+ {input: "0x444444gg", wantErr: ErrSyntax, want: []byte{0, 0, 0, 0}},
+ // valid inputs
+ {input: "44444444", want: []byte{0x44, 0x44, 0x44, 0x44}},
+ {input: "0x44444444", want: []byte{0x44, 0x44, 0x44, 0x44}},
+ }
+
+ for _, test := range tests {
+ out := make([]byte, 4)
+ err := UnmarshalFixedUnprefixedText("x", []byte(test.input), out)
+ switch {
+ case err == nil && test.wantErr != nil:
+ t.Errorf("%q: got no error, expected %q", test.input, test.wantErr)
+ case err != nil && test.wantErr == nil:
+ t.Errorf("%q: unexpected error %q", test.input, err)
+ case err != nil && err.Error() != test.wantErr.Error():
+ t.Errorf("%q: error mismatch: got %q, want %q", test.input, err, test.wantErr)
+ }
+ if test.want != nil && !bytes.Equal(out, test.want) {
+ t.Errorf("%q: output mismatch: got %x, want %x", test.input, out, test.want)
+ }
+ }
+}