Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update join #6757

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:preload", Preload)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
Expand All @@ -68,6 +69,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:preload", Preload)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
Expand Down
158 changes: 4 additions & 154 deletions callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ import (
"fmt"
"reflect"
"sort"
"strings"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)

func Query(db *gorm.DB) {
Expand Down Expand Up @@ -104,157 +101,8 @@ func BuildQuerySQL(db *gorm.DB) {
}

if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}

specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}

if isNestedJoin {
isRelations = true
relations = gussNestedRelations
}
}
}

if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}

columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}

selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}

exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}

{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}

if join.On != nil {
onStmt.AddClause(join.On)
}

if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)

if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}

exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}

return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}

parentTableName := clause.CurrentTable
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}

if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}

fromClause.Joins = append(fromClause.Joins, gorm.GenJoinClauses(db, &clauseSelect)...)
db.Statement.AddClause(fromClause)
db.Statement.Joins = nil
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
Expand Down Expand Up @@ -286,7 +134,9 @@ func Preload(db *gorm.DB) {
})

if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
return
if err := preloadDB.Statement.Parse(db.Statement.Model); err != nil {
return
}
}
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
preloadDB.Statement.Unscoped = db.Statement.Unscoped
Expand Down
4 changes: 3 additions & 1 deletion callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ func Update(config *Config) func(db *gorm.DB) {

if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})

gorm.CreateUpdateClause(db.Statement)

if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
Expand Down
6 changes: 6 additions & 0 deletions clause/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package clause
type Update struct {
Modifier string
Table Table
Joins []Join
}

// Name update clause name
Expand All @@ -22,6 +23,11 @@ func (update Update) Build(builder Builder) {
} else {
builder.WriteQuoted(update.Table)
}

for _, join := range update.Joins {
builder.WriteByte(' ')
join.Build(builder)
}
}

// MergeClause merge update clause
Expand Down
Loading
Loading