diff --git a/go.mod b/go.mod index 2a7ba53..f774f66 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,9 @@ module github.com/radulucut/search go 1.22.1 -require github.com/google/go-cmp v0.6.0 +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.9.0 + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 5a8d551..60ce688 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,10 @@ -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/search.go b/search.go index 1bf80e2..e198855 100644 --- a/search.go +++ b/search.go @@ -1,7 +1,8 @@ package search import ( - "sort" + "math" + "slices" "sync" ) @@ -55,24 +56,30 @@ type itemScore struct { score int } +type SearchOptions struct { + Query string + Limit int + Offset int + Ignore []int64 +} + // Search finds the most similar items to the given query. // limit is the maximum number of items to return. // ignore is a list of item ids to ignore. -func (e *Engine) Search(query string, limit int, ignore []int64) []int64 { +func (e *Engine) Search(opts SearchOptions) []int64 { var ignoreMap map[int64]struct{} hasIgnore := false - if len(ignore) != 0 { + if len(opts.Ignore) != 0 { hasIgnore = true ignoreMap = make(map[int64]struct{}) - for i := range ignore { - ignoreMap[ignore[i]] = struct{}{} + for i := range opts.Ignore { + ignoreMap[opts.Ignore[i]] = struct{}{} } } - - q := e.tokenize(query) + q := e.tokenize(opts.Query) e.RLock() defer e.RUnlock() - scores := make([]itemScore, 0) + scores := make([]*itemScore, 0) for id := range e.items { if hasIgnore { if _, ok := ignoreMap[id]; ok { @@ -83,18 +90,26 @@ func (e *Engine) Search(query string, limit int, ignore []int64) []int64 { if score == -1 { continue } - scores = append(scores, itemScore{id: id, score: score}) + scores = append(scores, &itemScore{id: id, score: score}) } - sort.Slice(scores, func(i, j int) bool { - if scores[i].score == scores[j].score { - return scores[i].id > scores[j].id - } else { - return scores[i].score < scores[j].score + slices.SortFunc(scores, func(a, b *itemScore) int { + if a.score < b.score { + return -1 + } + if a.score > b.score { + return 1 + } + if a.id > b.id { + return -1 + } + if a.id < b.id { + return 1 } + return 0 }) - limit = min(limit, len(scores)) + limit := min(opts.Offset+opts.Limit, len(scores)) res := make([]int64, 0, limit) - for i := 0; i < limit; i++ { + for i := opts.Offset; i < limit; i++ { res = append(res, scores[i].id) } return res @@ -104,7 +119,7 @@ func (e *Engine) score(q, b [][]rune) int { var score int skip := true for i := range q { - best := (1<<63 - 1) + best := math.MaxInt for j := range b { best = min(best, LevenshteinDistance(q[i], b[j])) } diff --git a/search_test.go b/search_test.go index bbdd5d9..04023d4 100644 --- a/search_test.go +++ b/search_test.go @@ -3,7 +3,7 @@ package search import ( "testing" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" ) type Book struct { @@ -45,41 +45,51 @@ func Test_Engine(t *testing.T) { {"spânzuraţilor", []int64{2}}, {"amintiri din copilărie", []int64{8, 11, 10, 5, 15}}, {"xyz zyx", []int64{}}, + {"din", []int64{11, 8, 15, 14, 13}}, } for _, test := range tests { t.Run(test.query, func(t *testing.T) { - actual := engine.Search(test.query, 5, nil) - if diff := cmp.Diff(test.expected, actual); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } + actual := engine.Search(SearchOptions{Query: test.query, Limit: 5}) + assert.Equal(t, test.expected, actual) }) } + t.Run("offset", func(t *testing.T) { + actual := engine.Search(SearchOptions{ + Query: "de", + Limit: 5, + Offset: 5, + Ignore: []int64{15}, + }) + assert.Equal(t, []int64{9, 8, 7, 6, 5}, actual) + }) + t.Run("Ignore ids", func(t *testing.T) { - actual := engine.Search("maitreyi", 5, []int64{4}) - expected := []int64{} - if diff := cmp.Diff(expected, actual); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } + actual := engine.Search(SearchOptions{ + Query: "maitreyi", + Limit: 5, + Ignore: []int64{4}, + }) + assert.ElementsMatch(t, []int64{}, actual) }) engine.SetItem(16, "Ciocoii vechi și noi de Nicolae Filimon") t.Run("SetItem", func(t *testing.T) { - actual := engine.Search("Ciocoii vechi", 5, nil) - expected := []int64{16} - if diff := cmp.Diff(expected, actual); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } + actual := engine.Search(SearchOptions{ + Query: "Ciocoii vechi", + Limit: 5, + }) + assert.ElementsMatch(t, []int64{16}, actual) }) engine.DeleteItem(7) t.Run("DeleteItem", func(t *testing.T) { - actual := engine.Search("Moara", 5, nil) - expected := []int64{} - if diff := cmp.Diff(expected, actual); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } + actual := engine.Search(SearchOptions{ + Query: "Moara", + Limit: 5, + }) + assert.ElementsMatch(t, []int64{}, actual) }) } @@ -94,9 +104,7 @@ func Test_Tokenize(t *testing.T) { {'4'}, {'a', 'a', 'a', 'a', 'i', 'i', 's', 's', 's', 's', 't', 't', 't', 't'}, } - if diff := cmp.Diff(expected, tokens); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } + assert.Equal(t, expected, tokens) } func Test_LevenshteinDistance(t *testing.T) { @@ -121,9 +129,7 @@ func Test_LevenshteinDistance(t *testing.T) { } for _, test := range tests { t.Run("LevenshteinDistance", func(t *testing.T) { - if diff := cmp.Diff(test.expected, LevenshteinDistance(test.a, test.b)); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } + assert.Equal(t, test.expected, LevenshteinDistance(test.a, test.b)) }) } }