From 59fa75c9f3d0b453e67002f976c64a7c8bf3eae4 Mon Sep 17 00:00:00 2001 From: Luke Bakken Date: Thu, 26 Dec 2024 15:57:18 -0800 Subject: [PATCH] Fix cancellation of RPC methods Fixes #1750 * Start by adding a test that demonstrates the error. Give a 5ms cancellation to `BasicConsumeAsync`, with a much longer delay via a hacked RabbitMQ. If running in debug mode, you will see the same `task canceled` exception, but it does not propagate to the test itself. * Set cancellation correctly for TaskCompletionSource in AsyncRpcContinuation * Handle continuation cancellation and timeouts correctly * Refactor repeated code * Add wait on `RegisteredAsync` to see how the changes in #1750 affect the order of operations. * Check to ensure no connection shutdown in `TestBasicConsumeCancellation_GH1750` * @danielmarbach noticed an improvement. Do not elide `await` here. * Extend GHA timeout for Windows builds, yuck (see https://github.com/actions/runner-images/issues/7320) --- .github/workflows/build-test.yaml | 4 +- .../ConsumerDispatcherChannelBase.cs | 56 +++--- .../Impl/AsyncRpcContinuations.cs | 174 ++++++++++++------ projects/RabbitMQ.Client/Impl/Channel.cs | 131 +++++-------- projects/Test/Common/IntegrationFixture.cs | 21 ++- .../Test/Integration/GH/TestGitHubIssues.cs | 121 ++++++++++++ .../Test/Integration/TestAsyncConsumer.cs | 16 ++ 7 files changed, 351 insertions(+), 172 deletions(-) create mode 100644 projects/Test/Integration/GH/TestGitHubIssues.cs diff --git a/.github/workflows/build-test.yaml b/.github/workflows/build-test.yaml index 04b9b3abd..af72a803d 100644 --- a/.github/workflows/build-test.yaml +++ b/.github/workflows/build-test.yaml @@ -67,7 +67,7 @@ jobs: id: install-start-rabbitmq run: ${{ github.workspace }}\.ci\windows\gha-setup.ps1 - name: Integration Tests - timeout-minutes: 25 + timeout-minutes: 45 run: | Start-Job -Verbose -ScriptBlock { & "${{ github.workspace }}\.ci\windows\toxiproxy\toxiproxy-server.exe" | Out-File -LiteralPath $env:APPDATA\RabbitMQ\log\toxiproxy-log.txt }; ` dotnet test ` @@ -113,7 +113,7 @@ jobs: id: install-start-rabbitmq run: ${{ github.workspace }}\.ci\windows\gha-setup.ps1 - name: Sequential Integration Tests - timeout-minutes: 25 + timeout-minutes: 45 run: dotnet test ` --environment 'RABBITMQ_LONG_RUNNING_TESTS=true' ` --environment "RABBITMQ_RABBITMQCTL_PATH=${{ steps.install-start-rabbitmq.outputs.path }}" ` diff --git a/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs b/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs index 963fced57..b2ff4f2bc 100644 --- a/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs +++ b/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs @@ -83,61 +83,65 @@ internal ConsumerDispatcherChannelBase(Impl.Channel channel, ushort concurrency) public ushort Concurrency => _concurrency; - public ValueTask HandleBasicConsumeOkAsync(IAsyncBasicConsumer consumer, string consumerTag, CancellationToken cancellationToken) + public async ValueTask HandleBasicConsumeOkAsync(IAsyncBasicConsumer consumer, string consumerTag, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { - AddConsumer(consumer, consumerTag); - WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag); - return _writer.WriteAsync(work, cancellationToken); - } - else - { - return default; + try + { + AddConsumer(consumer, consumerTag); + WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag); + await _writer.WriteAsync(work, cancellationToken) + .ConfigureAwait(false); + } + catch + { + _ = GetAndRemoveConsumer(consumerTag); + throw; + } } } - public ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag, bool redelivered, + public async ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag, bool redelivered, string exchange, string routingKey, IReadOnlyBasicProperties basicProperties, RentedMemory body, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { IAsyncBasicConsumer consumer = GetConsumerOrDefault(consumerTag); var work = WorkStruct.CreateDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, body); - return _writer.WriteAsync(work, cancellationToken); - } - else - { - return default; + await _writer.WriteAsync(work, cancellationToken) + .ConfigureAwait(false); } } - public ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken cancellationToken) + public async ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag); WorkStruct work = WorkStruct.CreateCancelOk(consumer, consumerTag); - return _writer.WriteAsync(work, cancellationToken); - } - else - { - return default; + await _writer.WriteAsync(work, cancellationToken) + .ConfigureAwait(false); } } - public ValueTask HandleBasicCancelAsync(string consumerTag, CancellationToken cancellationToken) + public async ValueTask HandleBasicCancelAsync(string consumerTag, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag); WorkStruct work = WorkStruct.CreateCancel(consumer, consumerTag); - return _writer.WriteAsync(work, cancellationToken); - } - else - { - return default; + await _writer.WriteAsync(work, cancellationToken) + .ConfigureAwait(false); } } diff --git a/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs b/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs index 38ba391df..8781453b9 100644 --- a/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs +++ b/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs @@ -43,6 +43,9 @@ namespace RabbitMQ.Client.Impl { internal abstract class AsyncRpcContinuation : IRpcContinuation { + private readonly TimeSpan _continuationTimeout; + private readonly CancellationToken _rpcCancellationToken; + private readonly CancellationToken _continuationTimeoutCancellationToken; private readonly CancellationTokenSource _continuationTimeoutCancellationTokenSource; private readonly CancellationTokenRegistration _continuationTimeoutCancellationTokenRegistration; private readonly CancellationTokenSource _linkedCancellationTokenSource; @@ -51,45 +54,33 @@ internal abstract class AsyncRpcContinuation : IRpcContinuation private bool _disposedValue; - public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken cancellationToken) + public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken rpcCancellationToken) { + _continuationTimeout = continuationTimeout; + _rpcCancellationToken = rpcCancellationToken; + /* * Note: we can't use an ObjectPool for these because the netstandard2.0 * version of CancellationTokenSource can't be reset prior to checking * in to the ObjectPool */ _continuationTimeoutCancellationTokenSource = new CancellationTokenSource(continuationTimeout); + _continuationTimeoutCancellationToken = _continuationTimeoutCancellationTokenSource.Token; #if NET - _continuationTimeoutCancellationTokenRegistration = _continuationTimeoutCancellationTokenSource.Token.UnsafeRegister((object? state) => - { - var tcs = (TaskCompletionSource)state!; - if (tcs.TrySetCanceled()) - { - // Cancellation was successful, does this mean we set a TimeoutException - // in the same manner as BlockingCell used to - string msg = $"operation '{GetType().FullName}' timed out after {continuationTimeout}"; - tcs.TrySetException(new TimeoutException(msg)); - } - }, _tcs); + _continuationTimeoutCancellationTokenRegistration = + _continuationTimeoutCancellationToken.UnsafeRegister( + callback: HandleContinuationTimeout, state: _tcs); #else - _continuationTimeoutCancellationTokenRegistration = _continuationTimeoutCancellationTokenSource.Token.Register((object state) => - { - var tcs = (TaskCompletionSource)state; - if (tcs.TrySetCanceled()) - { - // Cancellation was successful, does this mean we set a TimeoutException - // in the same manner as BlockingCell used to - string msg = $"operation '{GetType().FullName}' timed out after {continuationTimeout}"; - tcs.TrySetException(new TimeoutException(msg)); - } - }, state: _tcs, useSynchronizationContext: false); + _continuationTimeoutCancellationTokenRegistration = + _continuationTimeoutCancellationToken.Register( + callback: HandleContinuationTimeout, state: _tcs, useSynchronizationContext: false); #endif _tcsConfiguredTaskAwaitable = _tcs.Task.ConfigureAwait(false); _linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource( - _continuationTimeoutCancellationTokenSource.Token, cancellationToken); + _continuationTimeoutCancellationTokenSource.Token, rpcCancellationToken); } public CancellationToken CancellationToken @@ -105,13 +96,61 @@ public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter GetAwaiter() return _tcsConfiguredTaskAwaitable.GetAwaiter(); } - public abstract Task HandleCommandAsync(IncomingCommand cmd); + public async Task HandleCommandAsync(IncomingCommand cmd) + { + try + { + await DoHandleCommandAsync(cmd) + .ConfigureAwait(false); + } + catch (OperationCanceledException) + { + if (_rpcCancellationToken.IsCancellationRequested) + { +#if NET + _tcs.TrySetCanceled(_rpcCancellationToken); +#else + _tcs.TrySetCanceled(); +#endif + } + else if (_continuationTimeoutCancellationToken.IsCancellationRequested) + { +#if NET + if (_tcs.TrySetCanceled(_continuationTimeoutCancellationToken)) +#else + if (_tcs.TrySetCanceled()) +#endif + { + // Cancellation was successful, does this mean we set a TimeoutException + // in the same manner as BlockingCell used to + _tcs.TrySetException(GetTimeoutException()); + } + } + else + { + throw; + } + } + } public virtual void HandleChannelShutdown(ShutdownEventArgs reason) { _tcs.TrySetException(new OperationInterruptedException(reason)); } + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + protected abstract Task DoHandleCommandAsync(IncomingCommand cmd); + + protected void HandleUnexpectedCommand(IncomingCommand cmd) + { + _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + } + protected virtual void Dispose(bool disposing) { if (!_disposedValue) @@ -127,10 +166,33 @@ protected virtual void Dispose(bool disposing) } } - public void Dispose() +#if NET + private void HandleContinuationTimeout(object? state, CancellationToken cancellationToken) { - Dispose(disposing: true); - GC.SuppressFinalize(this); + var tcs = (TaskCompletionSource)state!; + if (tcs.TrySetCanceled(cancellationToken)) + { + tcs.TrySetException(GetTimeoutException()); + } + } +#else + private void HandleContinuationTimeout(object state) + { + var tcs = (TaskCompletionSource)state; + if (tcs.TrySetCanceled()) + { + tcs.TrySetException(GetTimeoutException()); + } + } +#endif + + private TimeoutException GetTimeoutException() + { + // TODO + // Cancellation was successful, does this mean we set a TimeoutException + // in the same manner as BlockingCell used to + string msg = $"operation '{GetType().FullName}' timed out after {_continuationTimeout}"; + return new TimeoutException(msg); } } @@ -141,17 +203,17 @@ public ConnectionSecureOrTuneAsyncRpcContinuation(TimeSpan continuationTimeout, { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.ConnectionSecure) { var secure = new ConnectionSecure(cmd.MethodSpan); - _tcs.TrySetResult(new ConnectionSecureOrTune(secure._challenge, default)); + _tcs.SetResult(new ConnectionSecureOrTune(secure._challenge, default)); } else if (cmd.CommandId == ProtocolCommandId.ConnectionTune) { var tune = new ConnectionTune(cmd.MethodSpan); - _tcs.TrySetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails + _tcs.SetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails { m_channelMax = tune._channelMax, m_frameMax = tune._frameMax, @@ -160,7 +222,7 @@ public override Task HandleCommandAsync(IncomingCommand cmd) } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } return Task.CompletedTask; @@ -178,15 +240,15 @@ public SimpleAsyncRpcContinuation(ProtocolCommandId expectedCommandId, TimeSpan _expectedCommandId = expectedCommandId; } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == _expectedCommandId) { - _tcs.TrySetResult(true); + _tcs.SetResult(true); } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } return Task.CompletedTask; @@ -206,18 +268,18 @@ public BasicCancelAsyncRpcContinuation(string consumerTag, IConsumerDispatcher c _consumerDispatcher = consumerDispatcher; } - public override async Task HandleCommandAsync(IncomingCommand cmd) + protected override async Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.BasicCancelOk) { - _tcs.TrySetResult(true); Debug.Assert(_consumerTag == new BasicCancelOk(cmd.MethodSpan)._consumerTag); await _consumerDispatcher.HandleBasicCancelOkAsync(_consumerTag, CancellationToken) .ConfigureAwait(false); + _tcs.SetResult(true); } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } } } @@ -235,18 +297,20 @@ public BasicConsumeAsyncRpcContinuation(IAsyncBasicConsumer consumer, IConsumerD _consumerDispatcher = consumerDispatcher; } - public override async Task HandleCommandAsync(IncomingCommand cmd) + protected override async Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.BasicConsumeOk) { var method = new BasicConsumeOk(cmd.MethodSpan); - _tcs.TrySetResult(method._consumerTag); + await _consumerDispatcher.HandleBasicConsumeOkAsync(_consumer, method._consumerTag, CancellationToken) .ConfigureAwait(false); + + _tcs.SetResult(method._consumerTag); } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } } } @@ -264,7 +328,7 @@ public BasicGetAsyncRpcContinuation(Func adjustDeliveryTag, internal DateTime StartTime { get; } = DateTime.UtcNow; - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.BasicGetOk) { @@ -280,15 +344,15 @@ public override Task HandleCommandAsync(IncomingCommand cmd) header, cmd.Body.ToArray()); - _tcs.TrySetResult(result); + _tcs.SetResult(result); } else if (cmd.CommandId == ProtocolCommandId.BasicGetEmpty) { - _tcs.TrySetResult(null); + _tcs.SetResult(null); } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } return Task.CompletedTask; @@ -325,7 +389,7 @@ public override void HandleChannelShutdown(ShutdownEventArgs reason) public Task OnConnectionShutdownAsync(object? sender, ShutdownEventArgs reason) { - _tcs.TrySetResult(true); + _tcs.SetResult(true); return Task.CompletedTask; } } @@ -377,17 +441,17 @@ public QueueDeclareAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellati { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueueDeclareOk) { var method = new Client.Framing.QueueDeclareOk(cmd.MethodSpan); var result = new QueueDeclareOk(method._queue, method._messageCount, method._consumerCount); - _tcs.TrySetResult(result); + _tcs.SetResult(result); } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } return Task.CompletedTask; @@ -417,16 +481,16 @@ public QueueDeleteAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellatio { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueueDeleteOk) { var method = new QueueDeleteOk(cmd.MethodSpan); - _tcs.TrySetResult(method._messageCount); + _tcs.SetResult(method._messageCount); } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } return Task.CompletedTask; @@ -440,16 +504,16 @@ public QueuePurgeAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellation { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueuePurgeOk) { var method = new QueuePurgeOk(cmd.MethodSpan); - _tcs.TrySetResult(method._messageCount); + _tcs.SetResult(method._messageCount); } else { - _tcs.SetException(new InvalidOperationException($"Received unexpected command of type {cmd.CommandId}!")); + HandleUnexpectedCommand(cmd); } return Task.CompletedTask; diff --git a/projects/RabbitMQ.Client/Impl/Channel.cs b/projects/RabbitMQ.Client/Impl/Channel.cs index 20b4175d5..0538107f2 100644 --- a/projects/RabbitMQ.Client/Impl/Channel.cs +++ b/projects/RabbitMQ.Client/Impl/Channel.cs @@ -256,10 +256,7 @@ await ConsumerDispatcher.WaitForShutdownAsync() } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); ChannelShutdownAsync -= k.OnConnectionShutdownAsync; } @@ -303,10 +300,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -342,10 +336,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -392,10 +383,7 @@ await MaybeConfirmSelect(cancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } @@ -935,10 +923,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -967,10 +952,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1006,10 +988,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1048,10 +1027,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1078,10 +1054,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1119,10 +1092,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1166,10 +1136,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1206,10 +1173,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1247,10 +1211,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1320,10 +1281,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1361,10 +1319,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1416,10 +1371,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1444,10 +1396,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1475,10 +1424,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1504,10 +1450,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1533,10 +1476,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1562,10 +1502,7 @@ await ModelSendAsync(in method, k.CancellationToken) } finally { - if (false == enqueued) - { - k.Dispose(); - } + MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); } } @@ -1577,6 +1514,32 @@ internal static Task CreateAndOpenAsync(CreateChannelOptions createCha return channel.OpenAsync(createChannelOptions, cancellationToken); } + private void MaybeDisposeContinuation(bool enqueued, IRpcContinuation continuation) + { + try + { + if (enqueued) + { + if (_continuationQueue.TryPeek(out IRpcContinuation? enqueuedContinuation)) + { + if (object.ReferenceEquals(continuation, enqueuedContinuation)) + { + IRpcContinuation dequeuedContinuation = _continuationQueue.Next(); + dequeuedContinuation.Dispose(); + } + } + } + else + { + continuation.Dispose(); + } + } + catch + { + // TODO low-level debug logging + } + } + /// /// Returning true from this method means that the command was server-originated, /// and handled already. diff --git a/projects/Test/Common/IntegrationFixture.cs b/projects/Test/Common/IntegrationFixture.cs index 3251fd7f7..8f790cc5c 100644 --- a/projects/Test/Common/IntegrationFixture.cs +++ b/projects/Test/Common/IntegrationFixture.cs @@ -185,16 +185,27 @@ public virtual async Task DisposeAsync() await _conn.CloseAsync(); } } + catch (Exception ex) + { + _output.WriteLine("[WARNING] IntegrationFixture.CloseAsync() exception: {0}", ex); + } finally { - _eventListener?.Dispose(); - if (_channel is not null) + try { - await _channel.DisposeAsync(); + _eventListener?.Dispose(); + if (_channel is not null) + { + await _channel.DisposeAsync(); + } + if (_conn is not null) + { + await _conn.DisposeAsync(); + } } - if (_conn is not null) + catch (Exception ex) { - await _conn.DisposeAsync(); + _output.WriteLine("[WARNING] IntegrationFixture.DisposeAsync() exception: {0}", ex); } _channel = null; _conn = null; diff --git a/projects/Test/Integration/GH/TestGitHubIssues.cs b/projects/Test/Integration/GH/TestGitHubIssues.cs new file mode 100644 index 000000000..13f2be7fb --- /dev/null +++ b/projects/Test/Integration/GH/TestGitHubIssues.cs @@ -0,0 +1,121 @@ +// This source code is dual-licensed under the Apache License, version +// 2.0, and the Mozilla Public License, version 2.0. +// +// The APL v2.0: +// +//--------------------------------------------------------------------------- +// Copyright (c) 2007-2024 Broadcom. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//--------------------------------------------------------------------------- +// +// The MPL v2.0: +// +//--------------------------------------------------------------------------- +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright (c) 2007-2024 Broadcom. All Rights Reserved. +//--------------------------------------------------------------------------- + +using System; +using System.Threading; +using System.Threading.Tasks; +using RabbitMQ.Client; +using RabbitMQ.Client.Events; +using Xunit; +using Xunit.Abstractions; + +#nullable enable + +namespace Test.Integration.GH +{ + public class TestGitHubIssues : IntegrationFixture + { + public TestGitHubIssues(ITestOutputHelper output) : base(output) + { + } + + public override Task InitializeAsync() + { + // NB: nothing to do here since each test creates its own factory, + // connections and channels + Assert.Null(_connFactory); + Assert.Null(_conn); + Assert.Null(_channel); + return Task.CompletedTask; + } + + [Fact] + public async Task TestBasicConsumeCancellation_GH1750() + { + /* + * Note: + * Testing that the task is actually canceled requires a hacked RabbitMQ server. + * Modify deps/rabbit/src/rabbit_channel.erl, handle_cast for basic.consume_ok + * Before send/2, add timer:sleep(1000), then `make run-broker` + * + * The _output line at the end of the test will print TaskCanceledException + */ + Assert.Null(_connFactory); + Assert.Null(_conn); + Assert.Null(_channel); + + _connFactory = CreateConnectionFactory(); + _connFactory.NetworkRecoveryInterval = TimeSpan.FromMilliseconds(250); + _connFactory.AutomaticRecoveryEnabled = true; + _connFactory.TopologyRecoveryEnabled = true; + + _conn = await _connFactory.CreateConnectionAsync(); + _channel = await _conn.CreateChannelAsync(); + + QueueDeclareOk q = await _channel.QueueDeclareAsync(); + + var consumer = new AsyncEventingBasicConsumer(_channel); + consumer.ReceivedAsync += (o, a) => + { + return Task.CompletedTask; + }; + + bool sawConnectionShutdown = false; + _conn.ConnectionShutdownAsync += (o, ea) => + { + sawConnectionShutdown = true; + return Task.CompletedTask; + }; + + try + { + // Note: use this to test timeout via the passed-in RPC token + /* + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(5)); + await _channel.BasicConsumeAsync(q.QueueName, true, consumer, cts.Token); + */ + + // Note: use these to test timeout of the continuation RPC operation + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); + _channel.ContinuationTimeout = TimeSpan.FromMilliseconds(5); + await _channel.BasicConsumeAsync(q.QueueName, true, consumer, cts.Token); + } + catch (Exception ex) + { + _output.WriteLine("ex: {0}", ex); + } + + await Task.Delay(500); + + Assert.False(sawConnectionShutdown); + } + } +} diff --git a/projects/Test/Integration/TestAsyncConsumer.cs b/projects/Test/Integration/TestAsyncConsumer.cs index 6210509b3..e8e684997 100644 --- a/projects/Test/Integration/TestAsyncConsumer.cs +++ b/projects/Test/Integration/TestAsyncConsumer.cs @@ -71,6 +71,7 @@ public async Task TestBasicRoundtripConcurrent() var consumer = new AsyncEventingBasicConsumer(_channel); + var consumerRegisteredTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var publish1SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -105,6 +106,20 @@ public async Task TestBasicRoundtripConcurrent() return Task.CompletedTask; }; + consumer.RegisteredAsync += (object sender, ConsumerEventArgs ea) => + { + if (ReferenceEquals(consumer, sender)) + { + consumerRegisteredTcs.SetResult(true); + } + else + { + var ex = Xunit.Sdk.EqualException.ForMismatchedValues(consumer, sender); + consumerRegisteredTcs.SetException(ex); + } + return Task.CompletedTask; + }; + consumer.ReceivedAsync += (o, a) => { if (ByteArraysEqual(a.Body.Span, body1)) @@ -126,6 +141,7 @@ public async Task TestBasicRoundtripConcurrent() }; await _channel.BasicConsumeAsync(q.QueueName, true, string.Empty, false, false, null, consumer); + await consumerRegisteredTcs.Task.WaitAsync(WaitSpan); try {