aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rlp/decode.go43
-rw-r--r--rlp/decode_test.go11
2 files changed, 40 insertions, 14 deletions
diff --git a/rlp/decode.go b/rlp/decode.go
index 6952ecaea..0c660426f 100644
--- a/rlp/decode.go
+++ b/rlp/decode.go
@@ -820,6 +820,16 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) {
func (s *Stream) readKind() (kind Kind, size uint64, err error) {
b, err := s.readByte()
if err != nil {
+ if len(s.stack) == 0 {
+ // At toplevel, Adjust the error to actual EOF. io.EOF is
+ // used by callers to determine when to stop decoding.
+ switch err {
+ case io.ErrUnexpectedEOF:
+ err = io.EOF
+ case ErrValueTooLarge:
+ err = io.EOF
+ }
+ }
return 0, 0, err
}
s.byteval = 0
@@ -876,9 +886,6 @@ func (s *Stream) readUint(size byte) (uint64, error) {
return 0, nil
case 1:
b, err := s.readByte()
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
- }
return uint64(b), err
default:
start := int(8 - size)
@@ -899,10 +906,9 @@ func (s *Stream) readUint(size byte) (uint64, error) {
}
func (s *Stream) readFull(buf []byte) (err error) {
- if s.limited && s.remaining < uint64(len(buf)) {
- return ErrValueTooLarge
+ if err := s.willRead(uint64(len(buf))); err != nil {
+ return err
}
- s.willRead(uint64(len(buf)))
var nn, n int
for n < len(buf) && err == nil {
nn, err = s.r.Read(buf[n:])
@@ -915,23 +921,32 @@ func (s *Stream) readFull(buf []byte) (err error) {
}
func (s *Stream) readByte() (byte, error) {
- if s.limited && s.remaining == 0 {
- return 0, io.EOF
+ if err := s.willRead(1); err != nil {
+ return 0, err
}
- s.willRead(1)
b, err := s.r.ReadByte()
- if len(s.stack) > 0 && err == io.EOF {
+ if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return b, err
}
-func (s *Stream) willRead(n uint64) {
+func (s *Stream) willRead(n uint64) error {
s.kind = -1 // rearm Kind
- if s.limited {
- s.remaining -= n
- }
+
if len(s.stack) > 0 {
+ // check list overflow
+ tos := s.stack[len(s.stack)-1]
+ if n > tos.size-tos.pos {
+ return ErrElemTooLarge
+ }
s.stack[len(s.stack)-1].pos += n
}
+ if s.limited {
+ if n > s.remaining {
+ return ErrValueTooLarge
+ }
+ s.remaining -= n
+ }
+ return nil
}
diff --git a/rlp/decode_test.go b/rlp/decode_test.go
index d07520bd0..ae65346a9 100644
--- a/rlp/decode_test.go
+++ b/rlp/decode_test.go
@@ -119,6 +119,10 @@ func TestStreamErrors(t *testing.T) {
{"8158", calls{"Uint", "Uint"}, nil, io.EOF},
{"C0", calls{"List", "ListEnd", "List"}, nil, io.EOF},
+ {"", calls{"List"}, withoutInputLimit, io.EOF},
+ {"8158", calls{"Uint", "Uint"}, withoutInputLimit, io.EOF},
+ {"C0", calls{"List", "ListEnd", "List"}, withoutInputLimit, io.EOF},
+
// Input limit errors.
{"81", calls{"Bytes"}, nil, ErrValueTooLarge},
{"81", calls{"Uint"}, nil, ErrValueTooLarge},
@@ -426,6 +430,13 @@ var decodeTests = []decodeTest{
ptr: new([]io.Reader),
error: "rlp: type io.Reader is not RLP-serializable",
},
+
+ // fuzzer crashes
+ {
+ input: "c330f9c030f93030ce3030303030303030bd303030303030",
+ ptr: new(interface{}),
+ error: "rlp: element is larger than containing list",
+ },
}
func uintp(i uint) *uint { return &i }