From f60664bd765cb15a70363bdea05023250cc384ca Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 19 Dec 2023 00:29:31 -0400 Subject: [PATCH 1/2] fix(typedef): fix CQLPretty to produce runable queries --- pkg/typedef/bag.go | 22 +++++++----------- pkg/typedef/interfaces.go | 2 +- pkg/typedef/simple_type.go | 47 ++++++++++++++++++++++---------------- pkg/typedef/tuple.go | 15 ++++++------ pkg/typedef/typedef.go | 42 ++++++++++++++++++++++++++-------- pkg/typedef/types.go | 22 +++++++++--------- pkg/typedef/types_test.go | 2 +- pkg/typedef/udt.go | 24 +++++++++---------- pkg/utils/utils.go | 8 +++++-- 9 files changed, 106 insertions(+), 78 deletions(-) diff --git a/pkg/typedef/bag.go b/pkg/typedef/bag.go index 9178c3cd..32f32d5d 100644 --- a/pkg/typedef/bag.go +++ b/pkg/typedef/bag.go @@ -59,26 +59,20 @@ func (ct *BagType) CQLHolder() string { return "?" } -func (ct *BagType) CQLPretty(query string, value []interface{}) (string, int) { - if len(value) == 0 { - return query, 0 - } - if reflect.TypeOf(value[0]).Kind() != reflect.Slice { +func (ct *BagType) CQLPretty(value interface{}) string { + if reflect.TypeOf(value).Kind() != reflect.Slice { panic(fmt.Sprintf("set cql pretty, unknown type %v", ct)) } - s := reflect.ValueOf(value[0]) - op, cl := "[", "]" + s := reflect.ValueOf(value) + format := "[%s]" if ct.ComplexType == TYPE_SET { - op, cl = "{", "}" + format = "{%s}" } - vv := op - vv += strings.Repeat("?,", s.Len()) - vv = strings.TrimRight(vv, ",") - vv += cl + out := make([]string, s.Len()) for i := 0; i < s.Len(); i++ { - vv, _ = ct.ValueType.CQLPretty(vv, []interface{}{s.Index(i).Interface()}) + out[i] = ct.ValueType.CQLPretty(s.Index(i).Interface()) } - return strings.Replace(query, "?", vv, 1), 1 + return fmt.Sprintf(format, strings.Join(out, ",")) } func (ct *BagType) GenValue(r *rand.Rand, p *PartitionRangeConfig) []interface{} { diff --git a/pkg/typedef/interfaces.go b/pkg/typedef/interfaces.go index d63e72de..5e79e456 100644 --- a/pkg/typedef/interfaces.go +++ b/pkg/typedef/interfaces.go @@ -23,7 +23,7 @@ type Type interface { Name() string CQLDef() string CQLHolder() string - CQLPretty(string, []interface{}) (string, int) + CQLPretty(interface{}) string GenValue(*rand.Rand, *PartitionRangeConfig) []interface{} GenJSONValue(*rand.Rand, *PartitionRangeConfig) interface{} LenValue() int diff --git a/pkg/typedef/simple_type.go b/pkg/typedef/simple_type.go index f8ca0105..92f96079 100644 --- a/pkg/typedef/simple_type.go +++ b/pkg/typedef/simple_type.go @@ -20,7 +20,6 @@ import ( "math" "math/big" "net" - "strings" "time" "github.com/gocql/gocql" @@ -67,43 +66,51 @@ func (st SimpleType) LenValue() int { return 1 } -func (st SimpleType) CQLPretty(query string, value []interface{}) (string, int) { - if len(value) == 0 { - return query, 0 - } - var replacement string +func (st SimpleType) CQLPretty(value interface{}) string { switch st { case TYPE_ASCII, TYPE_TEXT, TYPE_VARCHAR, TYPE_INET, TYPE_DATE: - replacement = fmt.Sprintf("'%s'", value[0]) + return fmt.Sprintf("'%s'", value) case TYPE_BLOB: - if v, ok := value[0].(string); ok { + if v, ok := value.(string); ok { if len(v) > 100 { v = v[:100] } - replacement = "textasblob('" + v + "')" + return "textasblob('" + v + "')" } + panic(fmt.Sprintf("unexpected blob value [%T]%+v", value, value)) case TYPE_BIGINT, TYPE_INT, TYPE_SMALLINT, TYPE_TINYINT: - replacement = fmt.Sprintf("%d", value[0]) + return fmt.Sprintf("%d", value) case TYPE_DECIMAL, TYPE_DOUBLE, TYPE_FLOAT: - replacement = fmt.Sprintf("%.2f", value[0]) + return fmt.Sprintf("%.2f", value) case TYPE_BOOLEAN: - if v, ok := value[0].(bool); ok { - replacement = fmt.Sprintf("%t", v) + if v, ok := value.(bool); ok { + return fmt.Sprintf("%t", v) + } + panic(fmt.Sprintf("unexpected boolean value [%T]%+v", value, value)) + case TYPE_TIME: + if v, ok := value.(int64); ok { + // CQL supports only 3 digits microseconds: + // '10:10:55.83275+0000': marshaling error: Milliseconds length exceeds expected (5)" + return fmt.Sprintf("'%s'", time.Time{}.Add(time.Duration(v)).Format("15:04:05.999")) } - case TYPE_TIME, TYPE_TIMESTAMP: - if v, ok := value[0].(time.Time); ok { - replacement = "'" + v.Format(time.RFC3339) + "'" + panic(fmt.Sprintf("unexpected time value [%T]%+v", value, value)) + case TYPE_TIMESTAMP: + if v, ok := value.(int64); ok { + // CQL supports only 3 digits microseconds: + // '1976-03-25T10:10:55.83275+0000': marshaling error: Milliseconds length exceeds expected (5)" + return time.UnixMicro(v).UTC().Format("'2006-01-02T15:04:05.999-0700'") } + panic(fmt.Sprintf("unexpected timestamp value [%T]%+v", value, value)) case TYPE_DURATION, TYPE_TIMEUUID, TYPE_UUID: - replacement = fmt.Sprintf("%s", value[0]) + return fmt.Sprintf("%s", value) case TYPE_VARINT: - if s, ok := value[0].(*big.Int); ok { - replacement = fmt.Sprintf("%d", s.Int64()) + if s, ok := value.(*big.Int); ok { + return fmt.Sprintf("%d", s.Int64()) } + panic(fmt.Sprintf("unexpected varint value [%T]%+v", value, value)) default: panic(fmt.Sprintf("cql pretty: not supported type %s", st)) } - return strings.Replace(query, "?", replacement, 1), 1 } func (st SimpleType) CQLType() gocql.TypeInfo { diff --git a/pkg/typedef/tuple.go b/pkg/typedef/tuple.go index d6e3420a..3a61f2e7 100644 --- a/pkg/typedef/tuple.go +++ b/pkg/typedef/tuple.go @@ -15,6 +15,7 @@ package typedef import ( + "fmt" "strings" "github.com/gocql/gocql" @@ -54,16 +55,16 @@ func (t *TupleType) CQLHolder() string { return "(" + strings.TrimRight(strings.Repeat("?,", len(t.ValueTypes)), ",") + ")" } -func (t *TupleType) CQLPretty(query string, value []interface{}) (string, int) { - if len(value) == 0 { - return query, 0 +func (t *TupleType) CQLPretty(value interface{}) string { + values, ok := value.([]interface{}) + if !ok { + return "()" } - var cnt, tmp int + out := make([]string, len(values)) for i, tp := range t.ValueTypes { - query, tmp = tp.CQLPretty(query, value[i:]) - cnt += tmp + out[i] = tp.CQLPretty(values[i]) } - return query, cnt + return fmt.Sprintf("(%s)", strings.Join(out, ",")) } func (t *TupleType) Indexable() bool { diff --git a/pkg/typedef/typedef.go b/pkg/typedef/typedef.go index 3d3d3c96..44afd033 100644 --- a/pkg/typedef/typedef.go +++ b/pkg/typedef/typedef.go @@ -16,6 +16,7 @@ package typedef import ( "fmt" + "strings" "github.com/scylladb/gocqlx/v2/qb" @@ -70,21 +71,12 @@ type Stmt struct { } func (s *Stmt) PrettyCQL() string { - var replaced int query, _ := s.Query.ToCql() values := s.Values.Copy() if len(values) == 0 { return query } - for _, typ := range s.Types { - query, replaced = typ.CQLPretty(query, values) - if len(values) >= replaced { - values = values[replaced:] - } else { - break - } - } - return query + return prettyCQL(query, values, s.Types) } type StatementType uint8 @@ -165,3 +157,33 @@ const ( CacheDelete CacheArrayLen ) + +func prettyCQL(query string, values Values, types Types) string { + if len(values) == 0 { + return query + } + + k := 0 + out := make([]string, 0, len(values)*2) + queryChunks := strings.Split(query, "?") + out = append(out, queryChunks[0]) + qID := 1 + for _, typ := range types { + tupleType, ok := typ.(*TupleType) + if !ok { + out = append(out, typ.CQLPretty(values[k])) + out = append(out, queryChunks[qID]) + qID++ + k++ + continue + } + for _, t := range tupleType.ValueTypes { + out = append(out, t.CQLPretty(values[k])) + out = append(out, queryChunks[qID]) + qID++ + k++ + } + } + out = append(out, queryChunks[qID:]...) + return strings.Join(out, "") +} diff --git a/pkg/typedef/types.go b/pkg/typedef/types.go index 8ac6f2e6..e311d42f 100644 --- a/pkg/typedef/types.go +++ b/pkg/typedef/types.go @@ -140,19 +140,19 @@ func (mt *MapType) CQLHolder() string { return "?" } -func (mt *MapType) CQLPretty(query string, value []interface{}) (string, int) { - if reflect.TypeOf(value[0]).Kind() != reflect.Map { +func (mt *MapType) CQLPretty(value interface{}) string { + if reflect.TypeOf(value).Kind() != reflect.Map { panic(fmt.Sprintf("map cql pretty, unknown type %v", mt)) } - s := reflect.ValueOf(value[0]).MapRange() - vv := "{" + vof := reflect.ValueOf(value) + s := vof.MapRange() + out := make([]string, len(vof.MapKeys())) + id := 0 for s.Next() { - vv += fmt.Sprintf("%v:?,", s.Key().Interface()) - vv, _ = mt.ValueType.CQLPretty(vv, []interface{}{s.Value().Interface()}) + out[id] = fmt.Sprintf("%s:%s", mt.KeyType.CQLPretty(s.Key().Interface()), mt.ValueType.CQLPretty(s.Value().Interface())) + id++ } - vv = strings.TrimSuffix(vv, ",") - vv += "}" - return strings.Replace(query, "?", vv, 1), 1 + return fmt.Sprintf("{%s}", strings.Join(out, ",")) } func (mt *MapType) GenJSONValue(r *rand.Rand, p *PartitionRangeConfig) interface{} { @@ -209,8 +209,8 @@ func (ct *CounterType) CQLHolder() string { return "?" } -func (ct *CounterType) CQLPretty(query string, value []interface{}) (string, int) { - return strings.Replace(query, "?", fmt.Sprintf("%d", value[0]), 1), 1 +func (ct *CounterType) CQLPretty(value interface{}) string { + return fmt.Sprintf("%d", value) } func (ct *CounterType) GenJSONValue(r *rand.Rand, _ *PartitionRangeConfig) interface{} { diff --git a/pkg/typedef/types_test.go b/pkg/typedef/types_test.go index 32cb9525..51c0a799 100644 --- a/pkg/typedef/types_test.go +++ b/pkg/typedef/types_test.go @@ -217,7 +217,7 @@ func TestCQLPretty(t *testing.T) { t.Parallel() for _, p := range prettytests { - result, _ := p.typ.CQLPretty(p.query, p.values) + result := p.typ.CQLPretty(p.values) if result != p.expected { t.Fatalf("expected '%s', got '%s' for values %v and type '%v'", p.expected, result, p.values, p.typ) } diff --git a/pkg/typedef/udt.go b/pkg/typedef/udt.go index d0f24da0..447f4475 100644 --- a/pkg/typedef/udt.go +++ b/pkg/typedef/udt.go @@ -48,21 +48,21 @@ func (t *UDTType) CQLHolder() string { return "?" } -func (t *UDTType) CQLPretty(query string, value []interface{}) (string, int) { - if len(value) == 0 { - return query, 0 +func (t *UDTType) CQLPretty(value interface{}) string { + s, ok := value.(map[string]interface{}) + if !ok { + panic(fmt.Sprintf("udt pretty, unknown type %v", t)) } - if s, ok := value[0].(map[string]interface{}); ok { - vv := "{" - for k, v := range t.ValueTypes { - vv += fmt.Sprintf("%s:?,", k) - vv, _ = v.CQLPretty(vv, []interface{}{s[k]}) + + out := make([]string, 0, len(t.ValueTypes)) + for k, v := range t.ValueTypes { + keyVal, kexExists := s[k] + if !kexExists { + continue } - vv = strings.TrimSuffix(vv, ",") - vv += "}" - return strings.Replace(query, "?", vv, 1), 1 + out = append(out, fmt.Sprintf("%s:%s", k, v.CQLPretty(keyVal))) } - panic(fmt.Sprintf("udt pretty, unknown type %v", t)) + return fmt.Sprintf("{%s}", strings.Join(out, ",")) } func (t *UDTType) Indexable() bool { diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index ed7610b0..0b775d48 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -25,7 +25,7 @@ import ( "golang.org/x/exp/rand" ) -var maxDateMs = time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC).UTC().UnixMilli() +var maxDateMs = time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC).UTC().UnixMilli() // RandDateStr generates time in string representation // it is done in such way because we wanted to make JSON statement to work @@ -34,8 +34,12 @@ func RandDateStr(rnd *rand.Rand) string { return time.UnixMilli(rnd.Int63n(maxDateMs)).UTC().Format("2006-01-02") } +// RandTimestamp generates timestamp in nanoseconds +// Date limit needed to make sure that textual representation of the date is parsable by cql and drivers +// Currently CQL fails: unable to parse date '95260-10-10T19:09:07.148+0000': marshaling error: Unable to parse timestamp from '95260-10-10t19:09:07.148+0000'" +// if year is bigger than 9999, same as golang time.Parse func RandTimestamp(rnd *rand.Rand) int64 { - return rnd.Int63() + return rnd.Int63n(maxDateMs) } func RandDate(rnd *rand.Rand) time.Time { From 69aaefd159825184be9e3afe70da088ec59d380e Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 14 Dec 2023 01:53:00 -0400 Subject: [PATCH 2/2] feat(cli): add statement-log-file option --- cmd/gemini/root.go | 25 +++-- pkg/generators/statement_generator.go | 2 +- pkg/jobs/jobs.go | 8 +- pkg/stmtlogger/filelogger.go | 135 +++++++++++++++++++++++ pkg/store/cqlstore.go | 31 +++--- pkg/store/store.go | 153 ++++++++++++++++---------- pkg/typedef/const.go | 13 +++ pkg/typedef/interfaces.go | 5 + pkg/typedef/simple_type.go | 4 +- pkg/typedef/typedef.go | 59 ++++++++-- pkg/typedef/types_test.go | 105 +++++++++--------- 11 files changed, 385 insertions(+), 155 deletions(-) create mode 100644 pkg/stmtlogger/filelogger.go diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 3893c680..5ed2b95b 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -105,6 +105,8 @@ var ( requestTimeout time.Duration connectTimeout time.Duration profilingPort int + testStatementLogFile string + oracleStatementLogFile string ) func interactive() bool { @@ -133,14 +135,6 @@ func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Sc return schemaBuilder.Build(), nil } -type createBuilder struct { - stmt string -} - -func (cb createBuilder) ToCql() (stmt string, names []string) { - return cb.stmt, nil -} - func run(_ *cobra.Command, _ []string) error { logger := createLogger(level) globalStatus := status.NewGlobalStatus(1000) @@ -219,6 +213,8 @@ func run(_ *cobra.Command, _ []string) error { MaxRetriesMutate: maxRetriesMutate, MaxRetriesMutateSleep: maxRetriesMutateSleep, UseServerSideTimestamps: useServerSideTimestamps, + TestLogStatementsFile: testStatementLogFile, + OracleLogStatementsFile: oracleStatementLogFile, } var tracingFile *os.File if tracingOutFile != "" { @@ -243,22 +239,25 @@ func run(_ *cobra.Command, _ []string) error { defer utils.IgnoreError(st.Close) if dropSchema && mode != jobs.ReadMode { - for _, stmt := range generators.GetDropSchema(schema) { + for _, stmt := range generators.GetDropKeyspace(schema) { logger.Debug(stmt) - if err = st.Mutate(context.Background(), createBuilder{stmt: stmt}); err != nil { + if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to drop schema") } } } testKeyspace, oracleKeyspace := generators.GetCreateKeyspaces(schema) - if err = st.Create(context.Background(), createBuilder{stmt: testKeyspace}, createBuilder{stmt: oracleKeyspace}); err != nil { + if err = st.Create( + context.Background(), + typedef.SimpleStmt(testKeyspace, typedef.CreateKeyspaceStatementType), + typedef.SimpleStmt(oracleKeyspace, typedef.CreateKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to create keyspace") } for _, stmt := range generators.GetCreateSchema(schema) { logger.Debug(stmt) - if err = st.Mutate(context.Background(), createBuilder{stmt: stmt}); err != nil { + if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.CreateSchemaStatementType)); err != nil { return errors.Wrap(err, "unable to create schema") } } @@ -531,6 +530,8 @@ func init() { rootCmd.Flags().DurationVarP(&connectTimeout, "connect-timeout", "", 30*time.Second, "Duration of waiting connection established") rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'") rootCmd.Flags().IntVarP(&maxErrorsToStore, "max-errors-to-store", "", 1000, "Maximum number of errors to store and output at the end") + rootCmd.Flags().StringVarP(&testStatementLogFile, "test-statement-log-file", "", "", "File to write statements flow to") + rootCmd.Flags().StringVarP(&oracleStatementLogFile, "oracle-statement-log-file", "", "", "File to write statements flow to") } func printSetup(seed, schemaSeed uint64) { diff --git a/pkg/generators/statement_generator.go b/pkg/generators/statement_generator.go index 6d2cb84b..9aca0e7f 100644 --- a/pkg/generators/statement_generator.go +++ b/pkg/generators/statement_generator.go @@ -140,7 +140,7 @@ func GetCreateSchema(s *typedef.Schema) []string { return stmts } -func GetDropSchema(s *typedef.Schema) []string { +func GetDropKeyspace(s *typedef.Schema) []string { return []string{ fmt.Sprintf("DROP KEYSPACE IF EXISTS %s", s.Keyspace.Name), } diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 3de54e9f..ac21c09f 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -329,7 +329,7 @@ func ddl( if w := logger.Check(zap.DebugLevel, "ddl statement"); w != nil { w.Write(zap.String("pretty_cql", ddlStmt.PrettyCQL())) } - if err = s.Mutate(ctx, ddlStmt.Query); err != nil { + if err = s.Mutate(ctx, ddlStmt); err != nil { if errors.Is(err, context.Canceled) { return nil } @@ -376,13 +376,11 @@ func mutation( } return err } - mutateQuery := mutateStmt.Query - mutateValues := mutateStmt.Values if w := logger.Check(zap.DebugLevel, "mutation statement"); w != nil { w.Write(zap.String("pretty_cql", mutateStmt.PrettyCQL())) } - if err = s.Mutate(ctx, mutateQuery, mutateValues...); err != nil { + if err = s.Mutate(ctx, mutateStmt); err != nil { if errors.Is(err, context.Canceled) { return nil } @@ -425,7 +423,7 @@ func validation( attempt := 1 for { lastErr = err - err = s.Check(ctx, table, stmt.Query, attempt == maxAttempts, stmt.Values...) + err = s.Check(ctx, table, stmt, attempt == maxAttempts) if err == nil { if attempt > 1 { diff --git a/pkg/stmtlogger/filelogger.go b/pkg/stmtlogger/filelogger.go new file mode 100644 index 00000000..5087556c --- /dev/null +++ b/pkg/stmtlogger/filelogger.go @@ -0,0 +1,135 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmtlogger + +import ( + "log" + "os" + "strconv" + "sync/atomic" + "time" + + "github.com/pkg/errors" + + "github.com/scylladb/gemini/pkg/typedef" +) + +const ( + defaultChanSize = 1000 + errorsOnFileLimit = 5 +) + +type FileLogger struct { + fd *os.File + activeChannel atomic.Pointer[loggerChan] + channel loggerChan + filename string + isFileNonOperational bool +} + +type loggerChan chan logRec + +type logRec struct { + stmt *typedef.Stmt + ts time.Time +} + +func (fl *FileLogger) LogStmt(stmt *typedef.Stmt) { + ch := fl.activeChannel.Load() + if ch != nil { + *ch <- logRec{ + stmt: stmt, + } + } +} + +func (fl *FileLogger) LogStmtWithTimeStamp(stmt *typedef.Stmt, ts time.Time) { + ch := fl.activeChannel.Load() + if ch != nil { + *ch <- logRec{ + stmt: stmt, + ts: ts, + } + } +} + +func (fl *FileLogger) Close() error { + return fl.fd.Close() +} + +func (fl *FileLogger) committer() { + var err2 error + + defer func() { + fl.activeChannel.Swap(nil) + close(fl.channel) + }() + + errsAtRow := 0 + + for rec := range fl.channel { + if fl.isFileNonOperational { + continue + } + + _, err1 := fl.fd.Write([]byte(rec.stmt.PrettyCQL())) + opType := rec.stmt.QueryType.OpType() + if rec.ts.IsZero() || !(opType == typedef.OpInsert || opType == typedef.OpUpdate || opType == typedef.OpDelete) { + _, err2 = fl.fd.Write([]byte(";\n")) + } else { + _, err2 = fl.fd.Write([]byte(" USING TIMESTAMP " + strconv.FormatInt(rec.ts.UnixNano()/1000, 10) + ";\n")) + } + if err2 == nil && err1 == nil { + errsAtRow = 0 + continue + } + + if errors.Is(err2, os.ErrClosed) || errors.Is(err1, os.ErrClosed) { + fl.isFileNonOperational = true + return + } + + errsAtRow++ + if errsAtRow > errorsOnFileLimit { + fl.isFileNonOperational = true + } + + if err2 != nil { + err1 = err2 + } + log.Printf("failed to write to file %q: %s", fl.filename, err1) + return + } +} + +func NewFileLogger(filename string) (*FileLogger, error) { + if filename == "" { + return nil, nil + } + fd, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return nil, err + } + + out := &FileLogger{ + filename: filename, + fd: fd, + channel: make(loggerChan, defaultChanSize), + } + out.activeChannel.Store(&out.channel) + + go out.committer() + return out, nil +} diff --git a/pkg/store/cqlstore.go b/pkg/store/cqlstore.go index 82a99868..c756e07a 100644 --- a/pkg/store/cqlstore.go +++ b/pkg/store/cqlstore.go @@ -30,7 +30,7 @@ import ( "github.com/scylladb/gemini/pkg/typedef" ) -type cqlStore struct { +type cqlStore struct { //nolint:govet session *gocql.Session schema *typedef.Schema ops *prometheus.CounterVec @@ -39,20 +39,21 @@ type cqlStore struct { maxRetriesMutate int maxRetriesMutateSleep time.Duration useServerSideTimestamps bool + stmtLogger stmtLogger } func (cs *cqlStore) name() string { return cs.system } -func (cs *cqlStore) mutate(ctx context.Context, builder qb.Builder, values ...interface{}) (err error) { +func (cs *cqlStore) mutate(ctx context.Context, stmt *typedef.Stmt) (err error) { var i int for i = 0; i < cs.maxRetriesMutate; i++ { // retry with new timestamp as list modification with the same ts // will produce duplicated values, see https://github.com/scylladb/scylladb/issues/7937 - err = cs.doMutate(ctx, builder, time.Now(), values...) + err = cs.doMutate(ctx, stmt, time.Now()) if err == nil { - cs.ops.WithLabelValues(cs.system, opType(builder)).Inc() + cs.ops.WithLabelValues(cs.system, opType(stmt)).Inc() return nil } select { @@ -67,14 +68,15 @@ func (cs *cqlStore) mutate(ctx context.Context, builder qb.Builder, values ...in return err } -func (cs *cqlStore) doMutate(ctx context.Context, builder qb.Builder, ts time.Time, values ...interface{}) error { - queryBody, _ := builder.ToCql() - - query := cs.session.Query(queryBody, values...).WithContext(ctx) +func (cs *cqlStore) doMutate(ctx context.Context, stmt *typedef.Stmt, ts time.Time) error { + queryBody, _ := stmt.Query.ToCql() + query := cs.session.Query(queryBody, stmt.Values...).WithContext(ctx) if cs.useServerSideTimestamps { query = query.DefaultTimestamp(false) + cs.stmtLogger.LogStmt(stmt) } else { query = query.WithTimestamp(ts.UnixNano() / 1000) + cs.stmtLogger.LogStmtWithTimeStamp(stmt, ts) } if err := query.Exec(); err != nil { @@ -90,10 +92,11 @@ func (cs *cqlStore) doMutate(ctx context.Context, builder qb.Builder, ts time.Ti return nil } -func (cs *cqlStore) load(ctx context.Context, builder qb.Builder, values []interface{}) (result []map[string]interface{}, err error) { - query, _ := builder.ToCql() - iter := cs.session.Query(query, values...).WithContext(ctx).Iter() - cs.ops.WithLabelValues(cs.system, opType(builder)).Inc() +func (cs *cqlStore) load(ctx context.Context, stmt *typedef.Stmt) (result []map[string]interface{}, err error) { + query, _ := stmt.Query.ToCql() + cs.stmtLogger.LogStmt(stmt) + iter := cs.session.Query(query, stmt.Values...).WithContext(ctx).Iter() + cs.ops.WithLabelValues(cs.system, opType(stmt)).Inc() return loadSet(iter), iter.Close() } @@ -126,8 +129,8 @@ func ignore(err error) bool { } } -func opType(builder qb.Builder) string { - switch builder.(type) { +func opType(stmt *typedef.Stmt) string { + switch stmt.Query.(type) { case *qb.InsertBuilder: return "insert" case *qb.DeleteBuilder: diff --git a/pkg/store/store.go b/pkg/store/store.go index ed5fd98a..d9890add 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -32,20 +32,21 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/scylladb/go-set/strset" - "github.com/scylladb/gocqlx/v2/qb" "go.uber.org/multierr" "gopkg.in/inf.v0" + "github.com/scylladb/go-set/strset" + + "github.com/scylladb/gemini/pkg/stmtlogger" "github.com/scylladb/gemini/pkg/typedef" ) type loader interface { - load(context.Context, qb.Builder, []interface{}) ([]map[string]interface{}, error) + load(context.Context, *typedef.Stmt) ([]map[string]interface{}, error) } type storer interface { - mutate(context.Context, qb.Builder, ...interface{}) error + mutate(context.Context, *typedef.Stmt) error } type storeLoader interface { @@ -55,14 +56,22 @@ type storeLoader interface { name() string } +type stmtLogger interface { + LogStmt(*typedef.Stmt) + LogStmtWithTimeStamp(stmt *typedef.Stmt, ts time.Time) + Close() error +} + type Store interface { - Create(context.Context, qb.Builder, qb.Builder) error - Mutate(context.Context, qb.Builder, ...interface{}) error - Check(context.Context, *typedef.Table, qb.Builder, bool, ...interface{}) error + Create(context.Context, *typedef.Stmt, *typedef.Stmt) error + Mutate(context.Context, *typedef.Stmt) error + Check(context.Context, *typedef.Table, *typedef.Stmt, bool) error Close() error } type Config struct { + TestLogStatementsFile string + OracleLogStatementsFile string MaxRetriesMutate int MaxRetriesMutateSleep time.Duration UseServerSideTimestamps bool @@ -75,48 +84,23 @@ func New(schema *typedef.Schema, testCluster, oracleCluster *gocql.ClusterConfig }, []string{"system", "method"}, ) - var oracleStore storeLoader - var validations bool - if oracleCluster != nil { - oracleSession, err := newSession(oracleCluster, traceOut) - if err != nil { - return nil, errors.Wrapf(err, "failed to connect to oracle cluster") - } - oracleStore = &cqlStore{ - session: oracleSession, - schema: schema, - system: "oracle", - ops: ops, - maxRetriesMutate: cfg.MaxRetriesMutate + 10, - maxRetriesMutateSleep: cfg.MaxRetriesMutateSleep, - useServerSideTimestamps: cfg.UseServerSideTimestamps, - logger: logger, - } - validations = true - } else { - oracleStore = &noOpStore{ - system: "oracle", - } + oracleStore, err := getStore("oracle", schema, oracleCluster, cfg, cfg.OracleLogStatementsFile, traceOut, logger, ops) + if err != nil { + return nil, err } - testSession, err := newSession(testCluster, traceOut) + if testCluster == nil { + return nil, errors.New("test cluster is empty") + } + testStore, err := getStore("test", schema, testCluster, cfg, cfg.TestLogStatementsFile, traceOut, logger, ops) if err != nil { - return nil, errors.Wrapf(err, "failed to connect to oracle cluster") + return nil, err } return &delegatingStore{ - testStore: &cqlStore{ - session: testSession, - schema: schema, - system: "test", - ops: ops, - maxRetriesMutate: cfg.MaxRetriesMutate, - maxRetriesMutateSleep: cfg.MaxRetriesMutateSleep, - useServerSideTimestamps: cfg.UseServerSideTimestamps, - logger: logger, - }, + testStore: testStore, oracleStore: oracleStore, - validations: validations, + validations: oracleStore != nil, logger: logger.Named("delegating_store"), }, nil } @@ -125,11 +109,11 @@ type noOpStore struct { system string } -func (n *noOpStore) mutate(context.Context, qb.Builder, ...interface{}) error { +func (n *noOpStore) mutate(context.Context, *typedef.Stmt) error { return nil } -func (n *noOpStore) load(context.Context, qb.Builder, []interface{}) ([]map[string]interface{}, error) { +func (n *noOpStore) load(context.Context, *typedef.Stmt) ([]map[string]interface{}, error) { return nil, nil } @@ -146,31 +130,39 @@ func (n *noOpStore) close() error { } type delegatingStore struct { - oracleStore storeLoader - testStore storeLoader - logger *zap.Logger - validations bool + oracleStore storeLoader + testStore storeLoader + statementLogger stmtLogger + logger *zap.Logger + validations bool } -func (ds delegatingStore) Create(ctx context.Context, testBuilder, oracleBuilder qb.Builder) error { - if err := mutate(ctx, ds.oracleStore, oracleBuilder, []interface{}{}); err != nil { +func (ds delegatingStore) Create(ctx context.Context, testBuilder, oracleBuilder *typedef.Stmt) error { + if ds.statementLogger != nil { + ds.statementLogger.LogStmt(testBuilder) + } + if err := mutate(ctx, ds.oracleStore, oracleBuilder); err != nil { return errors.Wrap(err, "oracle failed store creation") } - if err := mutate(ctx, ds.testStore, testBuilder, []interface{}{}); err != nil { + if err := mutate(ctx, ds.testStore, testBuilder); err != nil { return errors.Wrap(err, "test failed store creation") } return nil } -func (ds delegatingStore) Mutate(ctx context.Context, builder qb.Builder, values ...interface{}) error { +func (ds delegatingStore) Mutate(ctx context.Context, stmt *typedef.Stmt) error { var testErr error var wg sync.WaitGroup wg.Add(1) + go func() { - testErr = mutate(ctx, ds.testStore, builder, values...) - wg.Done() + defer wg.Done() + testErr = errors.Wrapf( + ds.testStore.mutate(ctx, stmt), + "unable to apply mutations to the %s store", ds.testStore.name()) }() - if oracleErr := mutate(ctx, ds.oracleStore, builder, values...); oracleErr != nil { + + if oracleErr := ds.oracleStore.mutate(ctx, stmt); oracleErr != nil { // Oracle failed, transition cannot take place ds.logger.Info("oracle store failed mutation, transition to next state impossible so continuing with next mutation", zap.Error(oracleErr)) return oracleErr @@ -184,23 +176,24 @@ func (ds delegatingStore) Mutate(ctx context.Context, builder qb.Builder, values return nil } -func mutate(ctx context.Context, s storeLoader, builder qb.Builder, values ...interface{}) error { - if err := s.mutate(ctx, builder, values...); err != nil { +func mutate(ctx context.Context, s storeLoader, stmt *typedef.Stmt) error { + if err := s.mutate(ctx, stmt); err != nil { return errors.Wrapf(err, "unable to apply mutations to the %s store", s.name()) } return nil } -func (ds delegatingStore) Check(ctx context.Context, table *typedef.Table, builder qb.Builder, detailedDiff bool, values ...interface{}) error { +func (ds delegatingStore) Check(ctx context.Context, table *typedef.Table, stmt *typedef.Stmt, detailedDiff bool) error { var testRows, oracleRows []map[string]interface{} var testErr, oracleErr error var wg sync.WaitGroup wg.Add(1) + go func() { - testRows, testErr = ds.testStore.load(ctx, builder, values) + testRows, testErr = ds.testStore.load(ctx, stmt) wg.Done() }() - oracleRows, oracleErr = ds.oracleStore.load(ctx, builder, values) + oracleRows, oracleErr = ds.oracleStore.load(ctx, stmt) if oracleErr != nil { return errors.Wrapf(oracleErr, "unable to load check data from the oracle store") } @@ -257,7 +250,47 @@ func (ds delegatingStore) Check(ctx context.Context, table *typedef.Table, build } func (ds delegatingStore) Close() (err error) { + if ds.statementLogger != nil { + err = multierr.Append(err, ds.statementLogger.Close()) + } err = multierr.Append(err, ds.testStore.close()) err = multierr.Append(err, ds.oracleStore.close()) return } + +func getStore( + name string, + schema *typedef.Schema, + clusterConfig *gocql.ClusterConfig, + cfg Config, + stmtLogFile string, + traceOut *os.File, + logger *zap.Logger, + ops *prometheus.CounterVec, +) (out storeLoader, err error) { + if clusterConfig == nil { + return &noOpStore{ + system: name, + }, nil + } + oracleSession, err := newSession(clusterConfig, traceOut) + if err != nil { + return nil, errors.Wrapf(err, "failed to connect to %s cluster", name) + } + oracleFileLogger, err := stmtlogger.NewFileLogger(stmtLogFile) + if err != nil { + return nil, err + } + + return &cqlStore{ + session: oracleSession, + schema: schema, + system: name, + ops: ops, + maxRetriesMutate: cfg.MaxRetriesMutate + 10, + maxRetriesMutateSleep: cfg.MaxRetriesMutateSleep, + useServerSideTimestamps: cfg.UseServerSideTimestamps, + logger: logger.Named(name), + stmtLogger: oracleFileLogger, + }, nil +} diff --git a/pkg/typedef/const.go b/pkg/typedef/const.go index b61befe2..93e7e1eb 100644 --- a/pkg/typedef/const.go +++ b/pkg/typedef/const.go @@ -36,6 +36,19 @@ const ( AlterColumnStatementType DropColumnStatementType AddColumnStatementType + DropKeyspaceStatementType + CreateKeyspaceStatementType + CreateSchemaStatementType +) + +const ( + OpSelect OpType = iota + OpInsert + OpUpdate + OpDelete + OpSchemaAlter + OpSchemaDrop + OpSchemaCreate ) //nolint:revive diff --git a/pkg/typedef/interfaces.go b/pkg/typedef/interfaces.go index 5e79e456..f12d26f4 100644 --- a/pkg/typedef/interfaces.go +++ b/pkg/typedef/interfaces.go @@ -33,6 +33,11 @@ type Type interface { CQLType() gocql.TypeInfo } +type Statement interface { + ToCql() (stmt string, names []string) + PrettyCQL() string +} + type Types []Type func (l Types) LenValue() int { diff --git a/pkg/typedef/simple_type.go b/pkg/typedef/simple_type.go index 92f96079..bd173687 100644 --- a/pkg/typedef/simple_type.go +++ b/pkg/typedef/simple_type.go @@ -96,9 +96,9 @@ func (st SimpleType) CQLPretty(value interface{}) string { panic(fmt.Sprintf("unexpected time value [%T]%+v", value, value)) case TYPE_TIMESTAMP: if v, ok := value.(int64); ok { - // CQL supports only 3 digits microseconds: + // CQL supports only 3 digits milliseconds: // '1976-03-25T10:10:55.83275+0000': marshaling error: Milliseconds length exceeds expected (5)" - return time.UnixMicro(v).UTC().Format("'2006-01-02T15:04:05.999-0700'") + return time.UnixMilli(v).UTC().Format("'2006-01-02T15:04:05.999-0700'") } panic(fmt.Sprintf("unexpected timestamp value [%T]%+v", value, value)) case TYPE_DURATION, TYPE_TIMEUUID, TYPE_UUID: diff --git a/pkg/typedef/typedef.go b/pkg/typedef/typedef.go index 44afd033..9194e98d 100644 --- a/pkg/typedef/typedef.go +++ b/pkg/typedef/typedef.go @@ -24,6 +24,9 @@ import ( ) type ( + CQLFeature int + OpType uint8 + ValueWithToken struct { Value Values Token uint64 @@ -48,20 +51,26 @@ type ( UseLWT bool } - CQLFeature int + Stmts struct { + PostStmtHook func() + List []*Stmt + QueryType StatementType + } + + StmtCache struct { + Query qb.Builder + Types Types + QueryType StatementType + LenValue int + } ) -type Stmts struct { - PostStmtHook func() - List []*Stmt - QueryType StatementType +type SimpleQuery struct { + query string } -type StmtCache struct { - Query qb.Builder - Types Types - QueryType StatementType - LenValue int +func (q SimpleQuery) ToCql() (stmt string, names []string) { + return q.query, nil } type Stmt struct { @@ -70,6 +79,15 @@ type Stmt struct { Values Values } +func SimpleStmt(query string, queryType StatementType) *Stmt { + return &Stmt{ + StmtCache: &StmtCache{ + Query: SimpleQuery{query}, + QueryType: queryType, + }, + } +} + func (s *Stmt) PrettyCQL() string { query, _ := s.Query.ToCql() values := s.Values.Copy() @@ -110,6 +128,27 @@ func (st StatementType) ToString() string { } } +func (st StatementType) OpType() OpType { + switch st { + case SelectStatementType, SelectRangeStatementType, SelectByIndexStatementType, SelectFromMaterializedViewStatementType: + return OpSelect + case InsertStatementType, InsertJSONStatementType: + return OpInsert + case UpdateStatementType: + return OpUpdate + case DeleteStatementType: + return OpDelete + case AlterColumnStatementType, DropColumnStatementType, AddColumnStatementType: + return OpSchemaAlter + case DropKeyspaceStatementType: + return OpSchemaDrop + case CreateKeyspaceStatementType, CreateSchemaStatementType: + return OpSchemaCreate + default: + panic(fmt.Sprintf("unknown statement type %d", st)) + } +} + func (st StatementType) PossibleAsyncOperation() bool { switch st { case SelectByIndexStatementType, SelectFromMaterializedViewStatementType: diff --git a/pkg/typedef/types_test.go b/pkg/typedef/types_test.go index 51c0a799..9a8a0a5c 100644 --- a/pkg/typedef/types_test.go +++ b/pkg/typedef/types_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package typedef_test +package typedef import ( "math/big" @@ -21,142 +21,140 @@ import ( "time" "gopkg.in/inf.v0" - - "github.com/scylladb/gemini/pkg/typedef" ) var millennium = time.Date(1999, 12, 31, 23, 59, 59, 0, time.UTC) var prettytests = []struct { - typ typedef.Type + typ Type query string expected string values []interface{} }{ { - typ: typedef.TYPE_ASCII, + typ: TYPE_ASCII, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{"a"}, expected: "SELECT * FROM tbl WHERE pk0='a'", }, { - typ: typedef.TYPE_BIGINT, + typ: TYPE_BIGINT, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{big.NewInt(10)}, expected: "SELECT * FROM tbl WHERE pk0=10", }, { - typ: typedef.TYPE_BLOB, + typ: TYPE_BLOB, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{"a"}, expected: "SELECT * FROM tbl WHERE pk0=textasblob('a')", }, { - typ: typedef.TYPE_BOOLEAN, + typ: TYPE_BOOLEAN, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{true}, expected: "SELECT * FROM tbl WHERE pk0=true", }, { - typ: typedef.TYPE_DATE, + typ: TYPE_DATE, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{millennium.Format("2006-01-02")}, expected: "SELECT * FROM tbl WHERE pk0='1999-12-31'", }, { - typ: typedef.TYPE_DECIMAL, + typ: TYPE_DECIMAL, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{inf.NewDec(1000, 0)}, expected: "SELECT * FROM tbl WHERE pk0=1000", }, { - typ: typedef.TYPE_DOUBLE, + typ: TYPE_DOUBLE, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{10.0}, expected: "SELECT * FROM tbl WHERE pk0=10.00", }, { - typ: typedef.TYPE_DURATION, + typ: TYPE_DURATION, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{10 * time.Minute}, expected: "SELECT * FROM tbl WHERE pk0=10m0s", }, { - typ: typedef.TYPE_FLOAT, + typ: TYPE_FLOAT, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{10.0}, expected: "SELECT * FROM tbl WHERE pk0=10.00", }, { - typ: typedef.TYPE_INET, + typ: TYPE_INET, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{net.ParseIP("192.168.0.1")}, expected: "SELECT * FROM tbl WHERE pk0='192.168.0.1'", }, { - typ: typedef.TYPE_INT, + typ: TYPE_INT, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{10}, expected: "SELECT * FROM tbl WHERE pk0=10", }, { - typ: typedef.TYPE_SMALLINT, + typ: TYPE_SMALLINT, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{2}, expected: "SELECT * FROM tbl WHERE pk0=2", }, { - typ: typedef.TYPE_TEXT, + typ: TYPE_TEXT, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{"a"}, expected: "SELECT * FROM tbl WHERE pk0='a'", }, { - typ: typedef.TYPE_TIME, + typ: TYPE_TIME, query: "SELECT * FROM tbl WHERE pk0=?", - values: []interface{}{millennium}, - expected: "SELECT * FROM tbl WHERE pk0='" + millennium.Format(time.RFC3339) + "'", + values: []interface{}{millennium.UnixNano()}, + expected: "SELECT * FROM tbl WHERE pk0='" + millennium.Format("15:04:05.999") + "'", }, { - typ: typedef.TYPE_TIMESTAMP, + typ: TYPE_TIMESTAMP, query: "SELECT * FROM tbl WHERE pk0=?", - values: []interface{}{millennium}, - expected: "SELECT * FROM tbl WHERE pk0='" + millennium.Format(time.RFC3339) + "'", + values: []interface{}{millennium.UnixMilli()}, + expected: "SELECT * FROM tbl WHERE pk0='" + millennium.Format("2006-01-02T15:04:05.999-0700") + "'", }, { - typ: typedef.TYPE_TIMEUUID, + typ: TYPE_TIMEUUID, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{"63176980-bfde-11d3-bc37-1c4d704231dc"}, expected: "SELECT * FROM tbl WHERE pk0=63176980-bfde-11d3-bc37-1c4d704231dc", }, { - typ: typedef.TYPE_TINYINT, + typ: TYPE_TINYINT, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{1}, expected: "SELECT * FROM tbl WHERE pk0=1", }, { - typ: typedef.TYPE_UUID, + typ: TYPE_UUID, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{"63176980-bfde-11d3-bc37-1c4d704231dc"}, expected: "SELECT * FROM tbl WHERE pk0=63176980-bfde-11d3-bc37-1c4d704231dc", }, { - typ: typedef.TYPE_VARCHAR, + typ: TYPE_VARCHAR, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{"a"}, expected: "SELECT * FROM tbl WHERE pk0='a'", }, { - typ: typedef.TYPE_VARINT, + typ: TYPE_VARINT, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{big.NewInt(1001)}, expected: "SELECT * FROM tbl WHERE pk0=1001", }, { - typ: &typedef.BagType{ - ComplexType: typedef.TYPE_SET, - ValueType: typedef.TYPE_ASCII, + typ: &BagType{ + ComplexType: TYPE_SET, + ValueType: TYPE_ASCII, Frozen: false, }, query: "SELECT * FROM tbl WHERE pk0=?", @@ -164,9 +162,9 @@ var prettytests = []struct { expected: "SELECT * FROM tbl WHERE pk0={'a','b'}", }, { - typ: &typedef.BagType{ - ComplexType: typedef.TYPE_LIST, - ValueType: typedef.TYPE_ASCII, + typ: &BagType{ + ComplexType: TYPE_LIST, + ValueType: TYPE_ASCII, Frozen: false, }, query: "SELECT * FROM tbl WHERE pk0=?", @@ -174,28 +172,28 @@ var prettytests = []struct { expected: "SELECT * FROM tbl WHERE pk0=['a','b']", }, { - typ: &typedef.MapType{ - KeyType: typedef.TYPE_ASCII, - ValueType: typedef.TYPE_ASCII, + typ: &MapType{ + KeyType: TYPE_ASCII, + ValueType: TYPE_ASCII, Frozen: false, }, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{map[string]string{"a": "b"}}, - expected: "SELECT * FROM tbl WHERE pk0={a:'b'}", + expected: "SELECT * FROM tbl WHERE pk0={'a':'b'}", }, { - typ: &typedef.MapType{ - KeyType: typedef.TYPE_ASCII, - ValueType: typedef.TYPE_BLOB, + typ: &MapType{ + KeyType: TYPE_ASCII, + ValueType: TYPE_BLOB, Frozen: false, }, query: "SELECT * FROM tbl WHERE pk0=?", values: []interface{}{map[string]string{"a": "b"}}, - expected: "SELECT * FROM tbl WHERE pk0={a:textasblob('b')}", + expected: "SELECT * FROM tbl WHERE pk0={'a':textasblob('b')}", }, { - typ: &typedef.TupleType{ - ValueTypes: []typedef.SimpleType{typedef.TYPE_ASCII}, + typ: &TupleType{ + ValueTypes: []SimpleType{TYPE_ASCII}, Frozen: false, }, query: "SELECT * FROM tbl WHERE pk0=?", @@ -203,8 +201,8 @@ var prettytests = []struct { expected: "SELECT * FROM tbl WHERE pk0='a'", }, { - typ: &typedef.TupleType{ - ValueTypes: []typedef.SimpleType{typedef.TYPE_ASCII, typedef.TYPE_ASCII}, + typ: &TupleType{ + ValueTypes: []SimpleType{TYPE_ASCII, TYPE_ASCII}, Frozen: false, }, query: "SELECT * FROM tbl WHERE pk0={?,?}", @@ -216,10 +214,15 @@ var prettytests = []struct { func TestCQLPretty(t *testing.T) { t.Parallel() - for _, p := range prettytests { - result := p.typ.CQLPretty(p.values) - if result != p.expected { - t.Fatalf("expected '%s', got '%s' for values %v and type '%v'", p.expected, result, p.values, p.typ) - } + for id := range prettytests { + test := prettytests[id] + t.Run(test.typ.Name(), func(t *testing.T) { + t.Parallel() + + result := prettyCQL(test.query, test.values, []Type{test.typ}) + if result != test.expected { + t.Errorf("expected '%s', got '%s' for values %v and type '%v'", test.expected, result, test.values, test.typ) + } + }) } }