diff --git a/internal/partitions/partitions.go b/internal/partitions/partitions.go index 01419808..022a54e4 100644 --- a/internal/partitions/partitions.go +++ b/internal/partitions/partitions.go @@ -119,7 +119,6 @@ const ( func PartitionCollectionWithSize( ctx context.Context, uuidEntry *uuidutil.NamespaceAndUUID, - retryer *retry.Retryer, srcClient *mongo.Client, replicatorList []Replicator, subLogger *logger.Logger, @@ -137,7 +136,6 @@ func PartitionCollectionWithSize( partitions, docCount, byteCount, err := PartitionCollectionWithParameters( ctx, uuidEntry, - retryer, srcClient, replicatorList, defaultSampleRate, @@ -153,7 +151,6 @@ func PartitionCollectionWithSize( return PartitionCollectionWithParameters( ctx, uuidEntry, - retryer, srcClient, replicatorList, defaultSampleRate, @@ -174,7 +171,6 @@ func PartitionCollectionWithSize( func PartitionCollectionWithParameters( ctx context.Context, uuidEntry *uuidutil.NamespaceAndUUID, - retryer *retry.Retryer, srcClient *mongo.Client, replicatorList []Replicator, sampleRate float64, @@ -191,13 +187,13 @@ func PartitionCollectionWithParameters( // Get the collection's size in bytes and its document count. It is okay if these return zero since there might still be // items in the collection. Rely on getOuterIDBound to do a majority read to determine if we continue processing the collection. - collSizeInBytes, collDocCount, isCapped, err := GetSizeAndDocumentCount(ctx, subLogger, retryer, srcColl) + collSizeInBytes, collDocCount, isCapped, err := GetSizeAndDocumentCount(ctx, subLogger, srcColl) if err != nil { return nil, 0, 0, err } // The lower bound for the collection. There is no partitioning to do if the bound is nil. - minIDBound, err := getOuterIDBound(ctx, subLogger, retryer, minBound, srcDB, uuidEntry.CollName, globalFilter) + minIDBound, err := getOuterIDBound(ctx, subLogger, minBound, srcDB, uuidEntry.CollName, globalFilter) if err != nil { return nil, 0, 0, err } @@ -210,7 +206,7 @@ func PartitionCollectionWithParameters( } // The upper bound for the collection. There is no partitioning to do if the bound is nil. - maxIDBound, err := getOuterIDBound(ctx, subLogger, retryer, maxBound, srcDB, uuidEntry.CollName, globalFilter) + maxIDBound, err := getOuterIDBound(ctx, subLogger, maxBound, srcDB, uuidEntry.CollName, globalFilter) if err != nil { return nil, 0, 0, err } @@ -232,7 +228,7 @@ func PartitionCollectionWithParameters( // If a filter is used for partitioning, number of partitions is calculated with the ratio of filtered documents. if len(globalFilter) > 0 { - numFilteredDocs, filteredCntErr := GetDocumentCountAfterFiltering(ctx, subLogger, retryer, srcColl, globalFilter) + numFilteredDocs, filteredCntErr := GetDocumentCountAfterFiltering(ctx, subLogger, srcColl, globalFilter) if filteredCntErr == nil { numPartitions = getNumPartitions(collSizeInBytes, partitionSizeInBytes, float64(numFilteredDocs)/float64(collDocCount)) } else { @@ -251,7 +247,6 @@ func PartitionCollectionWithParameters( midIDBounds, collDropped, err := getMidIDBounds( ctx, subLogger, - retryer, srcDB, uuidEntry.CollName, collDocCount, @@ -314,7 +309,7 @@ func PartitionCollectionWithParameters( // capped status, in that order. // // Exported for usage in integration tests. -func GetSizeAndDocumentCount(ctx context.Context, logger *logger.Logger, retryer *retry.Retryer, srcColl *mongo.Collection) (int64, int64, bool, error) { +func GetSizeAndDocumentCount(ctx context.Context, logger *logger.Logger, srcColl *mongo.Collection) (int64, int64, bool, error) { srcDB := srcColl.Database() collName := srcColl.Name() @@ -324,39 +319,43 @@ func GetSizeAndDocumentCount(ctx context.Context, logger *logger.Logger, retryer Capped bool `bson:"capped"` }{} - err := retryer.Run(ctx, logger, func(ctx context.Context, ri *retry.FuncInfo) error { - ri.Log(logger.Logger, "collStats", "source", srcDB.Name(), collName, "Retrieving collection size and document count.") - request := bson.D{ - {"aggregate", collName}, - {"pipeline", mongo.Pipeline{ - bson.D{{"$collStats", bson.D{ - {"storageStats", bson.E{"scale", 1}}, - }}}, - // The "$group" here behaves as a project and rename when there's only one - // document (non-sharded case). When there are multiple documents (one for - // each shard) it correctly sums the counts and sizes from each shard. - bson.D{{"$group", bson.D{ - {"_id", "ns"}, - {"count", bson.D{{"$sum", "$storageStats.count"}}}, - {"size", bson.D{{"$sum", "$storageStats.size"}}}, - {"capped", bson.D{{"$first", "$capped"}}}}}}, - }}, - {"cursor", bson.D{}}, - } + err := retry.New().WithCallback( + func(ctx context.Context, ri *retry.FuncInfo) error { + ri.Log(logger.Logger, "collStats", "source", srcDB.Name(), collName, "Retrieving collection size and document count.") + request := bson.D{ + {"aggregate", collName}, + {"pipeline", mongo.Pipeline{ + bson.D{{"$collStats", bson.D{ + {"storageStats", bson.E{"scale", 1}}, + }}}, + // The "$group" here behaves as a project and rename when there's only one + // document (non-sharded case). When there are multiple documents (one for + // each shard) it correctly sums the counts and sizes from each shard. + bson.D{{"$group", bson.D{ + {"_id", "ns"}, + {"count", bson.D{{"$sum", "$storageStats.count"}}}, + {"size", bson.D{{"$sum", "$storageStats.size"}}}, + {"capped", bson.D{{"$first", "$capped"}}}}}}, + }}, + {"cursor", bson.D{}}, + } - cursor, driverErr := srcDB.RunCommandCursor(ctx, request) - if driverErr != nil { - return driverErr - } + cursor, driverErr := srcDB.RunCommandCursor(ctx, request) + if driverErr != nil { + return driverErr + } - defer cursor.Close(ctx) - if cursor.Next(ctx) { - if err := cursor.Decode(&value); err != nil { - return errors.Wrapf(err, "failed to decode $collStats response for source namespace %s.%s", srcDB.Name(), collName) + defer cursor.Close(ctx) + if cursor.Next(ctx) { + if err := cursor.Decode(&value); err != nil { + return errors.Wrapf(err, "failed to decode $collStats response for source namespace %s.%s", srcDB.Name(), collName) + } } - } - return nil - }) + return nil + }, + "retrieving %#q's statistics", + srcDB.Name()+"."+collName, + ).Run(ctx, logger) // TODO (REP-960): remove this check. // If we get NamespaceNotFoundError then return 0,0 since we won't do any partitioning with those returns @@ -380,7 +379,7 @@ func GetSizeAndDocumentCount(ctx context.Context, logger *logger.Logger, retryer // // This function could take a long time, especially if the collection does not have an index // on the filtered fields. -func GetDocumentCountAfterFiltering(ctx context.Context, logger *logger.Logger, retryer *retry.Retryer, srcColl *mongo.Collection, filter map[string]any) (int64, error) { +func GetDocumentCountAfterFiltering(ctx context.Context, logger *logger.Logger, srcColl *mongo.Collection, filter map[string]any) (int64, error) { srcDB := srcColl.Database() collName := srcColl.Name() @@ -395,27 +394,31 @@ func GetDocumentCountAfterFiltering(ctx context.Context, logger *logger.Logger, } pipeline = append(pipeline, bson.D{{"$count", "numFilteredDocs"}}) - err := retryer.Run(ctx, logger, func(ctx context.Context, ri *retry.FuncInfo) error { - ri.Log(logger.Logger, "count", "source", srcDB.Name(), collName, "Counting filtered documents.") - request := bson.D{ - {"aggregate", collName}, - {"pipeline", pipeline}, - {"cursor", bson.D{}}, - } + err := retry.New().WithCallback( + func(ctx context.Context, ri *retry.FuncInfo) error { + ri.Log(logger.Logger, "count", "source", srcDB.Name(), collName, "Counting filtered documents.") + request := bson.D{ + {"aggregate", collName}, + {"pipeline", pipeline}, + {"cursor", bson.D{}}, + } - cursor, driverErr := srcDB.RunCommandCursor(ctx, request) - if driverErr != nil { - return driverErr - } + cursor, driverErr := srcDB.RunCommandCursor(ctx, request) + if driverErr != nil { + return driverErr + } - defer cursor.Close(ctx) - if cursor.Next(ctx) { - if err := cursor.Decode(&value); err != nil { - return errors.Wrapf(err, "failed to decode $count response (%+v) for source namespace %s.%s after filter (%+v)", cursor.Current, srcDB.Name(), collName, filter) + defer cursor.Close(ctx) + if cursor.Next(ctx) { + if err := cursor.Decode(&value); err != nil { + return errors.Wrapf(err, "failed to decode $count response (%+v) for source namespace %s.%s after filter (%+v)", cursor.Current, srcDB.Name(), collName, filter) + } } - } - return nil - }) + return nil + }, + "counting %#q's filtered documents", + srcDB.Name()+"."+collName, + ).Run(ctx, logger) // TODO (REP-960): remove this check. // If we get NamespaceNotFoundError then return 0 since we won't do any partitioning with those returns @@ -458,7 +461,6 @@ func getNumPartitions(collSizeInBytes, partitionSizeInBytes int64, filteredRatio func getOuterIDBound( ctx context.Context, subLogger *logger.Logger, - retryer *retry.Retryer, minOrMaxBound minOrMaxBound, srcDB *mongo.Database, collName string, @@ -488,30 +490,35 @@ func getOuterIDBound( }...) // Get one document containing only the smallest or largest _id value in the collection. - err := retryer.Run(ctx, subLogger, func(ctx context.Context, ri *retry.FuncInfo) error { - ri.Log(subLogger.Logger, "aggregate", "source", srcDB.Name(), collName, fmt.Sprintf("getting %s _id partition bound", minOrMaxBound)) - cursor, cmdErr := - srcDB.RunCommandCursor(ctx, bson.D{ - {"aggregate", collName}, - {"pipeline", pipeline}, - {"hint", bson.D{{"_id", 1}}}, - {"cursor", bson.D{}}, - }) - - if cmdErr != nil { - return cmdErr - } + err := retry.New().WithCallback( + func(ctx context.Context, ri *retry.FuncInfo) error { + ri.Log(subLogger.Logger, "aggregate", "source", srcDB.Name(), collName, fmt.Sprintf("getting %s _id partition bound", minOrMaxBound)) + cursor, cmdErr := + srcDB.RunCommandCursor(ctx, bson.D{ + {"aggregate", collName}, + {"pipeline", pipeline}, + {"hint", bson.D{{"_id", 1}}}, + {"cursor", bson.D{}}, + }) + + if cmdErr != nil { + return cmdErr + } - // If we don't have at least one document, the collection is either empty or was dropped. - defer cursor.Close(ctx) - if !cursor.Next(ctx) { - return nil - } + // If we don't have at least one document, the collection is either empty or was dropped. + defer cursor.Close(ctx) + if !cursor.Next(ctx) { + return nil + } - // Return the _id value from that document. - docID, cmdErr = cursor.Current.LookupErr("_id") - return cmdErr - }) + // Return the _id value from that document. + docID, cmdErr = cursor.Current.LookupErr("_id") + return cmdErr + }, + "finding %#q's %s _id", + srcDB.Name()+"."+collName, + minOrMaxBound, + ).Run(ctx, subLogger) if err != nil { return nil, errors.Wrapf(err, "could not get %s _id bound for source collection '%s.%s'", minOrMaxBound, srcDB.Name(), collName) @@ -528,7 +535,6 @@ func getOuterIDBound( func getMidIDBounds( ctx context.Context, logger *logger.Logger, - retryer *retry.Retryer, srcDB *mongo.Database, collName string, collDocCount int64, @@ -576,48 +582,53 @@ func getMidIDBounds( // Get a cursor for the $sample and $bucketAuto aggregation. var midIDBounds []interface{} - agRetryer := retryer.WithErrorCodes(util.SampleTooManyDuplicates) - err := agRetryer.Run(ctx, logger, func(ctx context.Context, ri *retry.FuncInfo) error { - ri.Log(logger.Logger, "aggregate", "source", srcDB.Name(), collName, "Retrieving mid _id partition bounds using $sample.") - cursor, cmdErr := - srcDB.RunCommandCursor(ctx, bson.D{ - {"aggregate", collName}, - {"pipeline", pipeline}, - {"allowDiskUse", true}, - {"cursor", bson.D{}}, - }) - - if cmdErr != nil { - return errors.Wrapf(cmdErr, "failed to $sample and $bucketAuto documents for source namespace '%s.%s'", srcDB.Name(), collName) - } - - defer cursor.Close(ctx) - - // Iterate through all $bucketAuto documents of the form: - // { - // "_id" : { - // "min" : ... , - // "max" : ... - // }, - // "count" : ... - // } - midIDBounds = make([]interface{}, 0, numPartitions) - for cursor.Next(ctx) { - // Get a mid _id bound using the $bucketAuto document's max value. - bucketAutoDoc := make(bson.Raw, len(cursor.Current)) - copy(bucketAutoDoc, cursor.Current) - bound, err := bucketAutoDoc.LookupErr("_id", "max") - if err != nil { - return errors.Wrapf(err, "failed to lookup '_id.max' key in $bucketAuto document for source namespace '%s.%s'", srcDB.Name(), collName) - } - - // Append the copied bound to the other mid _id bounds. - midIDBounds = append(midIDBounds, bound) - ri.NoteSuccess() - } - - return cursor.Err() - }) + agRetryer := retry.New().WithErrorCodes(util.SampleTooManyDuplicates) + err := agRetryer. + WithCallback( + func(ctx context.Context, ri *retry.FuncInfo) error { + ri.Log(logger.Logger, "aggregate", "source", srcDB.Name(), collName, "Retrieving mid _id partition bounds using $sample.") + cursor, cmdErr := + srcDB.RunCommandCursor(ctx, bson.D{ + {"aggregate", collName}, + {"pipeline", pipeline}, + {"allowDiskUse", true}, + {"cursor", bson.D{}}, + }) + + if cmdErr != nil { + return errors.Wrapf(cmdErr, "failed to $sample and $bucketAuto documents for source namespace '%s.%s'", srcDB.Name(), collName) + } + + defer cursor.Close(ctx) + + // Iterate through all $bucketAuto documents of the form: + // { + // "_id" : { + // "min" : ... , + // "max" : ... + // }, + // "count" : ... + // } + midIDBounds = make([]interface{}, 0, numPartitions) + for cursor.Next(ctx) { + // Get a mid _id bound using the $bucketAuto document's max value. + bucketAutoDoc := make(bson.Raw, len(cursor.Current)) + copy(bucketAutoDoc, cursor.Current) + bound, err := bucketAutoDoc.LookupErr("_id", "max") + if err != nil { + return errors.Wrapf(err, "failed to lookup '_id.max' key in $bucketAuto document for source namespace '%s.%s'", srcDB.Name(), collName) + } + + // Append the copied bound to the other mid _id bounds. + midIDBounds = append(midIDBounds, bound) + ri.NoteSuccess() + } + + return cursor.Err() + }, + "finding %#q's _id partition boundaries", + srcDB.Name()+"."+collName, + ).Run(ctx, logger) if err != nil { return nil, false, errors.Wrapf(err, "encountered a problem in the cursor when trying to $sample and $bucketAuto aggregation for source namespace '%s.%s'", srcDB.Name(), collName) diff --git a/internal/retry/retry.go b/internal/retry/retry.go index a4b9fb18..6630bbe0 100644 --- a/internal/retry/retry.go +++ b/internal/retry/retry.go @@ -2,30 +2,22 @@ package retry import ( "context" - "errors" "fmt" - "math/rand" "time" "github.com/10gen/migration-verifier/internal/logger" + "github.com/10gen/migration-verifier/internal/reportutils" "github.com/10gen/migration-verifier/internal/util" + "github.com/10gen/migration-verifier/mmongo" + "github.com/10gen/migration-verifier/msync" + "github.com/pkg/errors" + "github.com/rs/zerolog" "github.com/samber/lo" "golang.org/x/sync/errgroup" ) type RetryCallback = func(context.Context, *FuncInfo) error -// Retry is a convenience that creates a retryer and executes it. -// See RunForTransientErrorsOnly for argument details. -func Retry( - ctx context.Context, - logger *logger.Logger, - callbacks ...RetryCallback, -) error { - retryer := New(DefaultDurationLimit) - return retryer.Run(ctx, logger, callbacks...) -} - // Run() runs each given callback in parallel. If none of them fail, // then no error is returned. // @@ -54,43 +46,57 @@ func Retry( // // This returns an error if the duration limit is reached, or if f() returns a // non-transient error. -func (r *Retryer) Run( - ctx context.Context, logger *logger.Logger, funcs ...RetryCallback, -) error { - return r.runRetryLoop(ctx, logger, funcs) +func (r *Retryer) Run(ctx context.Context, logger *logger.Logger) error { + return r.runRetryLoop(ctx, logger) } // runRetryLoop contains the core logic for the retry loops. func (r *Retryer) runRetryLoop( ctx context.Context, logger *logger.Logger, - funcs []RetryCallback, ) error { var err error + if len(r.callbacks) == 0 { + return errors.Errorf( + "retryer (%s) run with no callbacks", + r.description.OrElse("no description"), + ) + } + startTime := time.Now() li := &LoopInfo{ durationLimit: r.retryLimit, } funcinfos := lo.RepeatBy( - len(funcs), + len(r.callbacks), func(_ int) *FuncInfo { return &FuncInfo{ - lastResetTime: startTime, - loopInfo: li, + lastResetTime: msync.NewTypedAtomic(startTime), + loopDescription: r.description, + loopInfo: li, } }, ) sleepTime := minSleepTime for { + if li.attemptsSoFar > 0 { + r.addDescriptionToEvent(logger.Info()). + Int("attemptsSoFar", li.attemptsSoFar). + Msg("Retrying after failure.") + } + if beforeFunc, hasBefore := r.before.Get(); hasBefore { beforeFunc() } eg, egCtx := errgroup.WithContext(ctx) - for i, curFunc := range funcs { + + for i, curCbInfo := range r.callbacks { + curFunc := curCbInfo.callback + if curFunc == nil { panic("curFunc should be non-nil") } @@ -99,6 +105,31 @@ func (r *Retryer) runRetryLoop( } eg.Go(func() error { + cbDoneChan := make(chan struct{}) + defer close(cbDoneChan) + + go func() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + lastSuccessTime := funcinfos[i].lastResetTime.Load() + + select { + case <-cbDoneChan: + return + case <-ticker.C: + if funcinfos[i].lastResetTime.Load() == lastSuccessTime { + logger.Warn(). + Str("callbackDescription", curCbInfo.description). + Time("lastSuccessAt", lastSuccessTime). + Str("elapsedTime", reportutils.DurationToHMS(time.Since(lastSuccessTime))). + Msg("Operation has not reported success for a while.") + } + } + } + }() + err := curFunc(egCtx, funcinfos[i]) if err != nil { @@ -113,8 +144,16 @@ func (r *Retryer) runRetryLoop( } err = eg.Wait() + li.attemptsSoFar++ + // No error? Success! if err == nil { + if li.attemptsSoFar > 1 { + r.addDescriptionToEvent(logger.Info()). + Int("attempts", li.attemptsSoFar). + Msg("Retried operation succeeded.") + } + return nil } @@ -124,19 +163,19 @@ func (r *Retryer) runRetryLoop( panic(fmt.Sprintf("Error should be a %T, not %T: %v", groupErr, err, err)) } + failedFuncInfo := funcinfos[groupErr.funcNum] + // Not a transient error? Fail immediately. - if !r.shouldRetryWithSleep(logger, sleepTime, groupErr.errFromCallback) { + if !r.shouldRetryWithSleep(logger, sleepTime, *failedFuncInfo, groupErr.errFromCallback) { return groupErr.errFromCallback } - li.attemptNumber++ - // Our error is transient. If we've exhausted the allowed time // then fail. - failedFuncInfo := funcinfos[groupErr.funcNum] + if failedFuncInfo.GetDurationSoFar() > li.durationLimit { return RetryDurationLimitExceededErr{ - attempts: li.attemptNumber, + attempts: li.attemptsSoFar, duration: failedFuncInfo.GetDurationSoFar(), lastErr: groupErr.errFromCallback, } @@ -146,7 +185,9 @@ func (r *Retryer) runRetryLoop( // up to maxSleepTime. select { case <-ctx.Done(): - logger.Error().Err(ctx.Err()).Msg("Context was canceled. Aborting retry loop.") + r.addDescriptionToEvent(logger.Error()). + Err(ctx.Err()). + Msg("Context was canceled. Aborting retry loop.") return ctx.Err() case <-time.After(sleepTime): sleepTime *= sleepTimeMultiplier @@ -160,12 +201,27 @@ func (r *Retryer) runRetryLoop( // Set all of the funcs that did *not* fail as having just succeeded. for i, curInfo := range funcinfos { if i != groupErr.funcNum { - curInfo.lastResetTime = now + curInfo.lastResetTime.Store(now) } } } } +func (r *Retryer) addDescriptionToEvent(event *zerolog.Event) *zerolog.Event { + if description, hasDesc := r.description.Get(); hasDesc { + event.Str("description", description) + } else { + event.Strs("description", lo.Map( + r.callbacks, + func(cbInfo retryCallbackInfo, _ int) string { + return cbInfo.description + }, + )) + } + + return event +} + // // For the above function, there have historically been concerns regarding majority write concern // upon retrying a write operation to the server. Mongomirror explicitly handled this: @@ -179,36 +235,41 @@ func (r *Retryer) runRetryLoop( func (r *Retryer) shouldRetryWithSleep( logger *logger.Logger, sleepTime time.Duration, + funcinfo FuncInfo, err error, ) bool { - // Randomly retry approximately 1 in 100 calls to the wrapped - // function. This is only enabled in tests. - if r.retryRandomly && rand.Int()%100 == 0 { - logger.Debug().Msgf("Waiting %s seconds to retry operation because of test code forcing a retry.", sleepTime) - return true - } - if err == nil { - return false + panic("nil error should not get here") } - errCode := util.GetErrorCode(err) - if util.IsTransientError(err) { - logger.Warn().Int("error code", errCode).Err(err).Msgf( - "Waiting %s seconds to retry operation after transient error.", sleepTime) - return true + isTransient := util.IsTransientError(err) || lo.SomeBy( + r.additionalErrorCodes, + func(code int) bool { + return mmongo.ErrorHasCode(err, code) + }, + ) + + event := logger.WithLevel( + lo.Ternary(isTransient, zerolog.InfoLevel, zerolog.WarnLevel), + ) + + if loopDesc, hasLoopDesc := r.description.Get(); hasLoopDesc { + event.Str("operationDescription", loopDesc) } - for _, code := range r.additionalErrorCodes { - if code == errCode { - logger.Warn().Int("error code", errCode).Err(err).Msgf( - "Waiting %s seconds to retry operation after an error because it is in our additional codes list.", sleepTime) - return true - } + event.Str("callbackDescription", funcinfo.description). + Int("error code", util.GetErrorCode(err)). + Err(err) + + if isTransient { + event. + Stringer("delay", sleepTime). + Msg("Pausing before retrying after transient error.") + + return true } - logger.Debug().Err(err).Int("error code", errCode). - Msg("Not retrying on error because it is not transient nor is it in our additional codes list.") + event.Msg("Non-transient error occurred.") return false } diff --git a/internal/retry/retry_info.go b/internal/retry/retry_info.go index 9f96e22a..ea325ac7 100644 --- a/internal/retry/retry_info.go +++ b/internal/retry/retry_info.go @@ -4,6 +4,8 @@ import ( "time" "github.com/10gen/migration-verifier/internal/reportutils" + "github.com/10gen/migration-verifier/msync" + "github.com/10gen/migration-verifier/option" "github.com/rs/zerolog" ) @@ -13,14 +15,15 @@ import ( // The attempt number is 0-indexed (0 means this is the first attempt). // The duration tracks the duration of retrying for transient errors only. type LoopInfo struct { - attemptNumber int + attemptsSoFar int durationLimit time.Duration } type FuncInfo struct { - loopInfo *LoopInfo - - lastResetTime time.Time + loopInfo *LoopInfo + description string + loopDescription option.Option[string] + lastResetTime *msync.TypedAtomic[time.Time] } // Log will log a debug-level message for the current Info values and the provided strings. @@ -60,13 +63,13 @@ func (fi *FuncInfo) Log(logger *zerolog.Logger, cmdName string, clientType strin // GetAttemptNumber returns the Info's current attempt number (0-indexed). func (fi *FuncInfo) GetAttemptNumber() int { - return fi.loopInfo.attemptNumber + return fi.loopInfo.attemptsSoFar } // GetDurationSoFar returns the Info's current duration so far. This duration // applies to the duration of retrying for transient errors only. func (fi *FuncInfo) GetDurationSoFar() time.Duration { - return time.Since(fi.lastResetTime) + return time.Since(fi.lastResetTime.Load()) } // NoteSuccess is used to tell the retry util to reset its measurement @@ -76,5 +79,5 @@ func (fi *FuncInfo) GetDurationSoFar() time.Duration { // Call this after every successful command in a multi-command callback. // (It’s useless--but harmless--in a single-command callback.) func (i *FuncInfo) NoteSuccess() { - i.lastResetTime = time.Now() + i.lastResetTime.Store(time.Now()) } diff --git a/internal/retry/retryer.go b/internal/retry/retryer.go index 6c269113..76b20df3 100644 --- a/internal/retry/retryer.go +++ b/internal/retry/retryer.go @@ -1,52 +1,99 @@ package retry import ( + "fmt" + "slices" "time" "github.com/10gen/migration-verifier/option" ) +type retryCallbackInfo struct { + callback RetryCallback + description string +} + // Retryer handles retrying operations that fail because of network failures. type Retryer struct { retryLimit time.Duration - retryRandomly bool before option.Option[func()] + callbacks []retryCallbackInfo + description option.Option[string] additionalErrorCodes []int } -// New returns a new retryer. -func New(retryLimit time.Duration) *Retryer { - return NewWithRandomlyRetries(retryLimit, false) -} - -// NewWithRandomlyRetries returns a new retryer, but allows the option of setting the -// retryRandomly field. -func NewWithRandomlyRetries(retryLimit time.Duration, retryRandomly bool) *Retryer { +// New returns a new Retryer with DefaultDurationLimit as its time limit. +func New() *Retryer { return &Retryer{ - retryLimit: retryLimit, - retryRandomly: retryRandomly, + retryLimit: DefaultDurationLimit, } } // WithErrorCodes returns a new Retryer that will retry on the codes passed to -// this method. This allows for a single function to customize the codes it +// this method. This allows for a single retryer to customize the codes it // wants to retry on. Note that if the Retryer already has additional custom // error codes set, these are _replaced_ when this method is called. func (r *Retryer) WithErrorCodes(codes ...int) *Retryer { - r2 := *r + r2 := r.clone() r2.additionalErrorCodes = codes - return &r2 + return r2 } -// WithBefore sets a callback that always runs before any retryer callback. +// WithRetryLimit returns a new retryer with the specified time limit. +func (r *Retryer) WithRetryLimit(limit time.Duration) *Retryer { + r2 := r.clone() + r2.retryLimit = limit + + return r2 +} + +// WithBefore returns a new retryer with a callback that always runs before +// any retryer callback. // // This is useful if there are multiple callbacks and you need to reset some // condition before each retryer iteration. (In the single-callback case it’s // largely redundant.) func (r *Retryer) WithBefore(todo func()) *Retryer { - r2 := *r + r2 := r.clone() r2.before = option.Some(todo) + return r2 +} + +// WithDescription returns a new retryer with the given description. +func (r *Retryer) WithDescription(msg string, args ...any) *Retryer { + r2 := r.clone() + r2.description = option.Some(fmt.Sprintf(msg, args...)) + + return r2 +} + +// WithCallback returns a new retryer with the additional callback. +func (r *Retryer) WithCallback( + callback RetryCallback, + msg string, args ...any, +) *Retryer { + r2 := r.clone() + + r2.callbacks = append( + r2.callbacks, + retryCallbackInfo{ + callback: callback, + description: fmt.Sprintf(msg, args...), + }, + ) + + return r2 +} + +func (r *Retryer) clone() *Retryer { + r2 := *r + + r2.before = option.FromPointer(r.before.ToPointer()) + r2.description = option.FromPointer(r.description.ToPointer()) + r2.callbacks = slices.Clone(r.callbacks) + r2.additionalErrorCodes = slices.Clone(r.additionalErrorCodes) + return &r2 } diff --git a/internal/retry/retryer_test.go b/internal/retry/retryer_test.go index 44b5e3fb..8b430407 100644 --- a/internal/retry/retryer_test.go +++ b/internal/retry/retryer_test.go @@ -19,7 +19,7 @@ var someNetworkError = &mongo.CommandError{ var badError = errors.New("I am fatal!") func (suite *UnitTestSuite) TestRetryer() { - retryer := New(DefaultDurationLimit) + retryer := New() logger := suite.Logger() suite.Run("with a function that immediately succeeds", func() { @@ -29,7 +29,7 @@ func (suite *UnitTestSuite) TestRetryer() { return nil } - err := retryer.Run(suite.Context(), logger, f) + err := retryer.WithCallback(f, "f").Run(suite.Context(), logger) suite.NoError(err) suite.Equal(0, attemptNumber) @@ -38,7 +38,7 @@ func (suite *UnitTestSuite) TestRetryer() { return nil } - err = retryer.Run(suite.Context(), logger, f2) + err = retryer.WithCallback(f2, "f2").Run(suite.Context(), logger) suite.NoError(err) suite.Equal(0, attemptNumber) }) @@ -53,7 +53,7 @@ func (suite *UnitTestSuite) TestRetryer() { return nil } - err := retryer.Run(suite.Context(), logger, f) + err := retryer.WithCallback(f, "f").Run(suite.Context(), logger) suite.NoError(err) suite.Equal(2, attemptNumber) @@ -66,14 +66,14 @@ func (suite *UnitTestSuite) TestRetryer() { return nil } - err = retryer.Run(suite.Context(), logger, f2) + err = retryer.WithCallback(f2, "f2").Run(suite.Context(), logger) suite.NoError(err) suite.Equal(2, attemptNumber) }) } func (suite *UnitTestSuite) TestRetryerDurationLimitIsZero() { - retryer := New(0) + retryer := New().WithRetryLimit(0) attemptNumber := -1 f := func(_ context.Context, ri *FuncInfo) error { @@ -81,13 +81,13 @@ func (suite *UnitTestSuite) TestRetryerDurationLimitIsZero() { return someNetworkError } - err := retryer.Run(suite.Context(), suite.Logger(), f) + err := retryer.WithCallback(f, "f").Run(suite.Context(), suite.Logger()) suite.Assert().ErrorIs(err, someNetworkError) suite.Assert().Equal(0, attemptNumber) } func (suite *UnitTestSuite) TestRetryerDurationReset() { - retryer := New(DefaultDurationLimit) + retryer := New() logger := suite.Logger() // In this test, the given function f takes longer than the durationLimit @@ -99,7 +99,9 @@ func (suite *UnitTestSuite) TestRetryerDurationReset() { noSuccessIterations := 0 f1 := func(_ context.Context, ri *FuncInfo) error { // Artificially advance how much time was taken. - ri.lastResetTime = ri.lastResetTime.Add(-2 * ri.loopInfo.durationLimit) + ri.lastResetTime.Store( + ri.lastResetTime.Load().Add(-2 * ri.loopInfo.durationLimit), + ) noSuccessIterations++ if noSuccessIterations == 1 { @@ -109,7 +111,7 @@ func (suite *UnitTestSuite) TestRetryerDurationReset() { return nil } - err := retryer.Run(suite.Context(), logger, f1) + err := retryer.WithCallback(f1, "f1").Run(suite.Context(), logger) // The error should be the limit-exceeded error, with the // last-noted error being the transient error. @@ -122,7 +124,9 @@ func (suite *UnitTestSuite) TestRetryerDurationReset() { successIterations := 0 f2 := func(_ context.Context, ri *FuncInfo) error { // Artificially advance how much time was taken. - ri.lastResetTime = ri.lastResetTime.Add(-2 * ri.loopInfo.durationLimit) + ri.lastResetTime.Store( + ri.lastResetTime.Load().Add(-2 * ri.loopInfo.durationLimit), + ) ri.NoteSuccess() @@ -134,13 +138,13 @@ func (suite *UnitTestSuite) TestRetryerDurationReset() { return nil } - err = retryer.Run(suite.Context(), logger, f2) + err = retryer.WithCallback(f2, "f2").Run(suite.Context(), logger) suite.Assert().NoError(err) suite.Assert().Equal(2, successIterations) } func (suite *UnitTestSuite) TestCancelViaContext() { - retryer := New(DefaultDurationLimit) + retryer := New() logger := suite.Logger() counter := 0 @@ -160,7 +164,7 @@ func (suite *UnitTestSuite) TestCancelViaContext() { // retry code will see the cancel before the timer it sets expires. cancel() go func() { - err := retryer.Run(ctx, logger, f) + err := retryer.WithCallback(f, "f").Run(ctx, logger) suite.ErrorIs(err, context.Canceled) suite.Equal(1, counter) wg.Done() @@ -187,32 +191,32 @@ func (suite *UnitTestSuite) TestRetryerAdditionalErrorCodes() { } suite.Run("with no additional error codes", func() { - retryer := New(DefaultDurationLimit) - err := retryer.Run(suite.Context(), logger, f) + retryer := New() + err := retryer.WithCallback(f, "f").Run(suite.Context(), logger) suite.Equal(42, util.GetErrorCode(err)) suite.Equal(0, attemptNumber) }) suite.Run("with one additional error code", func() { - retryer := New(DefaultDurationLimit) + retryer := New() retryer = retryer.WithErrorCodes(42) - err := retryer.Run(suite.Context(), logger, f) + err := retryer.WithCallback(f, "f").Run(suite.Context(), logger) suite.NoError(err) suite.Equal(1, attemptNumber) }) suite.Run("with multiple additional error codes", func() { - retryer := New(DefaultDurationLimit) + retryer := New() retryer = retryer.WithErrorCodes(42, 43, 44) - err := retryer.Run(suite.Context(), logger, f) + err := retryer.WithCallback(f, "f").Run(suite.Context(), logger) suite.NoError(err) suite.Equal(1, attemptNumber) }) suite.Run("with multiple additional error codes that don't match error", func() { - retryer := New(DefaultDurationLimit) + retryer := New() retryer = retryer.WithErrorCodes(41, 43, 44) - err := retryer.Run(suite.Context(), logger, f) + err := retryer.WithCallback(f, "f").Run(suite.Context(), logger) suite.Equal(42, util.GetErrorCode(err)) suite.Equal(0, attemptNumber) }) @@ -222,11 +226,7 @@ func (suite *UnitTestSuite) TestMulti_NonTransient() { ctx := suite.Context() logger := suite.Logger() - retryer := New(DefaultDurationLimit) - - err := retryer.Run( - ctx, - logger, + err := New().WithCallback( func(ctx context.Context, _ *FuncInfo) error { timer := time.NewTimer(10 * time.Second) select { @@ -236,10 +236,13 @@ func (suite *UnitTestSuite) TestMulti_NonTransient() { return nil } }, + "slow", + ).WithCallback( func(_ context.Context, _ *FuncInfo) error { return badError }, - ) + "fails quickly", + ).Run(ctx, logger) suite.Assert().ErrorIs(err, badError) } @@ -252,20 +255,17 @@ func (suite *UnitTestSuite) TestMulti_Transient() { suite.Run( fmt.Sprintf("final error: %v", finalErr), func() { - retryer := New(DefaultDurationLimit) cb1Attempts := 0 cb2Attempts := 0 - err := retryer.Run( - ctx, - logger, - - // This one succeeds every time. + err := New().WithCallback( func(ctx context.Context, _ *FuncInfo) error { cb1Attempts++ return nil }, + "succeeds every time", + ).WithCallback( func(_ context.Context, _ *FuncInfo) error { cb2Attempts++ @@ -276,7 +276,8 @@ func (suite *UnitTestSuite) TestMulti_Transient() { return finalErr } }, - ) + "fails variously", + ).Run(ctx, logger) if finalErr == nil { suite.Assert().NoError(err) @@ -300,13 +301,11 @@ func (suite *UnitTestSuite) TestMulti_LongRunningSuccess() { startTime := time.Now() retryerLimit := 2 * time.Second - retryer := New(retryerLimit) + retryer := New().WithRetryLimit(retryerLimit) succeedPastTime := startTime.Add(retryerLimit + 1*time.Second) - err := retryer.Run( - ctx, - logger, + err := retryer.WithCallback( func(ctx context.Context, fi *FuncInfo) error { fi.NoteSuccess() @@ -317,6 +316,8 @@ func (suite *UnitTestSuite) TestMulti_LongRunningSuccess() { return nil }, + "quick success, then fail; all success after a bit", + ).WithCallback( func(ctx context.Context, fi *FuncInfo) error { if time.Now().Before(succeedPastTime) { <-ctx.Done() @@ -325,7 +326,8 @@ func (suite *UnitTestSuite) TestMulti_LongRunningSuccess() { return nil }, - ) + "long-running: hangs then succeeds", + ).Run(ctx, logger) suite.Assert().NoError(err) } diff --git a/internal/uuidutil/get_uuid.go b/internal/uuidutil/get_uuid.go index 86996933..1a694d77 100644 --- a/internal/uuidutil/get_uuid.go +++ b/internal/uuidutil/get_uuid.go @@ -27,8 +27,8 @@ type NamespaceAndUUID struct { CollName string } -func GetCollectionNamespaceAndUUID(ctx context.Context, logger *logger.Logger, retryer *retry.Retryer, db *mongo.Database, collName string) (*NamespaceAndUUID, error) { - binaryUUID, uuidErr := GetCollectionUUID(ctx, logger, retryer, db, collName) +func GetCollectionNamespaceAndUUID(ctx context.Context, logger *logger.Logger, db *mongo.Database, collName string) (*NamespaceAndUUID, error) { + binaryUUID, uuidErr := GetCollectionUUID(ctx, logger, db, collName) if uuidErr != nil { return nil, uuidErr } @@ -39,20 +39,21 @@ func GetCollectionNamespaceAndUUID(ctx context.Context, logger *logger.Logger, r }, nil } -func GetCollectionUUID(ctx context.Context, logger *logger.Logger, retryer *retry.Retryer, db *mongo.Database, collName string) (*primitive.Binary, error) { +func GetCollectionUUID(ctx context.Context, logger *logger.Logger, db *mongo.Database, collName string) (*primitive.Binary, error) { filter := bson.D{{"name", collName}} opts := options.ListCollections().SetNameOnly(false) var collSpecs []*mongo.CollectionSpecification - err := retryer.Run( - ctx, - logger, + err := retry.New().WithCallback( func(_ context.Context, ri *retry.FuncInfo) error { ri.Log(logger.Logger, "ListCollectionSpecifications", db.Name(), collName, "Getting collection UUID.", "") var driverErr error collSpecs, driverErr = db.ListCollectionSpecifications(ctx, filter, opts) return driverErr - }) + }, + "getting namespace %#q's specification", + db.Name()+"."+collName, + ).Run(ctx, logger) if err != nil { return nil, errors.Wrapf(err, "failed to list collections specification") } diff --git a/internal/verifier/change_stream.go b/internal/verifier/change_stream.go index e48a22d5..4451dad1 100644 --- a/internal/verifier/change_stream.go +++ b/internal/verifier/change_stream.go @@ -493,14 +493,11 @@ func (csr *ChangeStreamReader) StartChangeStream(ctx context.Context) error { // notifies the verifier's change event handler to exit. defer close(csr.changeEventBatchChan) - retryer := retry.New(retry.DefaultDurationLimit) - retryer = retryer.WithErrorCodes(util.CursorKilled) + retryer := retry.New().WithErrorCodes(util.CursorKilled) parentThreadWaiting := true - err := retryer.Run( - ctx, - csr.logger, + err := retryer.WithCallback( func(ctx context.Context, ri *retry.FuncInfo) error { changeStream, startTs, err := csr.createChangeStream(ctx) if err != nil { @@ -522,7 +519,8 @@ func (csr *ChangeStreamReader) StartChangeStream(ctx context.Context) error { return csr.iterateChangeStream(ctx, ri, changeStream) }, - ) + "running %s", csr, + ).Run(ctx, csr.logger) if err != nil { // NB: This failure always happens after the initial change stream diff --git a/internal/verifier/check.go b/internal/verifier/check.go index 3207791f..365cd987 100644 --- a/internal/verifier/check.go +++ b/internal/verifier/check.go @@ -123,6 +123,10 @@ func (verifier *Verifier) CheckWorker(ctxIn context.Context) error { ) } + verifier.logger.Debug(). + Interface("taskCountsByStatus", verificationStatus). + Send() + if waitForTaskCreation%2 == 0 { if generation > 0 || verifier.gen0PendingCollectionTasks.Load() == 0 { verifier.PrintVerificationSummary(ctx, GenerationInProgress) @@ -190,9 +194,7 @@ func (verifier *Verifier) CheckDriver(ctx context.Context, filter map[string]any return err } } - err = retry.Retry( - ctx, - verifier.logger, + err = retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { err = verifier.AddMetaIndexes(ctx) if err != nil { @@ -211,7 +213,8 @@ func (verifier *Verifier) CheckDriver(ctx context.Context, filter map[string]any return nil }, - ) + "setting up verifier metadata", + ).Run(ctx, verifier.logger) if err != nil { return err @@ -325,13 +328,12 @@ func (verifier *Verifier) CheckDriver(ctx context.Context, filter map[string]any // Generation of recheck tasks can partial-fail. The following will // cause a full redo in that case, which is inefficient but simple. // Such failures seem unlikely anyhow. - err = retry.Retry( - ctx, - verifier.logger, + err = retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { return verifier.GenerateRecheckTasksWhileLocked(ctx) }, - ) + "generating recheck tasks", + ).Run(ctx, verifier.logger) if err != nil { verifier.mux.Unlock() return err @@ -437,9 +439,7 @@ func FetchFailedAndIncompleteTasks( ) ([]VerificationTask, []VerificationTask, error) { var FailedTasks, allTasks, IncompleteTasks []VerificationTask - err := retry.Retry( - ctx, - logger, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { cur, err := coll.Find(ctx, bson.D{ bson.E{Key: "type", Value: taskType}, @@ -463,7 +463,9 @@ func FetchFailedAndIncompleteTasks( return nil }, - ) + "fetching generation %d's failed & incomplete tasks", + generation, + ).Run(ctx, logger) return FailedTasks, IncompleteTasks, err } diff --git a/internal/verifier/clustertime.go b/internal/verifier/clustertime.go index 6fdbf279..116b0669 100644 --- a/internal/verifier/clustertime.go +++ b/internal/verifier/clustertime.go @@ -25,15 +25,11 @@ func GetNewClusterTime( logger *logger.Logger, client *mongo.Client, ) (primitive.Timestamp, error) { - retryer := retry.New(retry.DefaultDurationLimit) - var clusterTime primitive.Timestamp // First we just fetch the latest cluster time among all shards without // updating any shards’ oplogs. - err := retryer.Run( - ctx, - logger, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { var err error clusterTime, err = runAppendOplogNote( @@ -44,7 +40,8 @@ func GetNewClusterTime( ) return err }, - ) + "appending oplog note to get cluster time", + ).Run(ctx, logger) if err != nil { return primitive.Timestamp{}, err @@ -53,9 +50,7 @@ func GetNewClusterTime( // fetchClusterTime() will have taught the mongos about the most current // shard’s cluster time. Now we tell that mongos to update all lagging // shards to that time. - err = retryer.Run( - ctx, - logger, + err = retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { var err error _, err = runAppendOplogNote( @@ -66,7 +61,8 @@ func GetNewClusterTime( ) return err }, - ) + "appending oplog note to synchronize cluster", + ).Run(ctx, logger) if err != nil { // This isn't serious enough even for info-level. logger.Debug().Err(err). diff --git a/internal/verifier/compare.go b/internal/verifier/compare.go index 93234fc7..3fdcb12e 100644 --- a/internal/verifier/compare.go +++ b/internal/verifier/compare.go @@ -33,22 +33,30 @@ func (verifier *Verifier) FetchAndCompareDocuments( var docCount types.DocumentCount var byteCount types.ByteCount - retryer := retry.New(retry.DefaultDurationLimit) + retryer := retry.New().WithDescription( + "comparing task %v's documents (namespace: %s)", + task.PrimaryKey, + task.QueryFilter.Namespace, + ) err := retryer. WithBefore(func() { srcChannel, dstChannel, readSrcCallback, readDstCallback = verifier.getFetcherChannelsAndCallbacks(task) }). WithErrorCodes(util.CursorKilled). - Run( - givenCtx, - verifier.logger, + WithCallback( func(ctx context.Context, fi *retry.FuncInfo) error { return readSrcCallback(ctx, fi) }, + "reading from source", + ). + WithCallback( func(ctx context.Context, fi *retry.FuncInfo) error { return readDstCallback(ctx, fi) }, + "reading from destination", + ). + WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { var err error results, docCount, byteCount, err = verifier.compareDocsFromChannels( @@ -60,7 +68,8 @@ func (verifier *Verifier) FetchAndCompareDocuments( return err }, - ) + "comparing documents", + ).Run(givenCtx, verifier.logger) return results, docCount, byteCount, err } diff --git a/internal/verifier/migration_verifier.go b/internal/verifier/migration_verifier.go index cd158af0..78b9e7d3 100644 --- a/internal/verifier/migration_verifier.go +++ b/internal/verifier/migration_verifier.go @@ -612,7 +612,7 @@ func (verifier *Verifier) ProcessVerifyTask(ctx context.Context, workerNum int, Interface("task", task.PrimaryKey). Str("namespace", task.QueryFilter.Namespace). Int("mismatchesCount", len(problems)). - Msg("Document comparison task failed, but it may pass in the next generation.") + Msg("Discrepancies found. Will recheck in the next generation.") var mismatches []VerificationResult var missingIds []interface{} @@ -750,9 +750,8 @@ func (verifier *Verifier) getShardKeyFields( // 2. Fetch shard keys. // 3. Fetch the size: # of docs, and # of bytes. func (verifier *Verifier) partitionAndInspectNamespace(ctx context.Context, namespace string) ([]*partitions.Partition, []string, types.DocumentCount, types.ByteCount, error) { - retryer := retry.New(retry.DefaultDurationLimit) dbName, collName := SplitNamespace(namespace) - namespaceAndUUID, err := uuidutil.GetCollectionNamespaceAndUUID(ctx, verifier.logger, retryer, + namespaceAndUUID, err := uuidutil.GetCollectionNamespaceAndUUID(ctx, verifier.logger, verifier.srcClientDatabase(dbName), collName) if err != nil { return nil, nil, 0, 0, err @@ -767,7 +766,7 @@ func (verifier *Verifier) partitionAndInspectNamespace(ctx context.Context, name replicator1 := partitions.Replicator{ID: "verifier"} replicators := []partitions.Replicator{replicator1} partitionList, srcDocs, srcBytes, err := partitions.PartitionCollectionWithSize( - ctx, namespaceAndUUID, retryer, verifier.srcClient, replicators, verifier.logger, verifier.partitionSizeInBytes, verifier.globalFilter) + ctx, namespaceAndUUID, verifier.srcClient, replicators, verifier.logger, verifier.partitionSizeInBytes, verifier.globalFilter) if err != nil { return nil, nil, 0, 0, err } @@ -1239,9 +1238,7 @@ func (verifier *Verifier) GetVerificationStatus(ctx context.Context) (*Verificat var results []bson.Raw - err := retry.Retry( - ctx, - verifier.logger, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { cursor, err := taskCollection.Aggregate( ctx, @@ -1266,9 +1263,16 @@ func (verifier *Verifier) GetVerificationStatus(ctx context.Context) (*Verificat return cursor.All(ctx, &results) }, - ) + "counting generation %d's (non-primary) tasks by status", + generation, + ).Run(ctx, verifier.logger) + if err != nil { - return nil, err + return nil, errors.Wrapf( + err, + "failed to count generation %d's tasks by status", + generation, + ) } verificationStatus := VerificationStatus{} @@ -1497,6 +1501,7 @@ func (verifier *Verifier) PrintVerificationSummary(ctx context.Context, genstatu if err != nil { verifier.logger.Err(err).Msgf("Failed to report per-namespace statistics") + return } verifier.printChangeEventStatistics(strBuilder) diff --git a/internal/verifier/mongos_refresh.go b/internal/verifier/mongos_refresh.go index ec8fa70f..80fdb511 100644 --- a/internal/verifier/mongos_refresh.go +++ b/internal/verifier/mongos_refresh.go @@ -33,8 +33,6 @@ func RefreshAllMongosInstances( Strs("hosts", hosts). Msgf("Refreshing all %d mongos instance(s) on the source.", len(hosts)) - r := retry.New(retry.DefaultDurationLimit) - for _, host := range hosts { singleHostClientOpts := *clientOpts @@ -52,16 +50,13 @@ func RefreshAllMongosInstances( shardConnStr, err := getAnyExistingShardConnectionStr( ctx, l, - r, singleHostClient, ) if err != nil { return err } - err = r.Run( - ctx, - l, + err = retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { // Query a collection on the config server with linearizable read concern to advance the config // server primary's majority-committed optime. This populates the $configOpTime. @@ -112,7 +107,9 @@ func RefreshAllMongosInstances( } return nil - }) + }, + "refreshing mongos shard cache", + ).Run(ctx, l) if err != nil { return err @@ -137,10 +134,9 @@ func RefreshAllMongosInstances( func getAnyExistingShardConnectionStr( ctx context.Context, l *logger.Logger, - r *retry.Retryer, client *mongo.Client, ) (string, error) { - res, err := runListShards(ctx, l, r, client) + res, err := runListShards(ctx, l, client) if err != nil { return "", err } @@ -169,17 +165,15 @@ func getAnyExistingShardConnectionStr( func runListShards( ctx context.Context, l *logger.Logger, - r *retry.Retryer, client *mongo.Client, ) (*mongo.SingleResult, error) { var res *mongo.SingleResult - err := r.Run( - ctx, - l, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { res = client.Database("admin").RunCommand(ctx, bson.D{{"listShards", 1}}) return res.Err() }, - ) + "listing shards", + ).Run(ctx, l) return res, err } diff --git a/internal/verifier/recheck.go b/internal/verifier/recheck.go index 0ecbe20e..e4191655 100644 --- a/internal/verifier/recheck.go +++ b/internal/verifier/recheck.go @@ -128,10 +128,8 @@ func (verifier *Verifier) insertRecheckDocs( SetUpsert(true) } - retryer := retry.New(retry.DefaultDurationLimit) - err := retryer.Run( - groupCtx, - verifier.logger, + retryer := retry.New() + err := retryer.WithCallback( func(retryCtx context.Context, _ *retry.FuncInfo) error { _, err := verifier.verificationDatabase().Collection(recheckQueue).BulkWrite( retryCtx, @@ -141,7 +139,9 @@ func (verifier *Verifier) insertRecheckDocs( return err }, - ) + "persisting %d recheck(s)", + len(models), + ).Run(groupCtx, verifier.logger) return errors.Wrapf(err, "failed to persist %d recheck(s) for generation %d", len(models), generation) }) @@ -177,9 +177,7 @@ func (verifier *Verifier) ClearRecheckDocsWhileLocked(ctx context.Context) error Int("previousGeneration", prevGeneration). Msg("Deleting previous generation's enqueued rechecks.") - return retry.Retry( - ctx, - verifier.logger, + return retry.New().WithCallback( func(ctx context.Context, i *retry.FuncInfo) error { _, err := verifier.verificationDatabase().Collection(recheckQueue).DeleteMany( ctx, @@ -188,7 +186,9 @@ func (verifier *Verifier) ClearRecheckDocsWhileLocked(ctx context.Context) error return err }, - ) + "deleting generation %d's enqueued rechecks", + prevGeneration, + ).Run(ctx, verifier.logger) } func (verifier *Verifier) getPreviousGenerationWhileLocked() int { diff --git a/internal/verifier/verification_task.go b/internal/verifier/verification_task.go index c2b7bad4..3eeb4d29 100644 --- a/internal/verifier/verification_task.go +++ b/internal/verifier/verification_task.go @@ -138,14 +138,14 @@ func (verifier *Verifier) insertCollectionVerificationTask( }, } - err := retry.Retry( - ctx, - verifier.logger, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { _, err := verifier.verificationTaskCollection().InsertOne(ctx, verificationTask) return err }, - ) + "persisting namespace %#q's verification task", + srcNamespace, + ).Run(ctx, verifier.logger) return &verificationTask, err } @@ -185,15 +185,17 @@ func (verifier *Verifier) InsertPartitionVerificationTask( }, } - err := retry.Retry( - ctx, - verifier.logger, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { _, err := verifier.verificationTaskCollection().InsertOne(ctx, &task) return err }, - ) + "persisting partition verification task for %#q (%v to %v)", + task.QueryFilter.Namespace, + task.QueryFilter.Partition.Key.Lower, + task.QueryFilter.Partition.Upper, + ).Run(ctx, verifier.logger) return &task, err } @@ -227,11 +229,16 @@ func (verifier *Verifier) InsertDocumentRecheckTask( SourceByteCount: dataSize, } - err := retry.Retry(ctx, verifier.logger, func(ctx context.Context, _ *retry.FuncInfo) error { - _, err := verifier.verificationTaskCollection().InsertOne(ctx, &task) + err := retry.New().WithCallback( + func(ctx context.Context, _ *retry.FuncInfo) error { + _, err := verifier.verificationTaskCollection().InsertOne(ctx, &task) - return err - }) + return err + }, + "persisting recheck task for namespace %#q (%d document(s))", + task.QueryFilter.Namespace, + len(ids), + ).Run(ctx, verifier.logger) return &task, err } @@ -241,9 +248,7 @@ func (verifier *Verifier) FindNextVerifyTaskAndUpdate( ) (option.Option[VerificationTask], error) { task := &VerificationTask{} - err := retry.Retry( - ctx, - verifier.logger, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { err := verifier.verificationTaskCollection().FindOneAndUpdate( @@ -282,15 +287,15 @@ func (verifier *Verifier) FindNextVerifyTaskAndUpdate( return err }, - ) + "finding next task to do in generation %d", + verifier.generation, + ).Run(ctx, verifier.logger) return option.FromPointer(task), err } func (verifier *Verifier) UpdateVerificationTask(ctx context.Context, task *VerificationTask) error { - return retry.Retry( - ctx, - verifier.logger, + return retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { result, err := verifier.verificationTaskCollection().UpdateOne( ctx, @@ -317,7 +322,10 @@ func (verifier *Verifier) UpdateVerificationTask(ctx context.Context, task *Veri return err }, - ) + "updating task %v (namespace %#q)", + task.PrimaryKey, + task.QueryFilter.Namespace, + ).Run(ctx, verifier.logger) } func (verifier *Verifier) CreatePrimaryTaskIfNeeded(ctx context.Context) (bool, error) { @@ -325,9 +333,7 @@ func (verifier *Verifier) CreatePrimaryTaskIfNeeded(ctx context.Context) (bool, var created bool - err := retry.Retry( - ctx, - verifier.logger, + err := retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { result, err := verifier.verificationTaskCollection().UpdateOne( ctx, @@ -349,15 +355,14 @@ func (verifier *Verifier) CreatePrimaryTaskIfNeeded(ctx context.Context) (bool, return nil }, - ) + "ensuring primary task's existence", + ).Run(ctx, verifier.logger) return created, err } func (verifier *Verifier) UpdatePrimaryTaskComplete(ctx context.Context) error { - return retry.Retry( - ctx, - verifier.logger, + return retry.New().WithCallback( func(ctx context.Context, _ *retry.FuncInfo) error { result, err := verifier.verificationTaskCollection().UpdateMany( ctx, @@ -380,5 +385,6 @@ func (verifier *Verifier) UpdatePrimaryTaskComplete(ctx context.Context) error { return nil }, - ) + "noting completion of primary task", + ).Run(ctx, verifier.logger) } diff --git a/msync/typed_atomic.go b/msync/typed_atomic.go new file mode 100644 index 00000000..219fa34f --- /dev/null +++ b/msync/typed_atomic.go @@ -0,0 +1,50 @@ +package msync + +import "sync/atomic" + +// TypedAtomic is a type-safe wrapper around the standard-library atomic.Value. +// TypedAtomic serves largely the same purpose as atomic.Pointer but stores +// the value itself rather than a pointer to it. This is often more ergonomic +// than an atomic.Pointer: it can be used to store constants directly (where +// taking a pointer is inconvenient), and it defaults to the type's zero value +// rather than a nil pointer. +type TypedAtomic[T any] struct { + v atomic.Value +} + +// NewTypedAtomic returns a new TypedAtomic, initialized to val. +func NewTypedAtomic[T any](val T) *TypedAtomic[T] { + var v atomic.Value + v.Store(val) + return &TypedAtomic[T]{v} +} + +// Load returns the value set by the most recent Store. It returns the zero +// value for the type if there has been no call to Store. +func (ta *TypedAtomic[T]) Load() T { + return orZero[T](ta.v.Load()) +} + +// Store sets the value TypedAtomic to val. Store(nil) panics. +func (ta *TypedAtomic[T]) Store(val T) { + ta.v.Store(val) +} + +// Swap stores newVal into the TypedAtomic and returns the previous value. It +// returns the zero value for the type if the value is empty. +func (ta *TypedAtomic[T]) Swap(newVal T) T { + return orZero[T](ta.v.Swap(newVal)) +} + +// CompareAndSwap executes the compare-and-swap operation for the TypedAtomic. +func (ta *TypedAtomic[T]) CompareAndSwap(oldVal, newVal T) bool { + return ta.v.CompareAndSwap(oldVal, newVal) +} + +func orZero[T any](val any) T { + if val == nil { + return *new(T) + } + + return val.(T) +} diff --git a/msync/typed_atomic_test.go b/msync/typed_atomic_test.go new file mode 100644 index 00000000..d4789137 --- /dev/null +++ b/msync/typed_atomic_test.go @@ -0,0 +1,59 @@ +package msync + +import ( + "sync" +) + +func (s *unitTestSuite) TestTypedAtomic() { + ta := NewTypedAtomic(42) + + s.Require().Equal(42, ta.Load()) + s.Require().False(ta.CompareAndSwap(17, 99)) + s.Require().True(ta.CompareAndSwap(42, 99)) + s.Require().Equal(99, ta.Load()) + s.Require().Equal(99, ta.Swap(42)) + s.Require().Equal(42, ta.Load()) + + ta.Store(17) + s.Require().Equal(17, ta.Load()) + + // This block is for race detection under -race. + var wg sync.WaitGroup + for i := range 100 { + wg.Add(1) + go func() { + defer wg.Done() + ta.Load() + ta.Store(i) + }() + } + wg.Wait() +} + +func (s *unitTestSuite) TestAtomicZeroValues() { + s.Run("string", func() { + var ta TypedAtomic[string] + s.Require().Equal("", ta.Load()) + s.Require().Equal("", ta.Swap("foo")) + s.Require().Equal("foo", ta.Load()) + }) + + s.Run("int", func() { + var ta TypedAtomic[int] + s.Require().Equal(0, ta.Load()) + s.Require().Equal(0, ta.Swap(42)) + s.Require().Equal(42, ta.Load()) + }) + + s.Run("arbitrary data", func() { + type data struct { + I int + S string + } + + var ta TypedAtomic[data] + s.Require().Equal(data{}, ta.Load()) + s.Require().Equal(data{}, ta.Swap(data{76, "trombones"})) + s.Require().Equal(data{76, "trombones"}, ta.Load()) + }) +}