diff --git a/.run/Run Gemini Mixed.run.xml b/.run/Run Gemini Mixed.run.xml index c6d7c6ca..1021a5e0 100644 --- a/.run/Run Gemini Mixed.run.xml +++ b/.run/Run Gemini Mixed.run.xml @@ -3,7 +3,7 @@ - + diff --git a/Makefile b/Makefile index 489c78e1..6c78e719 100644 --- a/Makefile +++ b/Makefile @@ -20,30 +20,39 @@ GEMINI_TEST_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Netw GEMINI_ORACLE_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Networks.gemini.IPAddress }}' gemini-oracle) GEMINI_DOCKER_NETWORK ?= gemini GEMINI_FLAGS ?= --fail-fast \ - --level=info \ - --non-interactive \ - --consistency=LOCAL_QUORUM \ - --test-host-selection-policy=token-aware \ - --oracle-host-selection-policy=token-aware \ - --mode=$(MODE) \ - --non-interactive \ - --request-timeout=5s \ - --connect-timeout=15s \ - --use-server-timestamps=false \ --async-objects-stabilization-attempts=10 \ - --max-mutation-retries=10 \ - --replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \ - --oracle-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \ + --async-objects-stabilization-backoff=1s \ --concurrency=$(CONCURRENCY) \ - --dataset-size=$(DATASET_SIZE) \ + --cql-features=$(CQL_FEATURES) \ + --dataset-size=$(DATASET_SIZE) \ + --mode=$(MODE) \ + --duration=$(DURATION) \ --seed=$(SEED) \ --schema-seed=$(SEED) \ - --cql-features=$(CQL_FEATURES) \ - --duration=$(DURATION) \ --warmup=$(WARMUP) \ + --drop-schema=true \ + --level=info \ + --materialized-views=false \ + --async-objects-stabilization-attempts=10 \ --profiling-port=6060 \ - --drop-schema=true - + --token-range-slices=10000 \ + --partition-key-buffer-reuse-size=100 \ + --oracle-connect-timeout=15s \ + --oracle-request-timeout=5s \ + --oracle-consistency=LOCAL_QUORUM \ + --oracle-max-mutation-retries=10 \ + --oracle-max-mutation-retries-backoff=1s \ + --oracle-use-server-timestamps=false \ + --oracle-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \ + --oracle-host-selection-policy=token-aware \ + --test-consistency=LOCAL_QUORUM \ + --test-request-timeout=5s \ + --test-connect-timeout=15s \ + --test-use-server-timestamps=false \ + --test-host-selection-policy=token-aware \ + --test-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \ + --test-max-mutation-retries=10 \ + --test-dc=datacenter1 ifndef GOBIN export GOBIN := $(MAKEFILE_PATH)/bin diff --git a/cmd/gemini/main.go b/cmd/gemini/main.go index 5fbe1c29..949826b7 100644 --- a/cmd/gemini/main.go +++ b/cmd/gemini/main.go @@ -18,6 +18,8 @@ import ( "fmt" "os" "runtime/debug" + + _ "github.com/scylladb/gemini/pkg/metrics" ) //go:generate sh -c "git describe --tags --abbrev=0 | tr -d '\n' > ./Version" diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 6d2a22d0..ad6cdbc2 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -17,22 +17,21 @@ package main import ( "encoding/json" "fmt" + "io" "log" "math" "net/http" "net/http/pprof" "os" "os/signal" + "runtime" "strconv" "strings" "syscall" "text/tabwriter" "time" - "github.com/gocql/gocql" - "github.com/hailocab/go-hostpool" "github.com/pkg/errors" - "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -40,42 +39,38 @@ import ( "golang.org/x/net/context" "gonum.org/v1/gonum/stat/distuv" - "github.com/scylladb/gemini/pkg/auth" - "github.com/scylladb/gemini/pkg/builders" "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/jobs" "github.com/scylladb/gemini/pkg/realrandom" "github.com/scylladb/gemini/pkg/replication" "github.com/scylladb/gemini/pkg/status" "github.com/scylladb/gemini/pkg/store" + "github.com/scylladb/gemini/pkg/store/drivers" "github.com/scylladb/gemini/pkg/typedef" "github.com/scylladb/gemini/pkg/utils" ) var ( - testClusterHost []string - testClusterUsername string - testClusterPassword string - oracleClusterHost []string - oracleClusterUsername string - oracleClusterPassword string + level string + profilingPort int + mode string + warmup time.Duration + duration time.Duration + verbose bool + failFast bool + concurrency uint64 + + oracleConfig drivers.CQLConfig + testConfig drivers.CQLConfig + testReplicationStrategy string + oracleReplicationStrategy string + schemaFile string outFileArg string - concurrency uint64 seed string schemaSeed string dropSchema bool - verbose bool - mode string - failFast bool - nonInteractive bool - duration time.Duration - bind string - warmup time.Duration - replicationStrategy string tableOptions []string - oracleReplicationStrategy string - consistency string maxTables int maxPartitionKeys int minPartitionKeys int @@ -86,63 +81,24 @@ var ( datasetSize string cqlFeatures string useMaterializedViews bool - level string - maxRetriesMutate int - maxRetriesMutateSleep time.Duration maxErrorsToStore int pkBufferReuseSize uint64 partitionCount uint64 partitionKeyDistribution string normalDistMean float64 normalDistSigma float64 - tracingOutFile string useCounters bool asyncObjectStabilizationAttempts int asyncObjectStabilizationDelay time.Duration useLWT bool - testClusterHostSelectionPolicy string - oracleClusterHostSelectionPolicy string - useServerSideTimestamps bool - requestTimeout time.Duration - connectTimeout time.Duration - profilingPort int - testStatementLogFile string - oracleStatementLogFile string ) -func interactive() bool { - return !nonInteractive -} - -func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Schema, error) { - byteValue, err := os.ReadFile(confFile) - if err != nil { - return nil, err - } - - var shm typedef.Schema - - err = json.Unmarshal(byteValue, &shm) - if err != nil { - return nil, err - } - - schemaBuilder := builders.NewSchemaBuilder() - schemaBuilder.Keyspace(shm.Keyspace).Config(schemaConfig) - for t, tbl := range shm.Tables { - shm.Tables[t].LinkIndexAndColumns() - schemaBuilder.Table(tbl) - } - return schemaBuilder.Build(), nil -} - func run(_ *cobra.Command, _ []string) error { ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGABRT, syscall.SIGTERM, syscall.SIGINT) defer cancel() logger := createLogger(level) globalStatus := status.NewGlobalStatus(int32(maxErrorsToStore)) - defer utils.IgnoreError(logger.Sync) if err := validateSeed(seed); err != nil { return errors.Wrapf(err, "failed to parse --seed argument") @@ -156,26 +112,6 @@ func run(_ *cobra.Command, _ []string) error { rand.Seed(intSeed) - cons, err := gocql.ParseConsistencyWrapper(consistency) - if err != nil { - logger.Error("Unable parse consistency, error=%s. Falling back on Quorum", zap.Error(err)) - cons = gocql.Quorum - } - - testHostSelectionPolicy, err := getHostSelectionPolicy(testClusterHostSelectionPolicy, testClusterHost) - if err != nil { - return err - } - oracleHostSelectionPolicy, err := getHostSelectionPolicy(oracleClusterHostSelectionPolicy, oracleClusterHost) - if err != nil { - return err - } - - go func() { - http.Handle("/metrics", promhttp.Handler()) - _ = http.ListenAndServe(bind, nil) - }() - if profilingPort != 0 { go func() { mux := http.NewServeMux() @@ -184,19 +120,18 @@ func run(_ *cobra.Command, _ []string) error { }() } - outFile, err := createFile(outFileArg, os.Stdout) - if err != nil { - return err - } - defer utils.IgnoreError(outFile.Sync) - schemaConfig := createSchemaConfig(logger) - if err = schemaConfig.Valid(); err != nil { + if err := schemaConfig.Valid(); err != nil { return errors.Wrap(err, "invalid schema configuration") } - var schema *typedef.Schema + + var ( + schema *typedef.Schema + err error + ) + if len(schemaFile) > 0 { - schema, err = readSchema(schemaFile, schemaConfig) + schema, err = typedef.NewSchemaFromFile(schemaFile, schemaConfig) if err != nil { return errors.Wrap(err, "cannot create schema") } @@ -207,40 +142,35 @@ func run(_ *cobra.Command, _ []string) error { } } - jsonSchema, _ := json.MarshalIndent(schema, "", " ") + printSetup(intSeed, intSchemaSeed, schema) - printSetup(intSeed, intSchemaSeed) - fmt.Printf("Schema: %v\n", string(jsonSchema)) + var oracle store.Driver + if len(oracleConfig.Hosts) > 0 { + oracle, err = drivers.NewCQL(ctx, "oracle", schema, oracleConfig, logger.Named("oracle_store")) + if err != nil { + return errors.Wrap(err, "failed to create oracle store") + } - testCluster, oracleCluster := createClusters(cons, testHostSelectionPolicy, oracleHostSelectionPolicy, logger) - storeConfig := store.Config{ - MaxRetriesMutate: maxRetriesMutate, - MaxRetriesMutateSleep: maxRetriesMutateSleep, - UseServerSideTimestamps: useServerSideTimestamps, - TestLogStatementsFile: testStatementLogFile, - OracleLogStatementsFile: oracleStatementLogFile, - } - var tracingFile *os.File - if tracingOutFile != "" { - switch tracingOutFile { - case "stderr": - tracingFile = os.Stderr - case "stdout": - tracingFile = os.Stdout - default: - tf, ioErr := createFile(tracingOutFile, os.Stdout) - if ioErr != nil { - return ioErr + defer func() { + if closer, ok := oracle.(io.Closer); ok { + utils.IgnoreError(closer.Close) } - tracingFile = tf - defer utils.IgnoreError(tracingFile.Sync) - } + }() + + } else { + oracle = drivers.NewNop() } - st, err := store.New(schema, testCluster, oracleCluster, storeConfig, tracingFile, logger) + + test, err := drivers.NewCQL(ctx, "test", schema, testConfig, logger.Named("test_store")) + if err != nil { + return errors.Wrap(err, "failed to create oracle store") + } + defer utils.IgnoreError(test.Close) + + st, err := store.New(logger, test, oracle) if err != nil { return err } - defer utils.IgnoreError(st.Close) if dropSchema && mode != jobs.ReadMode { for _, stmt := range generators.GetDropKeyspace(schema) { @@ -276,21 +206,6 @@ func run(_ *cobra.Command, _ []string) error { gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) defer utils.IgnoreError(gens.Close) - if !nonInteractive { - sp := createSpinner(interactive()) - ticker := time.NewTicker(time.Second) - go func() { - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - sp.Set(" Running Gemini... %v", globalStatus) - } - } - }() - } - if warmup > 0 { warmupCtx, warmupCancel := context.WithTimeout(ctx, warmup) defer warmupCancel() @@ -310,6 +225,12 @@ func run(_ *cobra.Command, _ []string) error { } logger.Info("test finished") + + outFile, err := utils.CreateFile(outFileArg, os.Stdout) + if err != nil { + return err + } + globalStatus.PrintResult(outFile, schema, version) if globalStatus.HasErrors() { return errors.Errorf("gemini encountered errors, exiting with non zero status") @@ -317,17 +238,6 @@ func run(_ *cobra.Command, _ []string) error { return nil } -func createFile(fname string, def *os.File) (*os.File, error) { - if fname != "" { - f, err := os.Create(fname) - if err != nil { - return nil, errors.Wrapf(err, "Unable to open output file %s", fname) - } - return f, nil - } - return def, nil -} - const ( stdDistMean = math.MaxUint64 / 2 oneStdDev = 0.341 * math.MaxUint64 @@ -360,70 +270,47 @@ func createDistributionFunc(distribution string, size, seed uint64, mu, sigma fl } func createLogger(level string) *zap.Logger { - lvl := zap.NewAtomicLevel() - if err := lvl.UnmarshalText([]byte(level)); err != nil { - lvl.SetLevel(zap.InfoLevel) + lvl, err := zap.ParseAtomicLevel(level) + if err != nil { + lvl = zap.NewAtomicLevelAt(zap.InfoLevel) } - encoderCfg := zap.NewDevelopmentEncoderConfig() + + encoderCfg := zap.NewProductionEncoderConfig() + encoderCfg.EncodeName = zapcore.FullNameEncoder + encoderCfg.EncodeLevel = zapcore.LowercaseLevelEncoder + encoderCfg.EncodeTime = zapcore.RFC3339TimeEncoder + encoderCfg.EncodeDuration = zapcore.MillisDurationEncoder + encoderCfg.LevelKey = "L" + encoderCfg.MessageKey = "M" + encoderCfg.TimeKey = "T" + encoderCfg.CallerKey = "C" + logger := zap.New(zapcore.NewCore( zapcore.NewJSONEncoder(encoderCfg), zapcore.Lock(os.Stdout), lvl, )) - return logger -} -func createClusters( - consistency gocql.Consistency, - testHostSelectionPolicy, oracleHostSelectionPolicy gocql.HostSelectionPolicy, - logger *zap.Logger, -) (*gocql.ClusterConfig, *gocql.ClusterConfig) { - retryPolicy := &gocql.ExponentialBackoffRetryPolicy{ - Min: time.Second, - Max: 60 * time.Second, - NumRetries: 5, - } - testCluster := gocql.NewCluster(testClusterHost...) - testCluster.Timeout = requestTimeout - testCluster.ConnectTimeout = connectTimeout - testCluster.RetryPolicy = retryPolicy - testCluster.Consistency = consistency - testCluster.PoolConfig.HostSelectionPolicy = testHostSelectionPolicy - testAuthenticator, testAuthErr := auth.BuildAuthenticator(testClusterUsername, testClusterPassword) - if testAuthErr != nil { - logger.Warn("%s for test cluster", zap.Error(testAuthErr)) - } - testCluster.Authenticator = testAuthenticator - if len(oracleClusterHost) == 0 { - return testCluster, nil - } - oracleCluster := gocql.NewCluster(oracleClusterHost...) - testCluster.Timeout = requestTimeout - testCluster.ConnectTimeout = connectTimeout - oracleCluster.RetryPolicy = retryPolicy - oracleCluster.Consistency = consistency - oracleCluster.PoolConfig.HostSelectionPolicy = oracleHostSelectionPolicy - oracleAuthenticator, oracleAuthErr := auth.BuildAuthenticator(oracleClusterUsername, oracleClusterPassword) - if oracleAuthErr != nil { - logger.Warn("%s for oracle cluster", zap.Error(oracleAuthErr)) - } - oracleCluster.Authenticator = oracleAuthenticator - return testCluster, oracleCluster + runtime.SetFinalizer(logger, func(l *zap.Logger) { + utils.IgnoreError(logger.Sync) + }) + + return logger } func getReplicationStrategy(rs string, fallback *replication.Replication, logger *zap.Logger) *replication.Replication { - switch rs { + switch strings.ToLower(rs) { case "network": return replication.NewNetworkTopologyStrategy() case "simple": return replication.NewSimpleStrategy() default: - replicationStrategy := &replication.Replication{} - if err := json.Unmarshal([]byte(strings.ReplaceAll(rs, "'", "\"")), replicationStrategy); err != nil { + rf := &replication.Replication{} + if err := json.Unmarshal([]byte(strings.ReplaceAll(rs, "'", "\"")), rf); err != nil { logger.Error("unable to parse replication strategy", zap.String("strategy", rs), zap.Error(err)) return fallback } - return replicationStrategy + return rf } } @@ -438,19 +325,6 @@ func getCQLFeature(feature string) typedef.CQLFeature { } } -func getHostSelectionPolicy(policy string, hosts []string) (gocql.HostSelectionPolicy, error) { - switch policy { - case "round-robin": - return gocql.RoundRobinHostPolicy(), nil - case "host-pool": - return gocql.HostPoolHostPolicy(hostpool.New(hosts)), nil - case "token-aware": - return gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()), nil - default: - return nil, fmt.Errorf("unknown host selection policy \"%s\"", policy) - } -} - var rootCmd = &cobra.Command{ Use: "gemini", Short: "Gemini is an automatic random testing tool for Scylla.", @@ -460,38 +334,52 @@ var rootCmd = &cobra.Command{ func init() { rootCmd.Version = version + ", commit " + commit + ", date " + date - rootCmd.Flags().StringSliceVarP(&testClusterHost, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is system under test") - _ = rootCmd.MarkFlagRequired("test-cluster") - rootCmd.Flags().StringVarP(&testClusterUsername, "test-username", "", "", "Username for the test cluster") - rootCmd.Flags().StringVarP(&testClusterPassword, "test-password", "", "", "Password for the test cluster") - rootCmd.Flags().StringSliceVarP( - &oracleClusterHost, "oracle-cluster", "o", []string{}, - "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used") - rootCmd.Flags().StringVarP(&oracleClusterUsername, "oracle-username", "", "", "Username for the oracle cluster") - rootCmd.Flags().StringVarP(&oracleClusterPassword, "oracle-password", "", "", "Password for the oracle cluster") - rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file") + + rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'") rootCmd.Flags().StringVarP(&mode, "mode", "m", jobs.MixedMode, "Query operation mode. Mode options: write, read, mixed (default)") rootCmd.Flags().Uint64VarP(&concurrency, "concurrency", "c", 10, "Number of threads per table to run concurrently") rootCmd.Flags().StringVarP(&seed, "seed", "s", "random", "Statement seed value") rootCmd.Flags().StringVarP(&schemaSeed, "schema-seed", "", "random", "Schema seed value") - rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", false, "Drop schema before starting tests run") rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output during test run") rootCmd.Flags().BoolVarP(&failFast, "fail-fast", "f", false, "Stop on the first failure") - rootCmd.Flags().BoolVarP(&nonInteractive, "non-interactive", "", false, "Run in non-interactive mode (disable progress indicator)") + rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal") + + rootCmd.Flags().StringSliceVarP(&testConfig.Hosts, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is system under test") + rootCmd.Flags().StringVarP(&testConfig.Trace, "test-tracing-outfile", "", "", "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.") + rootCmd.Flags().StringVarP(&testConfig.Consistency, "test-consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE") + rootCmd.Flags().StringVarP(&testConfig.DC, "test-dc", "", "", "Datacenter name for the test cluster") + rootCmd.Flags().StringVarP(&testConfig.HostSelectionPolicy, "test-host-selection-policy", "", "token-aware", "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware") + rootCmd.Flags().StringVarP(&testConfig.Username, "test-username", "", "", "Username for the test cluster") + rootCmd.Flags().StringVarP(&testConfig.Password, "test-password", "", "", "Password for the test cluster") + rootCmd.Flags().StringVarP(&testConfig.StatementLog, "test-statement-log-file", "", "", "File to write statements flow to") + rootCmd.Flags().DurationVarP(&testConfig.RequestTimeout, "test-request-timeout", "", 30*time.Second, "Duration of waiting request execution") + rootCmd.Flags().DurationVarP(&testConfig.ConnectTimeout, "test-connect-timeout", "", 30*time.Second, "Duration of waiting connection established") + rootCmd.Flags().DurationVarP(&testConfig.MaxRetriesMutateSleep, "test-max-mutation-retries-backoff", "", 10*time.Millisecond, "Duration between attempts to apply a mutation for example 10ms or 1s") + rootCmd.Flags().IntVarP(&testConfig.MaxRetriesMutate, "test-max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation") + rootCmd.Flags().BoolVarP(&testConfig.UseServerSideTimestamps, "test-use-server-timestamps", "", false, "Use server-side generated timestamps for writes") + rootCmd.Flags().StringVarP(&testReplicationStrategy, "test-replication-strategy", "", "simple", "Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide the entire specification in the form {'class':'....'}") + + rootCmd.Flags().StringSliceVarP(&oracleConfig.Hosts, "oracle-cluster", "o", []string{}, "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used") + rootCmd.Flags().StringVarP(&oracleConfig.Trace, "oracle-tracing-outfile", "", "", "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.") + rootCmd.Flags().StringVarP(&oracleConfig.Consistency, "oracle-consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE") + rootCmd.Flags().StringVarP(&oracleConfig.DC, "oracle-dc", "", "", "Datacenter name for the oracle cluster") + rootCmd.Flags().StringVarP(&oracleConfig.HostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware", "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware") + rootCmd.Flags().StringVarP(&oracleConfig.Username, "oracle-username", "", "", "Username for the oracle cluster") + rootCmd.Flags().StringVarP(&oracleConfig.Password, "oracle-password", "", "", "Password for the oracle cluster") + rootCmd.Flags().StringVarP(&oracleConfig.StatementLog, "oracle-statement-log-file", "", "", "File to write statements flow to") + rootCmd.Flags().DurationVarP(&oracleConfig.RequestTimeout, "oracle-request-timeout", "", 30*time.Second, "Duration of waiting request execution") + rootCmd.Flags().DurationVarP(&oracleConfig.ConnectTimeout, "oracle-connect-timeout", "", 30*time.Second, "Duration of waiting connection established") + rootCmd.Flags().DurationVarP(&oracleConfig.MaxRetriesMutateSleep, "oracle-max-mutation-retries-backoff", "", 10*time.Millisecond, "Duration between attempts to apply a mutation for example 10ms or 1s") + rootCmd.Flags().IntVarP(&oracleConfig.MaxRetriesMutate, "oracle-max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation") + rootCmd.Flags().BoolVarP(&oracleConfig.UseServerSideTimestamps, "oracle-use-server-timestamps", "", false, "Use server-side generated timestamps for writes") + rootCmd.Flags().StringVarP(&oracleReplicationStrategy, "oracle-replication-strategy", "", "simple", "Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each type or provide the entire specification in the form {'class':'....'}") + + rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file") + rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", false, "Drop schema before starting tests run") rootCmd.Flags().DurationVarP(&duration, "duration", "", 30*time.Second, "") rootCmd.Flags().StringVarP(&outFileArg, "outfile", "", "", "Specify the name of the file where the results should go") - rootCmd.Flags().StringVarP(&bind, "bind", "b", "0.0.0.0:2112", "Specify the interface and port which to bind prometheus metrics on. Default is ':2112'") rootCmd.Flags().DurationVarP(&warmup, "warmup", "", 30*time.Second, "Specify the warmup perid as a duration for example 30s or 10h") - rootCmd.Flags().StringVarP( - &replicationStrategy, "replication-strategy", "", "simple", - "Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide "+ - "the entire specification in the form {'class':'....'}") - rootCmd.Flags().StringVarP( - &oracleReplicationStrategy, "oracle-replication-strategy", "", "simple", - "Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each "+ - "type or provide the entire specification in the form {'class':'....'}") rootCmd.Flags().StringArrayVarP(&tableOptions, "table-options", "", []string{}, "Repeatable argument to set table options to be added to the created tables") - rootCmd.Flags().StringVarP(&consistency, "consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE") rootCmd.Flags().IntVarP(&maxTables, "max-tables", "", 1, "Maximum number of generated tables") rootCmd.Flags().IntVarP(&maxPartitionKeys, "max-partition-keys", "", 6, "Maximum number of generated partition keys") rootCmd.Flags().IntVarP(&minPartitionKeys, "min-partition-keys", "", 2, "Minimum number of generated partition keys") @@ -502,60 +390,42 @@ func init() { rootCmd.Flags().StringVarP(&datasetSize, "dataset-size", "", "large", "Specify the type of dataset size to use, small|large") rootCmd.Flags().StringVarP(&cqlFeatures, "cql-features", "", "basic", "Specify the type of cql features to use, basic|normal|all") rootCmd.Flags().BoolVarP(&useMaterializedViews, "materialized-views", "", false, "Run gemini with materialized views support") - rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal") - rootCmd.Flags().IntVarP(&maxRetriesMutate, "max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation") - rootCmd.Flags().DurationVarP( - &maxRetriesMutateSleep, "max-mutation-retries-backoff", "", 10*time.Millisecond, - "Duration between attempts to apply a mutation for example 10ms or 1s") rootCmd.Flags().Uint64VarP(&pkBufferReuseSize, "partition-key-buffer-reuse-size", "", 100, "Number of reused buffered partition keys") rootCmd.Flags().Uint64VarP(&partitionCount, "token-range-slices", "", 10000, "Number of slices to divide the token space into") - rootCmd.Flags().StringVarP( - &partitionKeyDistribution, "partition-key-distribution", "", "uniform", - "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf") + rootCmd.Flags().StringVarP(&partitionKeyDistribution, "partition-key-distribution", "", "uniform", "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf") rootCmd.Flags().Float64VarP(&normalDistMean, "normal-dist-mean", "", stdDistMean, "Mean of the normal distribution") rootCmd.Flags().Float64VarP(&normalDistSigma, "normal-dist-sigma", "", oneStdDev, "Sigma of the normal distribution, defaults to one standard deviation ~0.341") - rootCmd.Flags().StringVarP( - &tracingOutFile, "tracing-outfile", "", "", - "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.") rootCmd.Flags().BoolVarP(&useCounters, "use-counters", "", false, "Ensure that at least one table is a counter table") - rootCmd.Flags().IntVarP( - &asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10, - "Maximum number of attempts to validate result sets from MV and SI") - rootCmd.Flags().DurationVarP( - &asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond, - "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s") + rootCmd.Flags().IntVarP(&asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10, "Maximum number of attempts to validate result sets from MV and SI") + rootCmd.Flags().DurationVarP(&asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond, "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s") rootCmd.Flags().BoolVarP(&useLWT, "use-lwt", "", false, "Emit LWT based updates") - rootCmd.Flags().StringVarP( - &oracleClusterHostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware", - "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware") - rootCmd.Flags().StringVarP( - &testClusterHostSelectionPolicy, "test-host-selection-policy", "", "token-aware", - "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware") - rootCmd.Flags().BoolVarP(&useServerSideTimestamps, "use-server-timestamps", "", false, "Use server-side generated timestamps for writes") - rootCmd.Flags().DurationVarP(&requestTimeout, "request-timeout", "", 30*time.Second, "Duration of waiting request execution") - rootCmd.Flags().DurationVarP(&connectTimeout, "connect-timeout", "", 30*time.Second, "Duration of waiting connection established") - rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'") rootCmd.Flags().IntVarP(&maxErrorsToStore, "max-errors-to-store", "", 1000, "Maximum number of errors to store and output at the end") - rootCmd.Flags().StringVarP(&testStatementLogFile, "test-statement-log-file", "", "", "File to write statements flow to") - rootCmd.Flags().StringVarP(&oracleStatementLogFile, "oracle-statement-log-file", "", "", "File to write statements flow to") + + _ = rootCmd.MarkFlagRequired("test-cluster") } -func printSetup(seed, schemaSeed uint64) { +func printSetup(seed, schemaSeed uint64, schema *typedef.Schema) { + jsonSchema, _ := json.MarshalIndent(schema, "", " ") + tw := new(tabwriter.Writer) tw.Init(os.Stdout, 0, 8, 2, '\t', tabwriter.AlignRight) - fmt.Fprintf(tw, "Seed:\t%d\n", seed) - fmt.Fprintf(tw, "Schema seed:\t%d\n", schemaSeed) - fmt.Fprintf(tw, "Maximum duration:\t%s\n", duration) - fmt.Fprintf(tw, "Warmup duration:\t%s\n", warmup) - fmt.Fprintf(tw, "Concurrency:\t%d\n", concurrency) - fmt.Fprintf(tw, "Test cluster:\t%s\n", testClusterHost) - fmt.Fprintf(tw, "Oracle cluster:\t%s\n", oracleClusterHost) + _, _ = fmt.Fprintf(tw, "Seed:\t%d\n", seed) + _, _ = fmt.Fprintf(tw, "Schema seed:\t%d\n", schemaSeed) + _, _ = fmt.Fprintf(tw, "Maximum duration:\t%s\n", duration) + _, _ = fmt.Fprintf(tw, "Warmup duration:\t%s\n", warmup) + _, _ = fmt.Fprintf(tw, "Concurrency:\t%d\n", concurrency) + _, _ = fmt.Fprintf(tw, "Test cluster:\t%s\n", testConfig.Hosts) + _, _ = fmt.Fprintf(tw, "Oracle cluster:\t%s\n", oracleConfig.Hosts) + _, _ = fmt.Printf("Schema: \t%s\n", string(jsonSchema)) + if outFileArg == "" { - fmt.Fprintf(tw, "Output file:\t%s\n", "") + _, _ = fmt.Fprintf(tw, "Output file:\t%s\n", "") } else { - fmt.Fprintf(tw, "Output file:\t%s\n", outFileArg) + _, _ = fmt.Fprintf(tw, "Output file:\t%s\n", outFileArg) + } + if err := tw.Flush(); err != nil { + log.Printf("Failed to print setup: %v", err) } - tw.Flush() } func RealRandom() uint64 { @@ -582,10 +452,11 @@ func seedFromString(seed string) uint64 { func generateSchema(logger *zap.Logger, sc typedef.SchemaConfig, schemaSeed string) (schema *typedef.Schema, intSchemaSeed uint64, err error) { intSchemaSeed = seedFromString(schemaSeed) schema = generators.GenSchema(sc, intSchemaSeed) - err = schema.Validate(partitionCount) - if err == nil { + + if err = schema.Validate(partitionCount); err == nil { return schema, intSchemaSeed, nil } + if schemaSeed != "random" { // If user provided schema, allow to run it, but log warning logger.Warn(errors.Wrap(err, "validation failed, running this test could end up in error or stale gemini").Error()) diff --git a/cmd/gemini/schema.go b/cmd/gemini/schema.go index 68884a62..8ceed93c 100644 --- a/cmd/gemini/schema.go +++ b/cmd/gemini/schema.go @@ -64,7 +64,7 @@ func createDefaultSchemaConfig(logger *zap.Logger) typedef.SchemaConfig { MaxTupleParts = 20 MaxUDTParts = 20 ) - rs := getReplicationStrategy(replicationStrategy, replication.NewSimpleStrategy(), logger) + rs := getReplicationStrategy(testReplicationStrategy, replication.NewSimpleStrategy(), logger) ors := getReplicationStrategy(oracleReplicationStrategy, rs, logger) return typedef.SchemaConfig{ ReplicationStrategy: rs, diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index cf797d9d..70473323 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -20,6 +20,11 @@ import ( "github.com/gocql/gocql" ) +var ( + ErrUsernameNotProvided = errors.New("username not provided") + ErrorPasswordNotProvided = errors.New("password not provided") +) + // BuildAuthenticator : Returns a new gocql.PasswordAuthenticator // if both username and password are provided. func BuildAuthenticator(username, password string) (*gocql.PasswordAuthenticator, error) { @@ -34,7 +39,7 @@ func BuildAuthenticator(username, password string) (*gocql.PasswordAuthenticator return &authenticator, nil } if username != "" { - return nil, errors.New("Password not provided") + return nil, ErrorPasswordNotProvided } - return nil, errors.New("Username not provided") + return nil, ErrUsernameNotProvided } diff --git a/pkg/generators/statement_generator.go b/pkg/generators/statement_generator.go index 1755e513..1c280949 100644 --- a/pkg/generators/statement_generator.go +++ b/pkg/generators/statement_generator.go @@ -20,14 +20,13 @@ import ( "golang.org/x/exp/rand" - "github.com/scylladb/gemini/pkg/builders" "github.com/scylladb/gemini/pkg/typedef" "github.com/scylladb/gemini/pkg/utils" ) func GenSchema(sc typedef.SchemaConfig, seed uint64) *typedef.Schema { r := rand.New(rand.NewSource(seed)) - builder := builders.NewSchemaBuilder() + builder := typedef.NewSchemaBuilder() builder.Config(sc) keyspace := typedef.Keyspace{ Name: "ks1", @@ -36,10 +35,12 @@ func GenSchema(sc typedef.SchemaConfig, seed uint64) *typedef.Schema { } builder.Keyspace(keyspace) numTables := utils.RandInt2(r, 1, sc.GetMaxTables()) + for i := 0; i < numTables; i++ { - table := genTable(sc, fmt.Sprintf("table%d", i+1), r) + table := genTable(sc, fmt.Sprintf("table_%d", i+1), r) builder.Table(table) } + return builder.Build() } @@ -60,9 +61,11 @@ func genTable(sc typedef.SchemaConfig, tableName string, r *rand.Rand) *typedef. typedef.KnownIssuesJSONWithTuples: true, }, } + for _, option := range sc.TableOptions { table.TableOptions = append(table.TableOptions, option.ToCQL()) } + if sc.UseCounters { table.Columns = typedef.Columns{ { @@ -78,21 +81,17 @@ func genTable(sc typedef.SchemaConfig, tableName string, r *rand.Rand) *typedef. for i := 0; i < len(columns); i++ { columns[i] = &typedef.ColumnDef{Name: GenColumnName("col", i), Type: GenColumnType(len(columns), &sc, r)} } + table.Columns = columns - var indexes []typedef.IndexDef if sc.CQLFeature > typedef.CQL_FEATURE_BASIC && len(columns) > 0 { - indexes = CreateIndexesForColumn(&table, utils.RandInt2(r, 1, len(columns))) + table.Indexes = CreateIndexesForColumn(&table, utils.RandInt2(r, 1, len(columns))) } - table.Indexes = indexes - var mvs []typedef.MaterializedView if sc.CQLFeature > typedef.CQL_FEATURE_BASIC && sc.UseMaterializedViews && len(clusteringKeys) > 0 && columns.ValidColumnsForPrimaryKey().Len() != 0 { - mvs = CreateMaterializedViews(columns, table.Name, partitionKeys, clusteringKeys, r) + table.MaterializedViews = CreateMaterializedViews(columns, table.Name, partitionKeys, clusteringKeys, r) } - table.MaterializedViews = mvs - return &table } diff --git a/pkg/jobs/gen_ddl_stmt.go b/pkg/jobs/gen_ddl_stmt.go index a8866fd4..6550097a 100644 --- a/pkg/jobs/gen_ddl_stmt.go +++ b/pkg/jobs/gen_ddl_stmt.go @@ -20,7 +20,6 @@ import ( "golang.org/x/exp/rand" - "github.com/scylladb/gemini/pkg/builders" "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/typedef" ) @@ -60,16 +59,16 @@ func genAddColumnStmt(t *typedef.Table, keyspace string, column *typedef.ColumnD stmt := fmt.Sprintf(createType, keyspace, c.TypeName, strings.Join(typs, ",")) stmts = append(stmts, &typedef.Stmt{ StmtCache: &typedef.StmtCache{ - Query: &builders.AlterTableBuilder{ + Query: &typedef.AlterTableBuilder{ Stmt: stmt, }, }, }) } - stmt := "ALTER TABLE " + keyspace + "." + t.Name + " ADD " + column.Name + " " + column.Type.CQLDef() + stmt := "alter table " + keyspace + "." + t.Name + " ADD " + column.Name + " " + column.Type.CQLDef() stmts = append(stmts, &typedef.Stmt{ StmtCache: &typedef.StmtCache{ - Query: &builders.AlterTableBuilder{ + Query: &typedef.AlterTableBuilder{ Stmt: stmt, }, }, @@ -87,10 +86,10 @@ func genAddColumnStmt(t *typedef.Table, keyspace string, column *typedef.ColumnD func genDropColumnStmt(t *typedef.Table, keyspace string, column *typedef.ColumnDef) (*typedef.Stmts, error) { var stmts []*typedef.Stmt - stmt := "ALTER TABLE " + keyspace + "." + t.Name + " DROP " + column.Name + stmt := "alter table " + keyspace + "." + t.Name + " DROP " + column.Name stmts = append(stmts, &typedef.Stmt{ StmtCache: &typedef.StmtCache{ - Query: &builders.AlterTableBuilder{ + Query: &typedef.AlterTableBuilder{ Stmt: stmt, }, QueryType: typedef.DropColumnStatementType, diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go new file mode 100644 index 00000000..54a166a6 --- /dev/null +++ b/pkg/metrics/metrics.go @@ -0,0 +1,29 @@ +package metrics + +import ( + "log" + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var CQLRequests *prometheus.CounterVec + +func init() { + CQLRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gemini_cql_requests", + Help: "How many CQL requests processed, partitioned by system and CQL query type aka 'method' (batch, delete, insert, update).", + }, + []string{"system", "method"}, + ) + + go func() { + http.Handle("/metrics", promhttp.Handler()) + if err := http.ListenAndServe("0.0.0.0:2121", nil); err != nil { + log.Fatalf("Failed to start metrics server: %v\n", err) + } + }() +} diff --git a/pkg/stmtlogger/filelogger.go b/pkg/stmtlogger/filelogger.go index d367bb88..f0561229 100644 --- a/pkg/stmtlogger/filelogger.go +++ b/pkg/stmtlogger/filelogger.go @@ -39,7 +39,7 @@ const ( ) type ( - StmtToFile interface { + Interface interface { LogStmt(stmt *typedef.Stmt, ts ...time.Time) error Close() error } @@ -55,7 +55,7 @@ type ( } ) -func NewFileLogger(filename string) (StmtToFile, error) { +func NewFileLogger(filename string) (Interface, error) { if filename == "" { return &nopFileLogger{}, nil } @@ -68,7 +68,7 @@ func NewFileLogger(filename string) (StmtToFile, error) { return NewLogger(fd) } -func NewLogger(w io.Writer) (StmtToFile, error) { +func NewLogger(w io.Writer) (Interface, error) { ctx, cancel := context.WithCancel(context.Background()) out := &logger{ diff --git a/pkg/store/cqlstore.go b/pkg/store/cqlstore.go deleted file mode 100644 index 6a2269ab..00000000 --- a/pkg/store/cqlstore.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package store - -import ( - "context" - "io" - "time" - - "github.com/gocql/gocql" - "github.com/pkg/errors" - "github.com/prometheus/client_golang/prometheus" - "github.com/scylladb/gocqlx/v2/qb" - "go.uber.org/zap" - - "github.com/scylladb/gemini/pkg/stmtlogger" - "github.com/scylladb/gemini/pkg/typedef" -) - -type cqlStore struct { //nolint:govet - session *gocql.Session - schema *typedef.Schema - ops *prometheus.CounterVec - logger *zap.Logger - system string - maxRetriesMutate int - maxRetriesMutateSleep time.Duration - useServerSideTimestamps bool - stmtLogger stmtlogger.StmtToFile -} - -func (cs *cqlStore) name() string { - return cs.system -} - -func (cs *cqlStore) mutate(ctx context.Context, stmt *typedef.Stmt) error { - for range cs.maxRetriesMutate { - if err := cs.doMutate(ctx, stmt); err == nil { - cs.ops.WithLabelValues(cs.system, opType(stmt)).Inc() - return nil - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(cs.maxRetriesMutateSleep): - } - } - - return errors.Errorf("failed to mutate after %d retries", cs.maxRetriesMutate) -} - -func (cs *cqlStore) doMutate(ctx context.Context, stmt *typedef.Stmt) error { - queryBody, _ := stmt.Query.ToCql() - query := cs.session.Query(queryBody, stmt.Values...).WithContext(ctx).DefaultTimestamp(false) - defer query.Release() - - var ts time.Time - - if !cs.useServerSideTimestamps { - ts = time.Now() - query = query.WithTimestamp(ts.UnixMicro()) - } - - if err := query.Exec(); err != nil { - return errors.Wrapf(err, "[cluster = %s, query = '%s']", cs.system, queryBody) - } - - if err := cs.stmtLogger.LogStmt(stmt, ts); err != nil { - return err - } - - return nil -} - -func (cs *cqlStore) load(ctx context.Context, stmt *typedef.Stmt) ([]map[string]any, error) { - cql, _ := stmt.Query.ToCql() - - query := cs.session.Query(cql, stmt.Values...).WithContext(ctx) - defer query.Release() - - iter := query.Iter() - cs.ops.WithLabelValues(cs.system, opType(stmt)).Inc() - - result := loadSet(iter) - - if err := cs.stmtLogger.LogStmt(stmt); err != nil { - return nil, err - } - - return result, iter.Close() -} - -func (cs *cqlStore) close() error { - cs.session.Close() - return nil -} - -func newSession(cluster *gocql.ClusterConfig, out io.Writer) (*gocql.Session, error) { - session, err := cluster.CreateSession() - if err != nil { - return nil, err - } - if out != nil { - session.SetTrace(gocql.NewTraceWriter(session, out)) - } - return session, nil -} - -func opType(stmt *typedef.Stmt) string { - switch stmt.Query.(type) { - case *qb.InsertBuilder: - return "insert" - case *qb.DeleteBuilder: - return "delete" - case *qb.UpdateBuilder: - return "update" - case *qb.SelectBuilder: - return "select" - case *qb.BatchBuilder: - return "batch" - default: - return "unknown" - } -} diff --git a/pkg/store/drivers/noop.go b/pkg/store/drivers/noop.go new file mode 100644 index 00000000..d13ac9d8 --- /dev/null +++ b/pkg/store/drivers/noop.go @@ -0,0 +1,21 @@ +package drivers + +import ( + "context" + + "github.com/scylladb/gemini/pkg/typedef" +) + +type Nop struct{} + +func NewNop() Nop { + return Nop{} +} + +func (n Nop) Execute(context.Context, *typedef.Stmt) error { + return nil +} + +func (n Nop) Fetch(context.Context, *typedef.Stmt) ([]map[string]any, error) { + return nil, nil +} diff --git a/pkg/store/drivers/scylla.go b/pkg/store/drivers/scylla.go new file mode 100644 index 00000000..f1b1e770 --- /dev/null +++ b/pkg/store/drivers/scylla.go @@ -0,0 +1,294 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package drivers + +import ( + "bufio" + "context" + "fmt" + "io" + "time" + + "github.com/gocql/gocql" + "github.com/hailocab/go-hostpool" + "github.com/pkg/errors" + "github.com/scylladb/gocqlx/v2/qb" + "go.uber.org/zap" + + "github.com/scylladb/gemini/pkg/auth" + "github.com/scylladb/gemini/pkg/metrics" + "github.com/scylladb/gemini/pkg/stmtlogger" + "github.com/scylladb/gemini/pkg/typedef" + "github.com/scylladb/gemini/pkg/utils" +) + +type ( + ScyllaDB struct { //nolint:govet + session *gocql.Session + schema *typedef.Schema + logger *zap.Logger + system string + + statementLog stmtlogger.Interface + maxRetriesMutateSleep time.Duration + maxRetriesMutate int + useServerSideTimestamps bool + } + + CQLConfig struct { + Hosts []string + Trace string + Consistency string + DC string + HostSelectionPolicy string + Username string + Password string + StatementLog string + RequestTimeout time.Duration + ConnectTimeout time.Duration + MaxRetriesMutateSleep time.Duration + MaxRetriesMutate int + UseServerSideTimestamps bool + } +) + +func NewCQL( + ctx context.Context, + name string, + schema *typedef.Schema, + cfg CQLConfig, + logger *zap.Logger, +) (*ScyllaDB, error) { + clusterConfig, err := createClusters(cfg, logger) + if err != nil { + return nil, err + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, errors.Wrapf(err, "failed to connect to %s cluster", name) + } + + if cfg.Trace != "" { + trace, err := cfg.getTraceLogFile() + if err != nil { + return nil, errors.Wrap(err, "failed to create trace file") + } + + session.SetTrace(gocql.NewTraceWriter(session, trace)) + } + + if err := session.AwaitSchemaAgreement(ctx); err != nil { + return nil, errors.Wrapf(err, "failed to await schema agreement for %s cluster", name) + } + + statementLog, err := cfg.getStatementLogFile() + if err != nil { + return nil, errors.Wrap(err, "failed to create statement logger") + } + + return &ScyllaDB{ + session: session, + schema: schema, + logger: logger.Named(name), + system: name, + statementLog: statementLog, + maxRetriesMutateSleep: cfg.MaxRetriesMutateSleep, + maxRetriesMutate: cfg.MaxRetriesMutate, + useServerSideTimestamps: cfg.UseServerSideTimestamps, + }, nil +} + +func (cs *ScyllaDB) Execute(ctx context.Context, stmt *typedef.Stmt) error { + for range cs.maxRetriesMutate { + if err := cs.doMutate(ctx, stmt); err == nil { + metrics.CQLRequests.WithLabelValues(cs.system, opType(stmt)).Inc() + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(cs.maxRetriesMutateSleep): + } + } + + return errors.Errorf("failed to Mutate after %d retries", cs.maxRetriesMutate) +} + +func (cs *ScyllaDB) doMutate(ctx context.Context, stmt *typedef.Stmt) error { + queryBody, _ := stmt.Query.ToCql() + query := cs.session.Query(queryBody, stmt.Values...).WithContext(ctx).DefaultTimestamp(false) + defer query.Release() + + var ts time.Time + + if !cs.useServerSideTimestamps { + ts = time.Now() + query = query.WithTimestamp(ts.UnixMicro()) + } + + if err := query.Exec(); err != nil { + return errors.Wrapf(err, "[cluster = %s, query = %s]", cs.system, queryBody) + } + + if err := cs.statementLog.LogStmt(stmt, ts); err != nil { + return err + } + + return nil +} + +func (cs *ScyllaDB) Fetch(ctx context.Context, stmt *typedef.Stmt) ([]map[string]any, error) { + cql, _ := stmt.Query.ToCql() + + query := cs.session. + Query(cql, stmt.Values...). + WithContext(ctx). + DefaultTimestamp(false) + + defer query.Release() + + metrics.CQLRequests.WithLabelValues(cs.system, opType(stmt)).Inc() + + result, err := loadSet(query.Iter()) + if err != nil { + return nil, errors.Wrapf(err, "[cluster = %s, query = %s]", cs.system, cql) + } + + if err := cs.statementLog.LogStmt(stmt); err != nil { + return nil, err + } + + return result, nil +} + +func (cs *ScyllaDB) Close() error { + cs.session.Close() + return nil +} + +func opType(stmt *typedef.Stmt) string { + switch stmt.Query.(type) { + case *qb.InsertBuilder: + return "insert" + case *qb.DeleteBuilder: + return "delete" + case *qb.UpdateBuilder: + return "update" + case *qb.SelectBuilder: + return "select" + case *qb.BatchBuilder: + return "batch" + default: + return "unknown" + } +} + +func createClusters(cfg CQLConfig, logger *zap.Logger) (*gocql.ClusterConfig, error) { + authenticator, err := auth.BuildAuthenticator(cfg.Username, cfg.Password) + if err != nil { + return nil, errors.Wrap(err, "failed to create authenticator") + } + + cons, err := gocql.ParseConsistencyWrapper(cfg.Consistency) + if err != nil { + return nil, errors.Wrapf(err, "failed to parse consistency: %s", cfg.Consistency) + } + + selectionPolicy, err := parseHostSelectionPolicy(cfg.HostSelectionPolicy, cfg.DC, cfg.Hosts) + if err != nil { + return nil, errors.Wrapf(err, "failed to parse host selection policy: %s", cfg.HostSelectionPolicy) + } + + cluster := gocql.NewCluster(cfg.Hosts...) + cluster.Timeout = cfg.RequestTimeout + cluster.ProtoVersion = 4 + cluster.ConnectTimeout = cfg.ConnectTimeout + cluster.Consistency = cons + cluster.PoolConfig.HostSelectionPolicy = selectionPolicy + cluster.DefaultTimestamp = false + cluster.Logger = zap.NewStdLog(logger) + cluster.MaxRoutingKeyInfo = 10_000 + cluster.MaxPreparedStmts = 10_000 + cluster.Authenticator = authenticator + cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{ + Min: 500 * time.Millisecond, + Max: cfg.RequestTimeout, + NumRetries: cfg.MaxRetriesMutate, + } + cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{ + InitialInterval: 500 * time.Millisecond, + MaxRetries: 10, + MaxInterval: cfg.ConnectTimeout, + } + + return cluster, nil +} + +func parseHostSelectionPolicy(policy, dc string, hosts []string) (gocql.HostSelectionPolicy, error) { + switch policy { + case "round-robin": + if dc != "" { + return gocql.DCAwareRoundRobinPolicy(dc), nil + } + + return gocql.RoundRobinHostPolicy(), nil + case "host-pool": + return gocql.HostPoolHostPolicy(hostpool.New(hosts)), nil + case "token-aware": + if dc != "" { + return gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy(dc)), nil + } + + return gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()), nil + default: + return nil, fmt.Errorf("unknown host selection policy \"%s\"", policy) + } +} + +func loadSet(iter *gocql.Iter) ([]map[string]any, error) { + rows := make([]map[string]any, 0, iter.NumRows()) + + for { + row := make(map[string]any, len(iter.Columns())) + + if !iter.MapScan(row) { + break + } + + rows = append(rows, row) + } + + return rows, iter.Close() +} + +func (c CQLConfig) getStatementLogFile() (stmtlogger.Interface, error) { + writer, err := utils.CreateFile(c.StatementLog) + if err != nil { + return nil, err + } + + return stmtlogger.NewLogger(writer) +} + +func (c CQLConfig) getTraceLogFile() (io.Writer, error) { + writer, err := utils.CreateFile(c.StatementLog) + if err != nil { + return nil, err + } + + return bufio.NewWriterSize(writer, 8192), nil +} diff --git a/pkg/store/drivers/scylla_test.go b/pkg/store/drivers/scylla_test.go new file mode 100644 index 00000000..fc62d2f0 --- /dev/null +++ b/pkg/store/drivers/scylla_test.go @@ -0,0 +1 @@ +package drivers_test diff --git a/pkg/store/helpers.go b/pkg/store/helpers.go index 8dff43be..ffe40fb7 100644 --- a/pkg/store/helpers.go +++ b/pkg/store/helpers.go @@ -85,15 +85,3 @@ func lt(mi, mj map[string]any) bool { panic(msg) } } - -func loadSet(iter *gocql.Iter) []map[string]any { - var rows []map[string]any - for { - row := make(map[string]any) - if !iter.MapScan(row) { - break - } - rows = append(rows, row) - } - return rows -} diff --git a/pkg/store/store.go b/pkg/store/store.go index 8eeb3c38..4d3112ba 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -16,280 +16,29 @@ package store import ( "context" - "fmt" - "io" - "math/big" - "reflect" - "sort" - "sync" - "time" "go.uber.org/zap" - "github.com/gocql/gocql" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/pkg/errors" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "go.uber.org/multierr" - "gopkg.in/inf.v0" - - "github.com/scylladb/go-set/strset" - - "github.com/scylladb/gemini/pkg/stmtlogger" "github.com/scylladb/gemini/pkg/typedef" ) -type loader interface { - load(context.Context, *typedef.Stmt) ([]map[string]any, error) -} - -type storer interface { - mutate(context.Context, *typedef.Stmt) error -} - -type storeLoader interface { - storer - loader - close() error - name() string -} - -type Store interface { - Create(context.Context, *typedef.Stmt, *typedef.Stmt) error - Mutate(context.Context, *typedef.Stmt) error - Check(context.Context, *typedef.Table, *typedef.Stmt, bool) error - Close() error -} - -type Config struct { - TestLogStatementsFile string - OracleLogStatementsFile string - MaxRetriesMutate int - MaxRetriesMutateSleep time.Duration - UseServerSideTimestamps bool -} - -func New(schema *typedef.Schema, testCluster, oracleCluster *gocql.ClusterConfig, cfg Config, traceOut io.Writer, logger *zap.Logger) (Store, error) { - ops := promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "gemini_cql_requests", - Help: "How many CQL requests processed, partitioned by system and CQL query type aka 'method' (batch, delete, insert, update).", - }, []string{"system", "method"}, - ) - - oracleStore, err := getStore("oracle", schema, oracleCluster, cfg, cfg.OracleLogStatementsFile, traceOut, logger, ops) - if err != nil { - return nil, err +type ( + Driver interface { + Execute(ctx context.Context, stmt *typedef.Stmt) error + Fetch(ctx context.Context, stmt *typedef.Stmt) ([]map[string]any, error) } - if testCluster == nil { - return nil, errors.New("test cluster is empty") - } - testStore, err := getStore("test", schema, testCluster, cfg, cfg.TestLogStatementsFile, traceOut, logger, ops) - if err != nil { - return nil, err + Store interface { + Create(context.Context, *typedef.Stmt, *typedef.Stmt) error + Mutate(context.Context, *typedef.Stmt) error + Check(context.Context, *typedef.Table, *typedef.Stmt, bool) error } +) - return &delegatingStore{ +func New(logger *zap.Logger, testStore, oracleStore Driver) (*ValidatingStore, error) { + return &ValidatingStore{ testStore: testStore, oracleStore: oracleStore, - validations: oracleStore != nil, logger: logger.Named("delegating_store"), }, nil } - -type noOpStore struct { - system string -} - -func (n *noOpStore) mutate(context.Context, *typedef.Stmt) error { - return nil -} - -func (n *noOpStore) load(context.Context, *typedef.Stmt) ([]map[string]any, error) { - return nil, nil -} - -func (n *noOpStore) Close() error { - return nil -} - -func (n *noOpStore) name() string { - return n.system -} - -func (n *noOpStore) close() error { - return nil -} - -type delegatingStore struct { - oracleStore storeLoader - testStore storeLoader - statementLogger stmtlogger.StmtToFile - logger *zap.Logger - validations bool -} - -func (ds delegatingStore) Create(ctx context.Context, testBuilder, oracleBuilder *typedef.Stmt) error { - if err := mutate(ctx, ds.oracleStore, oracleBuilder); err != nil { - return errors.Wrap(err, "oracle failed store creation") - } - - if err := mutate(ctx, ds.testStore, testBuilder); err != nil { - return errors.Wrap(err, "test failed store creation") - } - - if ds.statementLogger != nil { - if err := ds.statementLogger.LogStmt(testBuilder); err != nil { - return errors.Wrap(err, "failed to log test create statement") - } - } - - return nil -} - -func (ds delegatingStore) Mutate(ctx context.Context, stmt *typedef.Stmt) error { - var testErr error - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - testErr = errors.Wrapf( - ds.testStore.mutate(ctx, stmt), - "unable to apply mutations to the %s store", ds.testStore.name()) - }() - - if oracleErr := ds.oracleStore.mutate(ctx, stmt); oracleErr != nil { - // Oracle failed, transition cannot take place - ds.logger.Info("oracle store failed mutation, transition to next state impossible so continuing with next mutation", zap.Error(oracleErr)) - return oracleErr - } - wg.Wait() - if testErr != nil { - // Test store failed, transition cannot take place - ds.logger.Info("test store failed mutation, transition to next state impossible so continuing with next mutation", zap.Error(testErr)) - return testErr - } - return nil -} - -func mutate(ctx context.Context, s storeLoader, stmt *typedef.Stmt) error { - if err := s.mutate(ctx, stmt); err != nil { - return errors.Wrapf(err, "unable to apply mutations to the %s store", s.name()) - } - return nil -} - -func (ds delegatingStore) Check(ctx context.Context, table *typedef.Table, stmt *typedef.Stmt, detailedDiff bool) error { - var testRows, oracleRows []map[string]any - var testErr, oracleErr error - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - testRows, testErr = ds.testStore.load(ctx, stmt) - }() - oracleRows, oracleErr = ds.oracleStore.load(ctx, stmt) - if oracleErr != nil { - return errors.Wrapf(oracleErr, "unable to load check data from the oracle store") - } - wg.Wait() - if testErr != nil { - return errors.Wrapf(testErr, "unable to load check data from the test store") - } - if !ds.validations { - return nil - } - if len(testRows) == 0 && len(oracleRows) == 0 { - return nil - } - if len(testRows) != len(oracleRows) { - if !detailedDiff { - return fmt.Errorf("rows count differ (test store rows %d, oracle store rows %d, detailed information will be at last attempt)", len(testRows), len(oracleRows)) - } - testSet := strset.New(pks(table, testRows)...) - oracleSet := strset.New(pks(table, oracleRows)...) - missingInTest := strset.Difference(oracleSet, testSet).List() - missingInOracle := strset.Difference(testSet, oracleSet).List() - return fmt.Errorf("row count differ (test has %d rows, oracle has %d rows, test is missing rows: %s, oracle is missing rows: %s)", - len(testRows), len(oracleRows), missingInTest, missingInOracle) - } - if reflect.DeepEqual(testRows, oracleRows) { - return nil - } - if !detailedDiff { - return fmt.Errorf("test and oracle store have difference, detailed information will be at last attempt") - } - sort.SliceStable(testRows, func(i, j int) bool { - return lt(testRows[i], testRows[j]) - }) - sort.SliceStable(oracleRows, func(i, j int) bool { - return lt(oracleRows[i], oracleRows[j]) - }) - for i, oracleRow := range oracleRows { - testRow := testRows[i] - cmp.AllowUnexported() - diff := cmp.Diff(oracleRow, testRow, - cmpopts.SortMaps(func(x, y *inf.Dec) bool { - return x.Cmp(y) < 0 - }), - cmp.Comparer(func(x, y *inf.Dec) bool { - return x.Cmp(y) == 0 - }), cmp.Comparer(func(x, y *big.Int) bool { - return x.Cmp(y) == 0 - })) - if diff != "" { - return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff) - } - } - return nil -} - -func (ds delegatingStore) Close() (err error) { - if ds.statementLogger != nil { - err = multierr.Append(err, ds.statementLogger.Close()) - } - err = multierr.Append(err, ds.testStore.close()) - err = multierr.Append(err, ds.oracleStore.close()) - return -} - -func getStore( - name string, - schema *typedef.Schema, - clusterConfig *gocql.ClusterConfig, - cfg Config, - stmtLogFile string, - traceOut io.Writer, - logger *zap.Logger, - ops *prometheus.CounterVec, -) (out storeLoader, err error) { - if clusterConfig == nil { - return &noOpStore{ - system: name, - }, nil - } - oracleSession, err := newSession(clusterConfig, traceOut) - if err != nil { - return nil, errors.Wrapf(err, "failed to connect to %s cluster", name) - } - oracleFileLogger, err := stmtlogger.NewFileLogger(stmtLogFile) - if err != nil { - return nil, err - } - - return &cqlStore{ - session: oracleSession, - schema: schema, - system: name, - ops: ops, - maxRetriesMutate: cfg.MaxRetriesMutate + 10, - maxRetriesMutateSleep: cfg.MaxRetriesMutateSleep, - useServerSideTimestamps: cfg.UseServerSideTimestamps, - logger: logger.Named(name), - stmtLogger: oracleFileLogger, - }, nil -} diff --git a/pkg/store/validating_store.go b/pkg/store/validating_store.go new file mode 100644 index 00000000..9a2305eb --- /dev/null +++ b/pkg/store/validating_store.go @@ -0,0 +1,151 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package store + +import ( + "context" + "fmt" + "math/big" + "reflect" + "sort" + "sync" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/pkg/errors" + "github.com/scylladb/go-set/strset" + "go.uber.org/zap" + "gopkg.in/inf.v0" + + "github.com/scylladb/gemini/pkg/typedef" +) + +type ValidatingStore struct { + oracleStore Driver + testStore Driver + logger *zap.Logger +} + +func (ds ValidatingStore) Create(ctx context.Context, testBuilder, oracleBuilder *typedef.Stmt) error { + if err := ds.testStore.Execute(ctx, testBuilder); err != nil { + return errors.Wrap(err, "oracle failed store creation") + } + + if err := ds.testStore.Execute(ctx, oracleBuilder); err != nil { + return errors.Wrap(err, "oracle failed store creation") + } + + return nil +} + +func (ds ValidatingStore) Mutate(ctx context.Context, stmt *typedef.Stmt) error { + var ( + wg sync.WaitGroup + oracleErr error + ) + + if ds.oracleStore != nil { + wg.Add(1) + go func() { + defer wg.Done() + oracleErr = ds.oracleStore.Execute(ctx, stmt) + }() + } + + if err := ds.testStore.Execute(ctx, stmt); err != nil { + return errors.Wrap(err, "unable to apply mutations to the test store") + } + + wg.Wait() + + if oracleErr != nil { + return errors.Wrap(oracleErr, "unable to apply mutations to the oracle store") + } + + return nil +} + +func (ds ValidatingStore) Check(ctx context.Context, table *typedef.Table, stmt *typedef.Stmt, detailedDiff bool) error { + var ( + oracleErr error + oracleRows []map[string]any + wg sync.WaitGroup + ) + + if ds.oracleStore != nil { + wg.Add(1) + + go func() { + defer wg.Done() + oracleRows, oracleErr = ds.oracleStore.Fetch(ctx, stmt) + }() + } + + testRows, testErr := ds.testStore.Fetch(ctx, stmt) + + if testErr != nil { + return errors.Wrap(testErr, "unable to Load check data from the test store") + } + + wg.Wait() + + if oracleErr != nil { + return errors.Wrap(oracleErr, "unable to Load check data from the oracle store") + } + + if len(testRows) == 0 && len(oracleRows) == 0 { + return nil + } + if len(testRows) != len(oracleRows) { + if !detailedDiff { + return fmt.Errorf("rows count differ (test store rows %d, oracle store rows %d, detailed information will be at last attempt)", len(testRows), len(oracleRows)) + } + testSet := strset.New(pks(table, testRows)...) + oracleSet := strset.New(pks(table, oracleRows)...) + missingInTest := strset.Difference(oracleSet, testSet).List() + missingInOracle := strset.Difference(testSet, oracleSet).List() + return fmt.Errorf("row count differ (test has %d rows, oracle has %d rows, test is missing rows: %s, oracle is missing rows: %s)", + len(testRows), len(oracleRows), missingInTest, missingInOracle) + } + if reflect.DeepEqual(testRows, oracleRows) { + return nil + } + if !detailedDiff { + return fmt.Errorf("test and oracle store have difference, detailed information will be at last attempt") + } + sort.SliceStable(testRows, func(i, j int) bool { + return lt(testRows[i], testRows[j]) + }) + sort.SliceStable(oracleRows, func(i, j int) bool { + return lt(oracleRows[i], oracleRows[j]) + }) + for i, oracleRow := range oracleRows { + testRow := testRows[i] + cmp.AllowUnexported() + diff := cmp.Diff(oracleRow, testRow, + cmpopts.SortMaps(func(x, y *inf.Dec) bool { + return x.Cmp(y) < 0 + }), + cmp.Comparer(func(x, y *inf.Dec) bool { + return x.Cmp(y) == 0 + }), cmp.Comparer(func(x, y *big.Int) bool { + return x.Cmp(y) == 0 + })) + if diff != "" { + return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff) + } + } + return nil +} diff --git a/pkg/testutils/name_mappings.go b/pkg/testutils/name_mappings.go index d36bea16..dece82dc 100644 --- a/pkg/testutils/name_mappings.go +++ b/pkg/testutils/name_mappings.go @@ -22,7 +22,6 @@ import ( "golang.org/x/exp/rand" - "github.com/scylladb/gemini/pkg/builders" "github.com/scylladb/gemini/pkg/replication" "github.com/scylladb/gemini/pkg/routingkey" "github.com/scylladb/gemini/pkg/tableopts" @@ -176,7 +175,7 @@ func createTableOptions(cql string) []tableopts.Option { } func genTestSchema(sc typedef.SchemaConfig, table *typedef.Table) *typedef.Schema { - builder := builders.NewSchemaBuilder() + builder := typedef.NewSchemaBuilder() builder.Config(sc) keyspace := typedef.Keyspace{ Name: "ks1", diff --git a/pkg/builders/builders.go b/pkg/typedef/builders.go similarity index 57% rename from pkg/builders/builders.go rename to pkg/typedef/builders.go index 6ad39f1b..f66ee9ea 100644 --- a/pkg/builders/builders.go +++ b/pkg/typedef/builders.go @@ -12,18 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package builders - -import ( - "github.com/scylladb/gemini/pkg/querycache" - "github.com/scylladb/gemini/pkg/typedef" -) +package typedef type SchemaBuilder interface { - Config(config typedef.SchemaConfig) SchemaBuilder - Keyspace(typedef.Keyspace) SchemaBuilder - Table(*typedef.Table) SchemaBuilder - Build() *typedef.Schema + Config(config SchemaConfig) SchemaBuilder + Keyspace(Keyspace) SchemaBuilder + Table(*Table) SchemaBuilder + Build() *Schema } type AlterTableBuilder struct { @@ -35,30 +30,30 @@ func (atb *AlterTableBuilder) ToCql() (string, []string) { } type schemaBuilder struct { - keyspace typedef.Keyspace - tables []*typedef.Table - config typedef.SchemaConfig + keyspace Keyspace + tables []*Table + config SchemaConfig } -func (s *schemaBuilder) Keyspace(keyspace typedef.Keyspace) SchemaBuilder { +func (s *schemaBuilder) Keyspace(keyspace Keyspace) SchemaBuilder { s.keyspace = keyspace return s } -func (s *schemaBuilder) Config(config typedef.SchemaConfig) SchemaBuilder { +func (s *schemaBuilder) Config(config SchemaConfig) SchemaBuilder { s.config = config return s } -func (s *schemaBuilder) Table(table *typedef.Table) SchemaBuilder { +func (s *schemaBuilder) Table(table *Table) SchemaBuilder { s.tables = append(s.tables, table) return s } -func (s *schemaBuilder) Build() *typedef.Schema { - out := &typedef.Schema{Keyspace: s.keyspace, Tables: s.tables, Config: s.config} +func (s *schemaBuilder) Build() *Schema { + out := &Schema{Keyspace: s.keyspace, Tables: s.tables, Config: s.config} for id := range s.tables { - s.tables[id].Init(out, querycache.New(out)) + s.tables[id].Init(out, New(out)) } return out } diff --git a/pkg/querycache/querycache.go b/pkg/typedef/querycache.go similarity index 62% rename from pkg/querycache/querycache.go rename to pkg/typedef/querycache.go index 0aa755ae..948d5290 100644 --- a/pkg/querycache/querycache.go +++ b/pkg/typedef/querycache.go @@ -12,18 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package querycache +package typedef import ( "fmt" "sync" "github.com/scylladb/gocqlx/v2/qb" - - "github.com/scylladb/gemini/pkg/typedef" ) -type QueryCache [typedef.CacheArrayLen]*typedef.StmtCache +type QueryCache [CacheArrayLen]*StmtCache func (c QueryCache) Reset() { for id := range c { @@ -32,19 +30,19 @@ func (c QueryCache) Reset() { } type Cache struct { - schema *typedef.Schema - table *typedef.Table + schema *Schema + table *Table cache QueryCache mu sync.RWMutex } -func New(s *typedef.Schema) *Cache { +func New(s *Schema) *Cache { return &Cache{ schema: s, } } -func (c *Cache) BindToTable(t *typedef.Table) { +func (c *Cache) BindToTable(t *Table) { c.table = t } @@ -54,14 +52,14 @@ func (c *Cache) Reset() { c.cache.Reset() } -func (c *Cache) getQuery(qct typedef.StatementCacheType) *typedef.StmtCache { +func (c *Cache) getQuery(qct StatementCacheType) *StmtCache { c.mu.RLock() defer c.mu.RUnlock() rec := c.cache[qct] return rec } -func (c *Cache) GetQuery(qct typedef.StatementCacheType) *typedef.StmtCache { +func (c *Cache) GetQuery(qct StatementCacheType) *StmtCache { rec := c.getQuery(qct) if rec != nil { return rec @@ -73,35 +71,35 @@ func (c *Cache) GetQuery(qct typedef.StatementCacheType) *typedef.StmtCache { return rec } -type CacheBuilderFn func(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache +type CacheBuilderFn func(s *Schema, t *Table) *StmtCache -type CacheBuilderFnMap map[typedef.StatementCacheType]CacheBuilderFn +type CacheBuilderFnMap map[StatementCacheType]CacheBuilderFn -func (m CacheBuilderFnMap) ToList() [typedef.CacheArrayLen]CacheBuilderFn { - out := [typedef.CacheArrayLen]CacheBuilderFn{} +func (m CacheBuilderFnMap) ToList() [CacheArrayLen]CacheBuilderFn { + out := [CacheArrayLen]CacheBuilderFn{} for idx, builderFn := range m { out[idx] = builderFn } for idx := range out { if out[idx] == nil { - panic(fmt.Sprintf("no builder for %s", typedef.StatementCacheType(idx).ToString())) + panic(fmt.Sprintf("no builder for %s", StatementCacheType(idx).ToString())) } } return out } var CacheBuilders = CacheBuilderFnMap{ - typedef.CacheInsert: genInsertStmtCache, - typedef.CacheInsertIfNotExists: genInsertIfNotExistsStmtCache, - typedef.CacheDelete: genDeleteStmtCache, - typedef.CacheUpdate: genUpdateStmtCache, + CacheInsert: genInsertStmtCache, + CacheInsertIfNotExists: genInsertIfNotExistsStmtCache, + CacheDelete: genDeleteStmtCache, + CacheUpdate: genUpdateStmtCache, }.ToList() func genInsertStmtCache( - s *typedef.Schema, - t *typedef.Table, -) *typedef.StmtCache { - allTypes := make([]typedef.Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len()+t.Columns.Len()) + s *Schema, + t *Table, +) *StmtCache { + allTypes := make([]Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len()+t.Columns.Len()) builder := qb.Insert(s.Keyspace.Name + "." + t.Name) for _, pk := range t.PartitionKeys { builder = builder.Columns(pk.Name) @@ -113,38 +111,38 @@ func genInsertStmtCache( } for _, col := range t.Columns { switch colType := col.Type.(type) { - case *typedef.TupleType: + case *TupleType: builder = builder.TupleColumn(col.Name, len(colType.ValueTypes)) default: builder = builder.Columns(col.Name) } allTypes = append(allTypes, col.Type) } - return &typedef.StmtCache{ + return &StmtCache{ Query: builder, Types: allTypes, - QueryType: typedef.InsertStatementType, + QueryType: InsertStatementType, } } func genInsertIfNotExistsStmtCache( - s *typedef.Schema, - t *typedef.Table, -) *typedef.StmtCache { + s *Schema, + t *Table, +) *StmtCache { out := genInsertStmtCache(s, t) out.Query = out.Query.(*qb.InsertBuilder).Unique() return out } -func genUpdateStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache { - allTypes := make([]typedef.Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len()+t.Columns.Len()) +func genUpdateStmtCache(s *Schema, t *Table) *StmtCache { + var allTypes []Type builder := qb.Update(s.Keyspace.Name + "." + t.Name) for _, cdef := range t.Columns { switch t := cdef.Type.(type) { - case *typedef.TupleType: + case *TupleType: builder = builder.SetTuple(cdef.Name, len(t.ValueTypes)) - case *typedef.CounterType: + case *CounterType: builder = builder.SetLit(cdef.Name, cdef.Name+"+1") continue default: @@ -161,16 +159,15 @@ func genUpdateStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache builder = builder.Where(qb.Eq(ck.Name)) allTypes = append(allTypes, ck.Type) } - return &typedef.StmtCache{ + return &StmtCache{ Query: builder, Types: allTypes, - QueryType: typedef.UpdateStatementType, + QueryType: UpdateStatementType, } } -func genDeleteStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache { - allTypes := make([]typedef.Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len()) - +func genDeleteStmtCache(s *Schema, t *Table) *StmtCache { + var allTypes []Type builder := qb.Delete(s.Keyspace.Name + "." + t.Name) for _, pk := range t.PartitionKeys { builder = builder.Where(qb.Eq(pk.Name)) @@ -183,9 +180,9 @@ func genDeleteStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache allTypes = append(allTypes, ck.Type, ck.Type) } - return &typedef.StmtCache{ + return &StmtCache{ Query: builder, Types: allTypes, - QueryType: typedef.DeleteStatementType, + QueryType: DeleteStatementType, } } diff --git a/pkg/typedef/schema.go b/pkg/typedef/schema.go index 32f36350..639d872d 100644 --- a/pkg/typedef/schema.go +++ b/pkg/typedef/schema.go @@ -16,6 +16,7 @@ package typedef import ( "encoding/json" + "os" "strconv" "github.com/pkg/errors" @@ -37,6 +38,28 @@ type Schema struct { Config SchemaConfig `json:"-"` } +func NewSchemaFromFile(confFile string, schemaConfig SchemaConfig) (*Schema, error) { + byteValue, err := os.ReadFile(confFile) + if err != nil { + return nil, err + } + + shm := &Schema{} + + if err = json.Unmarshal(byteValue, shm); err != nil { + return nil, err + } + + sb := NewSchemaBuilder() + sb.Keyspace(shm.Keyspace).Config(schemaConfig) + for t, tbl := range shm.Tables { + shm.Tables[t].LinkIndexAndColumns() + sb.Table(tbl) + } + + return sb.Build(), nil +} + func (s *Schema) GetHash() string { out, err := json.Marshal(s) if err != nil { diff --git a/pkg/typedef/schemaconfig.go b/pkg/typedef/schemaconfig.go index 707b90e5..e5537ff2 100644 --- a/pkg/typedef/schemaconfig.go +++ b/pkg/typedef/schemaconfig.go @@ -50,12 +50,15 @@ func (sc *SchemaConfig) Valid() error { if sc.MaxPartitionKeys <= sc.MinPartitionKeys { return ErrSchemaConfigInvalidRangePK } + if sc.MaxClusteringKeys <= sc.MinClusteringKeys { return ErrSchemaConfigInvalidRangeCK } + if sc.MaxColumns <= sc.MinColumns { return ErrSchemaConfigInvalidRangeCols } + return nil } diff --git a/pkg/typedef/table.go b/pkg/typedef/table.go index 1af82947..56d4cdd0 100644 --- a/pkg/typedef/table.go +++ b/pkg/typedef/table.go @@ -18,7 +18,7 @@ import ( "sync" ) -type QueryCache interface { +type QueryCacheInterface interface { GetQuery(qct StatementCacheType) *StmtCache Reset() BindToTable(t *Table) @@ -27,7 +27,7 @@ type QueryCache interface { type KnownIssues map[string]bool type Table struct { - queryCache QueryCache + queryCache QueryCacheInterface schema *Schema Name string `json:"name"` PartitionKeys Columns `json:"partition_keys"` @@ -83,7 +83,7 @@ func (t *Table) ResetQueryCache() { t.partitionKeysLenValues = 0 } -func (t *Table) Init(s *Schema, c QueryCache) { +func (t *Table) Init(s *Schema, c QueryCacheInterface) { t.schema = s t.queryCache = c t.queryCache.BindToTable(t) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 956d5297..ae74492c 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -17,10 +17,15 @@ package utils import ( "encoding/hex" "fmt" + "io" + "os" + "runtime" "strconv" "strings" "time" + "github.com/pkg/errors" + "github.com/gocql/gocql" "golang.org/x/exp/rand" ) @@ -107,3 +112,30 @@ func UUIDFromTime(rnd *rand.Rand) string { } return gocql.UUIDFromTime(RandDate(rnd)).String() } + +func CreateFile(input string, def ...io.Writer) (io.Writer, error) { + switch input { + case "": + if len(def) > 0 && def[0] != nil { + return def[0], nil + } + + return io.Discard, nil + case "stderr": + return os.Stderr, nil + case "stdout": + return os.Stdout, nil + default: + tracingWriter, err := os.OpenFile(input, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + if err != nil { + return nil, errors.Wrapf(err, "failed to open file %s", input) + } + + runtime.SetFinalizer(tracingWriter, func(f *os.File) { + IgnoreError(f.Sync) + IgnoreError(f.Close) + }) + + return tracingWriter, nil + } +}