diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 278a102d..6d2a22d0 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -22,24 +22,13 @@ import ( "net/http" "net/http/pprof" "os" + "os/signal" "strconv" "strings" + "syscall" "text/tabwriter" "time" - "github.com/scylladb/gemini/pkg/auth" - "github.com/scylladb/gemini/pkg/builders" - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/jobs" - "github.com/scylladb/gemini/pkg/realrandom" - "github.com/scylladb/gemini/pkg/replication" - "github.com/scylladb/gemini/pkg/store" - "github.com/scylladb/gemini/pkg/typedef" - "github.com/scylladb/gemini/pkg/utils" - - "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" - "github.com/gocql/gocql" "github.com/hailocab/go-hostpool" "github.com/pkg/errors" @@ -50,6 +39,17 @@ import ( "golang.org/x/exp/rand" "golang.org/x/net/context" "gonum.org/v1/gonum/stat/distuv" + + "github.com/scylladb/gemini/pkg/auth" + "github.com/scylladb/gemini/pkg/builders" + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/jobs" + "github.com/scylladb/gemini/pkg/realrandom" + "github.com/scylladb/gemini/pkg/replication" + "github.com/scylladb/gemini/pkg/status" + "github.com/scylladb/gemini/pkg/store" + "github.com/scylladb/gemini/pkg/typedef" + "github.com/scylladb/gemini/pkg/utils" ) var ( @@ -137,8 +137,11 @@ func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Sc } func run(_ *cobra.Command, _ []string) error { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGABRT, syscall.SIGTERM, syscall.SIGINT) + defer cancel() + logger := createLogger(level) - globalStatus := status.NewGlobalStatus(1000) + globalStatus := status.NewGlobalStatus(int32(maxErrorsToStore)) defer utils.IgnoreError(logger.Sync) if err := validateSeed(seed); err != nil { @@ -242,7 +245,7 @@ func run(_ *cobra.Command, _ []string) error { if dropSchema && mode != jobs.ReadMode { for _, stmt := range generators.GetDropKeyspace(schema) { logger.Debug(stmt) - if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { + if err = st.Mutate(ctx, typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to drop schema") } } @@ -250,7 +253,7 @@ func run(_ *cobra.Command, _ []string) error { testKeyspace, oracleKeyspace := generators.GetCreateKeyspaces(schema) if err = st.Create( - context.Background(), + ctx, typedef.SimpleStmt(testKeyspace, typedef.CreateKeyspaceStatementType), typedef.SimpleStmt(oracleKeyspace, typedef.CreateKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to create keyspace") @@ -263,11 +266,7 @@ func run(_ *cobra.Command, _ []string) error { } } - ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2) - stopFlag := stop.NewFlag("main") - warmupStopFlag := stop.NewFlag("warmup") - stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag) - pump := jobs.NewPump(stopFlag, logger) + pump := jobs.NewPump(ctx, logger) distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, normalDistMean, normalDistSigma) if err != nil { @@ -281,10 +280,9 @@ func run(_ *cobra.Command, _ []string) error { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) go func() { - defer done() for { select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return case <-ticker.C: sp.Set(" Running Gemini... %v", globalStatus) @@ -293,20 +291,24 @@ func run(_ *cobra.Command, _ []string) error { }() } - if warmup > 0 && !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, warmupStopFlag, failFast, verbose); err != nil { + if warmup > 0 { + warmupCtx, warmupCancel := context.WithTimeout(ctx, warmup) + defer warmupCancel() + + jobsList := jobs.ListFromMode(jobs.WarmupMode, concurrency) + if err = jobsList.Run(warmupCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { logger.Error("warmup encountered an error", zap.Error(err)) - stopFlag.SetHard(true) } } - if !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(mode, duration, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil { - logger.Debug("error detected", zap.Error(err)) - } + jobsCtx, jobsCancel := context.WithTimeout(ctx, duration) + defer jobsCancel() + + jobsList := jobs.ListFromMode(mode, concurrency) + if err = jobsList.Run(jobsCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { + logger.Debug("error detected", zap.Error(err)) } + logger.Info("test finished") globalStatus.PrintResult(outFile, schema, version) if globalStatus.HasErrors() { diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go index 46c4b734..c0ee9214 100644 --- a/pkg/generators/generator.go +++ b/pkg/generators/generator.go @@ -27,7 +27,7 @@ import ( // TokenIndex represents the position of a token in the token ring. // A token index is translated to a token by a generators. If the generators -// preserves the exact position, then the token index becomes the token; +// preserve the exact position, then the token index becomes the token; // otherwise token index represents an approximation of the token. // // We use a token index approach, because our generators actually generate diff --git a/pkg/generators/generators.go b/pkg/generators/generators.go index 23c88de0..95a06d84 100644 --- a/pkg/generators/generators.go +++ b/pkg/generators/generators.go @@ -25,9 +25,9 @@ import ( ) type Generators struct { - Generators []Generator wg *sync.WaitGroup cancel context.CancelFunc + Generators []Generator } func New( diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index 138e2286..eb418a72 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -114,12 +114,11 @@ func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken { } func (s *Partition) Close() error { - for !s.closed.CompareAndSwap(false, true) { + if s.closed.CompareAndSwap(false, true) { + close(s.values) + close(s.oldValues) } - close(s.values) - close(s.oldValues) - return nil } diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 5274c28c..55ddf2cb 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -28,7 +28,6 @@ import ( "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/joberror" "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/store" "github.com/scylladb/gemini/pkg/typedef" ) @@ -53,10 +52,9 @@ var ( ) type List struct { - name string - jobs []job - duration time.Duration - workers uint64 + name string + jobs []job + workers uint64 } type job struct { @@ -72,16 +70,16 @@ type job struct { *generators.Generator, *status.GlobalStatus, *zap.Logger, - *stop.Flag, bool, bool, ) error name string } -func ListFromMode(mode string, duration time.Duration, workers uint64) List { +func ListFromMode(mode string, workers uint64) List { jobs := make([]job, 0, 2) name := "work cycle" + switch mode { case WriteMode: jobs = append(jobs, mutate) @@ -93,11 +91,11 @@ func ListFromMode(mode string, duration time.Duration, workers uint64) List { default: jobs = append(jobs, mutate, validate) } + return List{ - name: name, - jobs: jobs, - duration: duration, - workers: workers, + name: name, + jobs: jobs, + workers: workers, } } @@ -111,16 +109,10 @@ func (l List) Run( globalStatus *status.GlobalStatus, logger *zap.Logger, seed uint64, - stopFlag *stop.Flag, failFast, verbose bool, ) error { logger = logger.Named(l.name) - ctx = stopFlag.CancelContextOnSignal(ctx, stop.SignalHardStop) g, gCtx := errgroup.WithContext(ctx) - time.AfterFunc(l.duration, func() { - logger.Info("jobs time is up, begins jobs completion") - stopFlag.SetSoft(true) - }) partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() logger.Info("start jobs") @@ -131,7 +123,7 @@ func (l List) Run( jobF := l.jobs[idx].function r := rand.New(rand.NewSource(seed)) g.Go(func() error { - return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, stopFlag, failFast, verbose) + return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, failFast, verbose) }) } } @@ -154,7 +146,6 @@ func mutationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, verbose bool, ) error { schemaConfig := &schemaCfg @@ -164,11 +155,8 @@ func mutationJob( logger.Info("ending mutation loop") }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): logger.Debug("mutation job terminated") return nil case hb := <-pump: @@ -187,7 +175,6 @@ func mutationJob( } } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -207,7 +194,6 @@ func validationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -218,11 +204,8 @@ func validationJob( }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return nil case hb := <-pump: time.Sleep(hb) @@ -262,7 +245,6 @@ func validationJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -282,7 +264,6 @@ func warmupJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -292,10 +273,13 @@ func warmupJob( logger.Info("ending warmup loop") }() for { - if stopFlag.IsHardOrSoft() { + select { + case <-ctx.Done(): logger.Debug("warmup job terminated") return nil + default: } + // Do we care about errors during warmup? err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, false, logger) if err != nil { @@ -303,7 +287,6 @@ func warmupJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } diff --git a/pkg/jobs/pump.go b/pkg/jobs/pump.go index c929f8ce..4baf6a98 100644 --- a/pkg/jobs/pump.go +++ b/pkg/jobs/pump.go @@ -15,15 +15,14 @@ package jobs import ( + "context" "time" - "github.com/scylladb/gemini/pkg/stop" - "go.uber.org/zap" "golang.org/x/exp/rand" ) -func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { +func NewPump(ctx context.Context, logger *zap.Logger) <-chan time.Duration { pump := make(chan time.Duration, 10000) logger = logger.Named("Pump") go func() { @@ -32,8 +31,14 @@ func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { close(pump) logger.Debug("pump channel closed") }() - for !stopFlag.IsHardOrSoft() { - pump <- newHeartBeat() + + for { + select { + case <-ctx.Done(): + return + default: + pump <- newHeartBeat() + } } }() diff --git a/pkg/stop/flag.go b/pkg/stop/flag.go deleted file mode 100644 index 54c6e1f2..00000000 --- a/pkg/stop/flag.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop - -import ( - "context" - "fmt" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - - "go.uber.org/zap" -) - -const ( - SignalNoop uint32 = iota - SignalSoftStop - SignalHardStop -) - -type SignalChannel chan uint32 - -var closedChan = createClosedChan() - -func createClosedChan() SignalChannel { - ch := make(SignalChannel) - close(ch) - return ch -} - -type SyncList[T any] struct { - children []T - childrenLock sync.RWMutex -} - -func (f *SyncList[T]) Append(el T) { - f.childrenLock.Lock() - defer f.childrenLock.Unlock() - f.children = append(f.children, el) -} - -func (f *SyncList[T]) Get() []T { - f.childrenLock.RLock() - defer f.childrenLock.RUnlock() - return f.children -} - -type logger interface { - Debug(msg string, fields ...zap.Field) -} - -type Flag struct { - name string - log logger - ch atomic.Pointer[SignalChannel] - parent *Flag - children SyncList[*Flag] - stopHandlers SyncList[func(signal uint32)] - val atomic.Uint32 -} - -func (s *Flag) Name() string { - return s.name -} - -func (s *Flag) closeChannel() { - ch := s.ch.Swap(&closedChan) - if ch != &closedChan { - close(*ch) - } -} - -func (s *Flag) sendSignal(signal uint32, sendToParent bool) bool { - s.log.Debug(fmt.Sprintf("flag %s received signal %s", s.name, GetStateName(signal))) - s.closeChannel() - out := s.val.CompareAndSwap(SignalNoop, signal) - if !out { - return false - } - - for _, handler := range s.stopHandlers.Get() { - handler(signal) - } - - for _, child := range s.children.Get() { - child.sendSignal(signal, sendToParent) - } - if sendToParent && s.parent != nil { - s.parent.sendSignal(signal, sendToParent) - } - return out -} - -func (s *Flag) SetHard(sendToParent bool) bool { - return s.sendSignal(SignalHardStop, sendToParent) -} - -func (s *Flag) SetSoft(sendToParent bool) bool { - return s.sendSignal(SignalSoftStop, sendToParent) -} - -func (s *Flag) CreateChild(name string) *Flag { - child := newFlag(name, s) - s.children.Append(child) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - child.sendSignal(val, false) - } - return child -} - -func (s *Flag) SignalChannel() SignalChannel { - return *s.ch.Load() -} - -func (s *Flag) IsSoft() bool { - return s.val.Load() == SignalSoftStop -} - -func (s *Flag) IsHard() bool { - return s.val.Load() == SignalHardStop -} - -func (s *Flag) IsHardOrSoft() bool { - return s.val.Load() != SignalNoop -} - -func (s *Flag) AddHandler(handler func(signal uint32)) { - s.stopHandlers.Append(handler) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - handler(val) - } -} - -func (s *Flag) AddHandler2(handler func(), expectedSignal uint32) { - s.AddHandler(func(signal uint32) { - switch expectedSignal { - case SignalNoop: - handler() - default: - if signal == expectedSignal { - handler() - } - } - }) -} - -func (s *Flag) CancelContextOnSignal(ctx context.Context, expectedSignal uint32) context.Context { - ctx, cancel := context.WithCancel(ctx) - s.AddHandler2(cancel, expectedSignal) - return ctx -} - -func (s *Flag) SetLogger(log logger) { - s.log = log -} - -func NewFlag(name string) *Flag { - return newFlag(name, nil) -} - -func newFlag(name string, parent *Flag) *Flag { - out := Flag{ - name: name, - parent: parent, - log: zap.NewNop(), - } - ch := make(SignalChannel) - out.ch.Store(&ch) - return &out -} - -func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) { - graceful := make(chan os.Signal, 1) - signal.Notify(graceful, syscall.SIGTERM, syscall.SIGINT) - go func() { - sig := <-graceful - switch sig { - case syscall.SIGINT: - for i := range flags { - flags[i].SetSoft(true) - } - logger.Info("Get SIGINT signal, begin soft stop.") - default: - for i := range flags { - flags[i].SetHard(true) - } - logger.Info("Get SIGTERM signal, begin hard stop.") - } - }() -} - -func GetStateName(state uint32) string { - switch state { - case SignalSoftStop: - return "soft" - case SignalHardStop: - return "hard" - case SignalNoop: - return "no-signal" - default: - panic(fmt.Sprintf("unexpected signal %d", state)) - } -} diff --git a/pkg/stop/flag_test.go b/pkg/stop/flag_test.go deleted file mode 100644 index 81a4b344..00000000 --- a/pkg/stop/flag_test.go +++ /dev/null @@ -1,429 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop_test - -import ( - "context" - "errors" - "fmt" - "reflect" - "runtime" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/scylladb/gemini/pkg/stop" -) - -func TestHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHard, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func TestSoftStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } -} - -func TestSoftOrHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } - - workersDone.Store(uint32(0)) - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() != nil { - t.Error("Error:SetHard function apply hardStopHandler after SetSoft") - } - - testFlag, ctx, workersDone = initVars() - workersDone.Store(uint32(0)) - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func initVars() (testFlag *stop.Flag, ctx context.Context, workersDone *atomic.Uint32) { - testFlagOut := stop.NewFlag("main_test") - ctx = testFlagOut.CancelContextOnSignal(context.Background(), stop.SignalHardStop) - workersDone = &atomic.Uint32{} - return testFlagOut, ctx, workersDone -} - -func testSignals( - t *testing.T, - workersDone *atomic.Uint32, - workers int, - checkFunc func() bool, - setFunc func(propagation bool) bool, -) { - t.Helper() - for i := 0; i != workers; i++ { - go func() { - for { - if checkFunc() { - workersDone.Add(1) - return - } - time.Sleep(10 * time.Millisecond) - } - }() - } - time.Sleep(200 * time.Millisecond) - setFunc(false) - - for i := 0; i != 10; i++ { - time.Sleep(100 * time.Millisecond) - if workersDone.Load() == uint32(workers) { - break - } - } - - setFuncName := runtime.FuncForPC(reflect.ValueOf(setFunc).Pointer()).Name() - setFuncName, _ = strings.CutSuffix(setFuncName, "-fm") - _, setFuncName, _ = strings.Cut(setFuncName, ".(") - setFuncName = strings.ReplaceAll(setFuncName, ").", ".") - checkFuncName := runtime.FuncForPC(reflect.ValueOf(checkFunc).Pointer()).Name() - checkFuncName, _ = strings.CutSuffix(checkFuncName, "-fm") - _, checkFuncName, _ = strings.Cut(checkFuncName, ".(") - checkFuncName = strings.ReplaceAll(checkFuncName, ").", ".") - - if workersDone.Load() != uint32(workers) { - t.Errorf("Error:%s or %s functions works not correctly %[2]s=%v", setFuncName, checkFuncName, checkFunc()) - } -} - -func TestSendToParent(t *testing.T) { - t.Parallel() - tcases := []tCase{ - { - testName: "parent-hard-true", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-hard-false", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "parent-soft-false", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalNoop, - }, - { - testName: "child11-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child11-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalNoop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalNoop, - child2Signal: stop.SignalNoop, - }, - } - for id := range tcases { - tcase := tcases[id] - t.Run(tcase.testName, func(t *testing.T) { - t.Parallel() - if err := tcase.runTest(); err != nil { - t.Error(err) - } - }) - } -} - -// nolint: govet -type parentChildInfo struct { - parent *stop.Flag - parentSignal uint32 - child1 *stop.Flag - child1Signal uint32 - child11 *stop.Flag - child11Signal uint32 - child12 *stop.Flag - child12Signal uint32 - child2 *stop.Flag - child2Signal uint32 -} - -func (t *parentChildInfo) getFlag(flagName string) *stop.Flag { - switch flagName { - case "parent": - return t.parent - case "child1": - return t.child1 - case "child2": - return t.child2 - case "child11": - return t.child11 - case "child12": - return t.child12 - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) getFlagHandlerState(flagName string) uint32 { - switch flagName { - case "parent": - return t.parentSignal - case "child1": - return t.child1Signal - case "child2": - return t.child2Signal - case "child11": - return t.child11Signal - case "child12": - return t.child12Signal - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) checkFlagState(flag *stop.Flag, expectedState uint32) error { - var err error - flagName := flag.Name() - state := t.getFlagHandlerState(flagName) - if state != expectedState { - err = errors.Join(err, fmt.Errorf("flag %s handler has state %s while it is expected to be %s", flagName, stop.GetStateName(state), stop.GetStateName(expectedState))) - } - flagState := getFlagState(flag) - if stop.GetStateName(expectedState) != flagState { - err = errors.Join(err, fmt.Errorf("flag %s has state %s while it is expected to be %s", flagName, flagState, stop.GetStateName(expectedState))) - } - return err -} - -type tCase struct { - testName string - parentSignal uint32 - child1Signal uint32 - child11Signal uint32 - child12Signal uint32 - child2Signal uint32 -} - -func (t *tCase) runTest() error { - chunk := strings.Split(t.testName, "-") - if len(chunk) != 3 { - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - flagName := chunk[0] - signalTypeName := chunk[1] - sendToParentName := chunk[2] - - var sendToParent bool - switch sendToParentName { - case "true": - sendToParent = true - case "false": - sendToParent = false - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - runt := newParentChildInfo() - flag := runt.getFlag(flagName) - switch signalTypeName { - case "soft": - flag.SetSoft(sendToParent) - case "hard": - flag.SetHard(sendToParent) - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - var err error - err = errors.Join(err, runt.checkFlagState(runt.parent, t.parentSignal)) - err = errors.Join(err, runt.checkFlagState(runt.child1, t.child1Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child2, t.child2Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child11, t.child11Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child12, t.child12Signal)) - return err -} - -func newParentChildInfo() *parentChildInfo { - parent := stop.NewFlag("parent") - child1 := parent.CreateChild("child1") - out := parentChildInfo{ - parent: parent, - child1: child1, - child11: child1.CreateChild("child11"), - child12: child1.CreateChild("child12"), - child2: parent.CreateChild("child2"), - } - - out.parent.AddHandler(func(signal uint32) { - out.parentSignal = signal - }) - out.child1.AddHandler(func(signal uint32) { - out.child1Signal = signal - }) - out.child11.AddHandler(func(signal uint32) { - out.child11Signal = signal - }) - out.child12.AddHandler(func(signal uint32) { - out.child12Signal = signal - }) - out.child2.AddHandler(func(signal uint32) { - out.child2Signal = signal - }) - return &out -} - -func getFlagState(flag *stop.Flag) string { - switch { - case flag.IsSoft(): - return "soft" - case flag.IsHard(): - return "hard" - default: - return "no-signal" - } -} - -func TestSignalChannel(t *testing.T) { - t.Parallel() - t.Run("single-no-signal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - select { - case <-flag.SignalChannel(): - t.Error("should not get the signal") - case <-time.Tick(200 * time.Millisecond): - } - }) - - t.Run("single-beforehand", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - flag.SetSoft(true) - <-flag.SignalChannel() - }) - - t.Run("single-normal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - go func() { - time.Sleep(200 * time.Millisecond) - flag.SetSoft(true) - }() - <-flag.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - parent.SetSoft(true) - <-child.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - parent.SetSoft(true) - child := parent.CreateChild("child") - <-child.SignalChannel() - }) - - t.Run("parent-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - parent.SetSoft(true) - }() - <-child.SignalChannel() - }) - - t.Run("child-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - child.SetSoft(true) - <-parent.SignalChannel() - }) - - t.Run("child-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - child.SetSoft(true) - }() - <-parent.SignalChannel() - }) -}