Skip to content

Commit

Permalink
Replace lock with SemaphoreSlim
Browse files Browse the repository at this point in the history
  • Loading branch information
lukebakken committed Apr 22, 2024
1 parent 7966dcb commit 2783dfb
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 73 deletions.
124 changes: 86 additions & 38 deletions projects/RabbitMQ.Client/client/impl/ChannelBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ internal abstract class ChannelBase : IChannel, IRecoverable
private readonly RpcContinuationQueue _continuationQueue = new RpcContinuationQueue();
private readonly ManualResetEventSlim _flowControlBlock = new ManualResetEventSlim(true);

// TODO replace with SemaphoreSlim
private object _confirmLock;
private SemaphoreSlim _confirmSemaphore;
private readonly LinkedList<ulong> _pendingDeliveryTags = new LinkedList<ulong>();

private bool _onlyAcksReceived = true;
Expand Down Expand Up @@ -420,7 +419,7 @@ internal void FinishClose()
m_connectionStartCell?.TrySetResult(null);
}

private bool ConfirmsAreEnabled => _confirmLock != null;
private bool ConfirmsAreEnabled => _confirmSemaphore != null;

private async Task HandleCommandAsync(IncomingCommand cmd, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -484,7 +483,8 @@ private void OnChannelShutdown(ShutdownEventArgs reason)

if (ConfirmsAreEnabled)
{
lock (_confirmLock)
_confirmSemaphore.Wait();
try
{
if (_confirmsTaskCompletionSources?.Count > 0)
{
Expand All @@ -497,6 +497,10 @@ private void OnChannelShutdown(ShutdownEventArgs reason)
_confirmsTaskCompletionSources.Clear();
}
}
finally
{
_confirmSemaphore.Release();
}
}

_flowControlBlock.Set();
Expand Down Expand Up @@ -542,6 +546,7 @@ protected virtual void Dispose(bool disposing)

ConsumerDispatcher.Dispose();
_rpcSemaphore.Dispose();
_confirmSemaphore?.Dispose();
}
}

Expand Down Expand Up @@ -596,7 +601,8 @@ protected void HandleAckNack(ulong deliveryTag, bool multiple, bool isNack)
if (ConfirmsAreEnabled)
{
// let's take a lock so we can assume that deliveryTags are unique, never duplicated and always sorted
lock (_confirmLock)
_confirmSemaphore.Wait();
try
{
// No need to do anything if there are no delivery tags in the list
if (_pendingDeliveryTags.Count > 0)
Expand Down Expand Up @@ -633,6 +639,10 @@ protected void HandleAckNack(ulong deliveryTag, bool multiple, bool isNack)
_onlyAcksReceived = true;
}
}
finally
{
_confirmSemaphore.Release();
}
}
}

