Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make the arguments print themselves with type info #16232

Merged
merged 19 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"

"vitess.io/vitess/go/mysql"
Expand All @@ -41,11 +40,7 @@ func TestNormalizeAllFields(t *testing.T) {

insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
vtgateVersion, err := cluster.GetMajorVersion("vtgate")
require.NoError(t, err)
if vtgateVersion < 20 {
normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
}

selectQuery := "select * from t1"
utils.Exec(t, conn, insertQuery)
qr := utils.Exec(t, conn, selectQuery)
Expand Down
19 changes: 18 additions & 1 deletion go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"
)
Expand All @@ -34,7 +36,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) {
deleteAll := func() {
_, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp")

tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx"}
tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx", "user", "user_extra"}
for _, table := range tables {
_, _ = mcmp.ExecAndIgnore("delete from " + table)
}
Expand Down Expand Up @@ -232,3 +234,18 @@ func TestSubqueries(t *testing.T) {
})
}
}

func TestProperTypesOfPullOutValue(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate")

query := "select (select sum(id) from user) from user_extra"

mcmp, closer := start(t)
defer closer()

mcmp.Exec("INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')")
mcmp.Exec("INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')")

r := mcmp.Exec(query)
require.True(t, r.Fields[0].Type == sqltypes.Decimal)
}
58 changes: 58 additions & 0 deletions go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,64 @@ func (node *Literal) Format(buf *TrackedBuffer) {

// Format formats the node.
func (node *Argument) Format(buf *TrackedBuffer) {
// We need to make sure that any value used still returns
// the right type when interpolated. For example, if we have a
// decimal type with 0 scale, we don't want it to be interpreted
// as an integer after interpolation as that would the default
// literal interpretation in MySQL.
switch {
case node.Type == sqltypes.Unknown:
// Ensure we handle unknown first as we don't want to treat
// the type as a bitmask for the further tests.
// do nothing, the default literal will be correct.
case sqltypes.IsDecimal(node.Type) && node.Scale == 0:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.astPrintf(node, " AS DECIMAL(%d, %d))", node.Size, node.Scale)
return
case sqltypes.IsUnsigned(node.Type):
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS UNSIGNED)")
return
case node.Type == sqltypes.Float64:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS DOUBLE)")
return
case node.Type == sqltypes.Float32:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS FLOAT)")
return
case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS DATETIME")
if node.Size == 0 {
buf.WriteString(")")
return
}
buf.astPrintf(node, "(%d))", node.Size)
return
case sqltypes.IsDate(node.Type):
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS DATE")
buf.WriteString(")")
return
case node.Type == sqltypes.Time:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS TIME")
if node.Size == 0 {
buf.WriteString(")")
return
}
buf.astPrintf(node, "(%d))", node.Size)
return
}
// Nothing special to do, the default literal will be correct.
buf.WriteArg(":", node.Name)
if node.Type >= 0 {
// For bind variables that are statically typed, emit their type as an adjacent comment.
Expand Down
66 changes: 66 additions & 0 deletions go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 12 additions & 5 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,24 @@ func TestNormalize(t *testing.T) {
}, {
// datetime val
in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */",
outstmt: "select * from t where foobar = CAST(:foobar AS DATETIME(6))",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")),
},
}, {
// time val
in: "select * from t where foobar = time'12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* TIME(6) */",
outstmt: "select * from t where foobar = CAST(:foobar AS TIME(6))",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")),
},
}, {
// time val
in: "select * from t where foobar = time'12:34:56'",
outstmt: "select * from t where foobar = CAST(:foobar AS TIME)",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56")),
},
}, {
// multiple vals
in: "select * from t where foo = 1.2 and bar = 2",
Expand Down Expand Up @@ -334,21 +341,21 @@ func TestNormalize(t *testing.T) {
}, {
// DateVal should also be normalized
in: `select date'2022-08-06'`,
outstmt: `select :bv1 /* DATE */ from dual`,
outstmt: `select CAST(:bv1 AS DATE) from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Date, []byte("2022-08-06"))),
},
}, {
// TimeVal should also be normalized
in: `select time'17:05:12'`,
outstmt: `select :bv1 /* TIME */ from dual`,
outstmt: `select CAST(:bv1 AS TIME) from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Time, []byte("17:05:12"))),
},
}, {
// TimestampVal should also be normalized
in: `select timestamp'2022-08-06 17:05:12'`,
outstmt: `select :bv1 /* DATETIME */ from dual`,
outstmt: `select CAST(:bv1 AS DATETIME) from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Datetime, []byte("2022-08-06 17:05:12"))),
},
Expand Down
94 changes: 92 additions & 2 deletions go/vt/sqlparser/parsed_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"

"github.com/stretchr/testify/assert"
)

func TestNewParsedQuery(t *testing.T) {
Expand Down Expand Up @@ -205,3 +206,92 @@ func TestParseAndBind(t *testing.T) {
})
}
}

func TestCastBindVars(t *testing.T) {
testcases := []struct {
typ sqltypes.Type
size int
binds map[string]*querypb.BindVariable
out string
}{
{
typ: sqltypes.Decimal,
binds: map[string]*querypb.BindVariable{"arg": sqltypes.DecimalBindVariable("50")},
out: "select CAST(50 AS DECIMAL(0, 0)) from ",
},
{
typ: sqltypes.Uint32,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Uint32, Value: sqltypes.NewUint32(42).Raw()}},
out: "select CAST(42 AS UNSIGNED) from ",
},
{
typ: sqltypes.Float64,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float64, Value: sqltypes.NewFloat64(42.42).Raw()}},
out: "select CAST(42.42 AS DOUBLE) from ",
},
{
typ: sqltypes.Float32,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float32, Value: sqltypes.NewFloat32(42).Raw()}},
out: "select CAST(42 AS FLOAT) from ",
},
{
typ: sqltypes.Date,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Date, Value: sqltypes.NewDate("2021-10-30").Raw()}},
out: "select CAST('2021-10-30' AS DATE) from ",
},
{
typ: sqltypes.Time,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}},
out: "select CAST('12:00:00' AS TIME) from ",
},
{
typ: sqltypes.Time,
size: 6,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}},
out: "select CAST('12:00:00' AS TIME(6)) from ",
},
{
typ: sqltypes.Timestamp,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ",
},
{
typ: sqltypes.Timestamp,
size: 6,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ",
},
{
typ: sqltypes.Datetime,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ",
},
{
typ: sqltypes.Datetime,
size: 6,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ",
},
}

for _, testcase := range testcases {
t.Run(testcase.out, func(t *testing.T) {
argument := NewTypedArgument("arg", testcase.typ)
if testcase.size > 0 {
argument.Size = int32(testcase.size)
}

s := &Select{
SelectExprs: SelectExprs{
NewAliasedExpr(argument, ""),
},
}

pq := NewParsedQuery(s)
out, err := pq.GenerateQuery(testcase.binds, nil)

require.NoError(t, err)
require.Equal(t, testcase.out, out)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update in
1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14
2 ks_unsharded/-: commit

----------------------------------------------------------------------
----------------------------------------------------------------------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clearly some diff I cannot see.

4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ func (v *EnumSetValues) Equal(other *EnumSetValues) bool {
return slices.Equal(*v, *other)
}

func NewUnknownType() Type {
return NewType(sqltypes.Unknown, collations.Unknown)
}

func NewType(t sqltypes.Type, collation collations.ID) Type {
// New types default to being nullable
return NewTypeEx(t, collation, true, 0, 0, nil)
Expand Down
3 changes: 0 additions & 3 deletions go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "query argument '%s' cannot be a tuple", bv.Key)
}
typ := bvar.Type
if bv.typed() {
typ = bv.Type
}
systay marked this conversation as resolved.
Show resolved Hide resolved
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)), nil)
}
}
Expand Down
6 changes: 5 additions & 1 deletion go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ func breakExpressionInLHSandRHS(
Name: bvName,
Expr: nodeExpr,
})
arg := sqlparser.NewArgument(bvName)
typeForExpr, _ := ctx.TypeForExpr(nodeExpr)
arg := sqlparser.NewTypedArgument(bvName, typeForExpr.Type())
arg.Scale = typeForExpr.Scale()
arg.Size = typeForExpr.Size()

// we are replacing one of the sides of the comparison with an argument,
// but we don't want to lose the type information we have, so we copy it over
ctx.SemTable.CopyExprInfo(nodeExpr, arg)
Expand Down
Loading
Loading