From 9c3fda2a5f97cec6374d885c229904efdbcc9500 Mon Sep 17 00:00:00 2001 From: Bob Vawter Date: Wed, 6 Nov 2024 10:51:11 -0500 Subject: [PATCH] lockset: Import package This commit extracts the lockset package from cockroachdb/replicator at commit ee8e2894. There are some API modifications to generalize the concept of a Task and to generalize metrics collection. --- lockset/events.go | 54 +++++ lockset/executor.go | 290 +++++++++++++++++++++++++ lockset/executor_test.go | 449 +++++++++++++++++++++++++++++++++++++++ lockset/lockset.go | 19 ++ lockset/queue.go | 276 ++++++++++++++++++++++++ lockset/queue_test.go | 181 ++++++++++++++++ lockset/retry.go | 45 ++++ lockset/runner.go | 40 ++++ lockset/status.go | 106 +++++++++ lockset/task.go | 48 +++++ 10 files changed, 1508 insertions(+) create mode 100644 lockset/events.go create mode 100644 lockset/executor.go create mode 100644 lockset/executor_test.go create mode 100644 lockset/lockset.go create mode 100644 lockset/queue.go create mode 100644 lockset/queue_test.go create mode 100644 lockset/retry.go create mode 100644 lockset/runner.go create mode 100644 lockset/status.go create mode 100644 lockset/task.go diff --git a/lockset/events.go b/lockset/events.go new file mode 100644 index 0000000..96f1aed --- /dev/null +++ b/lockset/events.go @@ -0,0 +1,54 @@ +// Copyright 2024 The Cockroach Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package lockset + +import "time" + +// Events provides an [Executor] with optional callbacks to monitor the +// performance of enqueued tasks. +// +// See [Executor.SetEvents]. +type Events[K any] struct { + OnComplete func(task Task[K], sinceScheduled time.Duration) + OnRetried func(task Task[K]) + OnSchedule func(task Task[K], deferred bool) + OnStarted func(task Task[K], sinceScheduled time.Duration) +} + +func (e *Events[K]) doComplete(task Task[K], sinceScheduled time.Duration) { + if e != nil && e.OnComplete != nil { + e.OnComplete(task, sinceScheduled) + } +} + +func (e *Events[K]) doRetried(task Task[K]) { + if e != nil && e.OnRetried != nil { + e.OnRetried(task) + } +} + +func (e *Events[K]) doSchedule(task Task[K], deferred bool) { + if e != nil && e.OnSchedule != nil { + e.OnSchedule(task, deferred) + } +} + +func (e *Events[K]) doStarted(task Task[K], sinceScheduled time.Duration) { + if e != nil && e.OnStarted != nil { + e.OnStarted(task, sinceScheduled) + } +} diff --git a/lockset/executor.go b/lockset/executor.go new file mode 100644 index 0000000..8ba91b0 --- /dev/null +++ b/lockset/executor.go @@ -0,0 +1,290 @@ +// Copyright 2024 The Cockroach Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package lockset
+
+// This file was extracted from cockroachdb/replicator at ee8e2894.
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/cockroachdb/field-eng-powertools/notify"
+)
+
+// ErrScheduleCancel will be returned from [context.Cause] if a task's
+// context was canceled via the function returned from
+// [Executor.Schedule].
+var ErrScheduleCancel = fmt.Errorf("%w: Executor.Schedule cancel()", context.Canceled)
+
+// A waiter represents a request to acquire locks on some number of
+// keys. The fields within the mu struct may only be accessed while
+// holding the mutex.
+type waiter[K any] struct {
+	keys          []K                 // Desired key set.
+	result        notify.Var[*Status] // The outbox for the waiter.
+	scheduleStart time.Time           // The time at which Schedule was called.
+
+	mu struct {
+		sync.Mutex
+		cancel func()  // Non-nil when the task is executing.
+		task   Task[K] // nil if already executed.
+	}
+}
+
+// Executor invokes callbacks based on an in-order admission [Queue] for
+// potentially-overlapping sets of keys.
+//
+// An Executor is internally synchronized and is safe for concurrent
+// use. An Executor should not be copied after it has been created.
+type Executor[K comparable] struct {
+	events *Events[K]            // Injectable callbacks.
+	queue  *Queue[K, *waiter[K]] // Internally synchronized.
+	runner Runner                // Executes callbacks.
+}
+
+// NewExecutor constructs an Executor that executes tasks using the
+// given [Runner]. If runner is nil, tasks will be executed using
+// [context.Background].
+//
+// See [GoRunner] or
+// [github.com/cockroachdb/field-eng-powertools/workgroup.Group].
+func NewExecutor[K comparable](runner Runner) *Executor[K] {
+	if runner == nil {
+		runner = GoRunner(context.Background())
+	}
+	return &Executor[K]{
+		queue:  NewQueue[K, *waiter[K]](),
+		runner: runner,
+	}
+}
+
+// Schedule executes the [Task] once all keys have been locked. The
+// result from [Task.Call] is available through the returned [Outcome].
+//
+// Tasks that need to be retried may return [RetryAtHead]. This will
+// execute the task again when all other tasks scheduled before it have
+// been completed. A retrying task will continue to hold its key locks
+// until the retry has taken place.
+//
+// A task may return an empty key slice; the task will be executed
+// immediately.
+//
+// Tasks must not schedule new tasks and proceed to wait upon them. This
+// will lead to deadlocks.
+//
+// The cancel function may be called to asynchronously dequeue and
+// cancel the task. If the task has already started executing, the
+// cancel callback will cancel the task's context.
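+//
+// A sketch of typical usage (ctx is any [context.Context]):
+//
+//	exec := NewExecutor[string](GoRunner(ctx))
+//	outcome, cancel := exec.Schedule(TaskFunc([]string{"a", "b"},
+//		func(ctx context.Context, keys []string) error {
+//			return nil // Locks on "a" and "b" are held while this runs.
+//		}))
+//	defer cancel()
+//	if err := Wait(ctx, []Outcome{outcome}); err != nil {
+//		// The task failed, was canceled, or ctx expired.
+//	}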
+func (e *Executor[K]) Schedule(task Task[K]) (outcome Outcome, cancel func()) {
+	scheduleStart := time.Now()
+	keys := task.Keys()
+
+	w := &waiter[K]{
+		keys:          keys,
+		scheduleStart: scheduleStart,
+	}
+	w.mu.task = task
+	w.result.Set(queued)
+	ready, err := e.queue.Enqueue(keys, w)
+	if err != nil {
+		w.result.Set(StatusFor(err))
+		return &w.result, func() {}
+	}
+	if ready {
+		e.events.doSchedule(task, false)
+		e.dispose(w, false)
+	} else {
+		e.events.doSchedule(task, true)
+	}
+	return &w.result, func() {
+		// Swap the task for a no-op. We want to guard against
+		// revivifying an already-completed waiter, so we look at
+		// whether a task is still defined.
+		w.mu.Lock()
+		needsDispose := w.mu.task != nil
+		if needsDispose {
+			w.mu.task = &canceledTask[K]{}
+		}
+		if w.mu.cancel != nil {
+			w.mu.cancel()
+		}
+		w.mu.Unlock()
+
+		// Async cleanup.
+		if needsDispose {
+			e.dispose(w, true)
+		}
+	}
+}
+
+// SetEvents allows performance-monitoring callbacks to be injected into
+// the Executor. This method should be called prior to any call to
+// [Executor.Schedule].
+func (e *Executor[K]) SetEvents(events *Events[K]) {
+	e.events = events
+}
+
+// dispose of the waiter callback in a separate goroutine. The waiter
+// will be dequeued from the Executor, possibly leading to cascading
+// callbacks.
+func (e *Executor[K]) dispose(w *waiter[K], cancel bool) {
+	work := func(ctx context.Context) {
+		ctx, cancelCtx := context.WithCancelCause(ctx)
+
+		// Clear the task reference to make the effects of dispose a
+		// one-shot.
+		w.mu.Lock()
+		w.mu.cancel = func() { cancelCtx(ErrScheduleCancel) }
+		task := w.mu.task
+		w.mu.task = nil
+		w.mu.Unlock()
+		startedAtHead := e.queue.IsHead(w)
+
+		// Already executed and/or canceled.
+		if task == nil {
+			return
+		}
+
+		// Either record the canceled status or execute the callback.
+		var err error
+		if cancel {
+			err = ErrScheduleCancel
+		} else {
+			w.result.Set(executing)
+			e.events.doStarted(task, time.Since(w.scheduleStart))
+			err = tryCall(ctx, task)
+			w.mu.Lock()
+			w.mu.cancel = nil
+			w.mu.Unlock()
+			e.events.doComplete(task, time.Since(w.scheduleStart))
+		}
+
+		// Once the waiter has been called, update its status and call
+		// dequeue to find any tasks that have been unblocked.
+		switch t := err.(type) {
+		case nil:
+			w.result.Set(success)
+
+		case *RetryAtHeadErr:
+			// The callback requested to be retried later.
+			if startedAtHead {
+				// The waiter was already executing at the global head
+				// of the queue. Reject the request and execute any
+				// fallback handler that may have been provided.
+				if t.fallback != nil {
+					t.fallback()
+				}
+				retryErr := t.Unwrap()
+				if retryErr == nil {
+					w.result.Set(success)
+				} else {
+					w.result.Set(&Status{err: retryErr})
+				}
+			} else {
+				e.events.doRetried(task)
+
+				// Otherwise, re-enable the waiter. The status will be
+				// set to retryRequested for later re-dispatching by the
+				// dispose method.
+				w.mu.Lock()
+				w.mu.cancel = nil
+				w.mu.task = task
+				w.mu.Unlock()
+				w.result.Set(retryRequested)
+				endedAtHead := e.queue.IsHead(w)
+
+				// It's possible that another task completed while this
+				// one was executing, which moved it to the head of the
+				// global queue. If this happens, we need to immediately
+				// queue up its retry.
+				if !startedAtHead && endedAtHead {
+					e.dispose(w, false)
+				}
+
+				// We can't dequeue the waiter if it's going to retry
+				// at some later point in time. Since we know that the
+				// task was running somewhere in the middle of the
+				// global queue, there's nothing more that we need to
+				// do.
+				return
+			}
+		default:
+			w.result.Set(&Status{err: err})
+		}
+
+		// Remove the waiter's locks and get a slice of newly-unblocked
+		// tasks to kick off.
+		next, _ := e.queue.Dequeue(w)
+		// Calling dequeue also advances the global queue. If the
+		// element at the head of the queue wants to be retried, also
+		// add it to the list.
+		if head, ok := e.queue.PeekHead(); ok && head != nil {
+			if status, _ := head.result.Get(); status == retryRequested {
+				head.result.Set(retryQueued)
+				next = append(next, head)
+			}
+		}
+		for _, unblocked := range next {
+			e.dispose(unblocked, false)
+		}
+	}
+
+	if err := e.runner.Go(work); err != nil {
+		w.result.Set(&Status{err: err})
+	}
+}
+
+// Wait blocks until every outcome reaches a successful state or any
+// outcome reports an error; it returns the first non-nil error that it
+// encounters.
+func Wait(ctx context.Context, outcomes []Outcome) error {
+outcome:
+	for _, outcome := range outcomes {
+		for {
+			status, changed := outcome.Get()
+			if status.Success() {
+				continue outcome
+			}
+			if err := status.Err(); err != nil {
+				return err
+			}
+			select {
+			case <-changed:
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+	}
+	return nil
+}
+
+// tryCall invokes the function with a panic handler.
+func tryCall[K any](ctx context.Context, task Task[K]) (err error) {
+	// Install panic handler before executing user code.
+	defer func() {
+		x := recover()
+		switch t := x.(type) {
+		case nil:
+			// Success.
+		case error:
+			err = t
+		default:
+			err = fmt.Errorf("panic in task: %v", t)
+		}
+	}()
+
+	return task.Call(ctx)
+}
diff --git a/lockset/executor_test.go b/lockset/executor_test.go
new file mode 100644
index 0000000..d3a7a80
--- /dev/null
+++ b/lockset/executor_test.go
@@ -0,0 +1,449 @@
+// Copyright 2024 The Cockroach Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package lockset
+
+// This file was extracted from cockroachdb/replicator at ee8e2894.
+
+import (
+	"context"
+	"errors"
+	"math"
+	"math/rand"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/field-eng-powertools/notify"
+	"github.com/cockroachdb/field-eng-powertools/workgroup"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/sync/errgroup"
+)
+
+// Ensure serial ordering based on key.
+func TestSerial(t *testing.T) {
+	const numWaiters = 1024
+	r := require.New(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// We want to verify that the execution order for a key matches the
+	// scheduling order.
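+	// Each task increments a shared counter and checks that it
+	// observed the value it expected, so any out-of-order execution
+	// is reported as an error.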
+	var resource atomic.Int32
+	checker := func(expect int) Task[struct{}] {
+		return TaskFunc(
+			[]struct{}{{}},
+			func(context.Context, []struct{}) error {
+				current := resource.Add(1) - 1
+				if expect != int(current) {
+					return errors.New("out of order execution")
+				}
+				return nil
+			})
+	}
+
+	e := NewExecutor[struct{}](GoRunner(ctx))
+
+	outcomes := make([]*notify.Var[*Status], numWaiters)
+	for i := 0; i < numWaiters; i++ {
+		outcomes[i], _ = e.Schedule(checker(i))
+	}
+
+	r.NoError(Wait(ctx, outcomes))
+}
+
+// Use random key sets to ensure that we don't see any collisions on the
+// underlying resources and that execution occurs in the expected order.
+func TestSmoke(t *testing.T) {
+	const numResources = 128
+	const numWaiters = 10 * numResources
+	r := require.New(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Verify that each resource and waiter is run in the expected
+	// order and the expected number of times.
+	executionCounts := make([]int, numWaiters)
+	executionOrder := make([][]int, numResources)
+
+	// The checker function will toggle the values between 0 and a
+	// nonzero nonce value to look for collisions.
+	resources := make([]atomic.Int64, numResources)
+	checker := func(keys []int, retry bool, waiter int) error {
+		if len(keys) == 0 {
+			return errors.New("no keys")
+		}
+		executionCounts[waiter]++
+		for _, k := range keys {
+			executionOrder[k] = append(executionOrder[k], waiter)
+		}
+		fail := false
+		// Offset by one so that the nonce can never be zero, which
+		// would make the compare-and-swap checks trivially succeed.
+		nonce := rand.Int63n(math.MaxInt64) + 1
+		for _, k := range keys {
+			if !resources[k].CompareAndSwap(0, nonce) {
+				fail = true
+			}
+		}
+		// Create goroutine scheduling jitter.
+		runtime.Gosched()
+		for _, k := range keys {
+			if !resources[k].CompareAndSwap(nonce, 0) {
+				fail = true
+			}
+		}
+		if fail {
+			return errors.New("collision detected")
+		}
+		if retry {
+			return RetryAtHead(nil).Or(func() {
+				// If the task was at the head of the global queue
+				// already, this callback will be executed. We want to
+				// add a fake execution entry to make comparison below
+				// easy to think about.
+				if executionCounts[waiter] != 2 {
+					for _, k := range keys {
+						executionOrder[k] = append(executionOrder[k], waiter)
+					}
+				}
+			})
+		}
+		return nil
+	}
+
+	e := NewExecutor[int](workgroup.WithSize(ctx, numWaiters/2, numResources))
+
+	expectedOrder := make([][]int, numResources)
+	var expectedOrderMu sync.Mutex
+
+	outcomes := make([]*notify.Var[*Status], numWaiters)
+	eg, _ := errgroup.WithContext(ctx)
+	for i := 0; i < numWaiters; i++ {
+		i := i // Capture.
+		eg.Go(func() error {
+			// Pick a random set of keys, intentionally including
+			// duplicate key values.
+			count := rand.Intn(numResources) + 1
+			keys := make([]int, count)
+			for idx := range keys {
+				key := rand.Intn(numResources)
+				keys[idx] = key
+			}
+			// Compute the expected execution order against the same
+			// key deduplication that the scheduler will perform.
+			deduped := dedup(keys)
+			willRetry := i%10 == 0
+			expectedOrderMu.Lock()
+			for _, key := range deduped {
+				expectedOrder[key] = append(expectedOrder[key], i)
+				if willRetry {
+					expectedOrder[key] = append(expectedOrder[key], i)
+				}
+			}
+			outcomes[i], _ = e.Schedule(
+				TaskFunc(keys, func(_ context.Context, keys []int) error {
+					return checker(keys, willRetry, i)
+				}),
+			)
+			expectedOrderMu.Unlock()
+			return nil
+		})
+	}
+	r.NoError(eg.Wait())
+
+	// Wait for each task to arrive at a successful state.
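+	// The per-key orderings are compared before asserting on waitErr
+	// so that a mismatch produces a more descriptive failure.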
+	waitErr := Wait(ctx, outcomes)
+	for i := 0; i < numResources; i++ {
+		r.Equalf(expectedOrder[i], executionOrder[i], "key %d", i)
+	}
+	r.NoError(waitErr)
+}
+
+func TestCancel(t *testing.T) {
+	r := require.New(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	s := NewExecutor[int](GoRunner(ctx))
+
+	// Schedule a blocker first so we can control execution flow.
+	blockCh := make(chan struct{})
+	blocker, _ := s.Schedule(TaskFunc([]int{0}, func(context.Context, []int) error {
+		<-blockCh
+		return nil
+	}))
+
+	// Schedule a job to cancel.
+	canceled, cancel := s.Schedule(TaskFunc([]int{0}, func(context.Context, []int) error {
+		return errors.New("should not see this")
+	}))
+	status, _ := canceled.Get()
+	r.True(status.Queued()) // This should always be true.
+	cancel()                // The effects of cancel() are asynchronous.
+	cancel()                // Duplicate cancel is a no-op.
+	close(blockCh)          // Allow the machinery to proceed.
+
+	// The blocker should be successful.
+	r.NoError(Wait(ctx, []*notify.Var[*Status]{blocker}))
+
+	for {
+		status, changed := canceled.Get()
+		// Cancellation swaps in a trivial task, so it's possible to
+		// observe an execution that simply returns a cancellation
+		// error.
+		r.False(status.Success())
+		if status.Err() != nil {
+			r.ErrorIs(status.Err(), context.Canceled)
+			break
+		}
+		<-changed
+	}
+}
+
+func TestCancelWithinTask(t *testing.T) {
+	r := require.New(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	s := NewExecutor[int](GoRunner(ctx))
+
+	// Race-free handoff.
+	cancelTaskCh := make(chan func(), 1)
+	canceled, cancelTask := s.Schedule(TaskFunc([]int{0},
+		func(ctx context.Context, _ []int) error {
+			r.NoError(ctx.Err())
+			(<-cancelTaskCh)()
+			r.ErrorIs(ctx.Err(), context.Canceled)
+			r.ErrorIs(context.Cause(ctx), ErrScheduleCancel)
+			return ctx.Err()
+		}))
+	cancelTaskCh <- cancelTask
+	r.ErrorIs(Wait(ctx, []Outcome{canceled}), context.Canceled)
+}
+
+func TestRunnerRejection(t *testing.T) {
+	r := require.New(t)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	s := NewExecutor[int](workgroup.WithSize(ctx, 1, 0))
+
+	block := make(chan struct{})
+
+	// An empty key set will cause this to be executed immediately.
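+	// The runner was sized to one worker with a queue depth of zero,
+	// so this task occupies the only slot until block is closed.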
+ s.Schedule(TaskFunc(nil, func(ctx context.Context, keys []int) error { + select { + case <-block: + return nil + case <-ctx.Done(): + return ctx.Err() + } + })) + + rejectedStatus, _ := s.Schedule(TaskFunc(nil, func(context.Context, []int) error { + r.Fail("should not execute") + return nil + })) + rejected, _ := rejectedStatus.Get() + r.ErrorContains(rejected.Err(), "queue depth 0 exceeded") +} + +func TestPanic(t *testing.T) { + r := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := NewExecutor[int](GoRunner(ctx)) + + outcome, _ := s.Schedule(TaskFunc(nil, func(context.Context, []int) error { + panic("boom") + })) + + for { + status, changed := outcome.Get() + if status.Err() != nil { + r.ErrorContains(status.Err(), "boom") + break + } + <-changed + } + + outcome, _ = s.Schedule(TaskFunc(nil, func(context.Context, []int) error { + panic(errors.New("boom")) + })) + + for { + status, changed := outcome.Get() + if status.Err() != nil { + r.ErrorContains(status.Err(), "boom") + break + } + <-changed + } +} + +func TestRetry(t *testing.T) { + r := require.New(t) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s := NewExecutor[int](GoRunner(ctx)) + + // This task will be at the head of the queue. + block := make(chan struct{}) + blocker, _ := s.Schedule(TaskFunc([]int{0}, func(ctx context.Context, _ []int) error { + select { + case <-block: + return nil + case <-ctx.Done(): + return ctx.Err() + } + })) + + // This task will retry itself and block the checker below. + var didRetry, expectRetry atomic.Bool + retried, _ := s.Schedule(TaskFunc([]int{42}, func(context.Context, []int) error { + if expectRetry.CompareAndSwap(false, true) { + // This error should never be seen. + return RetryAtHead(errors.New("masked")) + } + + if didRetry.CompareAndSwap(false, true) { + // Retrying on a retry returns the error. + return RetryAtHead(errors.New("should see this")) + } + + r.Fail("called too many times") + return nil + })) + + // Set up a task that depends upon the retried task. It shouldn't + // execute until the retry has taken place. + checker, _ := s.Schedule(TaskFunc([]int{42}, func(context.Context, []int) error { + r.True(didRetry.Load()) + return nil + })) + + for { + // Check that the blocker hasn't yet completed. + status, _ := blocker.Get() + r.False(status.Completed(), + "expected uncompleted task, had %s", status) + + // Once we see the retry being requested, unblock the blocker. + status, changed := retried.Get() + if status.Retrying() { + close(block) + break + } + select { + case <-changed: + case <-ctx.Done(): + r.NoError(ctx.Err()) + } + } + + // Wait for all tasks to complete. + r.EqualError( + Wait(ctx, []*notify.Var[*Status]{blocker, checker, retried}), + "should see this") + + // Ensure that other tasks can still proceed. + simple, _ := s.Schedule(TaskFunc([]int{42}, func(context.Context, []int) error { + return nil + })) + r.NoError(Wait(ctx, []*notify.Var[*Status]{simple})) +} + +// This tests a case where a task requests rescheduling after another +// task promotes it to the head of the global queue. 
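+// The first retry request is re-queued because the task did not start
+// at the head; since the promoter has already completed, the retry is
+// immediately re-dispatched and, now running at the head, invokes the
+// Or fallback instead of retrying again.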
+func TestRetryAfterPromotion(t *testing.T) { + r := require.New(t) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s := NewExecutor[int](GoRunner(ctx)) + + blockPromoter := make(chan struct{}) + promoterOutcome, _ := s.Schedule(TaskFunc(nil, func(context.Context, []int) error { + <-blockPromoter + return nil + })) + + promoterWaiter, ok := s.queue.PeekHead() + r.True(ok) + r.NotNil(promoterWaiter) + + blockRetry := make(chan struct{}) + var retryCount atomic.Int32 + var retryRan atomic.Bool + var retryWaiter *waiter[int] + retryOutcome, _ := s.Schedule(TaskFunc(nil, func(context.Context, []int) error { + if retryCount.Add(1) == 1 { + <-blockRetry + } + return RetryAtHead(nil).Or(func() { + // Ensure the tail was promoted. + h, ok := s.queue.PeekHead() + r.True(ok) + r.Same(retryWaiter, h) + retryRan.Store(true) + }) + })) + + retryWaiter, ok = s.queue.PeekTail() + r.True(ok) + r.NotNil(retryWaiter) + r.NotSame(promoterWaiter, retryWaiter) + + close(blockPromoter) + r.NoError(Wait(ctx, []Outcome{promoterOutcome})) + close(blockRetry) + + r.NoError(Wait(ctx, []Outcome{retryOutcome})) + r.True(retryRan.Load()) +} + +func TestStatusFor(t *testing.T) { + r := require.New(t) + + r.True(StatusFor(nil).Success()) + r.False(StatusFor(context.Canceled).Success()) + r.ErrorIs(StatusFor(context.Canceled).Err(), context.Canceled) +} + +func TestFakeOutcome(t *testing.T) { + r := require.New(t) + + status, _ := NewOutcome().Get() + r.True(status.Executing()) +} + +func TestDedup(t *testing.T) { + r := require.New(t) + + src := []int{0, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 0} + cpy := append([]int(nil), src...) + expected := []int{0, 5, 4, 3, 2, 1} + + r.Equal(expected, dedup(src)) + // Ensure that the source was not modified. + r.Equal(src, cpy) +} diff --git a/lockset/lockset.go b/lockset/lockset.go new file mode 100644 index 0000000..51031c8 --- /dev/null +++ b/lockset/lockset.go @@ -0,0 +1,19 @@ +// Copyright 2024 The Cockroach Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package lockset contains utilities for ordering access to +// potentially-overlapping resources. +package lockset diff --git a/lockset/queue.go b/lockset/queue.go new file mode 100644 index 0000000..2fa9f29 --- /dev/null +++ b/lockset/queue.go @@ -0,0 +1,276 @@ +// Copyright 2024 The Cockroach Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package lockset
+
+// This file was extracted from cockroachdb/replicator at ee8e2894.
+
+import (
+	"fmt"
+	"sync"
+)
+
+type entry[K, V any] struct {
+	headCount int
+	elt       V
+	keys      []K
+	next      *entry[K, V]
+	valid     bool
+}
+
+func (e *entry[K, V]) invalidate() {
+	e.elt = *new(V)
+	e.keys = nil
+	e.valid = false
+}
+
+// A Queue implements an in-order admission queue for arbitrary values
+// associated with a set of potentially-overlapping keys. A Queue also
+// maintains a "global" queue of values, based on the order in which the
+// [Queue.Enqueue] method is called.
+//
+// Deadlocks between values are avoided since the relative order of
+// enqueued values is maintained. That is, if [Queue.Enqueue] is called
+// with V1 and then V2, the first value will be ahead of the second in
+// all key queues that they have in common.
+//
+// A Queue is internally synchronized and is safe for concurrent use. A
+// Queue should not be copied after it has been created.
+type Queue[K, V comparable] struct {
+	mu struct {
+		sync.RWMutex
+
+		// These waiters are used to maintain a global ordering.
+		head *entry[K, V]
+		tail *entry[K, V]
+
+		backRef map[V]*entry[K, V]
+		queues  map[K][]*entry[K, V]
+	}
+}
+
+// NewQueue constructs a [Queue].
+func NewQueue[K, V comparable]() *Queue[K, V] {
+	q := &Queue[K, V]{}
+	q.mu.backRef = make(map[V]*entry[K, V])
+	q.mu.queues = make(map[K][]*entry[K, V])
+	return q
+}
+
+// Dequeue removes the value from the queue and returns any
+// newly-unblocked values. The bool return value indicates whether the
+// value was in the queue.
+func (q *Queue[K, V]) Dequeue(val V) ([]V, bool) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	e, ok := q.mu.backRef[val]
+	// Not in the queue, so a no-op. Let the caller determine whether
+	// this is an incorrect use case or not.
+	if !ok {
+		return nil, false
+	}
+	delete(q.mu.backRef, val)
+
+	var ret []V
+
+	// Remove the waiter from each key's queue.
+	for _, k := range e.keys {
+		entries := q.mu.queues[k]
+
+		// Search for the waiter in the queue. It's always going to
+		// be the first element in the slice, except in the
+		// cancellation case.
+		found := false
+		var idx int
+		for idx = range entries {
+			if entries[idx] == e {
+				found = true
+				break
+			}
+		}
+
+		// The entry is known to hold this key, so a missing waiter
+		// indicates internal state corruption.
+		if !found {
+			panic(fmt.Sprintf("waiter not found in queue for key %v", k))
+		}
+
+		// If the waiter was the first in the queue (likely),
+		// promote the next waiter, possibly making it eligible to
+		// be run.
+		if idx == 0 {
+			entries = entries[1:]
+			if len(entries) == 0 {
+				// The waiter was the only element of the queue, so
+				// we'll just delete the slice from the map.
+				delete(q.mu.queues, k)
+				continue
+			}
+
+			// Promote the next waiter. If that waiter is now at the
+			// head of all of its key queues, it can be started.
+			head := entries[0]
+			head.headCount++
+			if head.headCount == len(head.keys) {
+				ret = append(ret, head.elt)
+			} else if head.headCount > len(head.keys) {
+				panic("over counted")
+			}
+		} else {
+			// The (canceled) waiter was in the middle of the queue,
+			// just remove it from the slice.
+			entries = append(entries[:idx], entries[idx+1:]...)
+		}
+
+		// Put the shortened queue back in the map.
+		q.mu.queues[k] = entries
+	}
+
+	// Make eligible for cleanup and remove key references.
+	e.invalidate()
+
+	// Clean up the global queue.
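+	// Entries are removed from the global queue lazily: advance the
+	// head past any entries that were invalidated by earlier calls.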
+	head := q.mu.head
+	for head != nil {
+		if head.valid {
+			break
+		}
+		head = head.next
+	}
+	q.mu.head = head
+	if q.mu.head == nil {
+		q.mu.tail = nil
+	}
+
+	return ret, true
+}
+
+// Enqueue adds a value to the queue and returns true if the value is
+// at the head of all of its key queues (i.e., it may be executed
+// immediately). It is an error to enqueue a value if it is already
+// enqueued.
+func (q *Queue[K, V]) Enqueue(keys []K, val V) (atHead bool, err error) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if _, dup := q.mu.backRef[val]; dup {
+		return false, fmt.Errorf("the value %v is already enqueued", val)
+	}
+
+	e := &entry[K, V]{
+		elt:   val,
+		keys:  dedup(keys),
+		valid: true,
+	}
+	q.mu.backRef[val] = e
+
+	// Insert the waiter into the global queue.
+	if q.mu.tail == nil {
+		q.mu.head = e
+	} else {
+		q.mu.tail.next = e
+	}
+	q.mu.tail = e
+
+	// Add the waiter to each key queue. If it's the only waiter for
+	// that key, also increment its headCount.
+	for _, k := range e.keys {
+		entries := q.mu.queues[k]
+		entries = append(entries, e)
+		q.mu.queues[k] = entries
+		if len(entries) == 1 {
+			e.headCount++
+		}
+	}
+
+	// This will also be satisfied if the waiter has an empty key set.
+	return e.headCount == len(e.keys), nil
+}
+
+// IsEmpty returns true if there are no elements in the queue.
+func (q *Queue[K, V]) IsEmpty() bool {
+	q.mu.RLock()
+	defer q.mu.RUnlock()
+	return q.mu.head == nil
+}
+
+// IsHead returns true if the value is at the head of the global queue.
+func (q *Queue[K, V]) IsHead(val V) bool {
+	q.mu.RLock()
+	defer q.mu.RUnlock()
+	head := q.mu.head
+	return head != nil && head.elt == val
+}
+
+// IsQueuedKey returns true if the key is present in the queue.
+func (q *Queue[K, V]) IsQueuedKey(key K) bool {
+	q.mu.RLock()
+	defer q.mu.RUnlock()
+	return len(q.mu.queues[key]) > 0
+}
+
+// IsQueuedValue returns true if the value is present in the queue.
+func (q *Queue[K, V]) IsQueuedValue(val V) bool {
+	q.mu.RLock()
+	defer q.mu.RUnlock()
+	_, ok := q.mu.backRef[val]
+	return ok
+}
+
+// IsTail returns true if the value is at the tail of the global queue.
+func (q *Queue[K, V]) IsTail(val V) bool {
+	q.mu.RLock()
+	defer q.mu.RUnlock()
+	tail := q.mu.tail
+	return tail != nil && tail.elt == val
+}
+
+// PeekHead returns the value at the head of the global queue. It
+// returns false if the queue is empty.
+func (q *Queue[K, V]) PeekHead() (V, bool) {
+	q.mu.RLock()
+	defer q.mu.RUnlock()
+	h := q.mu.head
+	if h == nil {
+		return *new(V), false
+	}
+	return h.elt, true
+}
+
+// PeekTail returns the value at the tail of the global queue. It
+// returns false if the queue is empty.
+func (q *Queue[K, V]) PeekTail() (V, bool) {
+	q.mu.RLock()
+	defer q.mu.RUnlock()
+	t := q.mu.tail
+	if t == nil {
+		return *new(V), false
+	}
+	return t.elt, true
+}
+
+// dedup makes a copy of the key slice and deduplicates it, preserving
+// the order of first occurrence.
+func dedup[K comparable](keys []K) []K {
+	keys = append([]K(nil), keys...)
+	seen := make(map[K]struct{}, len(keys))
+	idx := 0
+	for _, key := range keys {
+		if _, dup := seen[key]; dup {
+			continue
+		}
+		seen[key] = struct{}{}
+
+		keys[idx] = key
+		idx++
+	}
+	keys = keys[:idx]
+	return keys
+}
diff --git a/lockset/queue_test.go b/lockset/queue_test.go
new file mode 100644
index 0000000..995ea71
--- /dev/null
+++ b/lockset/queue_test.go
@@ -0,0 +1,181 @@
+// Copyright 2024 The Cockroach Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package lockset + +// This file was extracted from cockroachdb/replicator at ee8e2894. + +import ( + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueueSmoke(t *testing.T) { + r := require.New(t) + + q := NewQueue[int, string]() + keys := []int{1} + + r.True(q.IsEmpty()) + r.False(q.IsHead("1a")) + r.False(q.IsTail("1a")) + r.False(q.IsQueuedValue("1a")) + r.False(q.IsQueuedKey(keys[0])) + + r.True(q.Enqueue(keys, "1a")) + r.False(q.IsEmpty()) + r.True(q.IsHead("1a")) + r.True(q.IsTail("1a")) + r.True(q.IsQueuedValue("1a")) + r.True(q.IsQueuedKey(keys[0])) + peek, ok := q.PeekHead() + r.True(ok) + r.Equal("1a", peek) + peek, ok = q.PeekTail() + r.True(ok) + r.Equal("1a", peek) + + // Not allowed to double-enqueue + _, err := q.Enqueue(keys, "1a") + r.EqualError(err, "the value 1a is already enqueued") + + r.False(q.IsHead("1b")) + r.False(q.IsTail("1b")) + r.False(q.Enqueue(keys, "1b")) + + r.True(q.IsHead("1a")) + r.False(q.IsTail("1a")) + r.False(q.IsHead("1b")) + r.True(q.IsTail("1b")) + peek, ok = q.PeekHead() + r.True(ok) + r.Equal("1a", peek) + peek, ok = q.PeekTail() + r.True(ok) + r.Equal("1b", peek) + + next, ok := q.Dequeue("1a") + r.True(ok) + r.Equal([]string{"1b"}, next) + r.True(q.IsHead("1b")) + + peek, ok = q.PeekHead() + r.True(ok) + r.Equal("1b", peek) + peek, ok = q.PeekTail() + r.True(ok) + r.Equal("1b", peek) + + // It's not an error to repeatedly dequeue. + next, ok = q.Dequeue("1a") + r.False(ok) + r.Nil(next) + + next, ok = q.Dequeue("1b") + r.True(ok) + r.Nil(next) + + r.True(q.IsEmpty()) + + peek, ok = q.PeekHead() + r.False(ok) + r.Equal("", peek) + peek, ok = q.PeekTail() + r.False(ok) + r.Equal("", peek) +} + +func TestQueueMultipleKeys(t *testing.T) { + r := require.New(t) + + q := NewQueue[int, string]() + + r.True(q.Enqueue([]int{1, 2, 3, 4, 5}, "one")) + r.False(q.Enqueue([]int{2, 3, 4, 5}, "two")) + r.False(q.Enqueue([]int{3, 4, 5}, "three")) + r.False(q.Enqueue([]int{4}, "four-only")) + r.False(q.Enqueue([]int{5}, "five-only")) + + r.Nil(q.Dequeue("two")) + + next, ok := q.Dequeue("one") + r.True(ok) + r.Equal([]string{"three"}, next) + + next, ok = q.Dequeue("three") + r.True(ok) + r.Equal([]string{"four-only", "five-only"}, next) +} + +func ExampleQueue() { + q := NewQueue[int, string]() + // Returns true since this entry is unblocked. + fmt.Println(q.Enqueue([]int{1, 2, 3, 4, 5}, "one")) + // These next entries have keys that overlap and will be blocked. + fmt.Println(q.Enqueue([]int{2, 3, 4, 5}, "two")) + fmt.Println(q.Enqueue([]int{3, 4, 5}, "three")) + fmt.Println(q.Enqueue([]int{4}, "four")) + fmt.Println(q.Enqueue([]int{5}, "five")) + // This immediately returns true, since no keys are blocking. + fmt.Println(q.Enqueue([]int{6}, "six")) + + // Unlocks nothing, since one is still in the queue. + fmt.Println(q.Dequeue("two")) + // Unlocks three. + fmt.Println(q.Dequeue("one")) + // Unlocks both four and five. 
+	fmt.Println(q.Dequeue("three"))
+
+	// Note that fmt.Println also prints the nil error returned by each
+	// Enqueue call.
+	//
+	// Output:
+	// true <nil>
+	// false <nil>
+	// false <nil>
+	// false <nil>
+	// false <nil>
+	// true <nil>
+	// [] true
+	// [three] true
+	// [four five] true
+}
+
+func TestQueueManyWaiters(t *testing.T) {
+	r := require.New(t)
+
+	q := NewQueue[int, string]()
+
+	for i := range 100 {
+		atHead, err := q.Enqueue([]int{1, 2}, strconv.Itoa(i))
+		r.NoError(err)
+		if i == 0 {
+			r.True(atHead)
+		} else {
+			r.False(atHead)
+		}
+	}
+
+	for i := range 100 {
+		next, ok := q.Dequeue(strconv.Itoa(i))
+		r.True(ok)
+		if i == 99 {
+			r.Nil(next)
+		} else {
+			r.Equal([]string{strconv.Itoa(i + 1)}, next)
+		}
+	}
+}
diff --git a/lockset/retry.go b/lockset/retry.go
new file mode 100644
index 0000000..649b5e1
--- /dev/null
+++ b/lockset/retry.go
@@ -0,0 +1,45 @@
+// Copyright 2024 The Cockroach Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package lockset
+
+// This file was extracted from cockroachdb/replicator at ee8e2894.
+
+// RetryAtHead returns an error that tasks can use to be retried later,
+// once all preceding tasks have completed. If this error is returned
+// when there are no preceding tasks, the causal error will be emitted
+// through the [Outcome] returned by [Executor.Schedule].
+func RetryAtHead(cause error) *RetryAtHeadErr {
+	return &RetryAtHeadErr{cause, nil}
+}
+
+// RetryAtHeadErr is returned by [RetryAtHead].
+type RetryAtHeadErr struct {
+	cause    error
+	fallback func()
+}
+
+// Error implements the error interface with a fixed message.
+func (e *RetryAtHeadErr) Error() string { return "callback requested a retry" }
+
+// Or sets a fallback function to invoke if the task was already at the
+// head of the global queue. This is useful when a cleanup action must
+// be performed if the task is not going to be retried. The receiver is
+// returned.
+func (e *RetryAtHeadErr) Or(fn func()) *RetryAtHeadErr { e.fallback = fn; return e }
+
+// Unwrap returns the causal error passed to [RetryAtHead].
+func (e *RetryAtHeadErr) Unwrap() error { return e.cause }
diff --git a/lockset/runner.go b/lockset/runner.go
new file mode 100644
index 0000000..bc042a6
--- /dev/null
+++ b/lockset/runner.go
@@ -0,0 +1,40 @@
+// Copyright 2024 The Cockroach Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package lockset
+
+// This file was extracted from cockroachdb/replicator at ee8e2894.
+
+import "context"
+
+// A Runner is passed to [NewExecutor] to begin the execution of tasks.
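+//
+// A minimal sketch of a custom Runner (a hypothetical type, not part
+// of this package) that bounds concurrency with a semaphore channel:
+//
+//	type boundedRunner struct {
+//		ctx context.Context
+//		sem chan struct{}
+//	}
+//
+//	func (r *boundedRunner) Go(fn func(context.Context)) error {
+//		select {
+//		case r.sem <- struct{}{}:
+//			go func() {
+//				defer func() { <-r.sem }()
+//				fn(r.ctx)
+//			}()
+//			return nil
+//		default:
+//			return errors.New("runner is saturated")
+//		}
+//	}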
+type Runner interface {
+	// Go should execute the function in a non-blocking fashion.
+	Go(func(context.Context)) error
+}
+
+// GoRunner returns a Runner that executes tasks using the go keyword
+// and the specified context.
+func GoRunner(ctx context.Context) Runner { return &goRunner{ctx} }
+
+type goRunner struct {
+	ctx context.Context
+}
+
+func (r *goRunner) Go(fn func(context.Context)) error {
+	go fn(r.ctx)
+	return nil
+}
diff --git a/lockset/status.go b/lockset/status.go
new file mode 100644
index 0000000..aa8c91d
--- /dev/null
+++ b/lockset/status.go
@@ -0,0 +1,106 @@
+// Copyright 2024 The Cockroach Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package lockset
+
+// This file was extracted from cockroachdb/replicator at ee8e2894.
+
+import (
+	"fmt"
+
+	"github.com/cockroachdb/field-eng-powertools/notify"
+)
+
+// Outcome is a convenience type alias for the variable through which a
+// task's current [Status] is published.
+type Outcome = *notify.Var[*Status]
+
+// NewOutcome is a convenience function to allocate an Outcome.
+func NewOutcome() Outcome {
+	return notify.VarOf(executing)
+}
+
+// Status is the lifecycle state of a scheduled [Task]. It is delivered
+// through the [Outcome] returned by [Executor.Schedule].
+type Status struct {
+	err error
+}
+
+// StatusFor constructs a successful status if err is nil. Otherwise,
+// it returns a new Status object that returns the error.
+func StatusFor(err error) *Status {
+	if err == nil {
+		return success
+	}
+	return &Status{err: err}
+}
+
+// Sentinel instances of Status.
+var (
+	executing      = &Status{}
+	queued         = &Status{}
+	retryQueued    = &Status{}
+	retryRequested = &Status{}
+	success        = &Status{}
+)
+
+// Completed returns true if the callback has been called.
+// See also [Status.Success].
+func (s *Status) Completed() bool {
+	return s == success || s.err != nil
+}
+
+// Err returns any error returned by the Task.
+func (s *Status) Err() error {
+	return s.err
+}
+
+// Executing returns true if the Task is currently executing.
+func (s *Status) Executing() bool {
+	return s == executing
+}
+
+// Queued returns true if the Task has not been executed yet.
+func (s *Status) Queued() bool {
+	return s == queued
+}
+
+// Retrying returns true if the callback returned [RetryAtHead] and it
+// has not yet been re-attempted.
+func (s *Status) Retrying() bool {
+	return s == retryRequested || s == retryQueued
+}
+
+// Success returns true if the Status represents the successful
+// completion of a scheduled waiter.
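+//
+// A sketch of polling an Outcome until it settles, mirroring what
+// [Wait] does for a slice of outcomes:
+//
+//	for {
+//		status, changed := outcome.Get()
+//		if status.Success() {
+//			break // The task completed cleanly.
+//		}
+//		if err := status.Err(); err != nil {
+//			return err // The task failed or was canceled.
+//		}
+//		<-changed // Block until the status is next updated.
+//	}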
+func (s *Status) Success() bool { + return s == success +} + +func (s *Status) String() string { + switch s { + case executing: + return "executing" + case queued: + return "queued" + case retryQueued: + return "retryQueued" + case retryRequested: + return "retryRequested" + case success: + return "success" + default: + return fmt.Sprintf("error: %v", s.err) + } +} diff --git a/lockset/task.go b/lockset/task.go new file mode 100644 index 0000000..95d2a7e --- /dev/null +++ b/lockset/task.go @@ -0,0 +1,48 @@ +// Copyright 2024 The Cockroach Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package lockset + +import "context" + +// A Task is provided to [Executor.Schedule]. +type Task[K any] interface { + // Call contains the logic associated with the task. + Call(ctx context.Context) error + // Keys returns the set of keys that the Task depends upon. + Keys() []K +} + +// TaskFunc returns a [Task] that acquires locks on the given keys and +// then invokes the function callback. +func TaskFunc[K comparable](keys []K, fn func(ctx context.Context, keys []K) error) Task[K] { + return &taskFunc[K]{fn, dedup(keys)} +} + +// canceledTask is used internally for tasks that are canceled before +// being executed. +type canceledTask[K any] struct{} + +func (t *canceledTask[K]) Call(context.Context) error { return ErrScheduleCancel } +func (t *canceledTask[K]) Keys() []K { return nil } + +type taskFunc[K any] struct { + fn func(ctx context.Context, keys []K) error + keys []K +} + +func (t *taskFunc[K]) Call(ctx context.Context) error { return t.fn(ctx, t.keys) } +func (t *taskFunc[K]) Keys() []K { return t.keys }
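+
+// A custom Task implementation can carry state alongside its key set.
+// An illustrative sketch (rowUpdate is not part of this package):
+//
+//	type rowUpdate struct {
+//		ids []int64
+//	}
+//
+//	func (u *rowUpdate) Keys() []int64 { return u.ids }
+//
+//	func (u *rowUpdate) Call(ctx context.Context) error {
+//		// Apply the update while locks on all ids are held.
+//		return nil
+//	}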