Skip to content

Commit

Permalink
fix(sqlbuilder): fixed WithWhere/WithOrderBy for empty builder (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Apr 20, 2024
1 parent 665c28e commit d72e331
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 22 deletions.
47 changes: 27 additions & 20 deletions sqlbuilder_orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ import (
"strings"
)

// OrderByBuilder represents a SQL ORDER BY clause builder.
// It is used to construct ORDER BY clauses for SQL queries.
type OrderByBuilder struct {
*Builder
isWritten bool
allowedColumns []string
*Builder // The underlying SQL query builder.
written bool // Indicates if the ORDER BY clause has been written.
allowedColumns []string // The list of allowed columns for ordering.
}

// NewOrderBy creates a new instance of the OrderByBuilder.
Expand Down Expand Up @@ -43,8 +45,6 @@ func (b *Builder) Order(allowedColumns ...string) *OrderByBuilder {
allowedColumns: allowedColumns,
}

b.SQL(" ORDER BY ")

return ob
}

Expand Down Expand Up @@ -89,29 +89,36 @@ func (ob *OrderByBuilder) By(raw string) *OrderByBuilder {
// ByAsc order by ascending with columns
func (ob *OrderByBuilder) ByAsc(columns ...string) *OrderByBuilder {
for _, c := range columns {
if ob.isAllowed(c) {
if ob.isWritten {
ob.Builder.SQL(", ").SQL(c).SQL(" ASC")
} else {
ob.Builder.SQL(c).SQL(" ASC")
ob.isWritten = true
}
}
ob.add(c, " ASC")
}
return ob
}

// ByDesc order by descending with columns
func (ob *OrderByBuilder) ByDesc(columns ...string) *OrderByBuilder {
for _, c := range columns {
if ob.isAllowed(c) {
if ob.isWritten {
ob.Builder.SQL(", ").SQL(c).SQL(" DESC")
} else {
ob.Builder.SQL(c).SQL(" DESC")
ob.isWritten = true
ob.add(c, " DESC")
}
return ob
}

// add adds a column and its sorting direction to the OrderByBuilder.
// It checks if the column is allowed and appends it to the SQL query.
// If the column has already been written, it appends a comma before adding the column.
// If it's the first column being added, it appends "ORDER BY" before adding the column.
func (ob *OrderByBuilder) add(col, direction string) {
if ob.isAllowed(col) {
if ob.written {
ob.Builder.SQL(", ").SQL(col).SQL(direction)
} else {
// only write once
if !ob.written {
ob.Builder.SQL(" ORDER BY ")
}

ob.Builder.SQL(col).SQL(direction)

ob.written = true
}
}
return ob
}
16 changes: 16 additions & 0 deletions sqlbuilder_orderby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ func TestOrderByBuilder(t *testing.T) {
},
wanted: "SELECT * FROM users ORDER BY id ASC, created_at DESC, updated_at ASC",
},
{
name: "with_empty_order_by_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")

ob := NewOrderBy("age").
ByAsc("id", "name").
ByDesc("created_at", "unsafe_input").
ByAsc("updated_at")

b.WithOrderBy(ob)

return b
},
wanted: "SELECT * FROM users",
},
}

for _, test := range tests {
Expand Down
20 changes: 20 additions & 0 deletions sqlbuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,26 @@ func TestBuilder(t *testing.T) {
require.Equal(t, now, vars[2])
},
},
{
name: "build_with_empty_where",
build: func() *Builder {
b := New().Select("orders")
b.Where().
If(false).SQL("AND", "cancelled>={now}").
If(false).SQL("AND", "id={order_id}")
b.Param("order_id", 123456)
b.Param("now", now)

b.WithWhere(nil)
return b
},
assert: func(t *testing.T, b *Builder) {
s, vars, err := b.Build()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM `orders`", s)
require.Len(t, vars, 0)
},
},
{
name: "build_update",
build: func() *Builder {
Expand Down
10 changes: 8 additions & 2 deletions sqlbuilder_where.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ func NewWhere() *WhereBuilder {
func (b *Builder) Where(criteria ...string) *WhereBuilder {
wb := &WhereBuilder{Builder: b}

b.stmt.WriteString(" WHERE")
for _, it := range criteria {
if it != "" {

if !wb.written {
b.stmt.WriteString(" WHERE")
}
wb.written = true

b.stmt.WriteString(" ")
b.stmt.WriteString(it)
}
Expand All @@ -52,7 +56,7 @@ func (b *Builder) WithWhere(wb *WhereBuilder) *WhereBuilder {
b.Param(k, v)
}

return b.Where(strings.TrimSpace(wb.stmt.String()))
return b.Where(strings.TrimPrefix(wb.stmt.String(), " WHERE "))
}

// If sets a condition to skip the subsequent SQL statements.
Expand Down Expand Up @@ -84,6 +88,8 @@ func (wb *WhereBuilder) SQL(op string, criteria string) *WhereBuilder {
if wb.written {
wb.Builder.stmt.WriteString(" ")
wb.Builder.stmt.WriteString(op)
} else {
wb.Builder.stmt.WriteString(" WHERE")
}

wb.written = true
Expand Down

0 comments on commit d72e331

Please sign in to comment.