diff --git a/CHANGELOG.md b/CHANGELOG.md index e86c1ddd..bc99cd48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Ensure proper termination when errors happen. - Fix mutation timestamps to match on system under test and test oracle. - Gemini now timestamps errors for easier correlation. diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index a416b5c2..4b3d0b89 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -170,7 +170,7 @@ func run(cmd *cobra.Command, args []string) { if verbose { fmt.Println(stmt) } - if err := session.Mutate(stmt); err != nil { + if err := session.Mutate(context.Background(), stmt); err != nil { fmt.Printf("%v", err) return } @@ -180,7 +180,7 @@ func run(cmd *cobra.Command, args []string) { if verbose { fmt.Println(stmt) } - if err := session.Mutate(stmt); err != nil { + if err := session.Mutate(context.Background(), stmt); err != nil { fmt.Printf("%v", err) return } @@ -229,12 +229,12 @@ func runJob(f testJob, schema *gemini.Schema, s *gemini.Session, mode string, ou for { select { case <-timer.C: + cancelWorkers() + testRes = drain(c, testRes) testRes.PrintResult(out) fmt.Println("Test run completed. Exiting.") - cancelWorkers() return case <-reporterCtx.Done(): - testRes.PrintResult(out) return case res := <-c: testRes = res.Merge(&testRes) @@ -242,24 +242,26 @@ func runJob(f testJob, schema *gemini.Schema, s *gemini.Session, mode string, ou sp.Suffix = fmt.Sprintf(" Running Gemini... %v", testRes) } if testRes.ReadErrors > 0 { - testRes.PrintResult(out) - fmt.Println(testRes.Errors) if failFast { fmt.Println("Error in data validation. Exiting.") cancelWorkers() + testRes = drain(c, testRes) + testRes.PrintResult(out) return } + testRes.PrintResult(out) } } } }(duration) workers.Wait() + close(c) cancelReporter() reporter.Wait() } -func mutationJob(schema *gemini.Schema, table gemini.Table, s *gemini.Session, p gemini.PartitionRange, testStatus *Status, out *os.File) { +func mutationJob(ctx context.Context, schema *gemini.Schema, table gemini.Table, s *gemini.Session, p gemini.PartitionRange, testStatus *Status, out *os.File) { mutateStmt, err := schema.GenMutateStmt(table, &p) if err != nil { fmt.Printf("Failed! Mutation statement generation failed: '%v'\n", err) @@ -271,7 +273,7 @@ func mutationJob(schema *gemini.Schema, table gemini.Table, s *gemini.Session, p if verbose { fmt.Println(mutateStmt.PrettyCQL()) } - if err := s.Mutate(mutateQuery, mutateValues...); err != nil { + if err := s.Mutate(ctx, mutateQuery, mutateValues...); err != nil { e := gemini.JobError{ Timestamp: time.Now(), Message: "Mutation failed: " + err.Error(), @@ -284,14 +286,14 @@ func mutationJob(schema *gemini.Schema, table gemini.Table, s *gemini.Session, p } } -func validationJob(schema *gemini.Schema, table gemini.Table, s *gemini.Session, p gemini.PartitionRange, testStatus *Status, out *os.File) { +func validationJob(ctx context.Context, schema *gemini.Schema, table gemini.Table, s *gemini.Session, p gemini.PartitionRange, testStatus *Status, out *os.File) { checkStmt := schema.GenCheckStmt(table, &p) checkQuery := checkStmt.Query checkValues := checkStmt.Values() if verbose { fmt.Println(checkStmt.PrettyCQL()) } - if err := s.Check(table, checkQuery, checkValues...); err != nil { + if err := s.Check(ctx, table, checkQuery, checkValues...); err != nil { // De-duplication needed? e := gemini.JobError{ Timestamp: time.Now(), @@ -318,15 +320,15 @@ func Job(ctx context.Context, wg *sync.WaitGroup, schema *gemini.Schema, table g } switch mode { case writeMode: - mutationJob(schema, table, s, p, &testStatus, out) + mutationJob(ctx, schema, table, s, p, &testStatus, out) case readMode: - validationJob(schema, table, s, p, &testStatus, out) + validationJob(ctx, schema, table, s, p, &testStatus, out) default: ind := p.Rand.Intn(100000) % 2 if ind == 0 { - mutationJob(schema, table, s, p, &testStatus, out) + mutationJob(ctx, schema, table, s, p, &testStatus, out) } else { - validationJob(schema, table, s, p, &testStatus, out) + validationJob(ctx, schema, table, s, p, &testStatus, out) } } @@ -334,7 +336,7 @@ func Job(ctx context.Context, wg *sync.WaitGroup, schema *gemini.Schema, table g c <- testStatus testStatus = Status{} } - if failFast && testStatus.ReadErrors > 0 { + if failFast && (testStatus.ReadErrors > 0 || testStatus.WriteErrors > 0) { break } i++ @@ -390,3 +392,10 @@ func printSetup() error { tw.Flush() return nil } + +func drain(ch chan Status, testRes Status) Status { + for res := range ch { + testRes = res.Merge(&testRes) + } + return testRes +} diff --git a/session.go b/session.go index c258b691..e4b617f2 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package gemini import ( + "context" "fmt" "math/big" "sort" @@ -55,27 +56,27 @@ func (s *Session) Close() { s.oracleSession.Close() } -func (s *Session) Mutate(query string, values ...interface{}) error { +func (s *Session) Mutate(ctx context.Context, query string, values ...interface{}) error { ts := time.Now() var tsUsec int64 = ts.UnixNano() / 1000 - if err := s.testSession.Query(query, values...).WithTimestamp(tsUsec).Exec(); err != nil { + if err := s.testSession.Query(query, values...).WithContext(ctx).WithTimestamp(tsUsec).Exec(); !ignore(err) { return fmt.Errorf("%v [cluster = test, query = '%s']", err, query) } - if err := s.oracleSession.Query(query, values...).WithTimestamp(tsUsec).Exec(); err != nil { + if err := s.oracleSession.Query(query, values...).WithContext(ctx).WithTimestamp(tsUsec).Exec(); !ignore(err) { return fmt.Errorf("%v [cluster = oracle, query = '%s']", err, query) } return nil } -func (s *Session) Check(table Table, query string, values ...interface{}) (err error) { - testIter := s.testSession.Query(query, values...).Iter() - oracleIter := s.oracleSession.Query(query, values...).Iter() +func (s *Session) Check(ctx context.Context, table Table, query string, values ...interface{}) (err error) { + testIter := s.testSession.Query(query, values...).WithContext(ctx).Iter() + oracleIter := s.oracleSession.Query(query, values...).WithContext(ctx).Iter() defer func() { - if e := testIter.Close(); e != nil { - err = multierr.Append(err, errors.Errorf("test system failed: %s", err.Error())) + if e := testIter.Close(); !ignore(e) { + err = multierr.Append(err, errors.Errorf("test system failed: %s", e.Error())) } - if e := oracleIter.Close(); e != nil { - err = multierr.Append(err, errors.Errorf("oracle failed: %s", err.Error())) + if e := oracleIter.Close(); !ignore(e) { + err = multierr.Append(err, errors.Errorf("oracle failed: %s", e.Error())) } }() @@ -169,3 +170,15 @@ func loadSet(iter *gocql.Iter) []map[string]interface{} { } return rows } + +func ignore(err error) bool { + if err == nil { + return true + } + switch err { + case context.Canceled, context.DeadlineExceeded: + return true + default: + return false + } +}