Skip to content

Commit

Permalink
refactor create procedure and call procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
James Cor committed Jan 28, 2025
1 parent 8a1af52 commit 93bc19e
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 244 deletions.
108 changes: 23 additions & 85 deletions sql/analyzer/stored_procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,58 +42,25 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan

allDatabases := a.Catalog.AllDatabases(ctx)
for _, database := range allDatabases {
if pdb, ok := database.(sql.StoredProcedureDatabase); ok {
procedures, err := pdb.GetStoredProcedures(ctx)
pdb, ok := database.(sql.StoredProcedureDatabase);
if !ok {
continue
}
procedures, err := pdb.GetStoredProcedures(ctx)
if err != nil {
return nil, err
}
for _, procedure := range procedures {
proc := planbuilder.BuildProcedureHelper(ctx, a.Catalog, database, nil, procedure)
err = scope.Procedures.Register(database.Name(), proc)
if err != nil {
return nil, err
}

for _, procedure := range procedures {
var procToRegister *plan.Procedure
var parsedProcedure sql.Node
b := planbuilder.New(ctx, a.Catalog, nil, nil)
b.DisableAuth()
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
if err != nil {
procToRegister = &plan.Procedure{
CreateProcedureString: procedure.CreateStatement,
}
procToRegister.ValidationError = err
} else if cp, ok := parsedProcedure.(*plan.CreateProcedure); !ok {
return nil, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement)
} else {
procToRegister = cp.Procedure
}

procToRegister.CreatedAt = procedure.CreatedAt
procToRegister.ModifiedAt = procedure.ModifiedAt

err = scope.Procedures.Register(database.Name(), procToRegister)
if err != nil {
return nil, err
}
}
}
}
return scope, nil
}

// analyzeCreateProcedure checks the plan.CreateProcedure and returns a valid plan.Procedure or an error
func analyzeCreateProcedure(ctx *sql.Context, a *Analyzer, cp *plan.CreateProcedure, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (*plan.Procedure, error) {
var analyzedNode sql.Node
var err error
analyzedNode, _, err = analyzeProcedureBodies(ctx, a, cp.Procedure, false, scope, sel, qFlags)
if err != nil {
return nil, err
}
analyzedProc, ok := analyzedNode.(*plan.Procedure)
if !ok {
return nil, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode)
}
return analyzedProc, nil
}

func hasProcedureCall(n sql.Node) bool {
referencesProcedures := false
transform.Inspect(n, func(n sql.Node) bool {
Expand Down Expand Up @@ -164,9 +131,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
return n, transform.SameTree, nil
}

hasProcedureCall := hasProcedureCall(n)
_, isShowCreateProcedure := n.(*plan.ShowCreateProcedure)
if !hasProcedureCall && !isShowCreateProcedure {
if _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure); !hasProcedureCall(n) && !isShowCreateProcedure {
return n, transform.SameTree, nil
}

Expand Down Expand Up @@ -197,46 +162,19 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
return call.WithProcedure(externalProcedure), transform.NewTree, nil
}

if spdb, ok := call.Database().(sql.StoredProcedureDatabase); ok {
procedure, ok, err := spdb.GetStoredProcedure(ctx, call.Name)
if err != nil {
return nil, transform.SameTree, err
}
if !ok {
err := sql.ErrStoredProcedureDoesNotExist.New(call.Name)
if call.Database().Name() == "" {
return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err)
}
return nil, transform.SameTree, err
}
var parsedProcedure sql.Node
b := planbuilder.New(ctx, a.Catalog, nil, nil)
b.DisableAuth()
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
if call.AsOf() != nil {
asOf, err := call.AsOf().Eval(ctx, nil)
if err != nil {
return n, transform.SameTree, err
}
b.ProcCtx().AsOf = asOf
}
b.ProcCtx().DbName = call.Database().Name()
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
if err != nil {
return nil, transform.SameTree, err
}
cp, ok := parsedProcedure.(*plan.CreateProcedure)
if !ok {
return nil, transform.SameTree, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement)
}
analyzedProc, err := analyzeCreateProcedure(ctx, a, cp, scope, sel, nil)
if err != nil {
return nil, transform.SameTree, err
}
return call.WithProcedure(analyzedProc), transform.NewTree, nil
} else {
if _, isStoredProcDb := call.Database().(sql.StoredProcedureDatabase); !isStoredProcDb {
return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name())
}

