diff --git a/.run/Run Gemini Mixed.run.xml b/.run/Run Gemini Mixed.run.xml
index c6d7c6ca..1021a5e0 100644
--- a/.run/Run Gemini Mixed.run.xml
+++ b/.run/Run Gemini Mixed.run.xml
@@ -3,7 +3,7 @@
-
+
diff --git a/Makefile b/Makefile
index 489c78e1..6c78e719 100644
--- a/Makefile
+++ b/Makefile
@@ -20,30 +20,39 @@ GEMINI_TEST_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Netw
GEMINI_ORACLE_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Networks.gemini.IPAddress }}' gemini-oracle)
GEMINI_DOCKER_NETWORK ?= gemini
GEMINI_FLAGS ?= --fail-fast \
- --level=info \
- --non-interactive \
- --consistency=LOCAL_QUORUM \
- --test-host-selection-policy=token-aware \
- --oracle-host-selection-policy=token-aware \
- --mode=$(MODE) \
- --non-interactive \
- --request-timeout=5s \
- --connect-timeout=15s \
- --use-server-timestamps=false \
--async-objects-stabilization-attempts=10 \
- --max-mutation-retries=10 \
- --replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
- --oracle-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
+ --async-objects-stabilization-backoff=1s \
--concurrency=$(CONCURRENCY) \
- --dataset-size=$(DATASET_SIZE) \
+ --cql-features=$(CQL_FEATURES) \
+ --dataset-size=$(DATASET_SIZE) \
+ --mode=$(MODE) \
+ --duration=$(DURATION) \
--seed=$(SEED) \
--schema-seed=$(SEED) \
- --cql-features=$(CQL_FEATURES) \
- --duration=$(DURATION) \
--warmup=$(WARMUP) \
+ --drop-schema=true \
+ --level=info \
+ --materialized-views=false \
+ --async-objects-stabilization-attempts=10 \
--profiling-port=6060 \
- --drop-schema=true
-
+ --token-range-slices=10000 \
+ --partition-key-buffer-reuse-size=100 \
+ --oracle-connect-timeout=15s \
+ --oracle-request-timeout=5s \
+ --oracle-consistency=LOCAL_QUORUM \
+ --oracle-max-mutation-retries=10 \
+ --oracle-max-mutation-retries-backoff=1s \
+ --oracle-use-server-timestamps=false \
+ --oracle-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
+ --oracle-host-selection-policy=token-aware \
+ --test-consistency=LOCAL_QUORUM \
+ --test-request-timeout=5s \
+ --test-connect-timeout=15s \
+ --test-use-server-timestamps=false \
+ --test-host-selection-policy=token-aware \
+ --test-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
+ --test-max-mutation-retries=10 \
+ --test-dc=datacenter1
ifndef GOBIN
export GOBIN := $(MAKEFILE_PATH)/bin
diff --git a/cmd/gemini/main.go b/cmd/gemini/main.go
index 5fbe1c29..949826b7 100644
--- a/cmd/gemini/main.go
+++ b/cmd/gemini/main.go
@@ -18,6 +18,8 @@ import (
"fmt"
"os"
"runtime/debug"
+
+ _ "github.com/scylladb/gemini/pkg/metrics"
)
//go:generate sh -c "git describe --tags --abbrev=0 | tr -d '\n' > ./Version"
diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go
index 6d2a22d0..ad6cdbc2 100644
--- a/cmd/gemini/root.go
+++ b/cmd/gemini/root.go
@@ -17,22 +17,21 @@ package main
import (
"encoding/json"
"fmt"
+ "io"
"log"
"math"
"net/http"
"net/http/pprof"
"os"
"os/signal"
+ "runtime"
"strconv"
"strings"
"syscall"
"text/tabwriter"
"time"
- "github.com/gocql/gocql"
- "github.com/hailocab/go-hostpool"
"github.com/pkg/errors"
- "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/cobra"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -40,42 +39,38 @@ import (
"golang.org/x/net/context"
"gonum.org/v1/gonum/stat/distuv"
- "github.com/scylladb/gemini/pkg/auth"
- "github.com/scylladb/gemini/pkg/builders"
"github.com/scylladb/gemini/pkg/generators"
"github.com/scylladb/gemini/pkg/jobs"
"github.com/scylladb/gemini/pkg/realrandom"
"github.com/scylladb/gemini/pkg/replication"
"github.com/scylladb/gemini/pkg/status"
"github.com/scylladb/gemini/pkg/store"
+ "github.com/scylladb/gemini/pkg/store/drivers"
"github.com/scylladb/gemini/pkg/typedef"
"github.com/scylladb/gemini/pkg/utils"
)
var (
- testClusterHost []string
- testClusterUsername string
- testClusterPassword string
- oracleClusterHost []string
- oracleClusterUsername string
- oracleClusterPassword string
+ level string
+ profilingPort int
+ mode string
+ warmup time.Duration
+ duration time.Duration
+ verbose bool
+ failFast bool
+ concurrency uint64
+
+ oracleConfig drivers.CQLConfig
+ testConfig drivers.CQLConfig
+ testReplicationStrategy string
+ oracleReplicationStrategy string
+
schemaFile string
outFileArg string
- concurrency uint64
seed string
schemaSeed string
dropSchema bool
- verbose bool
- mode string
- failFast bool
- nonInteractive bool
- duration time.Duration
- bind string
- warmup time.Duration
- replicationStrategy string
tableOptions []string
- oracleReplicationStrategy string
- consistency string
maxTables int
maxPartitionKeys int
minPartitionKeys int
@@ -86,63 +81,24 @@ var (
datasetSize string
cqlFeatures string
useMaterializedViews bool
- level string
- maxRetriesMutate int
- maxRetriesMutateSleep time.Duration
maxErrorsToStore int
pkBufferReuseSize uint64
partitionCount uint64
partitionKeyDistribution string
normalDistMean float64
normalDistSigma float64
- tracingOutFile string
useCounters bool
asyncObjectStabilizationAttempts int
asyncObjectStabilizationDelay time.Duration
useLWT bool
- testClusterHostSelectionPolicy string
- oracleClusterHostSelectionPolicy string
- useServerSideTimestamps bool
- requestTimeout time.Duration
- connectTimeout time.Duration
- profilingPort int
- testStatementLogFile string
- oracleStatementLogFile string
)
-func interactive() bool {
- return !nonInteractive
-}
-
-func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Schema, error) {
- byteValue, err := os.ReadFile(confFile)
- if err != nil {
- return nil, err
- }
-
- var shm typedef.Schema
-
- err = json.Unmarshal(byteValue, &shm)
- if err != nil {
- return nil, err
- }
-
- schemaBuilder := builders.NewSchemaBuilder()
- schemaBuilder.Keyspace(shm.Keyspace).Config(schemaConfig)
- for t, tbl := range shm.Tables {
- shm.Tables[t].LinkIndexAndColumns()
- schemaBuilder.Table(tbl)
- }
- return schemaBuilder.Build(), nil
-}
-
func run(_ *cobra.Command, _ []string) error {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGABRT, syscall.SIGTERM, syscall.SIGINT)
defer cancel()
logger := createLogger(level)
globalStatus := status.NewGlobalStatus(int32(maxErrorsToStore))
- defer utils.IgnoreError(logger.Sync)
if err := validateSeed(seed); err != nil {
return errors.Wrapf(err, "failed to parse --seed argument")
@@ -156,26 +112,6 @@ func run(_ *cobra.Command, _ []string) error {
rand.Seed(intSeed)
- cons, err := gocql.ParseConsistencyWrapper(consistency)
- if err != nil {
- logger.Error("Unable parse consistency, error=%s. Falling back on Quorum", zap.Error(err))
- cons = gocql.Quorum
- }
-
- testHostSelectionPolicy, err := getHostSelectionPolicy(testClusterHostSelectionPolicy, testClusterHost)
- if err != nil {
- return err
- }
- oracleHostSelectionPolicy, err := getHostSelectionPolicy(oracleClusterHostSelectionPolicy, oracleClusterHost)
- if err != nil {
- return err
- }
-
- go func() {
- http.Handle("/metrics", promhttp.Handler())
- _ = http.ListenAndServe(bind, nil)
- }()
-
if profilingPort != 0 {
go func() {
mux := http.NewServeMux()
@@ -184,19 +120,18 @@ func run(_ *cobra.Command, _ []string) error {
}()
}
- outFile, err := createFile(outFileArg, os.Stdout)
- if err != nil {
- return err
- }
- defer utils.IgnoreError(outFile.Sync)
-
schemaConfig := createSchemaConfig(logger)
- if err = schemaConfig.Valid(); err != nil {
+ if err := schemaConfig.Valid(); err != nil {
return errors.Wrap(err, "invalid schema configuration")
}
- var schema *typedef.Schema
+
+ var (
+ schema *typedef.Schema
+ err error
+ )
+
if len(schemaFile) > 0 {
- schema, err = readSchema(schemaFile, schemaConfig)
+ schema, err = typedef.NewSchemaFromFile(schemaFile, schemaConfig)
if err != nil {
return errors.Wrap(err, "cannot create schema")
}
@@ -207,40 +142,35 @@ func run(_ *cobra.Command, _ []string) error {
}
}
- jsonSchema, _ := json.MarshalIndent(schema, "", " ")
+ printSetup(intSeed, intSchemaSeed, schema)
- printSetup(intSeed, intSchemaSeed)
- fmt.Printf("Schema: %v\n", string(jsonSchema))
+ var oracle store.Driver
+ if len(oracleConfig.Hosts) > 0 {
+ oracle, err = drivers.NewCQL(ctx, "oracle", schema, oracleConfig, logger.Named("oracle_store"))
+ if err != nil {
+ return errors.Wrap(err, "failed to create oracle store")
+ }
- testCluster, oracleCluster := createClusters(cons, testHostSelectionPolicy, oracleHostSelectionPolicy, logger)
- storeConfig := store.Config{
- MaxRetriesMutate: maxRetriesMutate,
- MaxRetriesMutateSleep: maxRetriesMutateSleep,
- UseServerSideTimestamps: useServerSideTimestamps,
- TestLogStatementsFile: testStatementLogFile,
- OracleLogStatementsFile: oracleStatementLogFile,
- }
- var tracingFile *os.File
- if tracingOutFile != "" {
- switch tracingOutFile {
- case "stderr":
- tracingFile = os.Stderr
- case "stdout":
- tracingFile = os.Stdout
- default:
- tf, ioErr := createFile(tracingOutFile, os.Stdout)
- if ioErr != nil {
- return ioErr
+ defer func() {
+ if closer, ok := oracle.(io.Closer); ok {
+ utils.IgnoreError(closer.Close)
}
- tracingFile = tf
- defer utils.IgnoreError(tracingFile.Sync)
- }
+ }()
+
+ } else {
+ oracle = drivers.NewNop()
}
- st, err := store.New(schema, testCluster, oracleCluster, storeConfig, tracingFile, logger)
+
+ test, err := drivers.NewCQL(ctx, "test", schema, testConfig, logger.Named("test_store"))
+ if err != nil {
+ return errors.Wrap(err, "failed to create oracle store")
+ }
+ defer utils.IgnoreError(test.Close)
+
+ st, err := store.New(logger, test, oracle)
if err != nil {
return err
}
- defer utils.IgnoreError(st.Close)
if dropSchema && mode != jobs.ReadMode {
for _, stmt := range generators.GetDropKeyspace(schema) {
@@ -276,21 +206,6 @@ func run(_ *cobra.Command, _ []string) error {
gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger)
defer utils.IgnoreError(gens.Close)
- if !nonInteractive {
- sp := createSpinner(interactive())
- ticker := time.NewTicker(time.Second)
- go func() {
- for {
- select {
- case <-ctx.Done():
- return
- case <-ticker.C:
- sp.Set(" Running Gemini... %v", globalStatus)
- }
- }
- }()
- }
-
if warmup > 0 {
warmupCtx, warmupCancel := context.WithTimeout(ctx, warmup)
defer warmupCancel()
@@ -310,6 +225,12 @@ func run(_ *cobra.Command, _ []string) error {
}
logger.Info("test finished")
+
+ outFile, err := utils.CreateFile(outFileArg, os.Stdout)
+ if err != nil {
+ return err
+ }
+
globalStatus.PrintResult(outFile, schema, version)
if globalStatus.HasErrors() {
return errors.Errorf("gemini encountered errors, exiting with non zero status")
@@ -317,17 +238,6 @@ func run(_ *cobra.Command, _ []string) error {
return nil
}
-func createFile(fname string, def *os.File) (*os.File, error) {
- if fname != "" {
- f, err := os.Create(fname)
- if err != nil {
- return nil, errors.Wrapf(err, "Unable to open output file %s", fname)
- }
- return f, nil
- }
- return def, nil
-}
-
const (
stdDistMean = math.MaxUint64 / 2
oneStdDev = 0.341 * math.MaxUint64
@@ -360,70 +270,47 @@ func createDistributionFunc(distribution string, size, seed uint64, mu, sigma fl
}
func createLogger(level string) *zap.Logger {
- lvl := zap.NewAtomicLevel()
- if err := lvl.UnmarshalText([]byte(level)); err != nil {
- lvl.SetLevel(zap.InfoLevel)
+ lvl, err := zap.ParseAtomicLevel(level)
+ if err != nil {
+ lvl = zap.NewAtomicLevelAt(zap.InfoLevel)
}
- encoderCfg := zap.NewDevelopmentEncoderConfig()
+
+ encoderCfg := zap.NewProductionEncoderConfig()
+ encoderCfg.EncodeName = zapcore.FullNameEncoder
+ encoderCfg.EncodeLevel = zapcore.LowercaseLevelEncoder
+ encoderCfg.EncodeTime = zapcore.RFC3339TimeEncoder
+ encoderCfg.EncodeDuration = zapcore.MillisDurationEncoder
+ encoderCfg.LevelKey = "L"
+ encoderCfg.MessageKey = "M"
+ encoderCfg.TimeKey = "T"
+ encoderCfg.CallerKey = "C"
+
logger := zap.New(zapcore.NewCore(
zapcore.NewJSONEncoder(encoderCfg),
zapcore.Lock(os.Stdout),
lvl,
))
- return logger
-}
-func createClusters(
- consistency gocql.Consistency,
- testHostSelectionPolicy, oracleHostSelectionPolicy gocql.HostSelectionPolicy,
- logger *zap.Logger,
-) (*gocql.ClusterConfig, *gocql.ClusterConfig) {
- retryPolicy := &gocql.ExponentialBackoffRetryPolicy{
- Min: time.Second,
- Max: 60 * time.Second,
- NumRetries: 5,
- }
- testCluster := gocql.NewCluster(testClusterHost...)
- testCluster.Timeout = requestTimeout
- testCluster.ConnectTimeout = connectTimeout
- testCluster.RetryPolicy = retryPolicy
- testCluster.Consistency = consistency
- testCluster.PoolConfig.HostSelectionPolicy = testHostSelectionPolicy
- testAuthenticator, testAuthErr := auth.BuildAuthenticator(testClusterUsername, testClusterPassword)
- if testAuthErr != nil {
- logger.Warn("%s for test cluster", zap.Error(testAuthErr))
- }
- testCluster.Authenticator = testAuthenticator
- if len(oracleClusterHost) == 0 {
- return testCluster, nil
- }
- oracleCluster := gocql.NewCluster(oracleClusterHost...)
- testCluster.Timeout = requestTimeout
- testCluster.ConnectTimeout = connectTimeout
- oracleCluster.RetryPolicy = retryPolicy
- oracleCluster.Consistency = consistency
- oracleCluster.PoolConfig.HostSelectionPolicy = oracleHostSelectionPolicy
- oracleAuthenticator, oracleAuthErr := auth.BuildAuthenticator(oracleClusterUsername, oracleClusterPassword)
- if oracleAuthErr != nil {
- logger.Warn("%s for oracle cluster", zap.Error(oracleAuthErr))
- }
- oracleCluster.Authenticator = oracleAuthenticator
- return testCluster, oracleCluster
+ runtime.SetFinalizer(logger, func(l *zap.Logger) {
+ utils.IgnoreError(logger.Sync)
+ })
+
+ return logger
}
func getReplicationStrategy(rs string, fallback *replication.Replication, logger *zap.Logger) *replication.Replication {
- switch rs {
+ switch strings.ToLower(rs) {
case "network":
return replication.NewNetworkTopologyStrategy()
case "simple":
return replication.NewSimpleStrategy()
default:
- replicationStrategy := &replication.Replication{}
- if err := json.Unmarshal([]byte(strings.ReplaceAll(rs, "'", "\"")), replicationStrategy); err != nil {
+ rf := &replication.Replication{}
+ if err := json.Unmarshal([]byte(strings.ReplaceAll(rs, "'", "\"")), rf); err != nil {
logger.Error("unable to parse replication strategy", zap.String("strategy", rs), zap.Error(err))
return fallback
}
- return replicationStrategy
+ return rf
}
}
@@ -438,19 +325,6 @@ func getCQLFeature(feature string) typedef.CQLFeature {
}
}
-func getHostSelectionPolicy(policy string, hosts []string) (gocql.HostSelectionPolicy, error) {
- switch policy {
- case "round-robin":
- return gocql.RoundRobinHostPolicy(), nil
- case "host-pool":
- return gocql.HostPoolHostPolicy(hostpool.New(hosts)), nil
- case "token-aware":
- return gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()), nil
- default:
- return nil, fmt.Errorf("unknown host selection policy \"%s\"", policy)
- }
-}
-
var rootCmd = &cobra.Command{
Use: "gemini",
Short: "Gemini is an automatic random testing tool for Scylla.",
@@ -460,38 +334,52 @@ var rootCmd = &cobra.Command{
func init() {
rootCmd.Version = version + ", commit " + commit + ", date " + date
- rootCmd.Flags().StringSliceVarP(&testClusterHost, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is system under test")
- _ = rootCmd.MarkFlagRequired("test-cluster")
- rootCmd.Flags().StringVarP(&testClusterUsername, "test-username", "", "", "Username for the test cluster")
- rootCmd.Flags().StringVarP(&testClusterPassword, "test-password", "", "", "Password for the test cluster")
- rootCmd.Flags().StringSliceVarP(
- &oracleClusterHost, "oracle-cluster", "o", []string{},
- "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used")
- rootCmd.Flags().StringVarP(&oracleClusterUsername, "oracle-username", "", "", "Username for the oracle cluster")
- rootCmd.Flags().StringVarP(&oracleClusterPassword, "oracle-password", "", "", "Password for the oracle cluster")
- rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file")
+
+ rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'")
rootCmd.Flags().StringVarP(&mode, "mode", "m", jobs.MixedMode, "Query operation mode. Mode options: write, read, mixed (default)")
rootCmd.Flags().Uint64VarP(&concurrency, "concurrency", "c", 10, "Number of threads per table to run concurrently")
rootCmd.Flags().StringVarP(&seed, "seed", "s", "random", "Statement seed value")
rootCmd.Flags().StringVarP(&schemaSeed, "schema-seed", "", "random", "Schema seed value")
- rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", false, "Drop schema before starting tests run")
rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output during test run")
rootCmd.Flags().BoolVarP(&failFast, "fail-fast", "f", false, "Stop on the first failure")
- rootCmd.Flags().BoolVarP(&nonInteractive, "non-interactive", "", false, "Run in non-interactive mode (disable progress indicator)")
+ rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal")
+
+ rootCmd.Flags().StringSliceVarP(&testConfig.Hosts, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is system under test")
+ rootCmd.Flags().StringVarP(&testConfig.Trace, "test-tracing-outfile", "", "", "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.")
+ rootCmd.Flags().StringVarP(&testConfig.Consistency, "test-consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE")
+ rootCmd.Flags().StringVarP(&testConfig.DC, "test-dc", "", "", "Datacenter name for the test cluster")
+ rootCmd.Flags().StringVarP(&testConfig.HostSelectionPolicy, "test-host-selection-policy", "", "token-aware", "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware")
+ rootCmd.Flags().StringVarP(&testConfig.Username, "test-username", "", "", "Username for the test cluster")
+ rootCmd.Flags().StringVarP(&testConfig.Password, "test-password", "", "", "Password for the test cluster")
+ rootCmd.Flags().StringVarP(&testConfig.StatementLog, "test-statement-log-file", "", "", "File to write statements flow to")
+ rootCmd.Flags().DurationVarP(&testConfig.RequestTimeout, "test-request-timeout", "", 30*time.Second, "Duration of waiting request execution")
+ rootCmd.Flags().DurationVarP(&testConfig.ConnectTimeout, "test-connect-timeout", "", 30*time.Second, "Duration of waiting connection established")
+ rootCmd.Flags().DurationVarP(&testConfig.MaxRetriesMutateSleep, "test-max-mutation-retries-backoff", "", 10*time.Millisecond, "Duration between attempts to apply a mutation for example 10ms or 1s")
+ rootCmd.Flags().IntVarP(&testConfig.MaxRetriesMutate, "test-max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation")
+ rootCmd.Flags().BoolVarP(&testConfig.UseServerSideTimestamps, "test-use-server-timestamps", "", false, "Use server-side generated timestamps for writes")
+ rootCmd.Flags().StringVarP(&testReplicationStrategy, "test-replication-strategy", "", "simple", "Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide the entire specification in the form {'class':'....'}")
+
+ rootCmd.Flags().StringSliceVarP(&oracleConfig.Hosts, "oracle-cluster", "o", []string{}, "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used")
+ rootCmd.Flags().StringVarP(&oracleConfig.Trace, "oracle-tracing-outfile", "", "", "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.")
+ rootCmd.Flags().StringVarP(&oracleConfig.Consistency, "oracle-consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE")
+ rootCmd.Flags().StringVarP(&oracleConfig.DC, "oracle-dc", "", "", "Datacenter name for the oracle cluster")
+ rootCmd.Flags().StringVarP(&oracleConfig.HostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware", "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware")
+ rootCmd.Flags().StringVarP(&oracleConfig.Username, "oracle-username", "", "", "Username for the oracle cluster")
+ rootCmd.Flags().StringVarP(&oracleConfig.Password, "oracle-password", "", "", "Password for the oracle cluster")
+ rootCmd.Flags().StringVarP(&oracleConfig.StatementLog, "oracle-statement-log-file", "", "", "File to write statements flow to")
+ rootCmd.Flags().DurationVarP(&oracleConfig.RequestTimeout, "oracle-request-timeout", "", 30*time.Second, "Duration of waiting request execution")
+ rootCmd.Flags().DurationVarP(&oracleConfig.ConnectTimeout, "oracle-connect-timeout", "", 30*time.Second, "Duration of waiting connection established")
+ rootCmd.Flags().DurationVarP(&oracleConfig.MaxRetriesMutateSleep, "oracle-max-mutation-retries-backoff", "", 10*time.Millisecond, "Duration between attempts to apply a mutation for example 10ms or 1s")
+ rootCmd.Flags().IntVarP(&oracleConfig.MaxRetriesMutate, "oracle-max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation")
+ rootCmd.Flags().BoolVarP(&oracleConfig.UseServerSideTimestamps, "oracle-use-server-timestamps", "", false, "Use server-side generated timestamps for writes")
+ rootCmd.Flags().StringVarP(&oracleReplicationStrategy, "oracle-replication-strategy", "", "simple", "Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each type or provide the entire specification in the form {'class':'....'}")
+
+ rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file")
+ rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", false, "Drop schema before starting tests run")
rootCmd.Flags().DurationVarP(&duration, "duration", "", 30*time.Second, "")
rootCmd.Flags().StringVarP(&outFileArg, "outfile", "", "", "Specify the name of the file where the results should go")
- rootCmd.Flags().StringVarP(&bind, "bind", "b", "0.0.0.0:2112", "Specify the interface and port which to bind prometheus metrics on. Default is ':2112'")
rootCmd.Flags().DurationVarP(&warmup, "warmup", "", 30*time.Second, "Specify the warmup perid as a duration for example 30s or 10h")
- rootCmd.Flags().StringVarP(
- &replicationStrategy, "replication-strategy", "", "simple",
- "Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide "+
- "the entire specification in the form {'class':'....'}")
- rootCmd.Flags().StringVarP(
- &oracleReplicationStrategy, "oracle-replication-strategy", "", "simple",
- "Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each "+
- "type or provide the entire specification in the form {'class':'....'}")
rootCmd.Flags().StringArrayVarP(&tableOptions, "table-options", "", []string{}, "Repeatable argument to set table options to be added to the created tables")
- rootCmd.Flags().StringVarP(&consistency, "consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE")
rootCmd.Flags().IntVarP(&maxTables, "max-tables", "", 1, "Maximum number of generated tables")
rootCmd.Flags().IntVarP(&maxPartitionKeys, "max-partition-keys", "", 6, "Maximum number of generated partition keys")
rootCmd.Flags().IntVarP(&minPartitionKeys, "min-partition-keys", "", 2, "Minimum number of generated partition keys")
@@ -502,60 +390,42 @@ func init() {
rootCmd.Flags().StringVarP(&datasetSize, "dataset-size", "", "large", "Specify the type of dataset size to use, small|large")
rootCmd.Flags().StringVarP(&cqlFeatures, "cql-features", "", "basic", "Specify the type of cql features to use, basic|normal|all")
rootCmd.Flags().BoolVarP(&useMaterializedViews, "materialized-views", "", false, "Run gemini with materialized views support")
- rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal")
- rootCmd.Flags().IntVarP(&maxRetriesMutate, "max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation")
- rootCmd.Flags().DurationVarP(
- &maxRetriesMutateSleep, "max-mutation-retries-backoff", "", 10*time.Millisecond,
- "Duration between attempts to apply a mutation for example 10ms or 1s")
rootCmd.Flags().Uint64VarP(&pkBufferReuseSize, "partition-key-buffer-reuse-size", "", 100, "Number of reused buffered partition keys")
rootCmd.Flags().Uint64VarP(&partitionCount, "token-range-slices", "", 10000, "Number of slices to divide the token space into")
- rootCmd.Flags().StringVarP(
- &partitionKeyDistribution, "partition-key-distribution", "", "uniform",
- "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf")
+ rootCmd.Flags().StringVarP(&partitionKeyDistribution, "partition-key-distribution", "", "uniform", "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf")
rootCmd.Flags().Float64VarP(&normalDistMean, "normal-dist-mean", "", stdDistMean, "Mean of the normal distribution")
rootCmd.Flags().Float64VarP(&normalDistSigma, "normal-dist-sigma", "", oneStdDev, "Sigma of the normal distribution, defaults to one standard deviation ~0.341")
- rootCmd.Flags().StringVarP(
- &tracingOutFile, "tracing-outfile", "", "",
- "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.")
rootCmd.Flags().BoolVarP(&useCounters, "use-counters", "", false, "Ensure that at least one table is a counter table")
- rootCmd.Flags().IntVarP(
- &asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10,
- "Maximum number of attempts to validate result sets from MV and SI")
- rootCmd.Flags().DurationVarP(
- &asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond,
- "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s")
+ rootCmd.Flags().IntVarP(&asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10, "Maximum number of attempts to validate result sets from MV and SI")
+ rootCmd.Flags().DurationVarP(&asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond, "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s")
rootCmd.Flags().BoolVarP(&useLWT, "use-lwt", "", false, "Emit LWT based updates")
- rootCmd.Flags().StringVarP(
- &oracleClusterHostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware",
- "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware")
- rootCmd.Flags().StringVarP(
- &testClusterHostSelectionPolicy, "test-host-selection-policy", "", "token-aware",
- "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware")
- rootCmd.Flags().BoolVarP(&useServerSideTimestamps, "use-server-timestamps", "", false, "Use server-side generated timestamps for writes")
- rootCmd.Flags().DurationVarP(&requestTimeout, "request-timeout", "", 30*time.Second, "Duration of waiting request execution")
- rootCmd.Flags().DurationVarP(&connectTimeout, "connect-timeout", "", 30*time.Second, "Duration of waiting connection established")
- rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'")
rootCmd.Flags().IntVarP(&maxErrorsToStore, "max-errors-to-store", "", 1000, "Maximum number of errors to store and output at the end")
- rootCmd.Flags().StringVarP(&testStatementLogFile, "test-statement-log-file", "", "", "File to write statements flow to")
- rootCmd.Flags().StringVarP(&oracleStatementLogFile, "oracle-statement-log-file", "", "", "File to write statements flow to")
+
+ _ = rootCmd.MarkFlagRequired("test-cluster")
}
-func printSetup(seed, schemaSeed uint64) {
+func printSetup(seed, schemaSeed uint64, schema *typedef.Schema) {
+ jsonSchema, _ := json.MarshalIndent(schema, "", " ")
+
tw := new(tabwriter.Writer)
tw.Init(os.Stdout, 0, 8, 2, '\t', tabwriter.AlignRight)
- fmt.Fprintf(tw, "Seed:\t%d\n", seed)
- fmt.Fprintf(tw, "Schema seed:\t%d\n", schemaSeed)
- fmt.Fprintf(tw, "Maximum duration:\t%s\n", duration)
- fmt.Fprintf(tw, "Warmup duration:\t%s\n", warmup)
- fmt.Fprintf(tw, "Concurrency:\t%d\n", concurrency)
- fmt.Fprintf(tw, "Test cluster:\t%s\n", testClusterHost)
- fmt.Fprintf(tw, "Oracle cluster:\t%s\n", oracleClusterHost)
+ _, _ = fmt.Fprintf(tw, "Seed:\t%d\n", seed)
+ _, _ = fmt.Fprintf(tw, "Schema seed:\t%d\n", schemaSeed)
+ _, _ = fmt.Fprintf(tw, "Maximum duration:\t%s\n", duration)
+ _, _ = fmt.Fprintf(tw, "Warmup duration:\t%s\n", warmup)
+ _, _ = fmt.Fprintf(tw, "Concurrency:\t%d\n", concurrency)
+ _, _ = fmt.Fprintf(tw, "Test cluster:\t%s\n", testConfig.Hosts)
+ _, _ = fmt.Fprintf(tw, "Oracle cluster:\t%s\n", oracleConfig.Hosts)
+ _, _ = fmt.Printf("Schema: \t%s\n", string(jsonSchema))
+
if outFileArg == "" {
- fmt.Fprintf(tw, "Output file:\t%s\n", "")
+ _, _ = fmt.Fprintf(tw, "Output file:\t%s\n", "")
} else {
- fmt.Fprintf(tw, "Output file:\t%s\n", outFileArg)
+ _, _ = fmt.Fprintf(tw, "Output file:\t%s\n", outFileArg)
+ }
+ if err := tw.Flush(); err != nil {
+ log.Printf("Failed to print setup: %v", err)
}
- tw.Flush()
}
func RealRandom() uint64 {
@@ -582,10 +452,11 @@ func seedFromString(seed string) uint64 {
func generateSchema(logger *zap.Logger, sc typedef.SchemaConfig, schemaSeed string) (schema *typedef.Schema, intSchemaSeed uint64, err error) {
intSchemaSeed = seedFromString(schemaSeed)
schema = generators.GenSchema(sc, intSchemaSeed)
- err = schema.Validate(partitionCount)
- if err == nil {
+
+ if err = schema.Validate(partitionCount); err == nil {
return schema, intSchemaSeed, nil
}
+
if schemaSeed != "random" {
// If user provided schema, allow to run it, but log warning
logger.Warn(errors.Wrap(err, "validation failed, running this test could end up in error or stale gemini").Error())
diff --git a/cmd/gemini/schema.go b/cmd/gemini/schema.go
index 68884a62..8ceed93c 100644
--- a/cmd/gemini/schema.go
+++ b/cmd/gemini/schema.go
@@ -64,7 +64,7 @@ func createDefaultSchemaConfig(logger *zap.Logger) typedef.SchemaConfig {
MaxTupleParts = 20
MaxUDTParts = 20
)
- rs := getReplicationStrategy(replicationStrategy, replication.NewSimpleStrategy(), logger)
+ rs := getReplicationStrategy(testReplicationStrategy, replication.NewSimpleStrategy(), logger)
ors := getReplicationStrategy(oracleReplicationStrategy, rs, logger)
return typedef.SchemaConfig{
ReplicationStrategy: rs,
diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go
index cf797d9d..70473323 100644
--- a/pkg/auth/auth.go
+++ b/pkg/auth/auth.go
@@ -20,6 +20,11 @@ import (
"github.com/gocql/gocql"
)
+var (
+ ErrUsernameNotProvided = errors.New("username not provided")
+ ErrorPasswordNotProvided = errors.New("password not provided")
+)
+
// BuildAuthenticator : Returns a new gocql.PasswordAuthenticator
// if both username and password are provided.
func BuildAuthenticator(username, password string) (*gocql.PasswordAuthenticator, error) {
@@ -34,7 +39,7 @@ func BuildAuthenticator(username, password string) (*gocql.PasswordAuthenticator
return &authenticator, nil
}
if username != "" {
- return nil, errors.New("Password not provided")
+ return nil, ErrorPasswordNotProvided
}
- return nil, errors.New("Username not provided")
+ return nil, ErrUsernameNotProvided
}
diff --git a/pkg/generators/statement_generator.go b/pkg/generators/statement_generator.go
index 1755e513..1c280949 100644
--- a/pkg/generators/statement_generator.go
+++ b/pkg/generators/statement_generator.go
@@ -20,14 +20,13 @@ import (
"golang.org/x/exp/rand"
- "github.com/scylladb/gemini/pkg/builders"
"github.com/scylladb/gemini/pkg/typedef"
"github.com/scylladb/gemini/pkg/utils"
)
func GenSchema(sc typedef.SchemaConfig, seed uint64) *typedef.Schema {
r := rand.New(rand.NewSource(seed))
- builder := builders.NewSchemaBuilder()
+ builder := typedef.NewSchemaBuilder()
builder.Config(sc)
keyspace := typedef.Keyspace{
Name: "ks1",
@@ -36,10 +35,12 @@ func GenSchema(sc typedef.SchemaConfig, seed uint64) *typedef.Schema {
}
builder.Keyspace(keyspace)
numTables := utils.RandInt2(r, 1, sc.GetMaxTables())
+
for i := 0; i < numTables; i++ {
- table := genTable(sc, fmt.Sprintf("table%d", i+1), r)
+ table := genTable(sc, fmt.Sprintf("table_%d", i+1), r)
builder.Table(table)
}
+
return builder.Build()
}
@@ -60,9 +61,11 @@ func genTable(sc typedef.SchemaConfig, tableName string, r *rand.Rand) *typedef.
typedef.KnownIssuesJSONWithTuples: true,
},
}
+
for _, option := range sc.TableOptions {
table.TableOptions = append(table.TableOptions, option.ToCQL())
}
+
if sc.UseCounters {
table.Columns = typedef.Columns{
{
@@ -78,21 +81,17 @@ func genTable(sc typedef.SchemaConfig, tableName string, r *rand.Rand) *typedef.
for i := 0; i < len(columns); i++ {
columns[i] = &typedef.ColumnDef{Name: GenColumnName("col", i), Type: GenColumnType(len(columns), &sc, r)}
}
+
table.Columns = columns
- var indexes []typedef.IndexDef
if sc.CQLFeature > typedef.CQL_FEATURE_BASIC && len(columns) > 0 {
- indexes = CreateIndexesForColumn(&table, utils.RandInt2(r, 1, len(columns)))
+ table.Indexes = CreateIndexesForColumn(&table, utils.RandInt2(r, 1, len(columns)))
}
- table.Indexes = indexes
- var mvs []typedef.MaterializedView
if sc.CQLFeature > typedef.CQL_FEATURE_BASIC && sc.UseMaterializedViews && len(clusteringKeys) > 0 && columns.ValidColumnsForPrimaryKey().Len() != 0 {
- mvs = CreateMaterializedViews(columns, table.Name, partitionKeys, clusteringKeys, r)
+ table.MaterializedViews = CreateMaterializedViews(columns, table.Name, partitionKeys, clusteringKeys, r)
}
- table.MaterializedViews = mvs
-
return &table
}
diff --git a/pkg/jobs/gen_ddl_stmt.go b/pkg/jobs/gen_ddl_stmt.go
index a8866fd4..6550097a 100644
--- a/pkg/jobs/gen_ddl_stmt.go
+++ b/pkg/jobs/gen_ddl_stmt.go
@@ -20,7 +20,6 @@ import (
"golang.org/x/exp/rand"
- "github.com/scylladb/gemini/pkg/builders"
"github.com/scylladb/gemini/pkg/generators"
"github.com/scylladb/gemini/pkg/typedef"
)
@@ -60,16 +59,16 @@ func genAddColumnStmt(t *typedef.Table, keyspace string, column *typedef.ColumnD
stmt := fmt.Sprintf(createType, keyspace, c.TypeName, strings.Join(typs, ","))
stmts = append(stmts, &typedef.Stmt{
StmtCache: &typedef.StmtCache{
- Query: &builders.AlterTableBuilder{
+ Query: &typedef.AlterTableBuilder{
Stmt: stmt,
},
},
})
}
- stmt := "ALTER TABLE " + keyspace + "." + t.Name + " ADD " + column.Name + " " + column.Type.CQLDef()
+ stmt := "alter table " + keyspace + "." + t.Name + " ADD " + column.Name + " " + column.Type.CQLDef()
stmts = append(stmts, &typedef.Stmt{
StmtCache: &typedef.StmtCache{
- Query: &builders.AlterTableBuilder{
+ Query: &typedef.AlterTableBuilder{
Stmt: stmt,
},
},
@@ -87,10 +86,10 @@ func genAddColumnStmt(t *typedef.Table, keyspace string, column *typedef.ColumnD
func genDropColumnStmt(t *typedef.Table, keyspace string, column *typedef.ColumnDef) (*typedef.Stmts, error) {
var stmts []*typedef.Stmt
- stmt := "ALTER TABLE " + keyspace + "." + t.Name + " DROP " + column.Name
+ stmt := "alter table " + keyspace + "." + t.Name + " DROP " + column.Name
stmts = append(stmts, &typedef.Stmt{
StmtCache: &typedef.StmtCache{
- Query: &builders.AlterTableBuilder{
+ Query: &typedef.AlterTableBuilder{
Stmt: stmt,
},
QueryType: typedef.DropColumnStatementType,
diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go
new file mode 100644
index 00000000..54a166a6
--- /dev/null
+++ b/pkg/metrics/metrics.go
@@ -0,0 +1,29 @@
+package metrics
+
+import (
+ "log"
+ "net/http"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+)
+
+var CQLRequests *prometheus.CounterVec
+
+func init() {
+ CQLRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gemini_cql_requests",
+ Help: "How many CQL requests processed, partitioned by system and CQL query type aka 'method' (batch, delete, insert, update).",
+ },
+ []string{"system", "method"},
+ )
+
+ go func() {
+ http.Handle("/metrics", promhttp.Handler())
+ if err := http.ListenAndServe("0.0.0.0:2121", nil); err != nil {
+ log.Fatalf("Failed to start metrics server: %v\n", err)
+ }
+ }()
+}
diff --git a/pkg/stmtlogger/filelogger.go b/pkg/stmtlogger/filelogger.go
index d367bb88..f0561229 100644
--- a/pkg/stmtlogger/filelogger.go
+++ b/pkg/stmtlogger/filelogger.go
@@ -39,7 +39,7 @@ const (
)
type (
- StmtToFile interface {
+ Interface interface {
LogStmt(stmt *typedef.Stmt, ts ...time.Time) error
Close() error
}
@@ -55,7 +55,7 @@ type (
}
)
-func NewFileLogger(filename string) (StmtToFile, error) {
+func NewFileLogger(filename string) (Interface, error) {
if filename == "" {
return &nopFileLogger{}, nil
}
@@ -68,7 +68,7 @@ func NewFileLogger(filename string) (StmtToFile, error) {
return NewLogger(fd)
}
-func NewLogger(w io.Writer) (StmtToFile, error) {
+func NewLogger(w io.Writer) (Interface, error) {
ctx, cancel := context.WithCancel(context.Background())
out := &logger{
diff --git a/pkg/store/cqlstore.go b/pkg/store/cqlstore.go
deleted file mode 100644
index 6a2269ab..00000000
--- a/pkg/store/cqlstore.go
+++ /dev/null
@@ -1,137 +0,0 @@
-// Copyright 2019 ScyllaDB
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package store
-
-import (
- "context"
- "io"
- "time"
-
- "github.com/gocql/gocql"
- "github.com/pkg/errors"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/scylladb/gocqlx/v2/qb"
- "go.uber.org/zap"
-
- "github.com/scylladb/gemini/pkg/stmtlogger"
- "github.com/scylladb/gemini/pkg/typedef"
-)
-
-type cqlStore struct { //nolint:govet
- session *gocql.Session
- schema *typedef.Schema
- ops *prometheus.CounterVec
- logger *zap.Logger
- system string
- maxRetriesMutate int
- maxRetriesMutateSleep time.Duration
- useServerSideTimestamps bool
- stmtLogger stmtlogger.StmtToFile
-}
-
-func (cs *cqlStore) name() string {
- return cs.system
-}
-
-func (cs *cqlStore) mutate(ctx context.Context, stmt *typedef.Stmt) error {
- for range cs.maxRetriesMutate {
- if err := cs.doMutate(ctx, stmt); err == nil {
- cs.ops.WithLabelValues(cs.system, opType(stmt)).Inc()
- return nil
- }
-
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-time.After(cs.maxRetriesMutateSleep):
- }
- }
-
- return errors.Errorf("failed to mutate after %d retries", cs.maxRetriesMutate)
-}
-
-func (cs *cqlStore) doMutate(ctx context.Context, stmt *typedef.Stmt) error {
- queryBody, _ := stmt.Query.ToCql()
- query := cs.session.Query(queryBody, stmt.Values...).WithContext(ctx).DefaultTimestamp(false)
- defer query.Release()
-
- var ts time.Time
-
- if !cs.useServerSideTimestamps {
- ts = time.Now()
- query = query.WithTimestamp(ts.UnixMicro())
- }
-
- if err := query.Exec(); err != nil {
- return errors.Wrapf(err, "[cluster = %s, query = '%s']", cs.system, queryBody)
- }
-
- if err := cs.stmtLogger.LogStmt(stmt, ts); err != nil {
- return err
- }
-
- return nil
-}
-
-func (cs *cqlStore) load(ctx context.Context, stmt *typedef.Stmt) ([]map[string]any, error) {
- cql, _ := stmt.Query.ToCql()
-
- query := cs.session.Query(cql, stmt.Values...).WithContext(ctx)
- defer query.Release()
-
- iter := query.Iter()
- cs.ops.WithLabelValues(cs.system, opType(stmt)).Inc()
-
- result := loadSet(iter)
-
- if err := cs.stmtLogger.LogStmt(stmt); err != nil {
- return nil, err
- }
-
- return result, iter.Close()
-}
-
-func (cs *cqlStore) close() error {
- cs.session.Close()
- return nil
-}
-
-func newSession(cluster *gocql.ClusterConfig, out io.Writer) (*gocql.Session, error) {
- session, err := cluster.CreateSession()
- if err != nil {
- return nil, err
- }
- if out != nil {
- session.SetTrace(gocql.NewTraceWriter(session, out))
- }
- return session, nil
-}
-
-func opType(stmt *typedef.Stmt) string {
- switch stmt.Query.(type) {
- case *qb.InsertBuilder:
- return "insert"
- case *qb.DeleteBuilder:
- return "delete"
- case *qb.UpdateBuilder:
- return "update"
- case *qb.SelectBuilder:
- return "select"
- case *qb.BatchBuilder:
- return "batch"
- default:
- return "unknown"
- }
-}
diff --git a/pkg/store/drivers/noop.go b/pkg/store/drivers/noop.go
new file mode 100644
index 00000000..d13ac9d8
--- /dev/null
+++ b/pkg/store/drivers/noop.go
@@ -0,0 +1,21 @@
+package drivers
+
+import (
+ "context"
+
+ "github.com/scylladb/gemini/pkg/typedef"
+)
+
+type Nop struct{}
+
+func NewNop() Nop {
+ return Nop{}
+}
+
+func (n Nop) Execute(context.Context, *typedef.Stmt) error {
+ return nil
+}
+
+func (n Nop) Fetch(context.Context, *typedef.Stmt) ([]map[string]any, error) {
+ return nil, nil
+}
diff --git a/pkg/store/drivers/scylla.go b/pkg/store/drivers/scylla.go
new file mode 100644
index 00000000..f1b1e770
--- /dev/null
+++ b/pkg/store/drivers/scylla.go
@@ -0,0 +1,294 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package drivers
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/gocql/gocql"
+ "github.com/hailocab/go-hostpool"
+ "github.com/pkg/errors"
+ "github.com/scylladb/gocqlx/v2/qb"
+ "go.uber.org/zap"
+
+ "github.com/scylladb/gemini/pkg/auth"
+ "github.com/scylladb/gemini/pkg/metrics"
+ "github.com/scylladb/gemini/pkg/stmtlogger"
+ "github.com/scylladb/gemini/pkg/typedef"
+ "github.com/scylladb/gemini/pkg/utils"
+)
+
+type (
+ ScyllaDB struct { //nolint:govet
+ session *gocql.Session
+ schema *typedef.Schema
+ logger *zap.Logger
+ system string
+
+ statementLog stmtlogger.Interface
+ maxRetriesMutateSleep time.Duration
+ maxRetriesMutate int
+ useServerSideTimestamps bool
+ }
+
+ CQLConfig struct {
+ Hosts []string
+ Trace string
+ Consistency string
+ DC string
+ HostSelectionPolicy string
+ Username string
+ Password string
+ StatementLog string
+ RequestTimeout time.Duration
+ ConnectTimeout time.Duration
+ MaxRetriesMutateSleep time.Duration
+ MaxRetriesMutate int
+ UseServerSideTimestamps bool
+ }
+)
+
+func NewCQL(
+ ctx context.Context,
+ name string,
+ schema *typedef.Schema,
+ cfg CQLConfig,
+ logger *zap.Logger,
+) (*ScyllaDB, error) {
+ clusterConfig, err := createClusters(cfg, logger)
+ if err != nil {
+ return nil, err
+ }
+
+ session, err := clusterConfig.CreateSession()
+ if err != nil {
+ return nil, errors.Wrapf(err, "failed to connect to %s cluster", name)
+ }
+
+ if cfg.Trace != "" {
+ trace, err := cfg.getTraceLogFile()
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to create trace file")
+ }
+
+ session.SetTrace(gocql.NewTraceWriter(session, trace))
+ }
+
+ if err := session.AwaitSchemaAgreement(ctx); err != nil {
+ return nil, errors.Wrapf(err, "failed to await schema agreement for %s cluster", name)
+ }
+
+ statementLog, err := cfg.getStatementLogFile()
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to create statement logger")
+ }
+
+ return &ScyllaDB{
+ session: session,
+ schema: schema,
+ logger: logger.Named(name),
+ system: name,
+ statementLog: statementLog,
+ maxRetriesMutateSleep: cfg.MaxRetriesMutateSleep,
+ maxRetriesMutate: cfg.MaxRetriesMutate,
+ useServerSideTimestamps: cfg.UseServerSideTimestamps,
+ }, nil
+}
+
+func (cs *ScyllaDB) Execute(ctx context.Context, stmt *typedef.Stmt) error {
+ for range cs.maxRetriesMutate {
+ if err := cs.doMutate(ctx, stmt); err == nil {
+ metrics.CQLRequests.WithLabelValues(cs.system, opType(stmt)).Inc()
+ return nil
+ }
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-time.After(cs.maxRetriesMutateSleep):
+ }
+ }
+
+ return errors.Errorf("failed to Mutate after %d retries", cs.maxRetriesMutate)
+}
+
+func (cs *ScyllaDB) doMutate(ctx context.Context, stmt *typedef.Stmt) error {
+ queryBody, _ := stmt.Query.ToCql()
+ query := cs.session.Query(queryBody, stmt.Values...).WithContext(ctx).DefaultTimestamp(false)
+ defer query.Release()
+
+ var ts time.Time
+
+ if !cs.useServerSideTimestamps {
+ ts = time.Now()
+ query = query.WithTimestamp(ts.UnixMicro())
+ }
+
+ if err := query.Exec(); err != nil {
+ return errors.Wrapf(err, "[cluster = %s, query = %s]", cs.system, queryBody)
+ }
+
+ if err := cs.statementLog.LogStmt(stmt, ts); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (cs *ScyllaDB) Fetch(ctx context.Context, stmt *typedef.Stmt) ([]map[string]any, error) {
+ cql, _ := stmt.Query.ToCql()
+
+ query := cs.session.
+ Query(cql, stmt.Values...).
+ WithContext(ctx).
+ DefaultTimestamp(false)
+
+ defer query.Release()
+
+ metrics.CQLRequests.WithLabelValues(cs.system, opType(stmt)).Inc()
+
+ result, err := loadSet(query.Iter())
+ if err != nil {
+ return nil, errors.Wrapf(err, "[cluster = %s, query = %s]", cs.system, cql)
+ }
+
+ if err := cs.statementLog.LogStmt(stmt); err != nil {
+ return nil, err
+ }
+
+ return result, nil
+}
+
+func (cs *ScyllaDB) Close() error {
+ cs.session.Close()
+ return nil
+}
+
+func opType(stmt *typedef.Stmt) string {
+ switch stmt.Query.(type) {
+ case *qb.InsertBuilder:
+ return "insert"
+ case *qb.DeleteBuilder:
+ return "delete"
+ case *qb.UpdateBuilder:
+ return "update"
+ case *qb.SelectBuilder:
+ return "select"
+ case *qb.BatchBuilder:
+ return "batch"
+ default:
+ return "unknown"
+ }
+}
+
+func createClusters(cfg CQLConfig, logger *zap.Logger) (*gocql.ClusterConfig, error) {
+ authenticator, err := auth.BuildAuthenticator(cfg.Username, cfg.Password)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to create authenticator")
+ }
+
+ cons, err := gocql.ParseConsistencyWrapper(cfg.Consistency)
+ if err != nil {
+ return nil, errors.Wrapf(err, "failed to parse consistency: %s", cfg.Consistency)
+ }
+
+ selectionPolicy, err := parseHostSelectionPolicy(cfg.HostSelectionPolicy, cfg.DC, cfg.Hosts)
+ if err != nil {
+ return nil, errors.Wrapf(err, "failed to parse host selection policy: %s", cfg.HostSelectionPolicy)
+ }
+
+ cluster := gocql.NewCluster(cfg.Hosts...)
+ cluster.Timeout = cfg.RequestTimeout
+ cluster.ProtoVersion = 4
+ cluster.ConnectTimeout = cfg.ConnectTimeout
+ cluster.Consistency = cons
+ cluster.PoolConfig.HostSelectionPolicy = selectionPolicy
+ cluster.DefaultTimestamp = false
+ cluster.Logger = zap.NewStdLog(logger)
+ cluster.MaxRoutingKeyInfo = 10_000
+ cluster.MaxPreparedStmts = 10_000
+ cluster.Authenticator = authenticator
+ cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
+ Min: 500 * time.Millisecond,
+ Max: cfg.RequestTimeout,
+ NumRetries: cfg.MaxRetriesMutate,
+ }
+ cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
+ InitialInterval: 500 * time.Millisecond,
+ MaxRetries: 10,
+ MaxInterval: cfg.ConnectTimeout,
+ }
+
+ return cluster, nil
+}
+
+func parseHostSelectionPolicy(policy, dc string, hosts []string) (gocql.HostSelectionPolicy, error) {
+ switch policy {
+ case "round-robin":
+ if dc != "" {
+ return gocql.DCAwareRoundRobinPolicy(dc), nil
+ }
+
+ return gocql.RoundRobinHostPolicy(), nil
+ case "host-pool":
+ return gocql.HostPoolHostPolicy(hostpool.New(hosts)), nil
+ case "token-aware":
+ if dc != "" {
+ return gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy(dc)), nil
+ }
+
+ return gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()), nil
+ default:
+ return nil, fmt.Errorf("unknown host selection policy \"%s\"", policy)
+ }
+}
+
+func loadSet(iter *gocql.Iter) ([]map[string]any, error) {
+ rows := make([]map[string]any, 0, iter.NumRows())
+
+ for {
+ row := make(map[string]any, len(iter.Columns()))
+
+ if !iter.MapScan(row) {
+ break
+ }
+
+ rows = append(rows, row)
+ }
+
+ return rows, iter.Close()
+}
+
+func (c CQLConfig) getStatementLogFile() (stmtlogger.Interface, error) {
+ writer, err := utils.CreateFile(c.StatementLog)
+ if err != nil {
+ return nil, err
+ }
+
+ return stmtlogger.NewLogger(writer)
+}
+
+func (c CQLConfig) getTraceLogFile() (io.Writer, error) {
+ writer, err := utils.CreateFile(c.StatementLog)
+ if err != nil {
+ return nil, err
+ }
+
+ return bufio.NewWriterSize(writer, 8192), nil
+}
diff --git a/pkg/store/drivers/scylla_test.go b/pkg/store/drivers/scylla_test.go
new file mode 100644
index 00000000..fc62d2f0
--- /dev/null
+++ b/pkg/store/drivers/scylla_test.go
@@ -0,0 +1 @@
+package drivers_test
diff --git a/pkg/store/helpers.go b/pkg/store/helpers.go
index 8dff43be..ffe40fb7 100644
--- a/pkg/store/helpers.go
+++ b/pkg/store/helpers.go
@@ -85,15 +85,3 @@ func lt(mi, mj map[string]any) bool {
panic(msg)
}
}
-
-func loadSet(iter *gocql.Iter) []map[string]any {
- var rows []map[string]any
- for {
- row := make(map[string]any)
- if !iter.MapScan(row) {
- break
- }
- rows = append(rows, row)
- }
- return rows
-}
diff --git a/pkg/store/store.go b/pkg/store/store.go
index 8eeb3c38..4d3112ba 100644
--- a/pkg/store/store.go
+++ b/pkg/store/store.go
@@ -16,280 +16,29 @@ package store
import (
"context"
- "fmt"
- "io"
- "math/big"
- "reflect"
- "sort"
- "sync"
- "time"
"go.uber.org/zap"
- "github.com/gocql/gocql"
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "github.com/pkg/errors"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_golang/prometheus/promauto"
- "go.uber.org/multierr"
- "gopkg.in/inf.v0"
-
- "github.com/scylladb/go-set/strset"
-
- "github.com/scylladb/gemini/pkg/stmtlogger"
"github.com/scylladb/gemini/pkg/typedef"
)
-type loader interface {
- load(context.Context, *typedef.Stmt) ([]map[string]any, error)
-}
-
-type storer interface {
- mutate(context.Context, *typedef.Stmt) error
-}
-
-type storeLoader interface {
- storer
- loader
- close() error
- name() string
-}
-
-type Store interface {
- Create(context.Context, *typedef.Stmt, *typedef.Stmt) error
- Mutate(context.Context, *typedef.Stmt) error
- Check(context.Context, *typedef.Table, *typedef.Stmt, bool) error
- Close() error
-}
-
-type Config struct {
- TestLogStatementsFile string
- OracleLogStatementsFile string
- MaxRetriesMutate int
- MaxRetriesMutateSleep time.Duration
- UseServerSideTimestamps bool
-}
-
-func New(schema *typedef.Schema, testCluster, oracleCluster *gocql.ClusterConfig, cfg Config, traceOut io.Writer, logger *zap.Logger) (Store, error) {
- ops := promauto.NewCounterVec(prometheus.CounterOpts{
- Name: "gemini_cql_requests",
- Help: "How many CQL requests processed, partitioned by system and CQL query type aka 'method' (batch, delete, insert, update).",
- }, []string{"system", "method"},
- )
-
- oracleStore, err := getStore("oracle", schema, oracleCluster, cfg, cfg.OracleLogStatementsFile, traceOut, logger, ops)
- if err != nil {
- return nil, err
+type (
+ Driver interface {
+ Execute(ctx context.Context, stmt *typedef.Stmt) error
+ Fetch(ctx context.Context, stmt *typedef.Stmt) ([]map[string]any, error)
}
- if testCluster == nil {
- return nil, errors.New("test cluster is empty")
- }
- testStore, err := getStore("test", schema, testCluster, cfg, cfg.TestLogStatementsFile, traceOut, logger, ops)
- if err != nil {
- return nil, err
+ Store interface {
+ Create(context.Context, *typedef.Stmt, *typedef.Stmt) error
+ Mutate(context.Context, *typedef.Stmt) error
+ Check(context.Context, *typedef.Table, *typedef.Stmt, bool) error
}
+)
- return &delegatingStore{
+func New(logger *zap.Logger, testStore, oracleStore Driver) (*ValidatingStore, error) {
+ return &ValidatingStore{
testStore: testStore,
oracleStore: oracleStore,
- validations: oracleStore != nil,
logger: logger.Named("delegating_store"),
}, nil
}
-
-type noOpStore struct {
- system string
-}
-
-func (n *noOpStore) mutate(context.Context, *typedef.Stmt) error {
- return nil
-}
-
-func (n *noOpStore) load(context.Context, *typedef.Stmt) ([]map[string]any, error) {
- return nil, nil
-}
-
-func (n *noOpStore) Close() error {
- return nil
-}
-
-func (n *noOpStore) name() string {
- return n.system
-}
-
-func (n *noOpStore) close() error {
- return nil
-}
-
-type delegatingStore struct {
- oracleStore storeLoader
- testStore storeLoader
- statementLogger stmtlogger.StmtToFile
- logger *zap.Logger
- validations bool
-}
-
-func (ds delegatingStore) Create(ctx context.Context, testBuilder, oracleBuilder *typedef.Stmt) error {
- if err := mutate(ctx, ds.oracleStore, oracleBuilder); err != nil {
- return errors.Wrap(err, "oracle failed store creation")
- }
-
- if err := mutate(ctx, ds.testStore, testBuilder); err != nil {
- return errors.Wrap(err, "test failed store creation")
- }
-
- if ds.statementLogger != nil {
- if err := ds.statementLogger.LogStmt(testBuilder); err != nil {
- return errors.Wrap(err, "failed to log test create statement")
- }
- }
-
- return nil
-}
-
-func (ds delegatingStore) Mutate(ctx context.Context, stmt *typedef.Stmt) error {
- var testErr error
- var wg sync.WaitGroup
- wg.Add(1)
-
- go func() {
- defer wg.Done()
- testErr = errors.Wrapf(
- ds.testStore.mutate(ctx, stmt),
- "unable to apply mutations to the %s store", ds.testStore.name())
- }()
-
- if oracleErr := ds.oracleStore.mutate(ctx, stmt); oracleErr != nil {
- // Oracle failed, transition cannot take place
- ds.logger.Info("oracle store failed mutation, transition to next state impossible so continuing with next mutation", zap.Error(oracleErr))
- return oracleErr
- }
- wg.Wait()
- if testErr != nil {
- // Test store failed, transition cannot take place
- ds.logger.Info("test store failed mutation, transition to next state impossible so continuing with next mutation", zap.Error(testErr))
- return testErr
- }
- return nil
-}
-
-func mutate(ctx context.Context, s storeLoader, stmt *typedef.Stmt) error {
- if err := s.mutate(ctx, stmt); err != nil {
- return errors.Wrapf(err, "unable to apply mutations to the %s store", s.name())
- }
- return nil
-}
-
-func (ds delegatingStore) Check(ctx context.Context, table *typedef.Table, stmt *typedef.Stmt, detailedDiff bool) error {
- var testRows, oracleRows []map[string]any
- var testErr, oracleErr error
- var wg sync.WaitGroup
- wg.Add(1)
-
- go func() {
- defer wg.Done()
- testRows, testErr = ds.testStore.load(ctx, stmt)
- }()
- oracleRows, oracleErr = ds.oracleStore.load(ctx, stmt)
- if oracleErr != nil {
- return errors.Wrapf(oracleErr, "unable to load check data from the oracle store")
- }
- wg.Wait()
- if testErr != nil {
- return errors.Wrapf(testErr, "unable to load check data from the test store")
- }
- if !ds.validations {
- return nil
- }
- if len(testRows) == 0 && len(oracleRows) == 0 {
- return nil
- }
- if len(testRows) != len(oracleRows) {
- if !detailedDiff {
- return fmt.Errorf("rows count differ (test store rows %d, oracle store rows %d, detailed information will be at last attempt)", len(testRows), len(oracleRows))
- }
- testSet := strset.New(pks(table, testRows)...)
- oracleSet := strset.New(pks(table, oracleRows)...)
- missingInTest := strset.Difference(oracleSet, testSet).List()
- missingInOracle := strset.Difference(testSet, oracleSet).List()
- return fmt.Errorf("row count differ (test has %d rows, oracle has %d rows, test is missing rows: %s, oracle is missing rows: %s)",
- len(testRows), len(oracleRows), missingInTest, missingInOracle)
- }
- if reflect.DeepEqual(testRows, oracleRows) {
- return nil
- }
- if !detailedDiff {
- return fmt.Errorf("test and oracle store have difference, detailed information will be at last attempt")
- }
- sort.SliceStable(testRows, func(i, j int) bool {
- return lt(testRows[i], testRows[j])
- })
- sort.SliceStable(oracleRows, func(i, j int) bool {
- return lt(oracleRows[i], oracleRows[j])
- })
- for i, oracleRow := range oracleRows {
- testRow := testRows[i]
- cmp.AllowUnexported()
- diff := cmp.Diff(oracleRow, testRow,
- cmpopts.SortMaps(func(x, y *inf.Dec) bool {
- return x.Cmp(y) < 0
- }),
- cmp.Comparer(func(x, y *inf.Dec) bool {
- return x.Cmp(y) == 0
- }), cmp.Comparer(func(x, y *big.Int) bool {
- return x.Cmp(y) == 0
- }))
- if diff != "" {
- return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff)
- }
- }
- return nil
-}
-
-func (ds delegatingStore) Close() (err error) {
- if ds.statementLogger != nil {
- err = multierr.Append(err, ds.statementLogger.Close())
- }
- err = multierr.Append(err, ds.testStore.close())
- err = multierr.Append(err, ds.oracleStore.close())
- return
-}
-
-func getStore(
- name string,
- schema *typedef.Schema,
- clusterConfig *gocql.ClusterConfig,
- cfg Config,
- stmtLogFile string,
- traceOut io.Writer,
- logger *zap.Logger,
- ops *prometheus.CounterVec,
-) (out storeLoader, err error) {
- if clusterConfig == nil {
- return &noOpStore{
- system: name,
- }, nil
- }
- oracleSession, err := newSession(clusterConfig, traceOut)
- if err != nil {
- return nil, errors.Wrapf(err, "failed to connect to %s cluster", name)
- }
- oracleFileLogger, err := stmtlogger.NewFileLogger(stmtLogFile)
- if err != nil {
- return nil, err
- }
-
- return &cqlStore{
- session: oracleSession,
- schema: schema,
- system: name,
- ops: ops,
- maxRetriesMutate: cfg.MaxRetriesMutate + 10,
- maxRetriesMutateSleep: cfg.MaxRetriesMutateSleep,
- useServerSideTimestamps: cfg.UseServerSideTimestamps,
- logger: logger.Named(name),
- stmtLogger: oracleFileLogger,
- }, nil
-}
diff --git a/pkg/store/validating_store.go b/pkg/store/validating_store.go
new file mode 100644
index 00000000..9a2305eb
--- /dev/null
+++ b/pkg/store/validating_store.go
@@ -0,0 +1,151 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package store
+
+import (
+ "context"
+ "fmt"
+ "math/big"
+ "reflect"
+ "sort"
+ "sync"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "github.com/pkg/errors"
+ "github.com/scylladb/go-set/strset"
+ "go.uber.org/zap"
+ "gopkg.in/inf.v0"
+
+ "github.com/scylladb/gemini/pkg/typedef"
+)
+
+type ValidatingStore struct {
+ oracleStore Driver
+ testStore Driver
+ logger *zap.Logger
+}
+
+func (ds ValidatingStore) Create(ctx context.Context, testBuilder, oracleBuilder *typedef.Stmt) error {
+ if err := ds.testStore.Execute(ctx, testBuilder); err != nil {
+ return errors.Wrap(err, "oracle failed store creation")
+ }
+
+ if err := ds.testStore.Execute(ctx, oracleBuilder); err != nil {
+ return errors.Wrap(err, "oracle failed store creation")
+ }
+
+ return nil
+}
+
+func (ds ValidatingStore) Mutate(ctx context.Context, stmt *typedef.Stmt) error {
+ var (
+ wg sync.WaitGroup
+ oracleErr error
+ )
+
+ if ds.oracleStore != nil {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ oracleErr = ds.oracleStore.Execute(ctx, stmt)
+ }()
+ }
+
+ if err := ds.testStore.Execute(ctx, stmt); err != nil {
+ return errors.Wrap(err, "unable to apply mutations to the test store")
+ }
+
+ wg.Wait()
+
+ if oracleErr != nil {
+ return errors.Wrap(oracleErr, "unable to apply mutations to the oracle store")
+ }
+
+ return nil
+}
+
+func (ds ValidatingStore) Check(ctx context.Context, table *typedef.Table, stmt *typedef.Stmt, detailedDiff bool) error {
+ var (
+ oracleErr error
+ oracleRows []map[string]any
+ wg sync.WaitGroup
+ )
+
+ if ds.oracleStore != nil {
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+ oracleRows, oracleErr = ds.oracleStore.Fetch(ctx, stmt)
+ }()
+ }
+
+ testRows, testErr := ds.testStore.Fetch(ctx, stmt)
+
+ if testErr != nil {
+ return errors.Wrap(testErr, "unable to Load check data from the test store")
+ }
+
+ wg.Wait()
+
+ if oracleErr != nil {
+ return errors.Wrap(oracleErr, "unable to Load check data from the oracle store")
+ }
+
+ if len(testRows) == 0 && len(oracleRows) == 0 {
+ return nil
+ }
+ if len(testRows) != len(oracleRows) {
+ if !detailedDiff {
+ return fmt.Errorf("rows count differ (test store rows %d, oracle store rows %d, detailed information will be at last attempt)", len(testRows), len(oracleRows))
+ }
+ testSet := strset.New(pks(table, testRows)...)
+ oracleSet := strset.New(pks(table, oracleRows)...)
+ missingInTest := strset.Difference(oracleSet, testSet).List()
+ missingInOracle := strset.Difference(testSet, oracleSet).List()
+ return fmt.Errorf("row count differ (test has %d rows, oracle has %d rows, test is missing rows: %s, oracle is missing rows: %s)",
+ len(testRows), len(oracleRows), missingInTest, missingInOracle)
+ }
+ if reflect.DeepEqual(testRows, oracleRows) {
+ return nil
+ }
+ if !detailedDiff {
+ return fmt.Errorf("test and oracle store have difference, detailed information will be at last attempt")
+ }
+ sort.SliceStable(testRows, func(i, j int) bool {
+ return lt(testRows[i], testRows[j])
+ })
+ sort.SliceStable(oracleRows, func(i, j int) bool {
+ return lt(oracleRows[i], oracleRows[j])
+ })
+ for i, oracleRow := range oracleRows {
+ testRow := testRows[i]
+ cmp.AllowUnexported()
+ diff := cmp.Diff(oracleRow, testRow,
+ cmpopts.SortMaps(func(x, y *inf.Dec) bool {
+ return x.Cmp(y) < 0
+ }),
+ cmp.Comparer(func(x, y *inf.Dec) bool {
+ return x.Cmp(y) == 0
+ }), cmp.Comparer(func(x, y *big.Int) bool {
+ return x.Cmp(y) == 0
+ }))
+ if diff != "" {
+ return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff)
+ }
+ }
+ return nil
+}
diff --git a/pkg/testutils/name_mappings.go b/pkg/testutils/name_mappings.go
index d36bea16..dece82dc 100644
--- a/pkg/testutils/name_mappings.go
+++ b/pkg/testutils/name_mappings.go
@@ -22,7 +22,6 @@ import (
"golang.org/x/exp/rand"
- "github.com/scylladb/gemini/pkg/builders"
"github.com/scylladb/gemini/pkg/replication"
"github.com/scylladb/gemini/pkg/routingkey"
"github.com/scylladb/gemini/pkg/tableopts"
@@ -176,7 +175,7 @@ func createTableOptions(cql string) []tableopts.Option {
}
func genTestSchema(sc typedef.SchemaConfig, table *typedef.Table) *typedef.Schema {
- builder := builders.NewSchemaBuilder()
+ builder := typedef.NewSchemaBuilder()
builder.Config(sc)
keyspace := typedef.Keyspace{
Name: "ks1",
diff --git a/pkg/builders/builders.go b/pkg/typedef/builders.go
similarity index 57%
rename from pkg/builders/builders.go
rename to pkg/typedef/builders.go
index 6ad39f1b..f66ee9ea 100644
--- a/pkg/builders/builders.go
+++ b/pkg/typedef/builders.go
@@ -12,18 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package builders
-
-import (
- "github.com/scylladb/gemini/pkg/querycache"
- "github.com/scylladb/gemini/pkg/typedef"
-)
+package typedef
type SchemaBuilder interface {
- Config(config typedef.SchemaConfig) SchemaBuilder
- Keyspace(typedef.Keyspace) SchemaBuilder
- Table(*typedef.Table) SchemaBuilder
- Build() *typedef.Schema
+ Config(config SchemaConfig) SchemaBuilder
+ Keyspace(Keyspace) SchemaBuilder
+ Table(*Table) SchemaBuilder
+ Build() *Schema
}
type AlterTableBuilder struct {
@@ -35,30 +30,30 @@ func (atb *AlterTableBuilder) ToCql() (string, []string) {
}
type schemaBuilder struct {
- keyspace typedef.Keyspace
- tables []*typedef.Table
- config typedef.SchemaConfig
+ keyspace Keyspace
+ tables []*Table
+ config SchemaConfig
}
-func (s *schemaBuilder) Keyspace(keyspace typedef.Keyspace) SchemaBuilder {
+func (s *schemaBuilder) Keyspace(keyspace Keyspace) SchemaBuilder {
s.keyspace = keyspace
return s
}
-func (s *schemaBuilder) Config(config typedef.SchemaConfig) SchemaBuilder {
+func (s *schemaBuilder) Config(config SchemaConfig) SchemaBuilder {
s.config = config
return s
}
-func (s *schemaBuilder) Table(table *typedef.Table) SchemaBuilder {
+func (s *schemaBuilder) Table(table *Table) SchemaBuilder {
s.tables = append(s.tables, table)
return s
}
-func (s *schemaBuilder) Build() *typedef.Schema {
- out := &typedef.Schema{Keyspace: s.keyspace, Tables: s.tables, Config: s.config}
+func (s *schemaBuilder) Build() *Schema {
+ out := &Schema{Keyspace: s.keyspace, Tables: s.tables, Config: s.config}
for id := range s.tables {
- s.tables[id].Init(out, querycache.New(out))
+ s.tables[id].Init(out, New(out))
}
return out
}
diff --git a/pkg/querycache/querycache.go b/pkg/typedef/querycache.go
similarity index 62%
rename from pkg/querycache/querycache.go
rename to pkg/typedef/querycache.go
index 0aa755ae..948d5290 100644
--- a/pkg/querycache/querycache.go
+++ b/pkg/typedef/querycache.go
@@ -12,18 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package querycache
+package typedef
import (
"fmt"
"sync"
"github.com/scylladb/gocqlx/v2/qb"
-
- "github.com/scylladb/gemini/pkg/typedef"
)
-type QueryCache [typedef.CacheArrayLen]*typedef.StmtCache
+type QueryCache [CacheArrayLen]*StmtCache
func (c QueryCache) Reset() {
for id := range c {
@@ -32,19 +30,19 @@ func (c QueryCache) Reset() {
}
type Cache struct {
- schema *typedef.Schema
- table *typedef.Table
+ schema *Schema
+ table *Table
cache QueryCache
mu sync.RWMutex
}
-func New(s *typedef.Schema) *Cache {
+func New(s *Schema) *Cache {
return &Cache{
schema: s,
}
}
-func (c *Cache) BindToTable(t *typedef.Table) {
+func (c *Cache) BindToTable(t *Table) {
c.table = t
}
@@ -54,14 +52,14 @@ func (c *Cache) Reset() {
c.cache.Reset()
}
-func (c *Cache) getQuery(qct typedef.StatementCacheType) *typedef.StmtCache {
+func (c *Cache) getQuery(qct StatementCacheType) *StmtCache {
c.mu.RLock()
defer c.mu.RUnlock()
rec := c.cache[qct]
return rec
}
-func (c *Cache) GetQuery(qct typedef.StatementCacheType) *typedef.StmtCache {
+func (c *Cache) GetQuery(qct StatementCacheType) *StmtCache {
rec := c.getQuery(qct)
if rec != nil {
return rec
@@ -73,35 +71,35 @@ func (c *Cache) GetQuery(qct typedef.StatementCacheType) *typedef.StmtCache {
return rec
}
-type CacheBuilderFn func(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache
+type CacheBuilderFn func(s *Schema, t *Table) *StmtCache
-type CacheBuilderFnMap map[typedef.StatementCacheType]CacheBuilderFn
+type CacheBuilderFnMap map[StatementCacheType]CacheBuilderFn
-func (m CacheBuilderFnMap) ToList() [typedef.CacheArrayLen]CacheBuilderFn {
- out := [typedef.CacheArrayLen]CacheBuilderFn{}
+func (m CacheBuilderFnMap) ToList() [CacheArrayLen]CacheBuilderFn {
+ out := [CacheArrayLen]CacheBuilderFn{}
for idx, builderFn := range m {
out[idx] = builderFn
}
for idx := range out {
if out[idx] == nil {
- panic(fmt.Sprintf("no builder for %s", typedef.StatementCacheType(idx).ToString()))
+ panic(fmt.Sprintf("no builder for %s", StatementCacheType(idx).ToString()))
}
}
return out
}
var CacheBuilders = CacheBuilderFnMap{
- typedef.CacheInsert: genInsertStmtCache,
- typedef.CacheInsertIfNotExists: genInsertIfNotExistsStmtCache,
- typedef.CacheDelete: genDeleteStmtCache,
- typedef.CacheUpdate: genUpdateStmtCache,
+ CacheInsert: genInsertStmtCache,
+ CacheInsertIfNotExists: genInsertIfNotExistsStmtCache,
+ CacheDelete: genDeleteStmtCache,
+ CacheUpdate: genUpdateStmtCache,
}.ToList()
func genInsertStmtCache(
- s *typedef.Schema,
- t *typedef.Table,
-) *typedef.StmtCache {
- allTypes := make([]typedef.Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len()+t.Columns.Len())
+ s *Schema,
+ t *Table,
+) *StmtCache {
+ allTypes := make([]Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len()+t.Columns.Len())
builder := qb.Insert(s.Keyspace.Name + "." + t.Name)
for _, pk := range t.PartitionKeys {
builder = builder.Columns(pk.Name)
@@ -113,38 +111,38 @@ func genInsertStmtCache(
}
for _, col := range t.Columns {
switch colType := col.Type.(type) {
- case *typedef.TupleType:
+ case *TupleType:
builder = builder.TupleColumn(col.Name, len(colType.ValueTypes))
default:
builder = builder.Columns(col.Name)
}
allTypes = append(allTypes, col.Type)
}
- return &typedef.StmtCache{
+ return &StmtCache{
Query: builder,
Types: allTypes,
- QueryType: typedef.InsertStatementType,
+ QueryType: InsertStatementType,
}
}
func genInsertIfNotExistsStmtCache(
- s *typedef.Schema,
- t *typedef.Table,
-) *typedef.StmtCache {
+ s *Schema,
+ t *Table,
+) *StmtCache {
out := genInsertStmtCache(s, t)
out.Query = out.Query.(*qb.InsertBuilder).Unique()
return out
}
-func genUpdateStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache {
- allTypes := make([]typedef.Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len()+t.Columns.Len())
+func genUpdateStmtCache(s *Schema, t *Table) *StmtCache {
+ var allTypes []Type
builder := qb.Update(s.Keyspace.Name + "." + t.Name)
for _, cdef := range t.Columns {
switch t := cdef.Type.(type) {
- case *typedef.TupleType:
+ case *TupleType:
builder = builder.SetTuple(cdef.Name, len(t.ValueTypes))
- case *typedef.CounterType:
+ case *CounterType:
builder = builder.SetLit(cdef.Name, cdef.Name+"+1")
continue
default:
@@ -161,16 +159,15 @@ func genUpdateStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache
builder = builder.Where(qb.Eq(ck.Name))
allTypes = append(allTypes, ck.Type)
}
- return &typedef.StmtCache{
+ return &StmtCache{
Query: builder,
Types: allTypes,
- QueryType: typedef.UpdateStatementType,
+ QueryType: UpdateStatementType,
}
}
-func genDeleteStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache {
- allTypes := make([]typedef.Type, 0, t.PartitionKeys.Len()+t.ClusteringKeys.Len())
-
+func genDeleteStmtCache(s *Schema, t *Table) *StmtCache {
+ var allTypes []Type
builder := qb.Delete(s.Keyspace.Name + "." + t.Name)
for _, pk := range t.PartitionKeys {
builder = builder.Where(qb.Eq(pk.Name))
@@ -183,9 +180,9 @@ func genDeleteStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache
allTypes = append(allTypes, ck.Type, ck.Type)
}
- return &typedef.StmtCache{
+ return &StmtCache{
Query: builder,
Types: allTypes,
- QueryType: typedef.DeleteStatementType,
+ QueryType: DeleteStatementType,
}
}
diff --git a/pkg/typedef/schema.go b/pkg/typedef/schema.go
index 32f36350..639d872d 100644
--- a/pkg/typedef/schema.go
+++ b/pkg/typedef/schema.go
@@ -16,6 +16,7 @@ package typedef
import (
"encoding/json"
+ "os"
"strconv"
"github.com/pkg/errors"
@@ -37,6 +38,28 @@ type Schema struct {
Config SchemaConfig `json:"-"`
}
+func NewSchemaFromFile(confFile string, schemaConfig SchemaConfig) (*Schema, error) {
+ byteValue, err := os.ReadFile(confFile)
+ if err != nil {
+ return nil, err
+ }
+
+ shm := &Schema{}
+
+ if err = json.Unmarshal(byteValue, shm); err != nil {
+ return nil, err
+ }
+
+ sb := NewSchemaBuilder()
+ sb.Keyspace(shm.Keyspace).Config(schemaConfig)
+ for t, tbl := range shm.Tables {
+ shm.Tables[t].LinkIndexAndColumns()
+ sb.Table(tbl)
+ }
+
+ return sb.Build(), nil
+}
+
func (s *Schema) GetHash() string {
out, err := json.Marshal(s)
if err != nil {
diff --git a/pkg/typedef/schemaconfig.go b/pkg/typedef/schemaconfig.go
index 707b90e5..e5537ff2 100644
--- a/pkg/typedef/schemaconfig.go
+++ b/pkg/typedef/schemaconfig.go
@@ -50,12 +50,15 @@ func (sc *SchemaConfig) Valid() error {
if sc.MaxPartitionKeys <= sc.MinPartitionKeys {
return ErrSchemaConfigInvalidRangePK
}
+
if sc.MaxClusteringKeys <= sc.MinClusteringKeys {
return ErrSchemaConfigInvalidRangeCK
}
+
if sc.MaxColumns <= sc.MinColumns {
return ErrSchemaConfigInvalidRangeCols
}
+
return nil
}
diff --git a/pkg/typedef/table.go b/pkg/typedef/table.go
index 1af82947..56d4cdd0 100644
--- a/pkg/typedef/table.go
+++ b/pkg/typedef/table.go
@@ -18,7 +18,7 @@ import (
"sync"
)
-type QueryCache interface {
+type QueryCacheInterface interface {
GetQuery(qct StatementCacheType) *StmtCache
Reset()
BindToTable(t *Table)
@@ -27,7 +27,7 @@ type QueryCache interface {
type KnownIssues map[string]bool
type Table struct {
- queryCache QueryCache
+ queryCache QueryCacheInterface
schema *Schema
Name string `json:"name"`
PartitionKeys Columns `json:"partition_keys"`
@@ -83,7 +83,7 @@ func (t *Table) ResetQueryCache() {
t.partitionKeysLenValues = 0
}
-func (t *Table) Init(s *Schema, c QueryCache) {
+func (t *Table) Init(s *Schema, c QueryCacheInterface) {
t.schema = s
t.queryCache = c
t.queryCache.BindToTable(t)
diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go
index 956d5297..ae74492c 100644
--- a/pkg/utils/utils.go
+++ b/pkg/utils/utils.go
@@ -17,10 +17,15 @@ package utils
import (
"encoding/hex"
"fmt"
+ "io"
+ "os"
+ "runtime"
"strconv"
"strings"
"time"
+ "github.com/pkg/errors"
+
"github.com/gocql/gocql"
"golang.org/x/exp/rand"
)
@@ -107,3 +112,30 @@ func UUIDFromTime(rnd *rand.Rand) string {
}
return gocql.UUIDFromTime(RandDate(rnd)).String()
}
+
+func CreateFile(input string, def ...io.Writer) (io.Writer, error) {
+ switch input {
+ case "":
+ if len(def) > 0 && def[0] != nil {
+ return def[0], nil
+ }
+
+ return io.Discard, nil
+ case "stderr":
+ return os.Stderr, nil
+ case "stdout":
+ return os.Stdout, nil
+ default:
+ tracingWriter, err := os.OpenFile(input, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
+ if err != nil {
+ return nil, errors.Wrapf(err, "failed to open file %s", input)
+ }
+
+ runtime.SetFinalizer(tracingWriter, func(f *os.File) {
+ IgnoreError(f.Sync)
+ IgnoreError(f.Close)
+ })
+
+ return tracingWriter, nil
+ }
+}