diff options
Diffstat (limited to 'rlp/decode.go')
-rw-r--r-- | rlp/decode.go | 109 |
1 files changed, 91 insertions, 18 deletions
diff --git a/rlp/decode.go b/rlp/decode.go index 3b5617475..ca9252575 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -9,6 +9,7 @@ import ( "io" "math/big" "reflect" + "strings" ) var ( @@ -70,14 +71,22 @@ type Decoder interface { // Non-empty interface types are not supported, nor are booleans, // signed integers, floating point numbers, maps, channels and // functions. +// +// Note that Decode does not set an input limit for all readers +// and may be vulnerable to panics cause by huge value sizes. If +// you need an input limit, use +// +// NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { - return NewStream(r).Decode(val) + // TODO: this could use a Stream from a pool. + return NewStream(r, 0).Decode(val) } // DecodeBytes parses RLP data from b into val. // Please see the documentation of Decode for the decoding rules. func DecodeBytes(b []byte, val interface{}) error { - return NewStream(bytes.NewReader(b)).Decode(val) + // TODO: this could use a Stream from a pool. + return NewStream(bytes.NewReader(b), uint64(len(b))).Decode(val) } type decodeError struct { @@ -470,10 +479,12 @@ var ( ErrExpectedList = errors.New("rlp: expected List") ErrCanonInt = errors.New("rlp: expected Int") ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") // internal errors - errNotInList = errors.New("rlp: call of ListEnd outside of any list") - errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errUintOverflow = errors.New("rlp: uint overflow") ) // ByteReader must be implemented by any input reader for a Stream. It @@ -496,7 +507,13 @@ type ByteReader interface { // // Stream is not safe for concurrent use. type Stream struct { - r ByteReader + r ByteReader + + // number of bytes remaining to be read from r. + remaining uint64 + limited bool + + // auxiliary buffer for integer decoding uintbuf []byte kind Kind // kind of value ahead @@ -507,12 +524,26 @@ type Stream struct { type listpos struct{ pos, size uint64 } -// NewStream creates a new stream reading from r. -// If r does not implement ByteReader, the Stream will -// introduce its own buffering. -func NewStream(r io.Reader) *Stream { +// NewStream creates a new decoding stream reading from r. +// +// If r implements the ByteReader interface, Stream will +// not introduce any buffering. +// +// For non-toplevel values, Stream returns ErrElemTooLarge +// for values that do not fit into the enclosing list. +// +// Stream supports an optional input limit. If a limit is set, the +// size of any toplevel value will be checked against the remaining +// input length. Stream operations that encounter a value exceeding +// the remaining input length will return ErrValueTooLarge. The limit +// can be set by passing a non-zero value for inputLimit. +// +// If r is a bytes.Reader or strings.Reader, the input limit is set to +// the length of r's underlying data unless an explicit limit is +// provided. +func NewStream(r io.Reader, inputLimit uint64) *Stream { s := new(Stream) - s.Reset(r) + s.Reset(r, inputLimit) return s } @@ -520,7 +551,7 @@ func NewStream(r io.Reader) *Stream { // at an encoded list of the given length. func NewListStream(r io.Reader, len uint64) *Stream { s := new(Stream) - s.Reset(r) + s.Reset(r, len) s.kind = List s.size = len return s @@ -574,8 +605,6 @@ func (s *Stream) Raw() ([]byte, error) { return buf, nil } -var errUintOverflow = errors.New("rlp: uint overflow") - // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. @@ -667,14 +696,36 @@ func (s *Stream) Decode(val interface{}) error { } // Reset discards any information about the current decoding context -// and starts reading from r. If r does not also implement ByteReader, -// Stream will do its own buffering. -func (s *Stream) Reset(r io.Reader) { +// and starts reading from r. This method is meant to facilitate reuse +// of a preallocated Stream across many decoding operations. +// +// If r does not also implement ByteReader, Stream will do its own +// buffering. +func (s *Stream) Reset(r io.Reader, inputLimit uint64) { + if inputLimit > 0 { + s.remaining = inputLimit + s.limited = true + } else { + // Attempt to automatically discover + // the limit when reading from a byte slice. + switch br := r.(type) { + case *bytes.Reader: + s.remaining = uint64(br.Len()) + s.limited = true + case *strings.Reader: + s.remaining = uint64(br.Len()) + s.limited = true + default: + s.limited = false + } + } + // Wrap r with a buffer if it doesn't have one. bufr, ok := r.(ByteReader) if !ok { bufr = bufio.NewReader(r) } s.r = bufr + // Reset the decoding context. s.stack = s.stack[:0] s.size = 0 s.kind = -1 @@ -700,6 +751,8 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) { tos = &s.stack[len(s.stack)-1] } if s.kind < 0 { + // don't read further if we're at the end of the + // innermost list. if tos != nil && tos.pos == tos.size { return 0, 0, EOL } @@ -709,8 +762,19 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) { } s.kind, s.size = kind, size } - if tos != nil && tos.pos+s.size > tos.size { - return 0, 0, ErrElemTooLarge + // Make sure size is reasonable. This is done always + // so Kind returns the same error when called multiple times. + if tos == nil { + // At toplevel, check that the value is smaller + // than the remaining input length. + if s.limited && s.size > s.remaining { + return 0, 0, ErrValueTooLarge + } + } else { + // Inside a list, check that the value doesn't overflow the list. + if tos.pos+s.size > tos.size { + return 0, 0, ErrElemTooLarge + } } return s.kind, s.size, nil } @@ -778,6 +842,9 @@ func (s *Stream) readUint(size byte) (uint64, error) { } func (s *Stream) readFull(buf []byte) (err error) { + if s.limited && s.remaining < uint64(len(buf)) { + return ErrValueTooLarge + } s.willRead(uint64(len(buf))) var nn, n int for n < len(buf) && err == nil { @@ -791,6 +858,9 @@ func (s *Stream) readFull(buf []byte) (err error) { } func (s *Stream) readByte() (byte, error) { + if s.limited && s.remaining == 0 { + return 0, io.EOF + } s.willRead(1) b, err := s.r.ReadByte() if len(s.stack) > 0 && err == io.EOF { @@ -801,6 +871,9 @@ func (s *Stream) readByte() (byte, error) { func (s *Stream) willRead(n uint64) { s.kind = -1 // rearm Kind + if s.limited { + s.remaining -= n + } if len(s.stack) > 0 { s.stack[len(s.stack)-1].pos += n } |