From 0d3af0a0fae643a94bae979fc8ce9bfe652617c9 Mon Sep 17 00:00:00 2001 From: John Behm Date: Fri, 5 Jul 2024 21:41:58 +0200 Subject: [PATCH] remove testify dependency, use own test utilities --- browser/browser_test.go | 16 +-- compression/huffman_test.go | 4 +- compression/packer_test.go | 174 ++++++++++++++++++++++++- compression/unpacker.go | 102 ++++++++++++++- compression/varint_test.go | 6 +- go.mod | 12 +- go.sum | 10 -- internal/testutils/require/compare.go | 101 +++++++++++++++ internal/testutils/require/fail.go | 31 +++++ internal/testutils/require/format.go | 71 +++++++++++ internal/testutils/require/helpers.go | 67 ++++++++++ internal/testutils/require/require.go | 176 ++++++++++++++++++++++++++ 12 files changed, 730 insertions(+), 40 deletions(-) create mode 100644 internal/testutils/require/compare.go create mode 100644 internal/testutils/require/fail.go create mode 100644 internal/testutils/require/format.go create mode 100644 internal/testutils/require/helpers.go create mode 100644 internal/testutils/require/require.go diff --git a/browser/browser_test.go b/browser/browser_test.go index cefebec..757bdfe 100644 --- a/browser/browser_test.go +++ b/browser/browser_test.go @@ -5,7 +5,7 @@ import ( "time" "github.com/jxsl13/twapi/browser" - "github.com/stretchr/testify/require" + "github.com/jxsl13/twapi/internal/testutils/require" ) func TestGetServerAddresses(t *testing.T) { @@ -14,12 +14,8 @@ func TestGetServerAddresses(t *testing.T) { start := time.Now() u, err := browser.GetServerAddresses() diff := time.Since(start) - if err != nil { - t.Fatal(err) - } - if len(u) == 0 { - t.Errorf("found %d server addresses in %d milliseconds", len(u), diff.Milliseconds()) - } + require.NoError(t, err) + require.NotZero(t, len(u), "found %d server addresses in %d milliseconds", len(u), diff.Milliseconds()) } func TestGetServerInfos(t *testing.T) { @@ -28,9 +24,7 @@ func TestGetServerInfos(t *testing.T) { start := time.Now() u, err := browser.GetServerInfos() diff := time.Since(start) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Logf("found %d server infos in %d milliseconds", len(u), diff.Milliseconds()) } @@ -41,7 +35,7 @@ func TestServerInfoOfSingleServer(t *testing.T) { u, err := browser.GetServerInfosOf(SimplyzCatch) diff := time.Since(start) require.NoError(t, err) - require.Len(t, u, 1) + require.Len(t, 1, u) t.Logf("found %d server infos in %d milliseconds", len(u), diff.Milliseconds()) } diff --git a/compression/huffman_test.go b/compression/huffman_test.go index 24ccded..ba88caa 100644 --- a/compression/huffman_test.go +++ b/compression/huffman_test.go @@ -7,8 +7,8 @@ import ( "testing" "github.com/jxsl13/twapi/compression" + "github.com/jxsl13/twapi/internal/testutils/require" "github.com/jxsl13/twapi/protocol" - "github.com/stretchr/testify/require" ) func FuzzNewHuffman(f *testing.F) { @@ -76,7 +76,7 @@ func TestHuffmanCompress(t *testing.T) { compressed := make([]byte, 1500) n, err := h.Compress(src, compressed) require.NoError(t, err) - require.Greater(t, n, 0) + require.Greater(t, 0, n) } func TestHuffmanCompressDecompress(t *testing.T) { diff --git a/compression/packer_test.go b/compression/packer_test.go index 6909c89..0dcb1f7 100644 --- a/compression/packer_test.go +++ b/compression/packer_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/require" + "github.com/jxsl13/twapi/internal/testutils/require" ) func TestReset(t *testing.T) { @@ -99,5 +99,177 @@ func TestPackerAndUnpacker(t *testing.T) { } wg.Wait() +} + +// rest + +func TestUnpackRest(t *testing.T) { + u := NewUnpacker([]byte{0x01, 0xff, 0xaa}) + + { + got, err := u.NextInt() + require.NoError(t, err) + require.Equal(t, 1, got) + } + + { + want := []byte{0xff, 0xaa} + got := u.Bytes() + require.Equal(t, want, got) + } +} + +func TestUnpackClientInfo(t *testing.T) { + require := require.New(t) + u := NewUnpacker([]byte{ + 0x24, 0x00, 0x01, 0x00, 0x67, 0x6f, 0x70, 0x68, 0x65, 0x72, 0x00, + 0x00, 0x40, 0x67, 0x72, 0x65, 0x65, 0x6e, 0x73, 0x77, 0x61, 0x72, + 0x64, 0x00, 0x64, 0x75, 0x6f, 0x64, 0x6f, 0x6e, 0x6e, 0x79, 0x00, + 0x00, 0x73, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x00, 0x73, + 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x00, 0x73, 0x74, 0x61, + 0x6e, 0x64, 0x61, 0x72, 0x64, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x80, 0xfc, 0xaf, 0x05, 0xeb, 0x83, 0xd0, 0x0a, 0x80, 0xfe, + 0x07, 0x80, 0xfe, 0x07, 0x80, 0xfe, 0x07, 0x80, 0xfe, 0x07, 0x00, + }) + + { + // message id + want := 36 + got, err := u.NextInt() + require.NoError(err) + require.Equal(want, got) + + // client id + want = 0 + got, err = u.NextInt() + require.NoError(err) + require.Equal(want, got) + + _, err = u.NextBool() // Local bool + require.NoError(err) + _, err = u.NextInt() // Team int + require.NoError(err) + } + + { + // name + want := "gopher" + got, err := u.NextString() + require.NoError(err) + require.Equal(want, got) + + // clan + want = "" + got, err = u.NextString() + require.NoError(err) + require.Equal(want, got) + + } + + { + // country + want := -1 + got, err := u.NextInt() + require.NoError(err) + require.Equal(want, got) + } + + { + // body + want := "greensward" + got, err := u.NextString() + require.NoError(err) + require.Equal(want, got) + } +} + +// unpack with state + +func TestUnpackSimpleInts(t *testing.T) { + require := require.New(t) + u := NewUnpacker([]byte{0x01, 0x02, 0x03, 0x0f}) + + want := 1 + got, err := u.NextInt() + require.NoError(err) + require.Equal(want, got) + want = 2 + got, err = u.NextInt() + require.NoError(err) + require.Equal(want, got) + + want = 3 + got, err = u.NextInt() + require.NoError(err) + require.Equal(want, got) + + want = 15 + got, err = u.NextInt() + require.NoError(err) + require.Equal(want, got) +} + +func TestUnpackString(t *testing.T) { + require := require.New(t) + u := NewUnpacker([]byte{'f', 'o', 'o', 0x00}) + + want := "foo" + got, err := u.NextString() + require.NoError(err) + require.Equal(want, got) +} + +func TestUnpackTwoStrings(t *testing.T) { + require := require.New(t) + u := NewUnpacker([]byte{'f', 'o', 'o', 0x00, 'b', 'a', 'r', 0x00}) + + want := "foo" + got, err := u.NextString() + require.NoError(err) + require.Equal(want, got) + + want = "bar" + got, err = u.NextString() + require.NoError(err) + require.Equal(want, got) +} + +func TestUnpackMixed(t *testing.T) { + require := require.New(t) + u := NewUnpacker([]byte{0x0F, 0x0F, 'f', 'o', 'o', 0x00, 'b', 'a', 'r', 0x00, 0x01}) + + // ints + { + want := 15 + got, err := u.NextInt() + require.NoError(err) + require.Equal(want, got) + + want = 15 + got, err = u.NextInt() + require.NoError(err) + require.Equal(want, got) + } + + // strings + { + want := "foo" + got, err := u.NextString() + require.NoError(err) + require.Equal(want, got) + + want = "bar" + got, err = u.NextString() + require.NoError(err) + require.Equal(want, got) + } + + // ints + { + want := 1 + got, err := u.NextInt() + require.NoError(err) + require.Equal(want, got) + } } diff --git a/compression/unpacker.go b/compression/unpacker.go index 75c0842..d4e0198 100644 --- a/compression/unpacker.go +++ b/compression/unpacker.go @@ -7,6 +7,12 @@ import ( "io" ) +const ( + Sanitize SanitizeKind = 1 + SanitizeCC SanitizeKind = 2 + SanitizeSkipWhitespaces SanitizeKind = 4 +) + var ( // ErrNoDataToUnpack is returned if the compressed array does not have sufficient data to unpack ErrNoDataToUnpack = fmt.Errorf("%w: no data", io.EOF) @@ -14,10 +20,15 @@ var ( // ErrNotAString if no separator after a string is found, the string cannot be unpacked, as there is no string ErrNotAString = errors.New("could not unpack string: terminator not found") + // ErrNotABool is returned when the data is neither 0x00 nor 0x01 + ErrNotABool = errors.New("could not unpack bool: invald value") + // ErrNotEnoughData is used when the user tries to retrieve more data with NextBytes() than there is available. ErrNotEnoughData = errors.New("trying to read more data than available") ) +type SanitizeKind int + // NewUnpacker constructs a new Unpacker func NewUnpacker(data []byte) *Unpacker { return &Unpacker{data} @@ -29,6 +40,8 @@ type Unpacker struct { } // Reset resets the underlying byte slice to a new slice +// The slice that is passed to this method should not be used +// as the ownership has been passed to the unpacker. func (u *Unpacker) Reset(b []byte) { u.buffer = b } @@ -38,6 +51,23 @@ func (u *Unpacker) Size() int { return len(u.buffer) } +func (u *Unpacker) NextBool() (bool, error) { + if len(u.buffer) == 0 { + return false, ErrNoDataToUnpack + } + var b bool + switch u.buffer[0] { + case 0x00: + b = false + case 0x01: + b = true + default: + return false, ErrNotABool + } + u.buffer = u.buffer[1:] + return b, nil +} + // NextInt unpacks the next integer func (u *Unpacker) NextInt() (i int, err error) { i, n := Varint(u.buffer) @@ -50,8 +80,8 @@ func (u *Unpacker) NextInt() (i int, err error) { return i, nil } -// NextString unpacks the next string from the message -func (u *Unpacker) NextString() (s string, err error) { +// NextRawString unpacks the next string from the message without sanitizig it. +func (u *Unpacker) NextRawString() (s string, err error) { if len(u.buffer) == 0 { return "", ErrNoDataToUnpack } @@ -91,7 +121,7 @@ func (u *Unpacker) NextByte() (b byte, err error) { return result, nil } -// Bytes returns the not yet used bytes. +// Bytes returns the not yet consumed bytes. // This operation consumes the buffer leaving it empty func (u *Unpacker) Bytes() []byte { if len(u.buffer) == 0 { @@ -103,3 +133,69 @@ func (u *Unpacker) Bytes() []byte { u.buffer = u.buffer[:0] return result } + +// first byte of the current buffer +func (u *Unpacker) peekByte(offset int) (byte, error) { + if len(u.buffer) < offset+1 { + return 0, ErrNotEnoughData + } + return u.buffer[offset], nil +} + +func (u *Unpacker) NextSanitizedString(sanitizeType SanitizeKind) (string, error) { + + i := bytes.IndexByte(u.buffer, StringTerminator) + if i < 0 { + return "", ErrNotAString + } + + var ( + // reduce slice reallocations by approximating the size + // real size might be less due to sanitization + result = make([]byte, 0, i) + skipping = sanitizeType&SanitizeSkipWhitespaces != 0 + err error + b byte + index = -1 + ) + + for { + index++ + + b, err = u.peekByte(index) + if err != nil { + return "", err + } + if b == StringTerminator { + break + } + + if skipping { + if b == ' ' || b == '\t' || b == '\n' { + continue + } + skipping = false + } + + if sanitizeType&SanitizeCC != 0 { + if b < 32 { + b = ' ' + } + } else if sanitizeType&Sanitize != 0 { + if b < 32 && !(b == '\r') && !(b == '\n') && !(b == '\t') { + b = ' ' + } + } + + result = append(result, b) + } + + u.buffer = u.buffer[index+1:] + return string(result), nil +} + +// NextString unpacks the next string from the message +// and sanitizes it by replacing control characters with spaces. +func (u *Unpacker) NextString() (string, error) { + return u.NextSanitizedString(Sanitize) +} diff --git a/compression/varint_test.go b/compression/varint_test.go index 3d2bc16..40e451d 100644 --- a/compression/varint_test.go +++ b/compression/varint_test.go @@ -6,7 +6,7 @@ import ( "math" "testing" - "github.com/stretchr/testify/require" + "github.com/jxsl13/twapi/internal/testutils/require" ) func varIntWriteRead(t *testing.T, inNumber int, expectedBytes int) { @@ -16,7 +16,7 @@ func varIntWriteRead(t *testing.T, inNumber int, expectedBytes int) { written := PutVarint(buf, inNumber) require.Equal(expectedBytes, written) out, read := Varint(buf) - require.GreaterOrEqual(read, 1, "read must be at least 0") + require.GreaterOrEqual(1, read, "read must be at least 0") require.Equal(inNumber, out, "out == in") require.Equal(written, read, "read == written") // buf := buf[:written] @@ -50,7 +50,7 @@ func TestVarintExtensive(t *testing.T) { written = PutVarint(buf, in) out, read = Varint(buf) ) - require.GreaterOrEqual(read, 1, "read must be at least 1") + require.GreaterOrEqual(1, read, "read must be at least 1") require.Equal(in, out, "in/out") require.Equal(written, read, "written/read") } diff --git a/go.mod b/go.mod index 7266709..47581fa 100644 --- a/go.mod +++ b/go.mod @@ -2,14 +2,6 @@ module github.com/jxsl13/twapi go 1.21.6 -require ( - github.com/reiver/go-telnet v0.0.0-20180421082511-9ff0b2ab096e - github.com/stretchr/testify v1.9.0 -) +require github.com/reiver/go-telnet v0.0.0-20180421082511-9ff0b2ab096e -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/reiver/go-oi v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) +require github.com/reiver/go-oi v1.0.0 // indirect diff --git a/go.sum b/go.sum index b71b7fe..4b7b7a4 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,4 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/reiver/go-oi v1.0.0 h1:nvECWD7LF+vOs8leNGV/ww+F2iZKf3EYjYZ527turzM= github.com/reiver/go-oi v1.0.0/go.mod h1:RrDBct90BAhoDTxB1fenZwfykqeGvhI6LsNfStJoEkI= github.com/reiver/go-telnet v0.0.0-20180421082511-9ff0b2ab096e h1:quuzZLi72kkJjl+f5AQ93FMcadG19WkS7MO6TXFOSas= github.com/reiver/go-telnet v0.0.0-20180421082511-9ff0b2ab096e/go.mod h1:+5vNVvEWwEIx86DB9Ke/+a5wBI464eDRo3eF0LcfpWg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/testutils/require/compare.go b/internal/testutils/require/compare.go new file mode 100644 index 0000000..a1b3eff --- /dev/null +++ b/internal/testutils/require/compare.go @@ -0,0 +1,101 @@ +package require + +import ( + "cmp" + "fmt" + "reflect" + "testing" +) + +func GreaterOrEqual(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, expected, actual) < 0 { + FailNow(t, fmt.Sprintf("expected: %v to be greater or equal to: %v", expected, actual), msgAndArgs...) + } +} + +func Greater(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, expected, actual) <= 0 { + FailNow(t, fmt.Sprintf("expected: %v to be greater than: %v", expected, actual), msgAndArgs...) + } +} + +func LessOrEqual(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, expected, actual) > 0 { + FailNow(t, fmt.Sprintf("expected: %v to be less or equal to: %v", expected, actual), msgAndArgs...) + } +} + +func Less(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + + if compare(t, expected, actual) >= 0 { + FailNow(t, fmt.Sprintf("expected: %v to be less than: %v", expected, actual), msgAndArgs...) + } +} + +func compare(t *testing.T, expected, actual any) int { + t.Helper() + + e := reflect.ValueOf(expected) + a := reflect.ValueOf(actual) + + if e.Kind() != a.Kind() { + FailNow(t, "type mismatch: expected %T, got %T", expected, actual) + } + + if !e.Comparable() { + FailNow(t, "expected value is not comparable") + } + + if !a.Comparable() { + FailNow(t, "actual value is not comparable") + } + + if e.Kind() != a.Kind() { + FailNow(t, "type mismatch: expected %T, got %T", expected, actual) + } + + switch e.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + { + ev := e.Convert(reflect.TypeOf(int64(0))).Interface().(int64) + av := a.Convert(reflect.TypeOf(int64(0))).Interface().(int64) + return cmp.Compare(av, ev) + } + + case reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + { + ev := e.Convert(reflect.TypeOf(uint64(0))).Interface().(uint64) + av := a.Convert(reflect.TypeOf(uint64(0))).Interface().(uint64) + return cmp.Compare(av, ev) + } + + case reflect.Float32, reflect.Float64: + { + ev := e.Convert(reflect.TypeOf(float64(0))).Interface().(float64) + av := a.Convert(reflect.TypeOf(float64(0))).Interface().(float64) + return cmp.Compare(av, ev) + } + case reflect.String: + { + ev := e.Convert(reflect.TypeOf(string(""))).Interface().(string) + av := a.Convert(reflect.TypeOf(string(""))).Interface().(string) + return cmp.Compare(av, ev) + } + case reflect.Uintptr: + { + ev := e.Convert(reflect.TypeOf(uintptr(0))).Interface().(uintptr) + av := a.Convert(reflect.TypeOf(uintptr(0))).Interface().(uintptr) + return cmp.Compare(av, ev) + } + } + + FailNow(t, "type not supported: %T", expected) + return 0 // should not be reached +} diff --git a/internal/testutils/require/fail.go b/internal/testutils/require/fail.go new file mode 100644 index 0000000..1da48f2 --- /dev/null +++ b/internal/testutils/require/fail.go @@ -0,0 +1,31 @@ +package require + +import ( + "strings" + "testing" +) + +func FailNow(t *testing.T, errMsg string, msgAndArgs ...any) { + t.Helper() + + labeledMessages := labeledMessages{ + { + label: "Error Trace", + message: strings.Join(CallStack(), "\n\t\t\t"), + }, + { + label: "Error", + message: errMsg, + }, + } + + message := msgOrFmtMsg(msgAndArgs...) + if len(message) > 0 { + labeledMessages = append(labeledMessages, labeledMessage{ + label: "Message", + message: message, + }) + } + + t.Fatal(labeledMessages.String()) +} diff --git a/internal/testutils/require/format.go b/internal/testutils/require/format.go new file mode 100644 index 0000000..fb4caeb --- /dev/null +++ b/internal/testutils/require/format.go @@ -0,0 +1,71 @@ +package require + +import ( + "bufio" + "fmt" + "strings" +) + +func msgOrFmtMsg(msgAndArgs ...any) string { + if len(msgAndArgs) == 0 || msgAndArgs == nil { + return "" + } + if len(msgAndArgs) == 1 { + msg := msgAndArgs[0] + if msgAsStr, ok := msg.(string); ok { + return msgAsStr + } + return fmt.Sprintf("%+v", msg) + } + if len(msgAndArgs) > 1 { + return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } + return "" +} + +type labeledMessage struct { + label string + message string +} + +type labeledMessages []labeledMessage + +func (lm labeledMessages) String() string { + longestLabel := 0 + numLabels := len(lm) + msgSizeTotal := 0 + for _, v := range lm { + if len(v.label) > longestLabel { + longestLabel = len(v.label) + } + msgSizeTotal += len(v.message) + } + + var sb strings.Builder + sb.Grow(msgSizeTotal + numLabels*(longestLabel+8)) + sb.WriteString("\n") + + for _, v := range lm { + sb.WriteString("\t") + sb.WriteString(v.label) + sb.WriteString(":") + sb.WriteString(strings.Repeat(" ", longestLabel-len(v.label))) + sb.WriteString("\t") + + // indent lines + for i, scanner := 0, bufio.NewScanner(strings.NewReader(v.message)); scanner.Scan(); i++ { + // no need to align first line because it starts at the correct location (after the label) + if i != 0 { + // append alignLen+1 spaces to align with "{{longestLabel}}:" before adding tab + sb.WriteString("\n\t") + sb.WriteString(strings.Repeat(" ", longestLabel+1)) + sb.WriteString("\t") + } + // write line + sb.WriteString(scanner.Text()) + } + sb.WriteString("\n") + } + + return sb.String() +} diff --git a/internal/testutils/require/helpers.go b/internal/testutils/require/helpers.go new file mode 100644 index 0000000..3fe65c1 --- /dev/null +++ b/internal/testutils/require/helpers.go @@ -0,0 +1,67 @@ +package require + +import ( + "fmt" + "runtime" + "strings" +) + +/* CallStack is necessary because the assert functions use the testing object +internally, causing it to print the file:line of the assert method, rather than where +the problem actually occurred in calling code.*/ + +// CallStack returns an array of strings containing the file and line number +// of each stack frame leading from the current test to the assert call that +// failed. +func CallStack() []string { + + var ( + pc uintptr + ok bool + file string + line int + name string + ) + + callers := []string{} + for i := 0; ; i++ { + pc, file, line, ok = runtime.Caller(i) + if !ok { + break + } + + if file == "" { + break + } + + f := runtime.FuncForPC(pc) + if f == nil { + break + } + name = f.Name() + + // testing.tRunner is the standard library function that calls tests. + if name == "testing.tRunner" { + break + } + + parts := strings.Split(file, "/") + if len(parts) > 1 { + dir := parts[len(parts)-2] + if dir != "require" { + callers = append(callers, fmt.Sprintf("%s:%d", file, line)) + } + } + + // Drop this package + segments := strings.Split(name, ".") + name = segments[len(segments)-1] + if strings.HasPrefix(name, "Test") || + strings.HasPrefix(name, "Benchmark") || + strings.HasPrefix(name, "Example") { + break + } + } + + return callers +} diff --git a/internal/testutils/require/require.go b/internal/testutils/require/require.go new file mode 100644 index 0000000..7bb9f73 --- /dev/null +++ b/internal/testutils/require/require.go @@ -0,0 +1,176 @@ +package require + +import ( + "errors" + "fmt" + "reflect" + "testing" +) + +func New(t *testing.T) *Require { + return &Require{ + t: t, + } +} + +type Require struct { + t *testing.T +} + +func (r *Require) Equal(expected, actual any, msgAndArgs ...any) { + r.t.Helper() + Equal(r.t, expected, actual, msgAndArgs...) +} + +func (r *Require) NoError(err error, msgAndArgs ...any) { + r.t.Helper() + NoError(r.t, err, msgAndArgs...) +} + +func (r *Require) Error(err error, msgAndArgs ...any) { + r.t.Helper() + Error(r.t, err, msgAndArgs...) +} + +func (r *Require) ErrorIs(expected, actual error, msgAndArgs ...any) { + r.t.Helper() + ErrorIs(r.t, expected, actual, msgAndArgs...) +} + +func (r *Require) NotNil(a any, msgAndArgs ...any) { + r.t.Helper() + NotNil(r.t, a, msgAndArgs...) +} + +func (r *Require) Nil(a any, msgAndArgs ...any) { + r.t.Helper() + Nil(r.t, a, msgAndArgs...) +} + +func (r *Require) GreaterOrEqual(expected, actual any, msgAndArgs ...any) { + r.t.Helper() + GreaterOrEqual(r.t, expected, actual, msgAndArgs...) +} +func (r *Require) Greater(expected, actual any, msgAndArgs ...any) { + r.t.Helper() + Greater(r.t, expected, actual, msgAndArgs...) +} +func (r *Require) LessOrEqual(expected, actual any, msgAndArgs ...any) { + r.t.Helper() + LessOrEqual(r.t, expected, actual, msgAndArgs...) +} +func (r *Require) Less(expected, actual any, msgAndArgs ...any) { + r.t.Helper() + Less(r.t, expected, actual, msgAndArgs...) +} + +func (r *Require) Zero(a any, msgAndArgs ...any) { + r.t.Helper() + Zero(r.t, a, msgAndArgs...) +} + +func (r *Require) NotZero(a any, msgAndArgs ...any) { + r.t.Helper() + NotZero(r.t, a, msgAndArgs...) +} + +func (r *Require) Len(expected int, s any, msgAndArgs ...any) { + r.t.Helper() + Len(r.t, expected, s, msgAndArgs...) +} + +func Equal(t *testing.T, expected, actual any, msgAndArgs ...any) { + t.Helper() + if !reflect.DeepEqual(expected, actual) { + FailNow(t, fmt.Sprintf("expected: %v, got: %v", expected, actual), msgAndArgs...) + } +} + +func NoError(t *testing.T, err error, msgAndArgs ...any) { + t.Helper() + if err != nil { + FailNow(t, fmt.Sprintf("expected no error, got: %v", err), msgAndArgs...) + } +} + +func Error(t *testing.T, err error, msgAndArgs ...any) { + t.Helper() + if err != nil { + return + } + FailNow(t, "expected error, got nil", msgAndArgs...) +} + +func ErrorIs(t *testing.T, expected, actual error, msgAndArgs ...any) { + t.Helper() + if errors.Is(actual, expected) { + return + } + FailNow(t, fmt.Sprintf("expected error: %v, got: %v", expected, actual), msgAndArgs...) +} + +func NotNil(t *testing.T, a any, msgAndArgs ...any) { + t.Helper() + if a != nil { + return + } + FailNow(t, "expected not nil, got nil", msgAndArgs...) +} + +func Nil(t *testing.T, a any, msgAndArgs ...any) { + t.Helper() + if a != nil { + FailNow(t, "expected nil, got not %v", append([]any{a}, msgAndArgs...)...) + } +} + +func True(t *testing.T, b bool, msgAndArgs ...any) { + t.Helper() + if b { + return + } + FailNow(t, "expected true, got false", msgAndArgs...) +} + +func False(t *testing.T, b bool, msgAndArgs ...any) { + t.Helper() + if !b { + return + } + FailNow(t, "expected false, got true", msgAndArgs...) +} + +func Zero(t *testing.T, a any, msgAndArgs ...any) { + t.Helper() + if reflect.ValueOf(a).IsZero() { + return + } + FailNow(t, fmt.Sprintf("expected zero value, got %v", a), msgAndArgs...) +} + +func NotZero(t *testing.T, a any, msgAndArgs ...any) { + t.Helper() + if !reflect.ValueOf(a).IsZero() { + return + } + FailNow(t, "expected not zero value, got zero", msgAndArgs...) +} + +func Len(t *testing.T, expected int, s any, msgAndArgs ...any) { + t.Helper() + vs := reflect.ValueOf(s) + + if vs.Kind() == reflect.Pointer { + vs = vs.Elem() + } + + switch vs.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + actual := vs.Len() + if expected != actual { + FailNow(t, fmt.Sprintf("expected length %d, got %d", expected, actual), msgAndArgs...) + } + default: + FailNow(t, fmt.Sprintf("expected array, chan, map, slice or string, got %T", s), msgAndArgs...) + } +}