Skip to content

Commit

Permalink
wip: Try to type more expressions for casting
Browse files Browse the repository at this point in the history
We want arguments to be better typed, but it also means we need more
typing information for functional arguments, aggregates etc.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Jun 19, 2024
1 parent 1cc3e14 commit b8398fe
Show file tree
Hide file tree
Showing 13 changed files with 128 additions and 64 deletions.
28 changes: 4 additions & 24 deletions go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,30 +200,10 @@ func TestSubqueries(t *testing.T) {
queries := []string{
`INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')`,
`INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (1, 'info2'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')`,
`SELECT (SELECT COUNT(*) FROM user_extra) AS order_count, id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra)`,
`SELECT id, (SELECT COUNT(*) FROM user_extra) AS order_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra)`,
`SELECT id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra) ORDER BY (SELECT COUNT(*) FROM user_extra)`,
`SELECT (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count, id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name, COUNT(*) FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, round(MAX(id + (SELECT COUNT(*) FROM user_extra where user_id = 42))) as r FROM user WHERE id = 42 GROUP BY id ORDER BY r`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * 2 AS double_extra_count FROM user`,
`SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE LENGTH(extra_info) > 4)`,
`SELECT id, COUNT(*) FROM user GROUP BY id HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user_extra.user_id = user.id) + 1`,
`SELECT id, name FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * id`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) + id AS extra_count_plus_id FROM user`,
`SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info1') OR id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info2')`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, SUM(id) AS sum_ids FROM user GROUP BY id, name ORDER BY (SELECT COUNT(*) FROM user_extra)`,
// `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, AVG(id) AS avg_ids FROM user GROUP BY id, name HAVING (SELECT SUM(LENGTH(extra_info)) FROM user_extra) > 10`,
`SELECT id, name, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, MAX(id) AS max_id FROM user WHERE id IN (SELECT user_id FROM user_extra) GROUP BY id, name`,
`SELECT id, name, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, MIN(id) AS min_id FROM user GROUP BY id, name ORDER BY (SELECT MAX(LENGTH(extra_info)) FROM user_extra)`,
`SELECT id, name, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT MIN(LENGTH(extra_info)) FROM user_extra) < 5`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, AVG(id) AS avg_ids FROM user WHERE id > (SELECT COUNT(*) FROM user_extra) GROUP BY id, name`,
// `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, COUNT(id) AS count_ids FROM user GROUP BY id, name ORDER BY (SELECT SUM(LENGTH(extra_info)) FROM user_extra)`,
// `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT AVG(LENGTH(extra_info)) FROM user_extra) > 2`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) + id AS total_extra_count_plus_id, AVG(id) AS avg_ids FROM user WHERE id < (SELECT MAX(user_id) FROM user_extra) GROUP BY id, name`,
`SELECT SUM(LENGTH(extra_info)) FROM user_extra`,
`SELECT (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info FROM user`,
`SELECT count(*) FROM user_extra`,
`SELECT (SELECT count(*) FROM user_extra) AS total_length_extra_info FROM user`,
}

