From 2783dfb8a5cdd50b33a64d923743384ae8720d9b Mon Sep 17 00:00:00 2001 From: Luke Bakken Date: Sun, 21 Apr 2024 08:37:32 -0700 Subject: [PATCH] Replace `lock` with `SemaphoreSlim` --- .../client/impl/ChannelBase.cs | 124 ++++++++++++------ .../client/impl/Connection.Commands.cs | 2 +- projects/Test/Common/IntegrationFixture.cs | 60 ++++----- .../Test/Integration/TestAsyncConsumer.cs | 8 +- 4 files changed, 121 insertions(+), 73 deletions(-) diff --git a/projects/RabbitMQ.Client/client/impl/ChannelBase.cs b/projects/RabbitMQ.Client/client/impl/ChannelBase.cs index 49683f3415..91ff79c9f4 100644 --- a/projects/RabbitMQ.Client/client/impl/ChannelBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ChannelBase.cs @@ -59,8 +59,7 @@ internal abstract class ChannelBase : IChannel, IRecoverable private readonly RpcContinuationQueue _continuationQueue = new RpcContinuationQueue(); private readonly ManualResetEventSlim _flowControlBlock = new ManualResetEventSlim(true); - // TODO replace with SemaphoreSlim - private object _confirmLock; + private SemaphoreSlim _confirmSemaphore; private readonly LinkedList _pendingDeliveryTags = new LinkedList(); private bool _onlyAcksReceived = true; @@ -420,7 +419,7 @@ internal void FinishClose() m_connectionStartCell?.TrySetResult(null); } - private bool ConfirmsAreEnabled => _confirmLock != null; + private bool ConfirmsAreEnabled => _confirmSemaphore != null; private async Task HandleCommandAsync(IncomingCommand cmd, CancellationToken cancellationToken) { @@ -484,7 +483,8 @@ private void OnChannelShutdown(ShutdownEventArgs reason) if (ConfirmsAreEnabled) { - lock (_confirmLock) + _confirmSemaphore.Wait(); + try { if (_confirmsTaskCompletionSources?.Count > 0) { @@ -497,6 +497,10 @@ private void OnChannelShutdown(ShutdownEventArgs reason) _confirmsTaskCompletionSources.Clear(); } } + finally + { + _confirmSemaphore.Release(); + } } _flowControlBlock.Set(); @@ -542,6 +546,7 @@ protected virtual void Dispose(bool disposing) ConsumerDispatcher.Dispose(); _rpcSemaphore.Dispose(); + _confirmSemaphore?.Dispose(); } } @@ -596,7 +601,8 @@ protected void HandleAckNack(ulong deliveryTag, bool multiple, bool isNack) if (ConfirmsAreEnabled) { // let's take a lock so we can assume that deliveryTags are unique, never duplicated and always sorted - lock (_confirmLock) + _confirmSemaphore.Wait(); + try { // No need to do anything if there are no delivery tags in the list if (_pendingDeliveryTags.Count > 0) @@ -633,6 +639,10 @@ protected void HandleAckNack(ulong deliveryTag, bool multiple, bool isNack) _onlyAcksReceived = true; } } + finally + { + _confirmSemaphore.Release(); + } } } @@ -1054,10 +1064,16 @@ public async ValueTask BasicPublishAsync(string exchange, string ro { if (ConfirmsAreEnabled) { - lock (_confirmLock) + await _confirmSemaphore.WaitAsync(cancellationToken) + .ConfigureAwait(false); + try { _pendingDeliveryTags.AddLast(NextPublishSeqNo++); } + finally + { + _confirmSemaphore.Release(); + } } try @@ -1084,11 +1100,17 @@ await ModelSendAsync(in cmd, in basicProperties, body, cancellationToken) { if (ConfirmsAreEnabled) { - lock (_confirmLock) + await _confirmSemaphore.WaitAsync(cancellationToken) + .ConfigureAwait(false); + try { NextPublishSeqNo--; _pendingDeliveryTags.RemoveLast(); } + finally + { + _confirmSemaphore.Release(); + } } throw; @@ -1102,10 +1124,16 @@ public async ValueTask BasicPublishAsync(CachedString exchange, Cac { if (ConfirmsAreEnabled) { - lock (_confirmLock) + await _confirmSemaphore.WaitAsync(cancellationToken) + .ConfigureAwait(false); + try { _pendingDeliveryTags.AddLast(NextPublishSeqNo++); } + finally + { + _confirmSemaphore.Release(); + } } try @@ -1133,11 +1161,17 @@ await ModelSendAsync(in cmd, in basicProperties, body, cancellationToken) { if (ConfirmsAreEnabled) { - lock (_confirmLock) + await _confirmSemaphore.WaitAsync(cancellationToken) + .ConfigureAwait(false); + try { NextPublishSeqNo--; _pendingDeliveryTags.RemoveLast(); } + finally + { + _confirmSemaphore.Release(); + } } throw; @@ -1242,7 +1276,7 @@ await ModelSendAsync(method, k.CancellationToken) // Note: // Non-null means confirms are enabled - _confirmLock = new object(); + _confirmSemaphore = new SemaphoreSlim(1, 1); return; } @@ -1742,7 +1776,7 @@ await ModelSendAsync(method, k.CancellationToken) private List> _confirmsTaskCompletionSources; - public Task WaitForConfirmsAsync(CancellationToken token = default) + public async Task WaitForConfirmsAsync(CancellationToken cancellationToken = default) { if (false == ConfirmsAreEnabled) { @@ -1750,55 +1784,42 @@ public Task WaitForConfirmsAsync(CancellationToken token = default) } TaskCompletionSource tcs; - lock (_confirmLock) + await _confirmSemaphore.WaitAsync(cancellationToken) + .ConfigureAwait(false); + try { if (_pendingDeliveryTags.Count == 0) { if (_onlyAcksReceived == false) { _onlyAcksReceived = true; - return Task.FromResult(false); + return false; } - return Task.FromResult(true); + return true; } tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _confirmsTaskCompletionSources.Add(tcs); } - - if (!token.CanBeCanceled) + finally { - return tcs.Task; + _confirmSemaphore.Release(); } - return WaitForConfirmsWithTokenAsync(tcs, token); - } + bool rv; - private async Task WaitForConfirmsWithTokenAsync(TaskCompletionSource tcs, CancellationToken token) - { - CancellationTokenRegistration tokenRegistration = -#if NET6_0_OR_GREATER - token.UnsafeRegister( - state => ((TaskCompletionSource)state).TrySetCanceled(), tcs); -#else - token.Register( - state => ((TaskCompletionSource)state).TrySetCanceled(), - state: tcs, useSynchronizationContext: false); -#endif - try + if (false == cancellationToken.CanBeCanceled) { - return await tcs.Task.ConfigureAwait(false); + rv = await tcs.Task.ConfigureAwait(false); } - finally + else { -#if NET6_0_OR_GREATER - await tokenRegistration.DisposeAsync() + rv = await WaitForConfirmsWithTokenAsync(tcs, cancellationToken) .ConfigureAwait(false); -#else - tokenRegistration.Dispose(); -#endif } + + return rv; } public async Task WaitForConfirmsOrDieAsync(CancellationToken token = default) @@ -1830,6 +1851,33 @@ await CloseAsync(ea, false, token) } } + private async Task WaitForConfirmsWithTokenAsync(TaskCompletionSource tcs, + CancellationToken cancellationToken) + { + CancellationTokenRegistration tokenRegistration = +#if NET6_0_OR_GREATER + cancellationToken.UnsafeRegister( + state => ((TaskCompletionSource)state).TrySetCanceled(), tcs); +#else + cancellationToken.Register( + state => ((TaskCompletionSource)state).TrySetCanceled(), + state: tcs, useSynchronizationContext: false); +#endif + try + { + return await tcs.Task.ConfigureAwait(false); + } + finally + { +#if NET6_0_OR_GREATER + await tokenRegistration.DisposeAsync() + .ConfigureAwait(false); +#else + tokenRegistration.Dispose(); +#endif + } + } + private static BasicProperties PopulateActivityAndPropagateTraceId(TProperties basicProperties, Activity sendActivity) where TProperties : IReadOnlyBasicProperties, IAmqpHeader { diff --git a/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs b/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs index 69e7376f28..9809f1634c 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs @@ -204,7 +204,7 @@ private Task NotifyCredentialRefreshed(bool succesfully) if (succesfully) { return UpdateSecretAsync(_config.CredentialsProvider.Password, "Token refresh", - CancellationToken.None); // TODO + CancellationToken.None); // TODO cancellation token } else { diff --git a/projects/Test/Common/IntegrationFixture.cs b/projects/Test/Common/IntegrationFixture.cs index 325e227cba..700078a16c 100644 --- a/projects/Test/Common/IntegrationFixture.cs +++ b/projects/Test/Common/IntegrationFixture.cs @@ -144,7 +144,10 @@ public virtual async Task InitializeAsync() _channel = await _conn.CreateChannelAsync(); } - AddCallbackHandlers(); + if (IsVerbose) + { + AddCallbackHandlers(); + } } if (_connFactory.AutomaticRecoveryEnabled) @@ -182,43 +185,40 @@ public virtual async Task DisposeAsync() protected virtual void AddCallbackHandlers() { - if (IsVerbose) + if (_conn != null) { - if (_conn != null) + _conn.CallbackException += (o, ea) => { - _conn.CallbackException += (o, ea) => - { - _output.WriteLine("{0} connection callback exception: {1}", - _testDisplayName, ea.Exception); - }; + _output.WriteLine("{0} connection callback exception: {1}", + _testDisplayName, ea.Exception); + }; - _conn.ConnectionShutdown += (o, ea) => + _conn.ConnectionShutdown += (o, ea) => + { + HandleConnectionShutdown(_conn, ea, (args) => { - HandleConnectionShutdown(_conn, ea, (args) => - { - _output.WriteLine("{0} connection shutdown, args: {1}", - _testDisplayName, args); - }); - }; - } + _output.WriteLine("{0} connection shutdown, args: {1}", + _testDisplayName, args); + }); + }; + } - if (_channel != null) + if (_channel != null) + { + _channel.CallbackException += (o, ea) => { - _channel.CallbackException += (o, ea) => - { - _output.WriteLine("{0} channel callback exception: {1}", - _testDisplayName, ea.Exception); - }; + _output.WriteLine("{0} channel callback exception: {1}", + _testDisplayName, ea.Exception); + }; - _channel.ChannelShutdown += (o, ea) => + _channel.ChannelShutdown += (o, ea) => + { + HandleChannelShutdown(_channel, ea, (args) => { - HandleChannelShutdown(_channel, ea, (args) => - { - _output.WriteLine("{0} channel shutdown, args: {1}", - _testDisplayName, args); - }); - }; - } + _output.WriteLine("{0} channel shutdown, args: {1}", + _testDisplayName, args); + }); + }; } } diff --git a/projects/Test/Integration/TestAsyncConsumer.cs b/projects/Test/Integration/TestAsyncConsumer.cs index 8a44cdd8bf..551c033f3d 100644 --- a/projects/Test/Integration/TestAsyncConsumer.cs +++ b/projects/Test/Integration/TestAsyncConsumer.cs @@ -357,6 +357,8 @@ await _channel.BasicConsumeAsync(queue: queueName, autoAck: false, [Fact] public async Task TestBasicAckAsync() { + AddCallbackHandlers(); + const int messageCount = 1024; int messagesReceived = 0; @@ -392,7 +394,7 @@ public async Task TestBasicAckAsync() var c = sender as AsyncEventingBasicConsumer; Assert.NotNull(c); await _channel.BasicAckAsync(args.DeliveryTag, false); - messagesReceived++; + Interlocked.Increment(ref messagesReceived); if (messagesReceived == messageCount) { publishSyncSource.SetResult(true); @@ -413,14 +415,12 @@ await _channel.BasicConsumeAsync(queue: queueName, autoAck: false, { byte[] _body = _encoding.GetBytes(Guid.NewGuid().ToString()); await _channel.BasicPublishAsync(string.Empty, queueName, _body); + await _channel.WaitForConfirmsOrDieAsync(); } }); - await _channel.WaitForConfirmsOrDieAsync(); Assert.True(await publishSyncSource.Task); - Assert.Equal(messageCount, messagesReceived); - await _channel.CloseAsync(_closeArgs, false, CancellationToken.None); }