Skip to content

Commit

Permalink
fix(sqlbuilder): added WithWhere and NewWhere (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Apr 18, 2024
1 parent c2510fe commit 7457db4
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 13 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).



## [1.4.4] - 2014-04-19
### Added
- added `NewWhere` and `WithWhere` (#35)

## [1.4.3] - 2014-04-12
### Fixes
- fixed close issue when it fails to build prepareStmt (#33)
Expand Down
12 changes: 7 additions & 5 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ var (
ErrMissingDHT = errors.New("sqle: missing_dht")
)

// DB represents a database connection pool with sharding support.
type DB struct {
*Context
_ noCopy //nolint: unused
Expand All @@ -23,6 +24,7 @@ type DB struct {
dbs []*Context
}

// Open creates a new DB instance with the provided database connections.
func Open(dbs ...*sql.DB) *DB {
d := &DB{
Context: &Context{
Expand All @@ -47,7 +49,7 @@ func Open(dbs ...*sql.DB) *DB {
return d
}

// Add dynamically scale out DB with new databases
// Add dynamically scales out the DB with new databases.
func (db *DB) Add(dbs ...*sql.DB) {
db.Lock()
defer db.Unlock()
Expand All @@ -65,30 +67,31 @@ func (db *DB) Add(dbs ...*sql.DB) {
}
}

// On select database from shardid.ID
// On selects the database context based on the shardid ID.
func (db *DB) On(id shardid.ID) *Context {
db.mu.RLock()
defer db.mu.RUnlock()

return db.dbs[int(id.DatabaseID)]
}

// NewDHT create new DTH with databases
// NewDHT creates a new DHT (Distributed Hash Table) with the specified databases.
func (db *DB) NewDHT(name string, dbs ...int) {
db.mu.Lock()
defer db.mu.Unlock()

db.dhts[name] = shardid.NewDHT(dbs...)
}

// GetDHT returns the DHT (Distributed Hash Table) with the specified name.
func (db *DB) GetDHT(name string) *shardid.DHT {
db.mu.RLock()
defer db.mu.RUnlock()

return db.dhts[name]
}

// OnDHT select database from DHT
// OnDHT selects the database context based on the DHT (Distributed Hash Table) key.
func (db *DB) OnDHT(key string, names ...string) (*Context, error) {
db.mu.RLock()
defer db.mu.RUnlock()
Expand All @@ -107,7 +110,6 @@ func (db *DB) OnDHT(key string, names ...string) (*Context, error) {

if err != nil {
return nil, err

}
return db.dbs[cur], nil
}
58 changes: 50 additions & 8 deletions sqlbuilder.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Package sqle provides a SQLBuilder for constructing SQL statements in a programmatic way.
// It allows you to build SELECT, INSERT, UPDATE, and DELETE statements with ease.
package sqle

import (
Expand All @@ -9,16 +11,19 @@ import (
)

var (
// ErrInvalidParamVariable is an error that is returned when an invalid parameter variable is encountered.
ErrInvalidParamVariable = errors.New("sqle: invalid param variable")
)

var (
DefaultSQLQuote = "`"
// DefaultSQLQuote is the default character used to escape column names in UPDATE and INSERT statements.
DefaultSQLQuote = "`"

// DefaultSQLParameterize is the default function used to parameterize values in SQL statements.
DefaultSQLParameterize = func(name string, index int) string {
return "?"
}
)

// Builder is a SQL query builder that allows you to construct SQL statements.
type Builder struct {
stmt strings.Builder
inputs map[string]string
Expand All @@ -29,8 +34,8 @@ type Builder struct {
Parameterize func(name string, index int) string
}

// New creates a new instance of the Builder with the given initial command(s).
func New(cmd ...string) *Builder {

b := &Builder{
inputs: make(map[string]string),
params: make(map[string]any),
Expand All @@ -48,11 +53,13 @@ func New(cmd ...string) *Builder {
return b
}

// Input sets the value of an input variable in the Builder.
func (b *Builder) Input(name, value string) *Builder {
b.inputs[name] = value
return b
}

// Inputs sets multiple input variables in the Builder.
func (b *Builder) Inputs(v map[string]string) *Builder {
for n, v := range v {
b.Input(n, v)
Expand All @@ -61,11 +68,13 @@ func (b *Builder) Inputs(v map[string]string) *Builder {
return b
}

// Param sets the value of a parameter variable in the Builder.
func (b *Builder) Param(name string, value any) *Builder {
b.params[name] = value
return b
}

// Params sets multiple parameter variables in the Builder.
func (b *Builder) Params(v map[string]any) *Builder {
for n, v := range v {
b.Param(n, v)
Expand All @@ -74,11 +83,15 @@ func (b *Builder) Params(v map[string]any) *Builder {
return b
}

// If sets a condition that determines whether the subsequent SQL command should be executed.
// If the predicate is false, the command is skipped.
func (b *Builder) If(predicate bool) *Builder {
b.shouldSkip = !predicate
return b
}

// SQL appends the given SQL command to the Builder's statement.
// If the Builder's shouldSkip flag is set, the command is skipped.
func (b *Builder) SQL(cmd string) *Builder {
if b.shouldSkip {
b.shouldSkip = false
Expand All @@ -91,10 +104,12 @@ func (b *Builder) SQL(cmd string) *Builder {
return b
}

// String returns the SQL statement constructed by the Builder.
func (b *Builder) String() string {
return b.stmt.String()
}

// Build constructs the final SQL statement and returns it along with the parameter values.
func (b *Builder) Build() (string, []any, error) {
tz := Tokenize(b.stmt.String())

Expand Down Expand Up @@ -131,6 +146,23 @@ func (b *Builder) Build() (string, []any, error) {

}

// WithWhere adds the input and parameter values from the given WhereBuilder to the current Builder
// and sets the WHERE clause of the SQL statement to the string representation of the WhereBuilder's statement.
// It returns the modified WhereBuilder.
func (b *Builder) WithWhere(wb *WhereBuilder) *WhereBuilder {
for k, v := range wb.inputs {
b.Input(k, v)
}

for k, v := range wb.params {
b.Param(k, v)
}

return b.Where(strings.TrimSpace(wb.stmt.String()))
}

// Where starts a new WhereBuilder and adds the given conditions to the current query builder.
// Returns the new WhereBuilder.
func (b *Builder) Where(cmd ...string) *WhereBuilder {
wb := &WhereBuilder{Builder: b}

Expand All @@ -146,6 +178,7 @@ func (b *Builder) Where(cmd ...string) *WhereBuilder {
return wb
}

// quoteColumn escapes the given column name using the Builder's Quote character.
func (b *Builder) quoteColumn(c string) string {
if strings.ContainsAny(c, "(") || strings.ContainsAny(c, " ") || strings.ContainsAny(c, "as") {
return c
Expand All @@ -154,13 +187,17 @@ func (b *Builder) quoteColumn(c string) string {
}
}

// Update starts a new UpdateBuilder and sets the table to update.
// Returns the new UpdateBuilder.
func (b *Builder) Update(table string) *UpdateBuilder {
b.SQL("UPDATE ").SQL(b.Quote).SQL(table).SQL(b.Quote).SQL(" SET ")
return &UpdateBuilder{
Builder: b,
}
}

// Insert starts a new InsertBuilder and sets the table to insert into.
// Returns the new InsertBuilder.
func (b *Builder) Insert(table string) *InsertBuilder {
return &InsertBuilder{
b: b,
Expand All @@ -169,6 +206,9 @@ func (b *Builder) Insert(table string) *InsertBuilder {
}
}

// Select adds a SELECT statement to the current query builder.
// If no columns are specified, it selects all columns using "*".
// Returns the current query builder.
func (b *Builder) Select(table string, columns ...string) *Builder {
b.SQL("SELECT")

Expand All @@ -189,18 +229,23 @@ func (b *Builder) Select(table string, columns ...string) *Builder {
return b
}

// Delete adds a DELETE statement to the current query builder.
// Returns the current query builder.
func (b *Builder) Delete(table string) *Builder {
b.SQL("DELETE FROM ").SQL(b.Quote).SQL(table).SQL(b.Quote)

return b
}

// On sets the "rotate" input variable to the given shard ID's rotate name.
// Returns the current query builder.
func (b *Builder) On(id shardid.ID) *Builder {
return b.Input("rotate", id.RotateName())
}

// sortColumns sorts the columns in the given map and returns them as a slice.
// It also allows customization of column names using BuilderOptions.
func sortColumns(m map[string]any, opts ...BuilderOption) []string {

bo := &BuilderOptions{}
for _, opt := range opts {
opt(bo)
Expand All @@ -209,15 +254,13 @@ func sortColumns(m map[string]any, opts ...BuilderOption) []string {
hasCustomizedColumns := len(bo.Columns) > 0

for n, v := range m {

name := n

if bo.ToName != nil {
name = bo.ToName(name)
if name != n {
m[name] = v
}

}

if !hasCustomizedColumns {
Expand All @@ -230,5 +273,4 @@ func sortColumns(m map[string]any, opts ...BuilderOption) []string {
}

return bo.Columns

}
26 changes: 26 additions & 0 deletions sqlbuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,32 @@ func TestBuilder(t *testing.T) {
require.Equal(t, now, vars[1])
},
},
{
name: "build_with_where",
build: func() *Builder {
b := New().Select("<prefix>orders")

wb := NewWhere().And("cancelled>={now}").
If(true).SQL("AND", "id={order_id}").
SQL("AND", "created>={now}")

wb.Input("prefix", "prefix_")
wb.Param("order_id", 123456).Param("now", now)

b.WithWhere(wb)

return b
},
assert: func(t *testing.T, b *Builder) {
s, vars, err := b.Build()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM `prefix_orders` WHERE cancelled>=? AND id=? AND created>=?", s)
require.Len(t, vars, 3)
require.Equal(t, now, vars[0])
require.Equal(t, 123456, vars[1])
require.Equal(t, now, vars[2])
},
},
{
name: "build_update",
build: func() *Builder {
Expand Down
15 changes: 15 additions & 0 deletions sqlbuilder_where.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
package sqle

// WhereBuilder is a struct that represents a SQL WHERE clause builder.
type WhereBuilder struct {
*Builder
written bool
shouldSkip bool
}

// NewWhere creates a new instance of WhereBuilder.
func NewWhere() *WhereBuilder {
return &WhereBuilder{
Builder: New(),
}
}

// If sets a condition to skip the subsequent SQL statements.
// If the predicate is false, the subsequent SQL statements will be skipped.
func (wb *WhereBuilder) If(predicate bool) *WhereBuilder {
wb.shouldSkip = !predicate
return wb
}

// And appends an AND operator and a command to the SQL statement.
func (wb *WhereBuilder) And(cmd string) *WhereBuilder {
return wb.SQL("AND", cmd)
}

// Or appends an OR operator and a command to the SQL statement.
func (wb *WhereBuilder) Or(cmd string) *WhereBuilder {
return wb.SQL("OR", cmd)
}

// SQL appends an operator and a command to the SQL statement.
// The operator is only appended if the command is not empty.
func (wb *WhereBuilder) SQL(op string, cmd string) *WhereBuilder {
if wb.shouldSkip {
wb.shouldSkip = false
Expand All @@ -40,6 +54,7 @@ func (wb *WhereBuilder) SQL(op string, cmd string) *WhereBuilder {
return wb
}

// End returns the underlying Builder instance.
func (wb *WhereBuilder) End() *Builder {
return wb.Builder
}

0 comments on commit 7457db4

Please sign in to comment.