for idx, query := range queries {
Expand Down
30 changes: 29 additions & 1 deletion go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,35 @@ func (node *Literal) Format(buf *TrackedBuffer) {

// Format formats the node.
func (node *Argument) Format(buf *TrackedBuffer) {
buf.WriteArg(":", node.Name)
// We need to make sure that any value used still returns
// the right type when interpolated. For example, if we have a
// decimal type with 0 scale, we don't want it to be interpreted
// as an integer after interpolation as that would be the default
// literal interpretation in MySQL.
switch {
case node.Type == sqltypes.Unknown:
// Ensure we handle unknown first as we don't want to treat
// the type as a bitmask for the further tests.
buf.WriteArg(":", node.Name)
case sqltypes.IsDecimal(node.Type):
buf.astPrintf(node, "CAST(:%#s AS DECIMAL(%d, %d))", node.Name, node.Size, node.Scale)
case sqltypes.IsUnsigned(node.Type):
buf.astPrintf(node, "CAST(:%#s AS UNSIGNED)", node.Name)
case node.Type == sqltypes.Float64:
buf.astPrintf(node, "CAST(:%#s AS DOUBLE)", node.Name)
case node.Type == sqltypes.Float32:
buf.astPrintf(node, "CAST(:%#s AS FLOAT)", node.Name)
case sqltypes.IsDate(node.Type):
buf.astPrintf(node, "date :%#s", node.Name)
case node.Type == sqltypes.Time:
buf.astPrintf(node, "time :%#s", node.Name)
case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime:
buf.astPrintf(node, "timestamp :%#s", node.Name)
default:
// Nothing special to do, the default literal will be correct.
buf.WriteArg(":", node.Name)
}

if node.Type >= 0 {
// For bind variables that are statically typed, emit their type as an adjacent comment.
// This comment will be ignored by older versions of Vitess (and by MySQL) but will provide
Expand Down
41 changes: 40 additions & 1 deletion go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ func parseBindVariable(yylex yyLexer, bvar string) *Argument {
return NewArgument(bvar)
}

func NewTypedArgument(in string, t sqltypes.Type) *Argument {
return &Argument{Name: in, Type: t}
func NewTypedArgument(in string, t sqltypes.Type, size, scale int32) *Argument {
return &Argument{Name: in, Type: t, Size: size, Scale: scale}
}

func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega
aggrParam.Original = aggr.Original
aggrParam.OrigOpcode = aggr.OriginalOpCode
aggrParam.WCol = aggr.WSOffset
aggrParam.Type = aggr.GetTypeCollation(ctx)
aggrParam.Type = aggr.GetParameterType(ctx)
aggregates = append(aggregates, aggrParam)
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/info_schema_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func extractInfoSchemaRoutingPredicate(ctx *plancontext.PlanningContext, in sqlp
} else {
name = ctx.GetReservedArgumentFor(col)
}
cmp.Right = sqlparser.NewTypedArgument(name, sqltypes.VarChar)
cmp.Right = sqlparser.NewTypedArgument(name, sqltypes.VarChar, 0, 0)
return isSchemaName, name, rhs
}

Expand Down
11 changes: 3 additions & 8 deletions go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,12 @@ func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool {
return aggr.OpCode.NeedsComparableValues() && ctx.SemTable.NeedsWeightString(aggr.Func.GetArg())
}

func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.Type {
func (aggr Aggr) GetParameterType(ctx *plancontext.PlanningContext) evalengine.Type {
if aggr.Func == nil {
return evalengine.Type{}
}
switch aggr.OpCode {
case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct:
typ, _ := ctx.TypeForExpr(aggr.Func.GetArg())
return typ

}
return evalengine.Type{}
typ, _ := ctx.TypeForExpr(aggr.Func.GetArg())
return typ
}

// NewGroupBy creates a new group by from the given fields.
Expand Down
26 changes: 5 additions & 21 deletions go/vt/vtgate/planbuilder/operators/sharded_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (tr *ShardedRouting) planIsExpr(ctx *plancontext.PlanningContext, node *sql
return false
}
vdValue := &sqlparser.NullVal{}
val := makeEvalEngineExpr(ctx, vdValue)
val := ctx.MakeEvalEngineExpr(vdValue)
if val == nil {
return false
}
Expand All @@ -285,7 +285,7 @@ func (tr *ShardedRouting) planInOp(ctx *plancontext.PlanningContext, cmp *sqlpar
return tr.planEqualOp(ctx, &sqlparser.ComparisonExpr{Left: left, Right: valTuple[0], Operator: sqlparser.EqualOp})
}

value := makeEvalEngineExpr(ctx, vdValue)
value := ctx.MakeEvalEngineExpr(vdValue)
if value == nil {
return false
}
Expand All @@ -309,7 +309,7 @@ func (tr *ShardedRouting) planLikeOp(ctx *plancontext.PlanningContext, node *sql
}

vdValue := node.Right
val := makeEvalEngineExpr(ctx, vdValue)
val := ctx.MakeEvalEngineExpr(vdValue)
if val == nil {
return false
}
Expand Down Expand Up @@ -496,7 +496,7 @@ func (tr *ShardedRouting) planEqualOp(ctx *plancontext.PlanningContext, node *sq
}
vdValue = node.Left
}
val := makeEvalEngineExpr(ctx, vdValue)
val := ctx.MakeEvalEngineExpr(vdValue)
if val == nil {
return false
}
Expand Down Expand Up @@ -538,7 +538,7 @@ func (tr *ShardedRouting) planCompositeInOpRecursive(
return false
}
}
newPlanValues := makeEvalEngineExpr(ctx, rightVals)
newPlanValues := ctx.MakeEvalEngineExpr(rightVals)
if newPlanValues == nil {
return false
}
Expand Down Expand Up @@ -681,19 +681,3 @@ func tryMergeJoinShardedRouting(
}
return nil
}

// makeEvalEngineExpr turns the given sqlparser.Expr into an evalengine
// expression. The expressions returned by ctx.SemTable.GetExprAndEqualities(n)
// are tried in order, and the first one that translates to a non-nil
// evalengine expression is returned. It returns nil when none of them do.
func makeEvalEngineExpr(ctx *plancontext.PlanningContext, n sqlparser.Expr) evalengine.Expr {
	for _, candidate := range ctx.SemTable.GetExprAndEqualities(n) {
		cfg := evalengine.Config{
			Collation:   ctx.SemTable.Collation,
			ResolveType: ctx.TypeForExpr,
			Environment: ctx.VSchema.Environment(),
		}
		// The translation error is intentionally dropped: a candidate that
		// cannot be translated is skipped and the next one is tried.
		if translated, _ := evalengine.Translate(candidate, &cfg); translated != nil {
			return translated
		}
	}

	return nil
}
8 changes: 7 additions & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,12 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp
}

for _, sq2 := range subqueries {
sel := sqlparser.GetFirstSelect(sq2.originalSubquery.Select)
alias, ok := sel.SelectExprs[0].(*sqlparser.AliasedExpr)
if !ok {
panic(vterrors.VT09015())
}
typ, _ := ctx.TypeForExpr(alias.Expr)
if s == sq2.ArgName {
switch {
case sq1.FilterType.NeedsListArg():
Expand All @@ -388,7 +394,7 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp
}
return sqlparser.NewArgument(sq1.HasValuesName)
default:
return sqlparser.NewArgument(s)
return sqlparser.NewTypedArgument(s, typ.Type(), typ.Size(), typ.Scale())
}
}
}
Expand Down
30 changes: 29 additions & 1 deletion go/vt/vtgate/planbuilder/plancontext/planning_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,39 @@ func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, t
return modifiedExpr
}

