Skip to content

Commit

Permalink
Prevent progress reversal in cluster membership (#8673)
Browse files Browse the repository at this point in the history
* Prevent progress reversal in cluster membership

* Use more consistent version manipulation in cluster membership tests
  • Loading branch information
ReubenBond authored Oct 25, 2023
1 parent b23e3d8 commit c1ae3ba
Show file tree
Hide file tree
Showing 14 changed files with 153 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public async Task<bool> InsertRow(MembershipEntry entry, TableVersion tableVersi
if (logger.IsEnabled(LogLevel.Debug)) logger.LogDebug("AdoNetClusteringTable.InsertRow aborted due to null check. MembershipEntry is null.");
throw new ArgumentNullException(nameof(entry));
}
if (tableVersion == null)
if (tableVersion is null)
{
if (logger.IsEnabled(LogLevel.Debug)) logger.LogDebug("AdoNetClusteringTable.InsertRow aborted due to null check. TableVersion is null ");
throw new ArgumentNullException(nameof(tableVersion));
Expand Down Expand Up @@ -132,7 +132,7 @@ public async Task<bool> UpdateRow(MembershipEntry entry, string etag, TableVersi
if (logger.IsEnabled(LogLevel.Debug)) logger.LogDebug("AdoNetClusteringTable.UpdateRow aborted due to null check. MembershipEntry is null.");
throw new ArgumentNullException(nameof(entry));
}
if (tableVersion == null)
if (tableVersion is null)
{
if (logger.IsEnabled(LogLevel.Debug)) logger.LogDebug("AdoNetClusteringTable.UpdateRow aborted due to null check. TableVersion is null");
throw new ArgumentNullException(nameof(tableVersion));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public async Task<bool> InsertRow(MembershipEntry entry, TableVersion tableVersi
{
logger.LogWarning((int)TableStorageErrorCode.AzureTable_23,
exc,
"Intermediate error inserting entry {Data} tableVersion {TableVersion} to the table {TableName}.", entry.ToString(), tableVersion == null ? "null" : tableVersion.ToString(), tableManager.TableName);
"Intermediate error inserting entry {Data} tableVersion {TableVersion} to the table {TableName}.", entry.ToString(), tableVersion is null ? "null" : tableVersion.ToString(), tableManager.TableName);
throw;
}
}
Expand All @@ -148,7 +148,7 @@ public async Task<bool> UpdateRow(MembershipEntry entry, string etag, TableVersi
{
logger.LogWarning((int)TableStorageErrorCode.AzureTable_25,
exc,
"Intermediate error updating entry {Data} tableVersion {TableVersion} to the table {TableName}.", entry.ToString(), tableVersion == null ? "null" : tableVersion.ToString(), tableManager.TableName);
"Intermediate error updating entry {Data} tableVersion {TableVersion} to the table {TableName}.", entry.ToString(), tableVersion is null ? "null" : tableVersion.ToString(), tableManager.TableName);
throw;
}
}
Expand Down
7 changes: 6 additions & 1 deletion src/Orleans.Core.Abstractions/Manifest/MajorMinorVersion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ public MajorMinorVersion(long majorVersion, long minorVersion)
/// <summary>
/// Gets the zero value.
/// </summary>
public static MajorMinorVersion Zero => new MajorMinorVersion(0, 0);
public static MajorMinorVersion Zero => new(0, 0);

/// <summary>
/// Gets the minimum value.
/// </summary>
public static MajorMinorVersion MinValue => new(long.MinValue, long.MinValue);

/// <summary>
/// Gets the most significant version component.
Expand Down
10 changes: 4 additions & 6 deletions src/Orleans.Core/Manifest/ClientClusterManifestProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,11 @@ public ClientClusterManifestProvider(
_services = services;
_gatewayManager = gatewayManager;
this.LocalGrainManifest = clientManifestProvider.ClientManifest;
_current = new ClusterManifest(MajorMinorVersion.Zero, ImmutableDictionary<SiloAddress, GrainManifest>.Empty, ImmutableArray.Create(this.LocalGrainManifest));
_current = new ClusterManifest(MajorMinorVersion.MinValue, ImmutableDictionary<SiloAddress, GrainManifest>.Empty, ImmutableArray.Create(this.LocalGrainManifest));
_updates = new AsyncEnumerable<ClusterManifest>(
(previous, proposed) => previous is null || proposed.Version == MajorMinorVersion.Zero || proposed.Version > previous.Version,
_current)
{
OnPublished = update => Interlocked.Exchange(ref _current, update)
};
initialValue: _current,
updateValidator: (previous, proposed) => previous is null || proposed.Version > previous.Version,
onPublished: update => Interlocked.Exchange(ref _current, update));
}

