diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index abbf5ff15e8..8497db0c482 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -200,30 +200,10 @@ func TestSubqueries(t *testing.T) { queries := []string{ `INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')`, `INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (1, 'info2'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')`, - `SELECT (SELECT COUNT(*) FROM user_extra) AS order_count, id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra)`, - `SELECT id, (SELECT COUNT(*) FROM user_extra) AS order_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra)`, - `SELECT id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra) ORDER BY (SELECT COUNT(*) FROM user_extra)`, - `SELECT (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count, id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0`, - `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, - `SELECT id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, - `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, - `SELECT id, name, COUNT(*) FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, - `SELECT id, round(MAX(id + (SELECT COUNT(*) FROM user_extra where user_id = 42))) as r FROM user WHERE id = 42 GROUP BY id ORDER BY r`, - `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * 2 AS double_extra_count FROM user`, - `SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE LENGTH(extra_info) > 4)`, - `SELECT id, COUNT(*) FROM user GROUP BY id HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user_extra.user_id = user.id) + 1`, - `SELECT id, name FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * id`, - `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) + id AS extra_count_plus_id FROM user`, - `SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info1') OR id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info2')`, - `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, SUM(id) AS sum_ids FROM user GROUP BY id, name ORDER BY (SELECT COUNT(*) FROM user_extra)`, - // `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, AVG(id) AS avg_ids FROM user GROUP BY id, name HAVING (SELECT SUM(LENGTH(extra_info)) FROM user_extra) > 10`, - `SELECT id, name, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, MAX(id) AS max_id FROM user WHERE id IN (SELECT user_id FROM user_extra) GROUP BY id, name`, - `SELECT id, name, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, MIN(id) AS min_id FROM user GROUP BY id, name ORDER BY (SELECT MAX(LENGTH(extra_info)) FROM user_extra)`, - `SELECT id, name, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT MIN(LENGTH(extra_info)) FROM user_extra) < 5`, - `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, AVG(id) AS avg_ids FROM user WHERE id > (SELECT COUNT(*) FROM user_extra) GROUP BY id, name`, - // `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, COUNT(id) AS count_ids FROM user GROUP BY id, name ORDER BY (SELECT SUM(LENGTH(extra_info)) FROM user_extra)`, - // `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT AVG(LENGTH(extra_info)) FROM user_extra) > 2`, - `SELECT id, name, (SELECT COUNT(*) FROM user_extra) + id AS total_extra_count_plus_id, AVG(id) AS avg_ids FROM user WHERE id < (SELECT MAX(user_id) FROM user_extra) GROUP BY id, name`, + `SELECT SUM(LENGTH(extra_info)) FROM user_extra`, + `SELECT (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info FROM user`, + `SELECT count(*) FROM user_extra`, + `SELECT (SELECT count(*) FROM user_extra) AS total_length_extra_info FROM user`, } for idx, query := range queries { diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 8d8a01a6eb2..f46f3aa4c0d 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1359,7 +1359,35 @@ func (node *Literal) Format(buf *TrackedBuffer) { // Format formats the node. func (node *Argument) Format(buf *TrackedBuffer) { - buf.WriteArg(":", node.Name) + // We need to make sure that any value used still returns + // the right type when interpolated. For example, if we have a + // decimal type with 0 scale, we don't want it to be interpreted + // as an integer after interpolation as that would the default + // literal interpretation in MySQL. + switch { + case node.Type == sqltypes.Unknown: + // Ensure we handle unknown first as we don't want to treat + // the type as a bitmask for the further tests. + buf.WriteArg(":", node.Name) + case sqltypes.IsDecimal(node.Type): + buf.astPrintf(node, "CAST(:%#s AS DECIMAL(%d, %d))", node.Name, node.Size, node.Scale) + case sqltypes.IsUnsigned(node.Type): + buf.astPrintf(node, "CAST(:%#s AS UNSIGNED)", node.Name) + case node.Type == sqltypes.Float64: + buf.astPrintf(node, "CAST(:%#s AS DOUBLE)", node.Name) + case node.Type == sqltypes.Float32: + buf.astPrintf(node, "CAST(:%#s AS FLOAT)", node.Name) + case sqltypes.IsDate(node.Type): + buf.astPrintf(node, "date :%#s", node.Name) + case node.Type == sqltypes.Time: + buf.astPrintf(node, "time :%#s", node.Name) + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.astPrintf(node, "timestamp :%#s", node.Name) + default: + // Nothing special to do, the default literal will be correct. + buf.WriteArg(":", node.Name) + } + if node.Type >= 0 { // For bind variables that are statically typed, emit their type as an adjacent comment. // This comment will be ignored by older versions of Vitess (and by MySQL) but will provide diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 4be0bfd75f7..5e5902d3236 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1778,7 +1778,46 @@ func (node *Literal) FormatFast(buf *TrackedBuffer) { // FormatFast formats the node. func (node *Argument) FormatFast(buf *TrackedBuffer) { - buf.WriteArg(":", node.Name) + // We need to make sure that any value used still returns + // the right type when interpolated. For example, if we have a + // decimal type with 0 scale, we don't want it to be interpreted + // as an integer after interpolation as that would the default + // literal interpretation in MySQL. + switch { + case sqltypes.IsDecimal(node.Type): + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS DECIMAL(") + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString(", ") + buf.WriteString(fmt.Sprintf("%d", node.Scale)) + buf.WriteString("))") + case sqltypes.IsUnsigned(node.Type): + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS UNSIGNED)") + case node.Type == sqltypes.Float64: + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS DOUBLE)") + case node.Type == sqltypes.Float32: + buf.WriteString("CAST(:") + buf.WriteString(node.Name) + buf.WriteString(" AS FLOAT)") + case sqltypes.IsDate(node.Type): + buf.WriteString("date :") + buf.WriteString(node.Name) + case node.Type == sqltypes.Time: + buf.WriteString("time :") + buf.WriteString(node.Name) + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.WriteString("timestamp :") + buf.WriteString(node.Name) + default: + // Nothing special to do, the default literal will be correct. + buf.WriteArg(":", node.Name) + } + if node.Type >= 0 { // For bind variables that are statically typed, emit their type as an adjacent comment. // This comment will be ignored by older versions of Vitess (and by MySQL) but will provide diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index f4f1e3a5455..e42344ff63c 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -560,8 +560,8 @@ func parseBindVariable(yylex yyLexer, bvar string) *Argument { return NewArgument(bvar) } -func NewTypedArgument(in string, t sqltypes.Type) *Argument { - return &Argument{Name: in, Type: t} +func NewTypedArgument(in string, t sqltypes.Type, size, scale int32) *Argument { + return &Argument{Name: in, Type: t, Size: size, Scale: scale} } func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) { diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index bec5cd28bb5..f7dd7ee83d8 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -312,7 +312,7 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega aggrParam.Original = aggr.Original aggrParam.OrigOpcode = aggr.OriginalOpCode aggrParam.WCol = aggr.WSOffset - aggrParam.Type = aggr.GetTypeCollation(ctx) + aggrParam.Type = aggr.GetParameterType(ctx) aggregates = append(aggregates, aggrParam) } diff --git a/go/vt/vtgate/planbuilder/operators/info_schema_planning.go b/go/vt/vtgate/planbuilder/operators/info_schema_planning.go index f8dc9b9d281..760be7473d0 100644 --- a/go/vt/vtgate/planbuilder/operators/info_schema_planning.go +++ b/go/vt/vtgate/planbuilder/operators/info_schema_planning.go @@ -147,7 +147,7 @@ func extractInfoSchemaRoutingPredicate(ctx *plancontext.PlanningContext, in sqlp } else { name = ctx.GetReservedArgumentFor(col) } - cmp.Right = sqlparser.NewTypedArgument(name, sqltypes.VarChar) + cmp.Right = sqlparser.NewTypedArgument(name, sqltypes.VarChar, 0, 0) return isSchemaName, name, rhs } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 5729dbd0c2e..0f136076ef8 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -95,17 +95,12 @@ func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool { return aggr.OpCode.NeedsComparableValues() && ctx.SemTable.NeedsWeightString(aggr.Func.GetArg()) } -func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.Type { +func (aggr Aggr) GetParameterType(ctx *plancontext.PlanningContext) evalengine.Type { if aggr.Func == nil { return evalengine.Type{} } - switch aggr.OpCode { - case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: - typ, _ := ctx.TypeForExpr(aggr.Func.GetArg()) - return typ - - } - return evalengine.Type{} + typ, _ := ctx.TypeForExpr(aggr.Func.GetArg()) + return typ } // NewGroupBy creates a new group by from the given fields. diff --git a/go/vt/vtgate/planbuilder/operators/sharded_routing.go b/go/vt/vtgate/planbuilder/operators/sharded_routing.go index 1319b76f040..e7b9040b6f0 100644 --- a/go/vt/vtgate/planbuilder/operators/sharded_routing.go +++ b/go/vt/vtgate/planbuilder/operators/sharded_routing.go @@ -261,7 +261,7 @@ func (tr *ShardedRouting) planIsExpr(ctx *plancontext.PlanningContext, node *sql return false } vdValue := &sqlparser.NullVal{} - val := makeEvalEngineExpr(ctx, vdValue) + val := ctx.MakeEvalEngineExpr(vdValue) if val == nil { return false } @@ -285,7 +285,7 @@ func (tr *ShardedRouting) planInOp(ctx *plancontext.PlanningContext, cmp *sqlpar return tr.planEqualOp(ctx, &sqlparser.ComparisonExpr{Left: left, Right: valTuple[0], Operator: sqlparser.EqualOp}) } - value := makeEvalEngineExpr(ctx, vdValue) + value := ctx.MakeEvalEngineExpr(vdValue) if value == nil { return false } @@ -309,7 +309,7 @@ func (tr *ShardedRouting) planLikeOp(ctx *plancontext.PlanningContext, node *sql } vdValue := node.Right - val := makeEvalEngineExpr(ctx, vdValue) + val := ctx.MakeEvalEngineExpr(vdValue) if val == nil { return false } @@ -496,7 +496,7 @@ func (tr *ShardedRouting) planEqualOp(ctx *plancontext.PlanningContext, node *sq } vdValue = node.Left } - val := makeEvalEngineExpr(ctx, vdValue) + val := ctx.MakeEvalEngineExpr(vdValue) if val == nil { return false } @@ -538,7 +538,7 @@ func (tr *ShardedRouting) planCompositeInOpRecursive( return false } } - newPlanValues := makeEvalEngineExpr(ctx, rightVals) + newPlanValues := ctx.MakeEvalEngineExpr(rightVals) if newPlanValues == nil { return false } @@ -681,19 +681,3 @@ func tryMergeJoinShardedRouting( } return nil } - -// makeEvalEngineExpr transforms the given sqlparser.Expr into an evalengine expression -func makeEvalEngineExpr(ctx *plancontext.PlanningContext, n sqlparser.Expr) evalengine.Expr { - for _, expr := range ctx.SemTable.GetExprAndEqualities(n) { - ee, _ := evalengine.Translate(expr, &evalengine.Config{ - Collation: ctx.SemTable.Collation, - ResolveType: ctx.TypeForExpr, - Environment: ctx.VSchema.Environment(), - }) - if ee != nil { - return ee - } - } - - return nil -} diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index cdc0b8b191a..ec2ce8d5325 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -377,6 +377,12 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp } for _, sq2 := range subqueries { + sel := sqlparser.GetFirstSelect(sq2.originalSubquery.Select) + alias, ok := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + if !ok { + panic(vterrors.VT09015()) + } + typ, _ := ctx.TypeForExpr(alias.Expr) if s == sq2.ArgName { switch { case sq1.FilterType.NeedsListArg(): @@ -388,7 +394,7 @@ func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Exp } return sqlparser.NewArgument(sq1.HasValuesName) default: - return sqlparser.NewArgument(s) + return sqlparser.NewTypedArgument(s, typ.Type(), typ.Size(), typ.Scale()) } } } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 90a6bdac6f8..bf786519e51 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -207,11 +207,39 @@ func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, t return modifiedExpr } +// MakeEvalEngineExpr transforms the given sqlparser.Expr into an evalengine expression +func (ctx *PlanningContext) MakeEvalEngineExpr(n sqlparser.Expr) evalengine.Expr { + for _, expr := range ctx.SemTable.GetExprAndEqualities(n) { + ee, _ := evalengine.Translate(expr, &evalengine.Config{ + Collation: ctx.SemTable.Collation, + ResolveType: ctx.TypeForExpr, + Environment: ctx.VSchema.Environment(), + }) + if ee != nil { + return ee + } + } + + return nil +} + // TypeForExpr returns the type of the given expression, with nullable set if the expression is from an outer table. func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { t, found := ctx.SemTable.TypeForExpr(e) if !found { - return t, found + // Try if we can compile the expression to retrieve the type. + // This doesn't work atm for aggregate expressions since the + // evalengine doesn't know about this. + expr := ctx.MakeEvalEngineExpr(e) + if expr == nil { + return t, found + } + env := evalengine.EmptyExpressionEnv(ctx.VSchema.Environment()) + typ, err := env.TypeOf(expr) + if err != nil { + return t, found + } + return typ, true } deps := ctx.SemTable.RecursiveDeps(e) // If the expression is from an outer table, it should be nullable diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index da7543f706a..1519ee89cd5 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,7 +1,7 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "query": "SELECT (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info FROM user", "plan": { } diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 1dcaaf87061..72473755480 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -671,7 +671,7 @@ func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0, nil), true } - return evalengine.Type{}, false + return evalengine.NewType(sqltypes.Unknown, collations.Unknown), false } // NeedsWeightString returns true if the given expression needs weight_string to do safe comparisons diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index b56c836a740..bb9d342695e 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -18,6 +18,7 @@ package semantics import ( "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine/opcode" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -60,7 +61,10 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { inputType = tt } } - t.m[node] = code.ResolveType(inputType, t.collationEnv) + typ := code.ResolveType(inputType, t.collationEnv) + if typ.Type() != sqltypes.Unknown { + t.m[node] = typ + } } return nil }