diff options
author | obscuren <geffobscura@gmail.com> | 2014-12-05 23:27:37 +0800 |
---|---|---|
committer | obscuren <geffobscura@gmail.com> | 2014-12-05 23:27:37 +0800 |
commit | 195b2d2ebdfe1e794c096c9b6c9fe91e9df7af1d (patch) | |
tree | 9b1a96d7c65c4dcb5515eddf6c2b03e7e6b0d397 /rlp | |
parent | 710360bab61178cf7fbc52213ec4c612be37ad18 (diff) | |
parent | 384b8c75f07b6811aa3012ad52a44844b3ab6e52 (diff) | |
download | go-tangerine-195b2d2ebdfe1e794c096c9b6c9fe91e9df7af1d.tar.gz go-tangerine-195b2d2ebdfe1e794c096c9b6c9fe91e9df7af1d.tar.zst go-tangerine-195b2d2ebdfe1e794c096c9b6c9fe91e9df7af1d.zip |
Merge branch 'fjl-feature/p2p-protocol-interface' into poc8
Diffstat (limited to 'rlp')
-rw-r--r-- | rlp/decode.go | 81 | ||||
-rw-r--r-- | rlp/decode_test.go | 89 |
2 files changed, 135 insertions, 35 deletions
diff --git a/rlp/decode.go b/rlp/decode.go index 96d912f56..7d95af02b 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -1,6 +1,7 @@ package rlp import ( + "bufio" "encoding/binary" "errors" "fmt" @@ -24,8 +25,9 @@ type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result -// in the value pointed to by val. Val must be a non-nil pointer. +// Decode parses RLP-encoded data from r and stores the result in the +// value pointed to by val. Val must be a non-nil pointer. If r does +// not implement ByteReader, Decode will do its own buffering. // // Decode uses the following type-dependent decoding rules: // @@ -66,10 +68,19 @@ type Decoder interface { // // Non-empty interface types are not supported, nor are bool, float32, // float64, maps, channel types and functions. -func Decode(r ByteReader, val interface{}) error { +func Decode(r io.Reader, val interface{}) error { return NewStream(r).Decode(val) } +type decodeError struct { + msg string + typ reflect.Type +} + +func (err decodeError) Error() string { + return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) +} + func makeNumDecoder(typ reflect.Type) decoder { kind := typ.Kind() switch { @@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder { } func decodeInt(s *Stream, val reflect.Value) error { - num, err := s.uint(val.Type().Bits()) - if err != nil { + typ := val.Type() + num, err := s.uint(typ.Bits()) + if err == errUintOverflow { + return decodeError{"input string too long", typ} + } else if err != nil { return err } val.SetInt(int64(num)) @@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error { } func decodeUint(s *Stream, val reflect.Value) error { - num, err := s.uint(val.Type().Bits()) - if err != nil { + typ := val.Type() + num, err := s.uint(typ.Bits()) + if err == errUintOverflow { + return decodeError{"input string too big", typ} + } else if err != nil { return err } val.SetUint(num) @@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro i := 0 for { if i > maxelem { - return fmt.Errorf("rlp: input List has more than %d elements", maxelem) + return decodeError{"input list has too many elements", val.Type()} } if val.Kind() == reflect.Slice { // grow slice if necessary @@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error { return err } -var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array") - func decodeByteArray(s *Stream, val reflect.Value) error { kind, size, err := s.Kind() if err != nil { @@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { switch kind { case Byte: if val.Len() == 0 { - return errStringDoesntFitArray + return decodeError{"input string too big", val.Type()} } bv, _ := s.Uint() val.Index(0).SetUint(bv) zero(val, 1) case String: if uint64(val.Len()) < size { - return errStringDoesntFitArray + return decodeError{"input string too big", val.Type()} } slice := val.Slice(0, int(size)).Interface().([]byte) if err := s.readFull(slice); err != nil { @@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { } } if err = s.ListEnd(); err == errNotAtEOL { - err = errors.New("rlp: input List has too many elements") + err = decodeError{"input list has too many elements", typ} } return err } @@ -432,8 +447,23 @@ type Stream struct { type listpos struct{ pos, size uint64 } -func NewStream(r ByteReader) *Stream { - return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} +// NewStream creates a new stream reading from r. +// If r does not implement ByteReader, the Stream will +// introduce its own buffering. +func NewStream(r io.Reader) *Stream { + s := new(Stream) + s.Reset(r) + return s +} + +// NewListStream creates a new stream that pretends to be positioned +// at an encoded list of the given length. +func NewListStream(r io.Reader, len uint64) *Stream { + s := new(Stream) + s.Reset(r) + s.kind = List + s.size = len + return s } // Bytes reads an RLP string and returns its contents as a byte slice. @@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) { } } +var errUintOverflow = errors.New("rlp: uint overflow") + // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. @@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) { return uint64(s.byteval), nil case String: if size > uint64(maxbits/8) { - return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) + return 0, errUintOverflow } return s.readUint(byte(size)) default: @@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error { return info.decoder(s, rval.Elem()) } +// Reset discards any information about the current decoding context +// and starts reading from r. If r does not also implement ByteReader, +// Stream will do its own buffering. +func (s *Stream) Reset(r io.Reader) { + bufr, ok := r.(ByteReader) + if !ok { + bufr = bufio.NewReader(r) + } + s.r = bufr + s.stack = s.stack[:0] + s.size = 0 + s.kind = -1 + if s.uintbuf == nil { + s.uintbuf = make([]byte, 8) + } +} + // Kind returns the kind and size of the next value in the // input stream. // diff --git a/rlp/decode_test.go b/rlp/decode_test.go index eb1618299..3b60234dd 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -3,7 +3,6 @@ package rlp import ( "bytes" "encoding/hex" - "errors" "fmt" "io" "math/big" @@ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) { } } +func TestNewListStream(t *testing.T) { + ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3) + if k, size, err := ls.Kind(); k != List || size != 3 || err != nil { + t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err) + } + if size, err := ls.List(); size != 3 || err != nil { + t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err) + } + for i := 0; i < 3; i++ { + if val, err := ls.Uint(); val != 1 || err != nil { + t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err) + } + } + if err := ls.ListEnd(); err != nil { + t.Errorf("ListEnd() returned %v, expected (3, nil)", err) + } +} + func TestStreamErrors(t *testing.T) { type calls []string tests := []struct { @@ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) { {"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, {"81", calls{"Uint"}, io.ErrUnexpectedEOF}, {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, - {"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")}, + {"89000000000000000001", calls{"Uint"}, errUintOverflow}, {"00", calls{"List"}, ErrExpectedList}, {"80", calls{"List"}, ErrExpectedList}, {"C0", calls{"List", "Uint"}, EOL}, @@ -163,7 +180,7 @@ type decodeTest struct { input string ptr interface{} value interface{} - error error + error string } type simplestruct struct { @@ -196,8 +213,8 @@ var decodeTests = []decodeTest{ {input: "820505", ptr: new(uint32), value: uint32(0x0505)}, {input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, - {input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")}, - {input: "C0", ptr: new(uint32), error: ErrExpectedString}, + {input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"}, + {input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()}, // slices {input: "C0", ptr: new([]int), value: []int{}}, @@ -206,7 +223,7 @@ var decodeTests = []decodeTest{ // arrays {input: "C0", ptr: new([5]int), value: [5]int{}}, {input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, - {input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")}, + {input: "C6010203040506", ptr: new([5]int), error: "rlp: input list has too many elements for [5]int"}, // byte slices {input: "01", ptr: new([]byte), value: []byte{1}}, @@ -214,7 +231,7 @@ var decodeTests = []decodeTest{ {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, - {input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")}, + {input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"}, // byte arrays {input: "01", ptr: new([5]byte), value: [5]byte{1}}, @@ -222,9 +239,9 @@ var decodeTests = []decodeTest{ {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, - {input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")}, - {input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray}, - {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF}, + {input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"}, + {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"}, + {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, // byte array reuse (should be zeroed) {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, @@ -237,25 +254,25 @@ var decodeTests = []decodeTest{ // zero sized byte arrays {input: "80", ptr: new([0]byte), value: [0]byte{}}, {input: "C0", ptr: new([0]byte), value: [0]byte{}}, - {input: "01", ptr: new([0]byte), error: errStringDoesntFitArray}, - {input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray}, + {input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, + {input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, // strings {input: "00", ptr: new(string), value: "\000"}, {input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, - {input: "C0", ptr: new(string), error: ErrExpectedString}, + {input: "C0", ptr: new(string), error: ErrExpectedString.Error()}, // big ints {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works - {input: "C0", ptr: new(*big.Int), error: ErrExpectedString}, + {input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()}, // structs {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, - {input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")}, + {input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"}, { input: "C501C302C103", ptr: new(recstruct), @@ -286,20 +303,20 @@ var decodeTests = []decodeTest{ func intp(i int) *int { return &i } -func TestDecode(t *testing.T) { +func runTests(t *testing.T, decode func([]byte, interface{}) error) { for i, test := range decodeTests { input, err := hex.DecodeString(test.input) if err != nil { t.Errorf("test %d: invalid hex input %q", i, test.input) continue } - err = Decode(bytes.NewReader(input), test.ptr) - if err != nil && test.error == nil { + err = decode(input, test.ptr) + if err != nil && test.error == "" { t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", i, err, test.ptr, test.input) continue } - if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) { + if test.error != "" && fmt.Sprint(err) != test.error { t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q", i, err, test.error, test.ptr, test.input) continue @@ -312,6 +329,40 @@ func TestDecode(t *testing.T) { } } +func TestDecodeWithByteReader(t *testing.T) { + runTests(t, func(input []byte, into interface{}) error { + return Decode(bytes.NewReader(input), into) + }) +} + +// dumbReader reads from a byte slice but does not +// implement ReadByte. +type dumbReader []byte + +func (r *dumbReader) Read(buf []byte) (n int, err error) { + if len(*r) == 0 { + return 0, io.EOF + } + n = copy(buf, *r) + *r = (*r)[n:] + return n, nil +} + +func TestDecodeWithNonByteReader(t *testing.T) { + runTests(t, func(input []byte, into interface{}) error { + r := dumbReader(input) + return Decode(&r, into) + }) +} + +func TestDecodeStreamReset(t *testing.T) { + s := NewStream(nil) + runTests(t, func(input []byte, into interface{}) error { + s.Reset(bytes.NewReader(input)) + return s.Decode(into) + }) +} + type testDecoder struct{ called bool } func (t *testDecoder) DecodeRLP(s *Stream) error { |