diff --git a/src/WebSocket4Net/WebSocket.cs b/src/WebSocket4Net/WebSocket.cs index 8582d8b..60dce1f 100644 --- a/src/WebSocket4Net/WebSocket.cs +++ b/src/WebSocket4Net/WebSocket.cs @@ -24,6 +24,10 @@ public class WebSocket : EasyClient, IWebSocket private static readonly Encoding _utf8Encoding = new UTF8Encoding(false); private const string _magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + private TaskCompletionSource _closePackageReceivedTaskSource; + + private bool _packageHandlerMode = false; + public Uri Uri { get; private set; } public bool AutoPingEnabled { get; set; } @@ -220,7 +224,11 @@ private void WriteHandshakeRequest(PipeWriter writer, string secKey) writer.Write("\r\n", _asciiEncoding); } - public new void StartReceive() => base.StartReceive(); + public new void StartReceive() + { + base.StartReceive(); + _packageHandlerMode = true; + } public new async ValueTask ReceiveAsync() => await ReceiveAsync( @@ -229,6 +237,11 @@ private void WriteHandshakeRequest(PipeWriter writer, string secKey) internal async ValueTask ReceiveAsync(bool handleControlPackage, bool returnControlPackage) { + if (_packageHandlerMode) + { + throw new InvalidOperationException($"You cannot call the method {nameof(ReceiveAsync)} if you already setup the client to process packages by PackageHandler."); + } + var package = await base.ReceiveAsync(); if (package == null) @@ -254,6 +267,15 @@ protected override async ValueTask OnPackageReceived(WebSocketPackage package) { if (package.OpCode != OpCode.Binary && package.OpCode != OpCode.Text) { + if (package.OpCode == OpCode.Close && _closePackageReceivedTaskSource is TaskCompletionSource closePackageReceivedTaskSource) + { + if (Interlocked.CompareExchange(ref _closePackageReceivedTaskSource, null, closePackageReceivedTaskSource) == closePackageReceivedTaskSource) + { + closePackageReceivedTaskSource.SetResult(package); + return; + } + } + await HandleControlPackage(package); return; } @@ -349,14 +371,21 @@ public async ValueTask CloseAsync(CloseReason closeReason, string message = null Reason = closeReason, ReasonText = message }; + + var closePackageReceivedTaskSource = default(TaskCompletionSource); + + if (_packageHandlerMode) + { + closePackageReceivedTaskSource = _closePackageReceivedTaskSource = new (); + } await SendAsync(_packageEncoder, package); State = WebSocketState.CloseSent; - var closeHandshakeResponse = await ReceiveAsync( - handleControlPackage: false, - returnControlPackage: true); + var closeHandshakeResponse = closePackageReceivedTaskSource != null + ? await closePackageReceivedTaskSource.Task + : await ReceiveAsync(handleControlPackage: false, returnControlPackage: true); if (closeHandshakeResponse.OpCode != OpCode.Close) {