Skip to content

Commit

Permalink
Update secret functionality (#342)
Browse files Browse the repository at this point in the history
* add update secret functionality.
* Update the secrets for all the pool connections

Signed-off-by: Gabriele Santomaggio <[email protected]>
Co-authored-by: Gabriele Santomaggio <[email protected]>
  • Loading branch information
simone-fariselli and Gsantomaggio authored Feb 1, 2024
1 parent d70458a commit c3d8e55
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 11 deletions.
35 changes: 26 additions & 9 deletions RabbitMQ.Stream.Client/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public class Client : IClient

private uint correlationId = 0; // allow for some pre-amble

private Connection connection;
private Connection _connection;

private readonly ConcurrentDictionary<uint, IValueTaskSource> requests = new();

Expand Down Expand Up @@ -148,7 +148,7 @@ public class Client : IClient

public int ConfirmFrames => confirmFrames;

public int IncomingFrames => connection.NumFrames;
public int IncomingFrames => _connection.NumFrames;

//public int IncomingChannelCount => this.incoming.Reader.Count;
private static readonly object Obj = new();
Expand Down Expand Up @@ -176,7 +176,7 @@ public bool IsClosed
{
get
{
if (connection.IsClosed)
if (_connection.IsClosed)
{
isClosed = true;
}
Expand Down Expand Up @@ -208,10 +208,10 @@ private async Task OnConnectionClosed(string reason)
public static async Task<Client> Create(ClientParameters parameters, ILogger logger = null)
{
var client = new Client(parameters, logger);
client.connection = await Connection
client._connection = await Connection
.Create(parameters.Endpoint, client.HandleIncoming, client.HandleClosed, parameters.Ssl, logger)
.ConfigureAwait(false);
client.connection.ClientId = client.ClientId;
client._connection.ClientId = client.ClientId;
// exchange properties
var peerPropertiesResponse = await client.Request<PeerPropertiesRequest, PeerPropertiesResponse>(corr =>
new PeerPropertiesRequest(corr, parameters.Properties)).ConfigureAwait(false);
Expand Down Expand Up @@ -283,6 +283,23 @@ await client.Publish(new TuneRequest(0,
return client;
}

public async Task UpdateSecret(string newSecret)
{
var saslData = Encoding.UTF8.GetBytes($"\0{Parameters.UserName}\0{newSecret}");

var authResponse =
await Request<SaslAuthenticateRequest, SaslAuthenticateResponse>(corr =>
new SaslAuthenticateRequest(
corr,
Parameters.AuthMechanism.ToString().ToUpperInvariant(),
saslData))
.ConfigureAwait(false);

ClientExceptions.MaybeThrowException(
authResponse.ResponseCode,
"Error while updating secret: the secret will not be updated.");
}

public async ValueTask<bool> Publish(Publish publishMsg)
{
var publishTask = await Publish<Publish>(publishMsg).ConfigureAwait(false);
Expand All @@ -296,7 +313,7 @@ public ValueTask<bool> Publish<T>(T msg) where T : struct, ICommand
{
try
{
return connection.Write(msg);
return _connection.Write(msg);
}
catch (Exception e)
{
Expand Down Expand Up @@ -757,7 +774,7 @@ public async Task<CloseResponse> Close(string reason)
InternalClose();
try
{
connection.UpdateCloseStatus(ConnectionClosedReason.Normal);
_connection.UpdateCloseStatus(ConnectionClosedReason.Normal);
var result =
await Request<CloseRequest, CloseResponse>(corr => new CloseRequest(corr, reason),
TimeSpan.FromSeconds(10)).ConfigureAwait(false);
Expand All @@ -771,11 +788,11 @@ public async Task<CloseResponse> Close(string reason)
}
catch (Exception e)
{
_logger.LogError(e, "An error occurred while calling {CalledFunction}", nameof(connection.Dispose));
_logger.LogError(e, "An error occurred while calling {CalledFunction}", nameof(_connection.Dispose));
}
finally
{
connection.Dispose();
_connection.Dispose();
}

return new CloseResponse(0, ResponseCode.Ok);
Expand Down
23 changes: 21 additions & 2 deletions RabbitMQ.Stream.Client/ConnectionsPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public ConnectionsPool(int maxConnections, byte idsPerConnection)
/// Value is the connection item
/// The Connections contains all the connections created by the pool
/// </summary>
internal ConcurrentDictionary<string, ConnectionItem> Connections { get; } = new();
private ConcurrentDictionary<string, ConnectionItem> Connections { get; } = new();

/// <summary>
/// GetOrCreateClient returns a client for the given brokerInfo.
Expand Down Expand Up @@ -162,7 +162,8 @@ internal async Task<IClient> GetOrCreateClient(string brokerInfo, Func<Task<ICli
// let's remove it from the pool
Connections.TryRemove(connectionItem.Client.ClientId, out _);
// let's create a new one
connectionItem = new ConnectionItem(brokerInfo, _idsPerConnection, await createClient().ConfigureAwait(false));
connectionItem = new ConnectionItem(brokerInfo, _idsPerConnection,
await createClient().ConfigureAwait(false));
Connections.TryAdd(connectionItem.Client.ClientId, connectionItem);

return connectionItem.Client;
Expand All @@ -185,6 +186,7 @@ internal async Task<IClient> GetOrCreateClient(string brokerInfo, Func<Task<ICli
_semaphoreSlim.Release();
}
}

public void Remove(string clientId)
{
_semaphoreSlim.Wait();
Expand All @@ -202,6 +204,23 @@ public void Remove(string clientId)
}
}

public async Task UpdateSecrets(string newSecret)
{
await _semaphoreSlim.WaitAsync().ConfigureAwait(false);
try
{
foreach (var connectionItem in Connections.Values)
{
await connectionItem.Client.UpdateSecret(newSecret).ConfigureAwait(false);
connectionItem.Client.Parameters.Password = newSecret;
}
}
finally
{
_semaphoreSlim.Release();
}
}

public void MaybeClose(string clientId, string reason)
{
_semaphoreSlim.Wait();
Expand Down
2 changes: 2 additions & 0 deletions RabbitMQ.Stream.Client/IClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public interface IClient
IDictionary<byte, (string, (Action<ReadOnlyMemory<ulong>>, Action<(ulong, ResponseCode)[]>))> Publishers { get; }
IDictionary<byte, (string, ConsumerEvents)> Consumers { get; }

Task UpdateSecret(string newSecret);

public bool IsClosed { get; }
}
}
6 changes: 6 additions & 0 deletions RabbitMQ.Stream.Client/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ RabbitMQ.Stream.Client.Client.QueryRoute(string superStream, string routingKey)
RabbitMQ.Stream.Client.Client.StreamStats(string stream) -> System.Threading.Tasks.ValueTask<RabbitMQ.Stream.Client.StreamStatsResponse>
RabbitMQ.Stream.Client.Client.Subscribe(string stream, RabbitMQ.Stream.Client.IOffsetType offsetType, ushort initialCredit, System.Collections.Generic.Dictionary<string, string> properties, System.Func<RabbitMQ.Stream.Client.Deliver, System.Threading.Tasks.Task> deliverHandler, System.Func<bool, System.Threading.Tasks.Task<RabbitMQ.Stream.Client.IOffsetType>> consumerUpdateHandler = null, RabbitMQ.Stream.Client.ConnectionsPool pool = null) -> System.Threading.Tasks.Task<(byte, RabbitMQ.Stream.Client.SubscribeResponse)>
RabbitMQ.Stream.Client.Client.Unsubscribe(byte subscriptionId, bool ignoreIfAlreadyRemoved = false) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.UnsubscribeResponse>
RabbitMQ.Stream.Client.Client.UpdateSecret(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.ClientParameters.AuthMechanism.get -> RabbitMQ.Stream.Client.AuthMechanism
RabbitMQ.Stream.Client.ClientParameters.AuthMechanism.set -> void
RabbitMQ.Stream.Client.ClientParameters.MetadataUpdateHandler
Expand Down Expand Up @@ -68,6 +69,7 @@ RabbitMQ.Stream.Client.ConnectionsPool.MaybeClose(string clientId, string reason
RabbitMQ.Stream.Client.ConnectionsPool.Remove(string clientId) -> void
RabbitMQ.Stream.Client.ConnectionsPool.RemoveConsumerEntityFromStream(string clientId, byte id, string stream) -> void
RabbitMQ.Stream.Client.ConnectionsPool.RemoveProducerEntityFromStream(string clientId, byte id, string stream) -> void
RabbitMQ.Stream.Client.ConnectionsPool.UpdateSecrets(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.ConsumerEvents
RabbitMQ.Stream.Client.ConsumerEvents.ConsumerEvents() -> void
RabbitMQ.Stream.Client.ConsumerEvents.ConsumerEvents(System.Func<RabbitMQ.Stream.Client.Deliver, System.Threading.Tasks.Task> deliverHandler, System.Func<bool, System.Threading.Tasks.Task<RabbitMQ.Stream.Client.IOffsetType>> consumerUpdateHandler) -> void
Expand Down Expand Up @@ -103,6 +105,7 @@ RabbitMQ.Stream.Client.IClient.ClientId.init -> void
RabbitMQ.Stream.Client.IClient.Consumers.get -> System.Collections.Generic.IDictionary<byte, (string, RabbitMQ.Stream.Client.ConsumerEvents)>
RabbitMQ.Stream.Client.IClient.IsClosed.get -> bool
RabbitMQ.Stream.Client.IClient.Publishers.get -> System.Collections.Generic.IDictionary<byte, (string, (System.Action<System.ReadOnlyMemory<ulong>>, System.Action<(ulong, RabbitMQ.Stream.Client.ResponseCode)[]>))>
RabbitMQ.Stream.Client.IClient.UpdateSecret(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.IClosable
RabbitMQ.Stream.Client.IClosable.Close() -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.ResponseCode>
RabbitMQ.Stream.Client.IConsumer.Info.get -> RabbitMQ.Stream.Client.ConsumerInfo
Expand Down Expand Up @@ -276,6 +279,7 @@ RabbitMQ.Stream.Client.StreamSystem.CreateRawSuperStreamProducer(RabbitMQ.Stream
RabbitMQ.Stream.Client.StreamSystem.CreateSuperStreamConsumer(RabbitMQ.Stream.Client.RawSuperStreamConsumerConfig rawSuperStreamConsumerConfig, Microsoft.Extensions.Logging.ILogger logger = null) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.ISuperStreamConsumer>
RabbitMQ.Stream.Client.StreamSystem.StreamInfo(string streamName) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.StreamInfo>
RabbitMQ.Stream.Client.StreamSystem.StreamStats(string stream) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.StreamStats>
RabbitMQ.Stream.Client.StreamSystem.UpdateSecret(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.StreamSystemConfig.AuthMechanism.get -> RabbitMQ.Stream.Client.AuthMechanism
RabbitMQ.Stream.Client.StreamSystemConfig.AuthMechanism.set -> void
RabbitMQ.Stream.Client.StreamSystemConfig.ConnectionPoolConfig.get -> RabbitMQ.Stream.Client.ConnectionPoolConfig
Expand All @@ -286,6 +290,8 @@ RabbitMQ.Stream.Client.UnknownCommandException
RabbitMQ.Stream.Client.UnknownCommandException.UnknownCommandException(string s) -> void
RabbitMQ.Stream.Client.UnsupportedOperationException
RabbitMQ.Stream.Client.UnsupportedOperationException.UnsupportedOperationException(string s) -> void
RabbitMQ.Stream.Client.UpdateSecretFailureException
RabbitMQ.Stream.Client.UpdateSecretFailureException.UpdateSecretFailureException(string s) -> void
static RabbitMQ.Stream.Client.Connection.Create(System.Net.EndPoint endpoint, System.Func<System.Memory<byte>, System.Threading.Tasks.Task> commandCallback, System.Func<string, System.Threading.Tasks.Task> closedCallBack, RabbitMQ.Stream.Client.SslOption sslOption, Microsoft.Extensions.Logging.ILogger logger) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.Connection>
static RabbitMQ.Stream.Client.Message.From(ref System.Buffers.ReadOnlySequence<byte> seq, uint len) -> RabbitMQ.Stream.Client.Message
static RabbitMQ.Stream.Client.RawConsumer.Create(RabbitMQ.Stream.Client.ClientParameters clientParameters, RabbitMQ.Stream.Client.RawConsumerConfig config, RabbitMQ.Stream.Client.StreamInfo metaStreamInfo, Microsoft.Extensions.Logging.ILogger logger = null) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.IConsumer>
Expand Down
18 changes: 18 additions & 0 deletions RabbitMQ.Stream.Client/StreamSystem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ private async Task MayBeReconnectLocator()
}
}

public async Task UpdateSecret(string newSecret)
{
if (_client.IsClosed)
throw new UpdateSecretFailureException("Cannot update a closed connection.");

await _client.UpdateSecret(newSecret).ConfigureAwait(false);
_clientParameters.Password = newSecret;
_client.Parameters.Password = newSecret;

}

public async Task<ISuperStreamProducer> CreateRawSuperStreamProducer(
RawSuperStreamProducerConfig rawSuperStreamProducerConfig, ILogger logger = null)
{
Expand Down Expand Up @@ -542,4 +553,11 @@ public StreamSystemInitialisationException(string error) : base(error)
{
}
}
public class UpdateSecretFailureException : ProtocolException
{
public UpdateSecretFailureException(string s)
: base(s)
{
}
}
}
32 changes: 32 additions & 0 deletions Tests/SystemTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,38 @@ await Assert.ThrowsAsync<AuthenticationFailureException>(
);
}

[Fact]
public async void UpdateSecretWithValidSecretShouldNoRaiseExceptions()
{
var config = new StreamSystemConfig { UserName = "guest", Password = "guest" }; // specified for readability
var streamSystem = await StreamSystem.Create(config);

await streamSystem.UpdateSecret("guest");
}

[Fact]
public async void UpdateSecretWithInvalidSecretShouldThrowAuthenticationFailureException()
{
var config = new StreamSystemConfig { UserName = "guest", Password = "guest" }; // specified for readability
var streamSystem = await StreamSystem.Create(config);

await Assert.ThrowsAsync<AuthenticationFailureException>(
async () => { await streamSystem.UpdateSecret("not_valid_secret"); }
);
}

[Fact]
public async void UpdateSecretForClosedConnectionShouldThrowUpdateSecretFailureException()
{
var config = new StreamSystemConfig { UserName = "guest", Password = "guest" }; // specified for readability
var streamSystem = await StreamSystem.Create(config);

await streamSystem.Close();
await Assert.ThrowsAsync<UpdateSecretFailureException>(
async () => { await streamSystem.UpdateSecret("guest"); }
);
}

[Fact]
public async void CreateExistStreamIdempotentShouldNoRaiseExceptions()
{
Expand Down
2 changes: 2 additions & 0 deletions Tests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public Task<CloseResponse> Close(string reason)
}

public IDictionary<byte, (string, ConsumerEvents)> Consumers { get; }
public Task UpdateSecret(string newSecret) => throw new NotImplementedException();

public bool IsClosed { get; }

public FakeClient(ClientParameters clientParameters)
Expand Down

0 comments on commit c3d8e55

Please sign in to comment.