Expand Down Expand Up @@ -1054,10 +1064,16 @@ public async ValueTask BasicPublishAsync<TProperties>(string exchange, string ro
{
if (ConfirmsAreEnabled)
{
lock (_confirmLock)
await _confirmSemaphore.WaitAsync(cancellationToken)
.ConfigureAwait(false);
try
{
_pendingDeliveryTags.AddLast(NextPublishSeqNo++);
}
finally
{
_confirmSemaphore.Release();
}
}

try
Expand All @@ -1084,11 +1100,17 @@ await ModelSendAsync(in cmd, in basicProperties, body, cancellationToken)
{
if (ConfirmsAreEnabled)
{
lock (_confirmLock)
await _confirmSemaphore.WaitAsync(cancellationToken)
.ConfigureAwait(false);
try
{
NextPublishSeqNo--;
_pendingDeliveryTags.RemoveLast();
}
finally
{
_confirmSemaphore.Release();
}
}

throw;
Expand All @@ -1102,10 +1124,16 @@ public async ValueTask BasicPublishAsync<TProperties>(CachedString exchange, Cac
{
if (ConfirmsAreEnabled)
{
lock (_confirmLock)
await _confirmSemaphore.WaitAsync(cancellationToken)
.ConfigureAwait(false);
try
{
_pendingDeliveryTags.AddLast(NextPublishSeqNo++);
}
finally
{
_confirmSemaphore.Release();
}
}

try
Expand Down Expand Up @@ -1133,11 +1161,17 @@ await ModelSendAsync(in cmd, in basicProperties, body, cancellationToken)
{
if (ConfirmsAreEnabled)
{
lock (_confirmLock)
await _confirmSemaphore.WaitAsync(cancellationToken)
.ConfigureAwait(false);
try
{
NextPublishSeqNo--;
_pendingDeliveryTags.RemoveLast();
}
finally
{
_confirmSemaphore.Release();
}
}

throw;
Expand Down Expand Up @@ -1242,7 +1276,7 @@ await ModelSendAsync(method, k.CancellationToken)

// Note:
// Non-null means confirms are enabled
_confirmLock = new object();
_confirmSemaphore = new SemaphoreSlim(1, 1);

return;
}
Expand Down Expand Up @@ -1742,63 +1776,50 @@ await ModelSendAsync(method, k.CancellationToken)

private List<TaskCompletionSource<bool>> _confirmsTaskCompletionSources;

public Task<bool> WaitForConfirmsAsync(CancellationToken token = default)
public async Task<bool> WaitForConfirmsAsync(CancellationToken cancellationToken = default)
{
if (false == ConfirmsAreEnabled)
{
throw new InvalidOperationException("Confirms not selected");
}

TaskCompletionSource<bool> tcs;
lock (_confirmLock)
await _confirmSemaphore.WaitAsync(cancellationToken)
.ConfigureAwait(false);
try
{
if (_pendingDeliveryTags.Count == 0)
{
if (_onlyAcksReceived == false)
{
_onlyAcksReceived = true;
return Task.FromResult(false);
return false;
}

return Task.FromResult(true);
return true;
}

tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_confirmsTaskCompletionSources.Add(tcs);
}

if (!token.CanBeCanceled)
finally
{
return tcs.Task;
_confirmSemaphore.Release();
}

return WaitForConfirmsWithTokenAsync(tcs, token);
}
bool rv;

private async Task<bool> WaitForConfirmsWithTokenAsync(TaskCompletionSource<bool> tcs, CancellationToken token)
{
CancellationTokenRegistration tokenRegistration =
#if NET6_0_OR_GREATER
token.UnsafeRegister(
state => ((TaskCompletionSource<bool>)state).TrySetCanceled(), tcs);
#else
token.Register(
state => ((TaskCompletionSource<bool>)state).TrySetCanceled(),
state: tcs, useSynchronizationContext: false);
#endif
try
if (false == cancellationToken.CanBeCanceled)
{
return await tcs.Task.ConfigureAwait(false);
rv = await tcs.Task.ConfigureAwait(false);
}
finally
else
{
#if NET6_0_OR_GREATER
await tokenRegistration.DisposeAsync()
rv = await WaitForConfirmsWithTokenAsync(tcs, cancellationToken)
.ConfigureAwait(false);
#else
tokenRegistration.Dispose();
#endif
}

return rv;
}

public async Task WaitForConfirmsOrDieAsync(CancellationToken token = default)
Expand Down Expand Up @@ -1830,6 +1851,33 @@ await CloseAsync(ea, false, token)
}
}

private async Task<bool> WaitForConfirmsWithTokenAsync(TaskCompletionSource<bool> tcs,
CancellationToken cancellationToken)
{
CancellationTokenRegistration tokenRegistration =
#if NET6_0_OR_GREATER
cancellationToken.UnsafeRegister(
state => ((TaskCompletionSource<bool>)state).TrySetCanceled(), tcs);
#else
cancellationToken.Register(
state => ((TaskCompletionSource<bool>)state).TrySetCanceled(),
state: tcs, useSynchronizationContext: false);
#endif
try
{
return await tcs.Task.ConfigureAwait(false);
}
finally
{
#if NET6_0_OR_GREATER
await tokenRegistration.DisposeAsync()
.ConfigureAwait(false);
#else
tokenRegistration.Dispose();
#endif
}
}

private static BasicProperties PopulateActivityAndPropagateTraceId<TProperties>(TProperties basicProperties,
Activity sendActivity) where TProperties : IReadOnlyBasicProperties, IAmqpHeader
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ private Task NotifyCredentialRefreshed(bool succesfully)
if (succesfully)
{
return UpdateSecretAsync(_config.CredentialsProvider.Password, "Token refresh",
CancellationToken.None); // TODO
CancellationToken.None); // TODO cancellation token
}
else
{
Expand Down
60 changes: 30 additions & 30 deletions projects/Test/Common/IntegrationFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ public virtual async Task InitializeAsync()
_channel = await _conn.CreateChannelAsync();
}

AddCallbackHandlers();
if (IsVerbose)
{
AddCallbackHandlers();
}
}

if (_connFactory.AutomaticRecoveryEnabled)
Expand Down Expand Up @@ -182,43 +185,40 @@ public virtual async Task DisposeAsync()

protected virtual void AddCallbackHandlers()
{
if (IsVerbose)
if (_conn != null)
{
if (_conn != null)
_conn.CallbackException += (o, ea) =>
{
_conn.CallbackException += (o, ea) =>
{
_output.WriteLine("{0} connection callback exception: {1}",
_testDisplayName, ea.Exception);
};
_output.WriteLine("{0} connection callback exception: {1}",
_testDisplayName, ea.Exception);
};

_conn.ConnectionShutdown += (o, ea) =>
_conn.ConnectionShutdown += (o, ea) =>
{
HandleConnectionShutdown(_conn, ea, (args) =>
{
HandleConnectionShutdown(_conn, ea, (args) =>
{
_output.WriteLine("{0} connection shutdown, args: {1}",
_testDisplayName, args);
});
};
}
_output.WriteLine("{0} connection shutdown, args: {1}",
_testDisplayName, args);
});
};
}

