diff --git a/compression/varint.go b/compression/varint.go index 04fd83a..177743f 100644 --- a/compression/varint.go +++ b/compression/varint.go @@ -14,7 +14,7 @@ const ( // PutVarint encodes an int32 into buf and returns the number of bytes written. // If the buffer is too small, PutVarint will panic. -// Format: ESDDDDDD EDDDDDDD EDD... Extended, Data, Sign +// Format: ESDDDDDD EDDDDDDD EDD... Extended, Sign, Data // E: is next byte part of the current integer // S: Sign of integer // Data, Integer bits that follow the sign @@ -91,12 +91,17 @@ func Varint(buf []byte) (i int, n int) { for i, b := range buf { index++ // overflow check - if i == maxAllowedLen-1 && b >= 0b10000000 { + // 1 sign bit + 6 data bits = 7 bits + // 7 bits * 4 bytes = 28 bits + // 7 + 28 = 35 bits, 3 too many + // last byte can only have 4 bits + if i == maxAllowedLen-1 && b > 0b00001111 { return 0, -(i + 1) } value |= int(b&0b01111111) << (6 + 7*i) if b < 0b10000000 { + // no extend bit set break } } @@ -145,9 +150,14 @@ func ReadVarint(r io.ByteReader) (int, error) { } return value, nil } - index++ - if i == maxAllowedLen-1 && b >= 0b10000000 { + + // overflow check + // 1 sign bit + 6 data bits = 7 bits + // 7 bits * 4 bytes = 28 bits + // 7 + 28 = 35 bits, 3 too many + // last byte can only have 4 bits + if i == maxAllowedLen-1 && b > 0b00001111 { return 0, errors.New("overflow due to invalid last byte") } diff --git a/compression/varint_test.go b/compression/varint_test.go index 723bb6a..61e269b 100644 --- a/compression/varint_test.go +++ b/compression/varint_test.go @@ -3,6 +3,7 @@ package compression import ( "bytes" "io" + "math" "testing" "github.com/stretchr/testify/require" @@ -15,12 +16,14 @@ func varIntWriteRead(t *testing.T, inNumber int, expectedBytes int) { written := PutVarint(buf, inNumber) require.Equal(expectedBytes, written) out, read := Varint(buf) + require.GreaterOrEqual(read, 1, "read must be at least 0") require.Equal(inNumber, out, "out == in") require.Equal(written, read, "read == written") // buf := buf[:written] } -func TestVarint(t *testing.T) { +func TestVarintBoundaries(t *testing.T) { + t.Parallel() varIntWriteRead(t, 63, 1) varIntWriteRead(t, 64, 2) @@ -30,6 +33,15 @@ func TestVarint(t *testing.T) { varIntWriteRead(t, 134217728-1, 4) // 2^(6+7+7+7) -1 varIntWriteRead(t, 134217728, 5) // 2^(6+7+7+7) + // int32 boundaries + varIntWriteRead(t, math.MaxInt32, 5) // 2^31 -1 = 2147483647 + varIntWriteRead(t, math.MinInt32, 5) // -2^31 = -2147483648 + +} + +func TestVarintExtensive(t *testing.T) { + t.Parallel() + require := require.New(t) for in := -20_000_000; in < 20_000_000; in++ { var ( @@ -38,26 +50,9 @@ func TestVarint(t *testing.T) { written = PutVarint(buf, in) out, read = Varint(buf) ) + require.GreaterOrEqual(0, read, "read must be at least 0") require.Equal(in, out, "in/out") require.Equal(written, read, "written/read") - require.GreaterOrEqual(read, 0, "read must be at least 0") - } -} - -func TestReadVarint(t *testing.T) { - require := require.New(t) - - buf := []byte{} - - for in := -2_000_000; in < 2_000_000; in++ { - buf = AppendVarint(buf, in) - } - - b := bytes.NewBuffer(buf) - for in := -2_000_000; in < 2_000_000; in++ { - out, err := ReadVarint(b) - require.NoError(err) - require.Equal(in, out, "out != in") } }