diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index c78d9759e1..9bf3a1bf2d 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -42,58 +42,25 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan allDatabases := a.Catalog.AllDatabases(ctx) for _, database := range allDatabases { - if pdb, ok := database.(sql.StoredProcedureDatabase); ok { - procedures, err := pdb.GetStoredProcedures(ctx) + pdb, ok := database.(sql.StoredProcedureDatabase); + if !ok { + continue + } + procedures, err := pdb.GetStoredProcedures(ctx) + if err != nil { + return nil, err + } + for _, procedure := range procedures { + proc := planbuilder.BuildProcedureHelper(ctx, a.Catalog, database, nil, procedure) + err = scope.Procedures.Register(database.Name(), proc) if err != nil { return nil, err } - - for _, procedure := range procedures { - var procToRegister *plan.Procedure - var parsedProcedure sql.Node - b := planbuilder.New(ctx, a.Catalog, nil, nil) - b.DisableAuth() - b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) - if err != nil { - procToRegister = &plan.Procedure{ - CreateProcedureString: procedure.CreateStatement, - } - procToRegister.ValidationError = err - } else if cp, ok := parsedProcedure.(*plan.CreateProcedure); !ok { - return nil, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement) - } else { - procToRegister = cp.Procedure - } - - procToRegister.CreatedAt = procedure.CreatedAt - procToRegister.ModifiedAt = procedure.ModifiedAt - - err = scope.Procedures.Register(database.Name(), procToRegister) - if err != nil { - return nil, err - } - } } } return scope, nil } -// analyzeCreateProcedure checks the plan.CreateProcedure and returns a valid plan.Procedure or an error -func analyzeCreateProcedure(ctx *sql.Context, a *Analyzer, cp *plan.CreateProcedure, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (*plan.Procedure, error) { - var analyzedNode sql.Node - var err error - analyzedNode, _, err = analyzeProcedureBodies(ctx, a, cp.Procedure, false, scope, sel, qFlags) - if err != nil { - return nil, err - } - analyzedProc, ok := analyzedNode.(*plan.Procedure) - if !ok { - return nil, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode) - } - return analyzedProc, nil -} - func hasProcedureCall(n sql.Node) bool { referencesProcedures := false transform.Inspect(n, func(n sql.Node) bool { @@ -164,9 +131,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return n, transform.SameTree, nil } - hasProcedureCall := hasProcedureCall(n) - _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure) - if !hasProcedureCall && !isShowCreateProcedure { + if _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure); !hasProcedureCall(n) && !isShowCreateProcedure { return n, transform.SameTree, nil } @@ -197,46 +162,19 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return call.WithProcedure(externalProcedure), transform.NewTree, nil } - if spdb, ok := call.Database().(sql.StoredProcedureDatabase); ok { - procedure, ok, err := spdb.GetStoredProcedure(ctx, call.Name) - if err != nil { - return nil, transform.SameTree, err - } - if !ok { - err := sql.ErrStoredProcedureDoesNotExist.New(call.Name) - if call.Database().Name() == "" { - return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err) - } - return nil, transform.SameTree, err - } - var parsedProcedure sql.Node - b := planbuilder.New(ctx, a.Catalog, nil, nil) - b.DisableAuth() - b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) - if call.AsOf() != nil { - asOf, err := call.AsOf().Eval(ctx, nil) - if err != nil { - return n, transform.SameTree, err - } - b.ProcCtx().AsOf = asOf - } - b.ProcCtx().DbName = call.Database().Name() - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) - if err != nil { - return nil, transform.SameTree, err - } - cp, ok := parsedProcedure.(*plan.CreateProcedure) - if !ok { - return nil, transform.SameTree, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement) - } - analyzedProc, err := analyzeCreateProcedure(ctx, a, cp, scope, sel, nil) - if err != nil { - return nil, transform.SameTree, err - } - return call.WithProcedure(analyzedProc), transform.NewTree, nil - } else { + if _, isStoredProcDb := call.Database().(sql.StoredProcedureDatabase); !isStoredProcDb { return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name()) } + + analyzedNode, _, err := analyzeProcedureBodies(ctx, a, call.Procedure, false, scope, sel, qFlags) + if err != nil { + return nil, transform.SameTree, err + } + analyzedProc, ok := analyzedNode.(*plan.Procedure) + if !ok { + return nil, transform.SameTree, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode) + } + return call.WithProcedure(analyzedProc), transform.NewTree, nil }) if err != nil { return nil, transform.SameTree, err diff --git a/sql/plan/call.go b/sql/plan/call.go index 60a6f83008..c15b12ac5f 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -39,13 +39,14 @@ var _ sql.Expressioner = (*Call)(nil) var _ Versionable = (*Call)(nil) // NewCall returns a *Call node. -func NewCall(db sql.Database, name string, params []sql.Expression, asOf sql.Expression, catalog sql.Catalog) *Call { +func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog) *Call { return &Call{ - db: db, - Name: name, - Params: params, - asOf: asOf, - cat: catalog, + db: db, + Name: name, + Params: params, + Procedure: proc, + asOf: asOf, + cat: catalog, } } diff --git a/sql/plan/ddl_procedure.go b/sql/plan/ddl_procedure.go index ff824959d1..35c43df152 100644 --- a/sql/plan/ddl_procedure.go +++ b/sql/plan/ddl_procedure.go @@ -1,4 +1,4 @@ -// Copyright 2021 Dolthub, Inc. +// Copyright 2021-2025 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,18 +15,16 @@ package plan import ( - "fmt" - "time" - - "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/go-mysql-server/sql" ) type CreateProcedure struct { - *Procedure - ddlNode - BodyString string + ddlNode ddlNode + + StoredProcDetails sql.StoredProcedureDetails + BodyString string } var _ sql.Node = (*CreateProcedure)(nil) @@ -37,48 +35,31 @@ var _ sql.CollationCoercible = (*CreateProcedure)(nil) // NewCreateProcedure returns a *CreateProcedure node. func NewCreateProcedure( db sql.Database, - name, - definer string, - params []ProcedureParam, - createdAt, modifiedAt time.Time, - securityContext ProcedureSecurityContext, - characteristics []Characteristic, - body sql.Node, - comment, createString, bodyString string, + storedProcDetails sql.StoredProcedureDetails, + bodyString string, ) *CreateProcedure { - procedure := NewProcedure( - name, - definer, - params, - securityContext, - comment, - characteristics, - createString, - body, - createdAt, - modifiedAt) return &CreateProcedure{ - Procedure: procedure, - BodyString: bodyString, - ddlNode: ddlNode{db}, + ddlNode: ddlNode{db}, + StoredProcDetails: storedProcDetails, + BodyString: bodyString, } } // Database implements the sql.Databaser interface. func (c *CreateProcedure) Database() sql.Database { - return c.Db + return c.ddlNode.Db } // WithDatabase implements the sql.Databaser interface. func (c *CreateProcedure) WithDatabase(database sql.Database) (sql.Node, error) { cp := *c - cp.Db = database + cp.ddlNode.Db = database return &cp, nil } // Resolved implements the sql.Node interface. func (c *CreateProcedure) Resolved() bool { - return c.ddlNode.Resolved() && c.Procedure.Resolved() + return c.ddlNode.Resolved() } func (c *CreateProcedure) IsReadOnly() bool { @@ -92,22 +73,15 @@ func (c *CreateProcedure) Schema() sql.Schema { // Children implements the sql.Node interface. func (c *CreateProcedure) Children() []sql.Node { - return []sql.Node{c.Procedure} + return []sql.Node{} } // WithChildren implements the sql.Node interface. func (c *CreateProcedure) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 1 { - return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) - } - procedure, ok := children[0].(*Procedure) - if !ok { - return nil, fmt.Errorf("expected `*Procedure` but got `%T`", children[0]) + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) } - - nc := *c - nc.Procedure = procedure - return &nc, nil + return c, nil } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -117,50 +91,54 @@ func (*CreateProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.C // String implements the sql.Node interface. func (c *CreateProcedure) String() string { - definer := "" - if c.Definer != "" { - definer = fmt.Sprintf(" DEFINER = %s", c.Definer) - } - params := "" - for i, param := range c.Params { - if i > 0 { - params += ", " - } - params += param.String() - } - comment := "" - if c.Comment != "" { - comment = fmt.Sprintf(" COMMENT '%s'", c.Comment) - } - characteristics := "" - for _, characteristic := range c.Characteristics { - characteristics += fmt.Sprintf(" %s", characteristic.String()) - } - return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", - definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, c.Procedure.String()) + // move this logic elsewhere + return "TODO" + //definer := "" + //if c.Procedure.Definer != "" { + // definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer) + //} + //params := "" + //for i, param := range c.Procedure.Params { + // if i > 0 { + // params += ", " + // } + // params += param.String() + //} + //comment := "" + //if c.Procedure.Comment != "" { + // comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment) + //} + //characteristics := "" + //for _, characteristic := range c.Procedure.Characteristics { + // characteristics += fmt.Sprintf(" %s", characteristic.String()) + //} + //return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", + // definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, c.Procedure.String()) } // DebugString implements the sql.DebugStringer interface. func (c *CreateProcedure) DebugString() string { - definer := "" - if c.Definer != "" { - definer = fmt.Sprintf(" DEFINER = %s", c.Definer) - } - params := "" - for i, param := range c.Params { - if i > 0 { - params += ", " - } - params += param.String() - } - comment := "" - if c.Comment != "" { - comment = fmt.Sprintf(" COMMENT '%s'", c.Comment) - } - characteristics := "" - for _, characteristic := range c.Characteristics { - characteristics += fmt.Sprintf(" %s", characteristic.String()) - } - return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", - definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure)) + // move this logic elsewhere + return "TODO" + //definer := "" + //if c.Procedure.Definer != "" { + // definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer) + //} + //params := "" + //for i, param := range c.Procedure.Params { + // if i > 0 { + // params += ", " + // } + // params += param.String() + //} + //comment := "" + //if c.Procedure.Comment != "" { + // comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment) + //} + //characteristics := "" + //for _, characteristic := range c.Procedure.Characteristics { + // characteristics += fmt.Sprintf(" %s", characteristic.String()) + //} + //return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", + // definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure)) } diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 47a1f5ae13..874bcb5dfd 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -16,7 +16,8 @@ package planbuilder import ( "fmt" - "strings" + "github.com/dolthub/go-mysql-server/sql/types" +"strings" "time" "unicode" @@ -25,9 +26,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" - "github.com/dolthub/go-mysql-server/sql/transform" - "github.com/dolthub/go-mysql-server/sql/types" -) + ) func (b *Builder) buildCreateTrigger(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { outScope = inScope.push() @@ -132,12 +131,9 @@ func getCurrentUserForDefiner(ctx *sql.Context, definer string) string { return definer } -func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { - b.qFlags.Set(sql.QFlagCreateProcedure) - defer func() { b.qFlags.Unset(sql.QFlagCreateProcedure) }() - +func (b *Builder) buildProcedureParams(procParams []ast.ProcedureParam) []plan.ProcedureParam { var params []plan.ProcedureParam - for _, param := range c.ProcedureSpec.Params { + for _, param := range procParams { var direction plan.ProcedureParamDirection switch param.Direction { case ast.ProcedureParamDirection_In: @@ -161,11 +157,14 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer Variadic: false, }) } + return params +} +func (b *Builder) buildProcedureCharacteristics(procCharacteristics []ast.Characteristic) ([]plan.Characteristic, plan.ProcedureSecurityContext, string) { var characteristics []plan.Characteristic securityType := plan.ProcedureSecurityContext_Definer // Default Security Context comment := "" - for _, characteristic := range c.ProcedureSpec.Characteristics { + for _, characteristic := range procCharacteristics { switch characteristic.Type { case ast.CharacteristicValue_Comment: comment = characteristic.Comment @@ -192,56 +191,30 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer b.handleErr(err) } } + return characteristics, securityType, comment +} - inScope.initProc() - procName := strings.ToLower(c.ProcedureSpec.ProcName.Name.String()) - for _, p := range params { - // populate inScope with the procedure parameters. this will be - // subject maybe a bug where an inner procedure has access to - // outer procedure parameters. - inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) - } - bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) - - bodyScope := b.buildSubquery(inScope, c.ProcedureSpec.Body, bodyStr, fullQuery) - b.validateStoredProcedure(bodyScope.node) - - // Check for recursive calls to same procedure - transform.Inspect(bodyScope.node, func(node sql.Node) bool { - switch n := node.(type) { - case *plan.Call: - if strings.EqualFold(procName, n.Name) { - b.handleErr(sql.ErrProcedureRecursiveCall.New(procName)) - } - return false - default: - return true - } - }) - +func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { var db sql.Database = nil - dbName := c.ProcedureSpec.ProcName.Qualifier.String() - if dbName != "" { + if dbName := c.ProcedureSpec.ProcName.Qualifier.String(); dbName != "" { db = b.resolveDb(dbName) } else { db = b.currentDb() } + now := time.Now() + spd := sql.StoredProcedureDetails{ + Name: strings.ToLower(c.ProcedureSpec.ProcName.Name.String()), + CreateStatement: subQuery, + CreatedAt: now, + ModifiedAt: now, + SqlMode: sql.LoadSqlMode(b.ctx).String(), + } + + bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) + outScope = inScope.push() - outScope.node = plan.NewCreateProcedure( - db, - procName, - c.ProcedureSpec.Definer, - params, - time.Now(), - time.Now(), - securityType, - characteristics, - bodyScope.node, - comment, - subQuery, - bodyStr, - ) + outScope.node = plan.NewCreateProcedure(db, spd, bodyStr) return outScope } diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 56ccda45d1..e6056ecd22 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -226,6 +226,43 @@ func (b *Builder) buildIfConditional(inScope *scope, n ast.IfStatementCondition, return outScope } +func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, db sql.Database, asOf sql.Expression, proc sql.StoredProcedureDetails) *plan.Procedure { + // TODO: new builder necessary? + b := New(ctx, cat, nil, nil) + b.DisableAuth() + b.SetParserOptions(sql.NewSqlModeFromString(proc.SqlMode).ParserOptions()) + if asOf != nil { + asOf, err := asOf.Eval(b.ctx, nil) + if err != nil { + b.handleErr(err) + } + b.ProcCtx().AsOf = asOf + } + b.ProcCtx().DbName = db.Name() + stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, proc.CreateStatement, ';', false, b.parserOpts) + procStmt := stmt.(*ast.DDL) + bodyStr := strings.TrimSpace(proc.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) + bodyScope := b.buildSubquery(nil, procStmt.ProcedureSpec.Body, bodyStr, proc.CreateStatement) // TODO: scope? + + // TODO: validate + + procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) + characteristics, securityType, comment := b.buildProcedureCharacteristics(procStmt.ProcedureSpec.Characteristics) + + return plan.NewProcedure( + proc.Name, + procStmt.ProcedureSpec.Definer, + procParams, + securityType, + comment, + characteristics, + proc.CreateStatement, + bodyScope.node, + proc.CreatedAt, + proc.ModifiedAt, + ) +} + func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { b.handleErr(err) @@ -255,14 +292,28 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { db = b.resolveDb(dbName) } else if b.ctx.GetCurrentDatabase() != "" { db = b.currentDb() + } else { + b.handleErr(sql.ErrDatabaseNotFound.New(c.ProcName.Qualifier.String())) + } + + // TODO: external stored procedures? + spdb, ok := db.(sql.StoredProcedureDatabase) + if !ok { + err := sql.ErrStoredProceduresNotSupported.New(db.Name()) + b.handleErr(err) + } + + procName := c.ProcName.Name.String() + proc, ok, err := spdb.GetStoredProcedure(b.ctx, procName) + if err != nil { + b.handleErr(err) + } + if !ok { + b.handleErr(sql.ErrStoredProcedureDoesNotExist.New(procName)) } - outScope.node = plan.NewCall( - db, - c.ProcName.Name.String(), - params, - asOf, - b.cat) + newProc := BuildProcedureHelper(b.ctx, b.cat, db, asOf, proc) + outScope.node = plan.NewCall( db, procName, params, newProc, asOf, b.cat) return outScope } diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go index af9516aa65..76ff4260b0 100644 --- a/sql/rowexec/ddl.go +++ b/sql/rowexec/ddl.go @@ -1128,16 +1128,9 @@ func createIndexesForCreateTable(ctx *sql.Context, db sql.Database, tableNode sq } func (b *BaseBuilder) buildCreateProcedure(ctx *sql.Context, n *plan.CreateProcedure, row sql.Row) (sql.RowIter, error) { - sqlMode := sql.LoadSqlMode(ctx) return &createProcedureIter{ - spd: sql.StoredProcedureDetails{ - Name: n.Name, - CreateStatement: n.CreateProcedureString, - CreatedAt: n.CreatedAt, - ModifiedAt: n.ModifiedAt, - SqlMode: sqlMode.String(), - }, - db: n.Database(), + spd: n.StoredProcDetails, + db: n.Database(), }, nil }