analyzedNode, _, err := analyzeProcedureBodies(ctx, a, call.Procedure, false, scope, sel, qFlags)
if err != nil {
return nil, transform.SameTree, err
}
analyzedProc, ok := analyzedNode.(*plan.Procedure)
if !ok {
return nil, transform.SameTree, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode)
}
return call.WithProcedure(analyzedProc), transform.NewTree, nil
})
if err != nil {
return nil, transform.SameTree, err
Expand Down
13 changes: 7 additions & 6 deletions sql/plan/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ var _ sql.Expressioner = (*Call)(nil)
var _ Versionable = (*Call)(nil)

// NewCall returns a *Call node.
func NewCall(db sql.Database, name string, params []sql.Expression, asOf sql.Expression, catalog sql.Catalog) *Call {
func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog) *Call {
return &Call{
db: db,
Name: name,
Params: params,
asOf: asOf,
cat: catalog,
db: db,
Name: name,
Params: params,
Procedure: proc,
asOf: asOf,
cat: catalog,
}
}

Expand Down
150 changes: 64 additions & 86 deletions sql/plan/ddl_procedure.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021 Dolthub, Inc.
// Copyright 2021-2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,18 +15,16 @@
package plan

import (
"fmt"
"time"

"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/go-mysql-server/sql/types"

"github.com/dolthub/go-mysql-server/sql"
)

type CreateProcedure struct {
*Procedure
ddlNode
BodyString string
ddlNode ddlNode

StoredProcDetails sql.StoredProcedureDetails
BodyString string
}

var _ sql.Node = (*CreateProcedure)(nil)
Expand All @@ -37,48 +35,31 @@ var _ sql.CollationCoercible = (*CreateProcedure)(nil)
// NewCreateProcedure returns a *CreateProcedure node.
func NewCreateProcedure(
db sql.Database,
name,
definer string,
params []ProcedureParam,
createdAt, modifiedAt time.Time,
securityContext ProcedureSecurityContext,
characteristics []Characteristic,
body sql.Node,
comment, createString, bodyString string,
storedProcDetails sql.StoredProcedureDetails,
bodyString string,
) *CreateProcedure {
procedure := NewProcedure(
name,
definer,
params,
securityContext,
comment,
characteristics,
createString,
body,
createdAt,
modifiedAt)
return &CreateProcedure{
Procedure: procedure,
BodyString: bodyString,
ddlNode: ddlNode{db},
ddlNode: ddlNode{db},
StoredProcDetails: storedProcDetails,
BodyString: bodyString,
}
}

// Database implements the sql.Databaser interface.
func (c *CreateProcedure) Database() sql.Database {
return c.Db
return c.ddlNode.Db
}

// WithDatabase implements the sql.Databaser interface.
func (c *CreateProcedure) WithDatabase(database sql.Database) (sql.Node, error) {
cp := *c
cp.Db = database
cp.ddlNode.Db = database
return &cp, nil
}

// Resolved implements the sql.Node interface.
func (c *CreateProcedure) Resolved() bool {
return c.ddlNode.Resolved() && c.Procedure.Resolved()
return c.ddlNode.Resolved()
}

func (c *CreateProcedure) IsReadOnly() bool {
Expand All @@ -92,22 +73,15 @@ func (c *CreateProcedure) Schema() sql.Schema {

// Children implements the sql.Node interface.
func (c *CreateProcedure) Children() []sql.Node {
return []sql.Node{c.Procedure}
return []sql.Node{}
}

// WithChildren implements the sql.Node interface.
func (c *CreateProcedure) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1)
}
procedure, ok := children[0].(*Procedure)
if !ok {
return nil, fmt.Errorf("expected `*Procedure` but got `%T`", children[0])
if len(children) != 0 {
return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0)
}

nc := *c
nc.Procedure = procedure
return &nc, nil
return c, nil
}

// CollationCoercibility implements the interface sql.CollationCoercible.
Expand All @@ -117,50 +91,54 @@ func (*CreateProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.C

// String implements the sql.Node interface.
func (c *CreateProcedure) String() string {
definer := ""
if c.Definer != "" {
definer = fmt.Sprintf(" DEFINER = %s", c.Definer)
}
params := ""
for i, param := range c.Params {
if i > 0 {
params += ", "
}
params += param.String()
}
comment := ""
if c.Comment != "" {
comment = fmt.Sprintf(" COMMENT '%s'", c.Comment)
}
characteristics := ""
for _, characteristic := range c.Characteristics {
characteristics += fmt.Sprintf(" %s", characteristic.String())
}
return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s",
definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, c.Procedure.String())
// move this logic elsewhere
return "TODO"
//definer := ""
//if c.Procedure.Definer != "" {
// definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer)
//}
//params := ""
//for i, param := range c.Procedure.Params {
// if i > 0 {
// params += ", "
// }
// params += param.String()
//}
//comment := ""
//if c.Procedure.Comment != "" {
// comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment)
//}
//characteristics := ""
//for _, characteristic := range c.Procedure.Characteristics {
// characteristics += fmt.Sprintf(" %s", characteristic.String())
//}
//return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s",
// definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, c.Procedure.String())
}

// DebugString implements the sql.DebugStringer interface.
func (c *CreateProcedure) DebugString() string {
definer := ""
if c.Definer != "" {
definer = fmt.Sprintf(" DEFINER = %s", c.Definer)
}
params := ""
for i, param := range c.Params {
if i > 0 {
params += ", "
}
params += param.String()
}
comment := ""
if c.Comment != "" {
comment = fmt.Sprintf(" COMMENT '%s'", c.Comment)
}
characteristics := ""
for _, characteristic := range c.Characteristics {
characteristics += fmt.Sprintf(" %s", characteristic.String())
}
return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s",
definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure))
// move this logic elsewhere
return "TODO"
//definer := ""
//if c.Procedure.Definer != "" {
// definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer)
//}
//params := ""
//for i, param := range c.Procedure.Params {
// if i > 0 {
// params += ", "
// }
// params += param.String()
//}
//comment := ""
//if c.Procedure.Comment != "" {
// comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment)
//}
//characteristics := ""
//for _, characteristic := range c.Procedure.Characteristics {
// characteristics += fmt.Sprintf(" %s", characteristic.String())
//}
//return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s",
// definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure))
}
Loading

0 comments on commit 93bc19e

Please sign in to comment.