/// <inheritdoc />
Expand Down
13 changes: 8 additions & 5 deletions src/Orleans.Core/SystemTargetInterfaces/IMembershipTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public interface IMembershipTableSystemTarget : IMembershipTable, ISystemTarget
}

[Serializable, GenerateSerializer, Immutable]
public sealed class TableVersion : ISpanFormattable
public sealed class TableVersion : ISpanFormattable, IEquatable<TableVersion>
{
/// <summary>
/// The version part of this TableVersion. Monotonically increasing number.
Expand All @@ -129,16 +129,19 @@ public TableVersion(int version, string eTag)
VersionEtag = eTag;
}

public TableVersion Next()
{
return new TableVersion(Version + 1, VersionEtag);
}
public TableVersion Next() => new (Version + 1, VersionEtag);

public override string ToString() => $"<{Version}, {VersionEtag}>";
string IFormattable.ToString(string format, IFormatProvider formatProvider) => ToString();

bool ISpanFormattable.TryFormat(Span<char> destination, out int charsWritten, ReadOnlySpan<char> format, IFormatProvider provider)
=> destination.TryWrite($"<{Version}, {VersionEtag}>", out charsWritten);

public override bool Equals(object obj) => Equals(obj as TableVersion);
public override int GetHashCode() => HashCode.Combine(Version, VersionEtag);
public bool Equals(TableVersion other) => other is not null && Version == other.Version && VersionEtag == other.VersionEtag;
public static bool operator ==(TableVersion left, TableVersion right) => EqualityComparer<TableVersion>.Default.Equals(left, right);
public static bool operator !=(TableVersion left, TableVersion right) => !(left == right);
}

[Serializable]
Expand Down
151 changes: 79 additions & 72 deletions src/Orleans.Core/Utils/AsyncEnumerable.cs
Original file line number Diff line number Diff line change
@@ -1,43 +1,37 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Orleans.Internal;