// MakeEvalEngineExpr converts the given sqlparser.Expr into an evalengine
// expression. Each expression returned by ctx.SemTable.GetExprAndEqualities(n)
// is translated in turn, using the semantic table's collation, the context's
// TypeForExpr as type resolver and the vschema environment. The first non-nil
// translation is returned; nil is returned when no candidate translates.
func (ctx *PlanningContext) MakeEvalEngineExpr(n sqlparser.Expr) evalengine.Expr {
	for _, candidate := range ctx.SemTable.GetExprAndEqualities(n) {
		cfg := &evalengine.Config{
			Collation:   ctx.SemTable.Collation,
			ResolveType: ctx.TypeForExpr,
			Environment: ctx.VSchema.Environment(),
		}
		// The error from Translate is deliberately ignored: a failing
		// candidate is skipped so that the next one can be attempted.
		translated, _ := evalengine.Translate(candidate, cfg)
		if translated == nil {
			continue
		}
		return translated
	}

	return nil
}

// TypeForExpr returns the type of the given expression, with nullable set if the expression is from an outer table.
func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) {
t, found := ctx.SemTable.TypeForExpr(e)
if !found {
return t, found
// Try to compile the expression to retrieve its type.
// This doesn't work at the moment for aggregate expressions,
// since the evalengine doesn't know about them.
expr := ctx.MakeEvalEngineExpr(e)
if expr == nil {
return t, found
}
env := evalengine.EmptyExpressionEnv(ctx.VSchema.Environment())
typ, err := env.TypeOf(expr)
if err != nil {
return t, found
}
return typ, true
}
deps := ctx.SemTable.RecursiveDeps(e)
// If the expression is from an outer table, it should be nullable
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/onecase.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"comment": "Add your test case here for debugging and run go test -run=One.",
"query": "",
"query": "SELECT (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info FROM user",
"plan": {

}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/semantics/semantic_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) {
return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0, nil), true
}

return evalengine.Type{}, false
return evalengine.NewType(sqltypes.Unknown, collations.Unknown), false
}

// NeedsWeightString returns true if the given expression needs weight_string to do safe comparisons
Expand Down
6 changes: 5 additions & 1 deletion go/vt/vtgate/semantics/typer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package semantics

import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -60,7 +61,10 @@ func (t *typer) up(cursor *sqlparser.Cursor) error {
inputType = tt
}
}
t.m[node] = code.ResolveType(inputType, t.collationEnv)
typ := code.ResolveType(inputType, t.collationEnv)
if typ.Type() != sqltypes.Unknown {
t.m[node] = typ
}
}
return nil
}
Expand Down

0 comments on commit b8398fe

Please sign in to comment.