if (_channel != null)
if (_channel != null)
{
_channel.CallbackException += (o, ea) =>
{
_channel.CallbackException += (o, ea) =>
{
_output.WriteLine("{0} channel callback exception: {1}",
_testDisplayName, ea.Exception);
};
_output.WriteLine("{0} channel callback exception: {1}",
_testDisplayName, ea.Exception);
};

_channel.ChannelShutdown += (o, ea) =>
_channel.ChannelShutdown += (o, ea) =>
{
HandleChannelShutdown(_channel, ea, (args) =>
{
HandleChannelShutdown(_channel, ea, (args) =>
{
_output.WriteLine("{0} channel shutdown, args: {1}",
_testDisplayName, args);
});
};
}
_output.WriteLine("{0} channel shutdown, args: {1}",
_testDisplayName, args);
});
};
}
}

Expand Down
8 changes: 4 additions & 4 deletions projects/Test/Integration/TestAsyncConsumer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ await _channel.BasicConsumeAsync(queue: queueName, autoAck: false,
[Fact]
public async Task TestBasicAckAsync()
{
AddCallbackHandlers();

const int messageCount = 1024;
int messagesReceived = 0;

Expand Down Expand Up @@ -392,7 +394,7 @@ public async Task TestBasicAckAsync()
var c = sender as AsyncEventingBasicConsumer;
Assert.NotNull(c);
await _channel.BasicAckAsync(args.DeliveryTag, false);
messagesReceived++;
Interlocked.Increment(ref messagesReceived);
if (messagesReceived == messageCount)
{
publishSyncSource.SetResult(true);
Expand All @@ -413,14 +415,12 @@ await _channel.BasicConsumeAsync(queue: queueName, autoAck: false,
{
byte[] _body = _encoding.GetBytes(Guid.NewGuid().ToString());
await _channel.BasicPublishAsync(string.Empty, queueName, _body);
await _channel.WaitForConfirmsOrDieAsync();
}
});

await _channel.WaitForConfirmsOrDieAsync();
Assert.True(await publishSyncSource.Task);

Assert.Equal(messageCount, messagesReceived);

await _channel.CloseAsync(_closeArgs, false, CancellationToken.None);
}

Expand Down

0 comments on commit 2783dfb

Please sign in to comment.