From 91ef419cd95a615b1e0f9905ce41bbeccb021081 Mon Sep 17 00:00:00 2001 From: Lz Date: Fri, 19 Apr 2024 10:32:15 +0800 Subject: [PATCH] fix(query): added LimitedResult/LimitedQuery --- limited_query.go | 25 ++++++++++++++++ limited_result.go | 7 +++++ query.go | 37 +++++++++++++++++------- query_option.go | 6 ++-- queryer.go | 16 ++++++++--- queryer_mapr.go | 25 ++++++++++------ sqlbuilder.go | 35 ++--------------------- sqlbuilder_orderby.go | 21 ++++++++++++++ sqlbuilder_orderby_test.go | 14 +++++++++ sqlbuilder_where.go | 58 ++++++++++++++++++++++++++++++-------- token.go | 14 +++++++++ 11 files changed, 188 insertions(+), 70 deletions(-) create mode 100644 limited_query.go create mode 100644 limited_result.go diff --git a/limited_query.go b/limited_query.go new file mode 100644 index 0000000..11b70c2 --- /dev/null +++ b/limited_query.go @@ -0,0 +1,25 @@ +package sqle + +// LimitedQueryOption is a function type that modifies a LimitedQuery. +type LimitedQueryOption func(q *LimitedQuery) + +// LimitedQuery represents a query with pagination and ordering options. +type LimitedQuery struct { + Offset int64 // PageIndex represents the index of the page to retrieve. + Limit int64 // PageSize represents the number of items per page. + OrderBy *OrderByBuilder // OrderBy represents the ordering of the query. +} + +// WithPageSize is a LimitedQueryOption that sets the page size of the LimitedQuery. +func WithPageSize(size int64) LimitedQueryOption { + return func(q *LimitedQuery) { + q.Limit = size + } +} + +// WithOrderBy is a LimitedQueryOption that sets the ordering of the LimitedQuery. +func WithOrderBy(ob *OrderByBuilder) LimitedQueryOption { + return func(q *LimitedQuery) { + q.OrderBy = ob + } +} diff --git a/limited_result.go b/limited_result.go new file mode 100644 index 0000000..708127b --- /dev/null +++ b/limited_result.go @@ -0,0 +1,7 @@ +package sqle + +// LimitedResult represents a limited result set with items and total count. +type LimitedResult[T any] struct { + Items []T `json:"items,omitempty"` // Items contains the limited result items. + Total int64 `json:"total,omitempty"` // Total represents the total count. +} diff --git a/query.go b/query.go index 4903398..8a97b03 100644 --- a/query.go +++ b/query.go @@ -14,12 +14,14 @@ func (e *Errors) Error() string { } type Query[T any] struct { - db *DB - queryer Queryer[T] - tables []string + db *DB + queryer Queryer[T] + withRotatedTables []string } -// NewQuery create a Query +// NewQuery creates a new Query instance. +// It takes a *DB as the first argument and optional QueryOption functions as the rest. +// It returns a pointer to the created Query instance. func NewQuery[T any](db *DB, options ...QueryOption[T]) *Query[T] { q := &Query[T]{ db: db, @@ -31,8 +33,8 @@ func NewQuery[T any](db *DB, options ...QueryOption[T]) *Query[T] { } } - if q.tables == nil { - q.tables = []string{""} + if q.withRotatedTables == nil { + q.withRotatedTables = []string{""} } if q.queryer == nil { @@ -44,18 +46,33 @@ func NewQuery[T any](db *DB, options ...QueryOption[T]) *Query[T] { return q } +// First executes the query and returns the first result. +// It takes a context.Context and a *Builder as arguments. +// It returns the result of type T and an error, if any. func (q *Query[T]) First(ctx context.Context, b *Builder) (T, error) { - return q.queryer.First(ctx, q.tables, b) + return q.queryer.First(ctx, q.withRotatedTables, b) } +// Count executes the query and returns the number of results. +// It takes a context.Context and a *Builder as arguments. +// It returns the count as an integer and an error, if any. func (q *Query[T]) Count(ctx context.Context, b *Builder) (int, error) { - return q.queryer.Count(ctx, q.tables, b) + return q.queryer.Count(ctx, q.withRotatedTables, b) } +// Query executes the query and returns all the results. +// It takes a context.Context, a *Builder, and a comparison function as arguments. +// The comparison function is used to sort the results. +// It returns a slice of results of type T and an error, if any. func (q *Query[T]) Query(ctx context.Context, b *Builder, less func(i, j T) bool) ([]T, error) { - return q.queryer.Query(ctx, q.tables, b, less) + return q.queryer.Query(ctx, q.withRotatedTables, b, less) } +// QueryLimit executes the query and returns a limited number of results. +// It takes a context.Context, a *Builder, a comparison function, and a limit as arguments. +// The comparison function is used to sort the results. +// The limit specifies the maximum number of results to return. +// It returns a slice of results of type T and an error, if any. func (q *Query[T]) QueryLimit(ctx context.Context, b *Builder, less func(i, j T) bool, limit int) ([]T, error) { - return q.queryer.QueryLimit(ctx, q.tables, b, less, limit) + return q.queryer.QueryLimit(ctx, q.withRotatedTables, b, less, limit) } diff --git a/query_option.go b/query_option.go index 4fe94ef..9639c63 100644 --- a/query_option.go +++ b/query_option.go @@ -11,7 +11,7 @@ type QueryOption[T any] func(q *Query[T]) func WithMonths[T any](start, end time.Time) QueryOption[T] { return func(q *Query[T]) { for t := start; !t.After(end); t = t.AddDate(0, 1, 0) { - q.tables = append(q.tables, shardid.FormatMonth(t)) + q.withRotatedTables = append(q.withRotatedTables, shardid.FormatMonth(t)) } } } @@ -19,7 +19,7 @@ func WithMonths[T any](start, end time.Time) QueryOption[T] { func WithWeeks[T any](start, end time.Time) QueryOption[T] { return func(q *Query[T]) { for t := start; !t.After(end); t = t.AddDate(0, 0, 7) { - q.tables = append(q.tables, shardid.FormatWeek(t)) + q.withRotatedTables = append(q.withRotatedTables, shardid.FormatWeek(t)) } } } @@ -27,7 +27,7 @@ func WithWeeks[T any](start, end time.Time) QueryOption[T] { func WithDays[T any](start, end time.Time) QueryOption[T] { return func(q *Query[T]) { for t := start; !t.After(end); t = t.AddDate(0, 0, 1) { - q.tables = append(q.tables, shardid.FormatDay(t)) + q.withRotatedTables = append(q.withRotatedTables, shardid.FormatDay(t)) } } } diff --git a/queryer.go b/queryer.go index 3038e05..2766867 100644 --- a/queryer.go +++ b/queryer.go @@ -2,9 +2,17 @@ package sqle import "context" +// Queryer is a query provider interface that defines methods for querying data. type Queryer[T any] interface { - First(ctx context.Context, tables []string, b *Builder) (T, error) - Count(ctx context.Context, tables []string, b *Builder) (int, error) - Query(ctx context.Context, tables []string, b *Builder, less func(i, j T) bool) ([]T, error) - QueryLimit(ctx context.Context, tables []string, b *Builder, less func(i, j T) bool, limit int) ([]T, error) + // First retrieves the first result that matches the query criteria. + First(ctx context.Context, rotatedTables []string, b *Builder) (T, error) + + // Count returns the number of results that match the query criteria. + Count(ctx context.Context, rotatedTables []string, b *Builder) (int, error) + + // Query retrieves all results that match the query criteria and sorts them using less function if it is provided. + Query(ctx context.Context, rotatedTables []string, b *Builder, less func(i, j T) bool) ([]T, error) + + // QueryLimit retrieves a limited number of results that match the query criteria, sorts them using the provided less function, and limits the number of results to the specified limit. + QueryLimit(ctx context.Context, rotatedTables []string, b *Builder, less func(i, j T) bool, limit int) ([]T, error) } diff --git a/queryer_mapr.go b/queryer_mapr.go index 92d62fb..25fd6cd 100644 --- a/queryer_mapr.go +++ b/queryer_mapr.go @@ -9,12 +9,13 @@ import ( "github.com/yaitoo/async" ) -// MapR Map/Reduce Query +// MapR is a Map/Reduce Query Provider based on databases. type MapR[T any] struct { dbs []*Context } -func (q *MapR[T]) First(ctx context.Context, tables []string, b *Builder) (T, error) { +// First executes the query and returns the first result. +func (q *MapR[T]) First(ctx context.Context, rotatedTables []string, b *Builder) (T, error) { var it T b.Input("rotate", "") // lazy replace on async.Wait query, args, err := b.Build() @@ -24,7 +25,7 @@ func (q *MapR[T]) First(ctx context.Context, tables []string, b *Builder) (T, er w := async.New[T]() - for _, r := range tables { + for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { w.Add(func(db *Context, qr string) func(context.Context) (T, error) { @@ -44,7 +45,9 @@ func (q *MapR[T]) First(ctx context.Context, tables []string, b *Builder) (T, er d, _, err := w.WaitAny(ctx) return d, err } -func (q *MapR[T]) Count(ctx context.Context, tables []string, b *Builder) (int, error) { + +// Count executes the query and returns the count of results. +func (q *MapR[T]) Count(ctx context.Context, rotatedTables []string, b *Builder) (int, error) { b.Input("rotate", "") // lazy replace on async.Wait query, args, err := b.Build() if err != nil { @@ -53,7 +56,7 @@ func (q *MapR[T]) Count(ctx context.Context, tables []string, b *Builder) (int, w := async.New[int]() - for _, r := range tables { + for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { w.Add(func(db *Context, qr string) func(context.Context) (int, error) { @@ -84,7 +87,9 @@ func (q *MapR[T]) Count(ctx context.Context, tables []string, b *Builder) (int, return total, nil } -func (q *MapR[T]) Query(ctx context.Context, tables []string, b *Builder, less func(i, j T) bool) ([]T, error) { + +// Query executes the query and returns a list of results. +func (q *MapR[T]) Query(ctx context.Context, rotatedTables []string, b *Builder, less func(i, j T) bool) ([]T, error) { b.Input("rotate", "") // lazy replace on async.Wait query, args, err := b.Build() @@ -94,7 +99,7 @@ func (q *MapR[T]) Query(ctx context.Context, tables []string, b *Builder, less f w := async.New[[]T]() - for _, r := range tables { + for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { w.Add(func(db *Context, qr string) func(context.Context) ([]T, error) { @@ -138,13 +143,15 @@ func (q *MapR[T]) Query(ctx context.Context, tables []string, b *Builder, less f return list, nil } -func (q *MapR[T]) QueryLimit(ctx context.Context, tables []string, b *Builder, less func(i, j T) bool, limit int) ([]T, error) { + +// QueryLimit executes the query and returns a limited list of results. +func (q *MapR[T]) QueryLimit(ctx context.Context, rotatedTables []string, b *Builder, less func(i, j T) bool, limit int) ([]T, error) { if limit > 0 { b.SQL(" LIMIT " + strconv.Itoa(limit*len(q.dbs))) } - list, err := q.Query(ctx, tables, b, less) + list, err := q.Query(ctx, rotatedTables, b, less) if err != nil { return nil, err } diff --git a/sqlbuilder.go b/sqlbuilder.go index 9b0c9de..2770820 100644 --- a/sqlbuilder.go +++ b/sqlbuilder.go @@ -146,38 +146,6 @@ func (b *Builder) Build() (string, []any, error) { } -// WithWhere adds the input and parameter values from the given WhereBuilder to the current Builder -// and sets the WHERE clause of the SQL statement to the string representation of the WhereBuilder's statement. -// It returns the modified WhereBuilder. -func (b *Builder) WithWhere(wb *WhereBuilder) *WhereBuilder { - for k, v := range wb.inputs { - b.Input(k, v) - } - - for k, v := range wb.params { - b.Param(k, v) - } - - return b.Where(strings.TrimSpace(wb.stmt.String())) -} - -// Where starts a new WhereBuilder and adds the given conditions to the current query builder. -// Returns the new WhereBuilder. -func (b *Builder) Where(cmd ...string) *WhereBuilder { - wb := &WhereBuilder{Builder: b} - - b.stmt.WriteString(" WHERE") - for _, it := range cmd { - if it != "" { - wb.written = true - b.stmt.WriteString(" ") - b.stmt.WriteString(it) - } - } - - return wb -} - // quoteColumn escapes the given column name using the Builder's Quote character. func (b *Builder) quoteColumn(c string) string { if strings.ContainsAny(c, "(") || strings.ContainsAny(c, " ") || strings.ContainsAny(c, "as") { @@ -243,7 +211,8 @@ func (b *Builder) On(id shardid.ID) *Builder { return b.Input("rotate", id.RotateName()) } -// sortColumns sorts the columns in the given map and returns them as a slice. +// sortColumns sorts the columns in the given map and returns them as a pre-sorted columns slice. +// It helps PrepareStmt works with sql statement as less as possible. // It also allows customization of column names using BuilderOptions. func sortColumns(m map[string]any, opts ...BuilderOption) []string { bo := &BuilderOptions{} diff --git a/sqlbuilder_orderby.go b/sqlbuilder_orderby.go index 771ecd8..4f5d91c 100644 --- a/sqlbuilder_orderby.go +++ b/sqlbuilder_orderby.go @@ -11,6 +11,27 @@ type OrderByBuilder struct { allowedColumns []string } +// NewOrderBy creates a new instance of the OrderByBuilder. +// It takes a variadic parameter `allowedColumns` which specifies the columns that are allowed to be used in the ORDER BY clause. +func NewOrderBy(allowedColumns ...string) *OrderByBuilder { + return &OrderByBuilder{ + Builder: New(), + allowedColumns: allowedColumns, + } +} + +// WithOrderBy sets the order by clause for the SQL query. +// It takes an instance of the OrderByBuilder and adds the allowed columns to the Builder's order list. +// It also appends the SQL string representation of the OrderByBuilder to the Builder's SQL string. +// It returns a new instance of the OrderByBuilder. +func (b *Builder) WithOrderBy(ob *OrderByBuilder) *OrderByBuilder { + n := b.Order(ob.allowedColumns...) + + b.SQL(ob.String()) + + return n +} + // Order create an OrderByBuilder with allowed columns to prevent sql injection. NB: any input is allowed if it is not provided func (b *Builder) Order(allowedColumns ...string) *OrderByBuilder { ob := &OrderByBuilder{ diff --git a/sqlbuilder_orderby_test.go b/sqlbuilder_orderby_test.go index e3c31f1..f4042a8 100644 --- a/sqlbuilder_orderby_test.go +++ b/sqlbuilder_orderby_test.go @@ -49,6 +49,20 @@ func TestOrderByBuilder(t *testing.T) { }, wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, updated_at ASC", }, + { + name: "with_order_should_work", + build: func() *Builder { + b := New("SELECT * FROM users") + + ob := NewOrderBy("id", "created_at", "updated_at", "age") + ob.By("created_at desc, id, name asc, updated_at asc, age invalid_by, unsafe_asc, unsafe_desc desc") + + b.WithOrderBy(ob) + + return b + }, + wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, updated_at ASC", + }, } for _, test := range tests { diff --git a/sqlbuilder_where.go b/sqlbuilder_where.go index 92c4d05..5f8f4df 100644 --- a/sqlbuilder_where.go +++ b/sqlbuilder_where.go @@ -1,5 +1,7 @@ package sqle +import "strings" + // WhereBuilder is a struct that represents a SQL WHERE clause builder. type WhereBuilder struct { *Builder @@ -14,6 +16,41 @@ func NewWhere() *WhereBuilder { } } +// Where adds a WHERE clause to the SQL statement. +// It takes one or more criteria strings as arguments. +// Each criteria string represents a condition in the WHERE clause. +// If a criteria string is empty, it will be ignored. +// Returns a *WhereBuilder that can be used to further build the SQL statement. +func (b *Builder) Where(criteria ...string) *WhereBuilder { + wb := &WhereBuilder{Builder: b} + + b.stmt.WriteString(" WHERE") + for _, it := range criteria { + if it != "" { + wb.written = true + b.stmt.WriteString(" ") + b.stmt.WriteString(it) + } + } + + return wb +} + +// WithWhere adds the input and parameter values from the given WhereBuilder to the current Builder +// and sets the WHERE clause of the SQL statement to the string representation of the WhereBuilder's statement. +// It returns the modified WhereBuilder. +func (b *Builder) WithWhere(wb *WhereBuilder) *WhereBuilder { + for k, v := range wb.inputs { + b.Input(k, v) + } + + for k, v := range wb.params { + b.Param(k, v) + } + + return b.Where(strings.TrimSpace(wb.stmt.String())) +} + // If sets a condition to skip the subsequent SQL statements. // If the predicate is false, the subsequent SQL statements will be skipped. func (wb *WhereBuilder) If(predicate bool) *WhereBuilder { @@ -21,25 +58,24 @@ func (wb *WhereBuilder) If(predicate bool) *WhereBuilder { return wb } -// And appends an AND operator and a command to the SQL statement. -func (wb *WhereBuilder) And(cmd string) *WhereBuilder { - return wb.SQL("AND", cmd) +// And adds an AND condition to the WHERE clause. +func (wb *WhereBuilder) And(criteria string) *WhereBuilder { + return wb.SQL("AND", criteria) } -// Or appends an OR operator and a command to the SQL statement. -func (wb *WhereBuilder) Or(cmd string) *WhereBuilder { - return wb.SQL("OR", cmd) +// Or adds an OR condition to the WHERE clause. +func (wb *WhereBuilder) Or(criteria string) *WhereBuilder { + return wb.SQL("OR", criteria) } -// SQL appends an operator and a command to the SQL statement. -// The operator is only appended if the command is not empty. -func (wb *WhereBuilder) SQL(op string, cmd string) *WhereBuilder { +// SQL adds a condition to the WHERE clause with the specified operator. +func (wb *WhereBuilder) SQL(op string, criteria string) *WhereBuilder { if wb.shouldSkip { wb.shouldSkip = false return wb } - if cmd != "" { + if criteria != "" { // first condition, op expression should not be written if wb.written { wb.Builder.stmt.WriteString(" ") @@ -48,7 +84,7 @@ func (wb *WhereBuilder) SQL(op string, cmd string) *WhereBuilder { wb.written = true wb.Builder.stmt.WriteString(" ") - wb.Builder.stmt.WriteString(cmd) + wb.Builder.stmt.WriteString(criteria) } return wb diff --git a/token.go b/token.go index d422b49..4d5950c 100644 --- a/token.go +++ b/token.go @@ -1,5 +1,6 @@ package sqle +// TokenType represents the type of a token. type TokenType uint const ( @@ -8,34 +9,47 @@ const ( ParamToken TokenType = 2 ) +// Token is an interface that represents a SQL token. type Token interface { Type() TokenType String() string } +// Text represents a text token. type Text string +// Type returns the type of the token. func (t Text) Type() TokenType { return TextToken } + +// String returns the string representation of the token. func (t Text) String() string { return string(t) } +// Input represents an input token. type Input string +// Type returns the type of the token. func (t Input) Type() TokenType { return InputToken } + +// String returns the string representation of the token. func (t Input) String() string { return string(t) } +// Param represents a parameter token. type Param string +// Type returns the type of the token. func (t Param) Type() TokenType { return ParamToken } + +// String returns the string representation of the token. func (t Param) String() string { return string(t) }