namespace Orleans.Runtime.Utilities
{
internal static class AsyncEnumerable
{
internal static readonly object InitialValue = new object();
internal static readonly object DisposedValue = new object();
internal static readonly object InitialValue = new();
internal static readonly object DisposedValue = new();
}

internal sealed class AsyncEnumerable<T> : IAsyncEnumerable<T>
{
private enum PublishResult
{
Success,
InvalidUpdate,
Disposed
}

private readonly object updateLock = new object();
private readonly Func<T, T, bool> updateValidator;
private Element current;
private readonly object _updateLock = new();
private readonly Func<T, T, bool> _updateValidator;
private readonly Action<T> _onPublished;
private Element _current;

public AsyncEnumerable(Func<T, T, bool> updateValidator, T initial)
public AsyncEnumerable(T initialValue, Func<T, T, bool> updateValidator, Action<T> onPublished)
{
this.updateValidator = updateValidator;
this.current = new Element(initial);
_updateValidator = updateValidator;
_current = new Element(initialValue);
_onPublished = onPublished;
onPublished(initialValue);
}

public Action<T> OnPublished { get; set; }

public bool TryPublish(T value) => this.TryPublish(new Element(value)) == PublishResult.Success;
public bool TryPublish(T value) => TryPublish(new Element(value)) == PublishResult.Success;

public void Publish(T value)
{
switch (this.TryPublish(new Element(value)))
switch (TryPublish(new Element(value)))
{
case PublishResult.Success:
return;
Expand All @@ -52,20 +46,20 @@ public void Publish(T value)

private PublishResult TryPublish(Element newItem)
{
if (this.current.IsDisposed) return PublishResult.Disposed;
if (_current.IsDisposed) return PublishResult.Disposed;

lock (this.updateLock)
lock (_updateLock)
{
if (this.current.IsDisposed) return PublishResult.Disposed;
if (_current.IsDisposed) return PublishResult.Disposed;

if (this.current.IsValid && newItem.IsValid && !this.updateValidator(this.current.Value, newItem.Value))
if (_current.IsValid && newItem.IsValid && !_updateValidator(_current.Value, newItem.Value))
{
return PublishResult.InvalidUpdate;
}

var curr = this.current;
Interlocked.Exchange(ref this.current, newItem);
if (newItem.IsValid) this.OnPublished?.Invoke(newItem.Value);
var curr = _current;
Interlocked.Exchange(ref _current, newItem);
if (newItem.IsValid) _onPublished(newItem.Value);
curr.SetNext(newItem);

return PublishResult.Success;
Expand All @@ -74,81 +68,100 @@ private PublishResult TryPublish(Element newItem)

public void Dispose()
{
if (this.current.IsDisposed) return;
if (_current.IsDisposed) return;

lock (this.updateLock)
lock (_updateLock)
{
if (this.current.IsDisposed) return;
if (_current.IsDisposed) return;

this.TryPublish(Element.CreateDisposed());
TryPublish(Element.CreateDisposed());
}
}

private void ThrowInvalidUpdate() => throw new ArgumentException("The value was not valid");
[DoesNotReturn]
private static void ThrowInvalidUpdate() => throw new ArgumentException("The value was not valid.");

[DoesNotReturn]
private static void ThrowDisposed() => throw new ObjectDisposedException("This instance has been disposed.");

private void ThrowDisposed() => throw new ObjectDisposedException("This instance has been disposed");
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default) => new AsyncEnumerator(_current, cancellationToken);

public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
private enum PublishResult
{
return new AsyncEnumerator(this.current, cancellationToken);
Success,
InvalidUpdate,
Disposed
}

private sealed class AsyncEnumerator : IAsyncEnumerator<T>
{
private readonly Task cancellation;
private Element current;
private readonly TaskCompletionSource _cancellation = new(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly CancellationTokenRegistration _registration;
private Element _current;

public AsyncEnumerator(Element initial, CancellationToken cancellation)
{
if (!initial.IsValid) this.current = initial;
if (!initial.IsValid)
{
_current = initial;
}
else
{
var result = Element.CreateInitial();
result.SetNext(initial);
this.current = result;
_current = result;
}

if (cancellation != default)
if (cancellation.CanBeCanceled)
{
this.cancellation = cancellation.WhenCancelled();
_registration = cancellation.Register(() => _cancellation.TrySetResult());
}
}

T IAsyncEnumerator<T>.Current => this.current.Value;
T IAsyncEnumerator<T>.Current => _current.Value;

async ValueTask<bool> IAsyncEnumerator<T>.MoveNextAsync()
{
Task<Element> next;
if (this.cancellation != default)
if (_current.IsDisposed || _cancellation.Task.IsCompleted)
{
next = this.current.NextAsync();
var result = await Task.WhenAny(this.cancellation, next);
if (ReferenceEquals(result, this.cancellation)) return false;
return false;
}
else

var next = _current.NextAsync();
var cancellationTask = _cancellation.Task;
var result = await Task.WhenAny(cancellationTask, next);
if (ReferenceEquals(result, cancellationTask))
{
next = this.current.NextAsync();
return false;
}

this.current = await next;
return this.current.IsValid;
_current = await next;
return _current.IsValid;
}

ValueTask IAsyncDisposable.DisposeAsync() => default;
async ValueTask IAsyncDisposable.DisposeAsync()
{
_cancellation.TrySetResult();
await _registration.DisposeAsync();
}
}

private sealed class Element
{
private readonly TaskCompletionSource<Element> next;
private readonly object value;
private readonly TaskCompletionSource<Element> _next;
private readonly object _value;

public Element(T value)
public Element(T value) : this(value, new TaskCompletionSource<Element>(TaskCreationOptions.RunContinuationsAsynchronously))
{
this.value = value;
this.next = new TaskCompletionSource<Element>(TaskCreationOptions.RunContinuationsAsynchronously);
}

public static Element CreateInitial() => new Element(
private Element(object value, TaskCompletionSource<Element> next)
{
_value = value;
_next = next;
}

public static Element CreateInitial() => new(
AsyncEnumerable.InitialValue,
new TaskCompletionSource<Element>(TaskCreationOptions.RunContinuationsAsynchronously));

Expand All @@ -159,33 +172,27 @@ public static Element CreateDisposed()
return new Element(AsyncEnumerable.DisposedValue, tcs);
}

private Element(object value, TaskCompletionSource<Element> next)
{
this.value = value;
this.next = next;
}

public bool IsValid => !this.IsInitial && !this.IsDisposed;
public bool IsValid => !IsInitial && !IsDisposed;

public T Value
{
get
{
if (this.IsInitial) ThrowInvalidInstance();
if (IsInitial) ThrowInvalidInstance();
ObjectDisposedException.ThrowIf(IsDisposed, this);
if (this.value is T typedValue) return typedValue;
if (_value is T typedValue) return typedValue;
return default;
}
}

public bool IsInitial => ReferenceEquals(this.value, AsyncEnumerable.InitialValue);
public bool IsDisposed => ReferenceEquals(this.value, AsyncEnumerable.DisposedValue);
public bool IsInitial => ReferenceEquals(_value, AsyncEnumerable.InitialValue);
public bool IsDisposed => ReferenceEquals(_value, AsyncEnumerable.DisposedValue);

public Task<Element> NextAsync() => this.next.Task;
public Task<Element> NextAsync() => _next.Task;

public void SetNext(Element next) => this.next.SetResult(next);
public void SetNext(Element next) => _next.SetResult(next);

private void ThrowInvalidInstance() => throw new InvalidOperationException("This instance does not have a value set.");
private static void ThrowInvalidInstance() => throw new InvalidOperationException("This instance does not have a value set.");
}
}
}
8 changes: 3 additions & 5 deletions src/Orleans.Runtime/Manifest/ClusterManifestProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ public ClusterManifestProvider(
ImmutableDictionary.CreateRange(new[] { new KeyValuePair<SiloAddress, GrainManifest>(localSiloDetails.SiloAddress, this.LocalGrainManifest) }),
ImmutableArray.Create(this.LocalGrainManifest));
_updates = new AsyncEnumerable<ClusterManifest>(
(previous, proposed) => previous.Version <= MajorMinorVersion.Zero || proposed.Version > previous.Version,
_current)
{
OnPublished = update => Interlocked.Exchange(ref _current, update)
};
initialValue: _current,
updateValidator: (previous, proposed) => proposed.Version > previous.Version,
onPublished: update => Interlocked.Exchange(ref _current, update));
}

public ClusterManifest Current => _current;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@ public ClusterMembershipService(
{
this.snapshot = membershipTableManager.MembershipTableSnapshot.CreateClusterMembershipSnapshot();
this.updates = new AsyncEnumerable<ClusterMembershipSnapshot>(
(previous, proposed) => proposed.Version == MembershipVersion.MinValue || proposed.Version > previous.Version,
this.snapshot)
{
OnPublished = update => Interlocked.Exchange(ref this.snapshot, update)
};
initialValue: this.snapshot,
updateValidator: (previous, proposed) => proposed.Version > previous.Version,
onPublished: update => Interlocked.Exchange(ref this.snapshot, update));
this.membershipTableManager = membershipTableManager;
this.log = log;
this.fatalErrorHandler = fatalErrorHandler;
Expand Down
Loading

0 comments on commit c1ae3ba

Please sign in to comment.