Skip to content
This repository has been archived by the owner on Apr 30, 2024. It is now read-only.

Commit

Permalink
Merge pull request #612 from UWPCommunity/rewrite/main
Browse files Browse the repository at this point in the history
Alpha release
  • Loading branch information
matthew4850 authored May 24, 2022
2 parents 1cc0fa6 + 1911211 commit d807e16
Show file tree
Hide file tree
Showing 93 changed files with 1,880 additions and 360 deletions.
2 changes: 1 addition & 1 deletion src/API/Discord.API/Gateways/Gateway.Events.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal partial class Gateway
private Action<Ready> Ready { get; }
private Action<Resumed> Resumed { get; }
private Action<InvalidSession> InvalidSession { get; }
private Action<Exception> GatewayClosed { get; }
private Action<GatewayStatus> GatewayStateChanged { get; }

private Action<JsonGuild> GuildCreated { get; }
private Action<JsonGuild> GuildUpdated { get; }
Expand Down
25 changes: 12 additions & 13 deletions src/API/Discord.API/Gateways/Gateway.Handshake.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,39 @@ internal partial class Gateway

private bool OnHelloReceived(SocketFrame<Hello> frame)
{
SetupGateway(frame.Payload.HeartbeatInterval);
_ = SetupGateway(frame.Payload.HeartbeatInterval);
return true;
}

private async void SetupGateway(int interval)
private async Task SetupGateway(int interval)
{
switch (_gatewayStatus)
switch (GatewayStatus)
{
case GatewayStatus.Reconnecting:
case GatewayStatus.Connecting:
await IdentifySelfToGateway();
_gatewayStatus = GatewayStatus.Connected;
GatewayStatus = GatewayStatus.Connected;
break;
case GatewayStatus.Resuming:
await SendResumeRequestAsync();
_gatewayStatus = GatewayStatus.Connected;
GatewayStatus = GatewayStatus.Connected;
break;
default:
_gatewayStatus = GatewayStatus.Error;
GatewayStatus = GatewayStatus.Error;
return;
}

double jitter = (new Random()).NextDouble();
await Task.Delay((int)(interval * jitter));
await BeginHeartbeatAsync(interval);
_ = BeginHeartbeatAsync(interval);
}

private bool OnInvalidSession(SocketFrame frame)
{
switch (_gatewayStatus)
switch (GatewayStatus)
{
case GatewayStatus.InvalidSession:
Guard.IsNotNull(_connectionUrl, nameof(_connectionUrl));

_ = ConnectAsync(_connectionUrl);
_ = ReconnectAsync();
break;
case GatewayStatus.Reconnecting:
FireEvent(frame, InvalidSession);
Expand All @@ -64,14 +63,14 @@ private bool OnHeartbeatAck()

private async Task BeginHeartbeatAsync(int interval)
{
while (_gatewayStatus == GatewayStatus.Connected)
while (GatewayStatus == GatewayStatus.Connected)
{
await SendHeartbeatAsync();
_recievedAck = false;
await Task.Delay(interval);
if (!_recievedAck)
{
_gatewayStatus = GatewayStatus.Disconnected;
GatewayStatus = GatewayStatus.Disconnected;
await CloseSocket();
await ResumeAsync();
}
Expand Down
190 changes: 141 additions & 49 deletions src/API/Discord.API/Gateways/Gateway.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace Discord.API.Gateways
Expand All @@ -17,26 +18,135 @@ internal partial class Gateway
{
private readonly JsonSerializerOptions _serialiseOptions;
private readonly JsonSerializerOptions _deserialiseOptions;
private WebSocketClient _socket;
private ClientWebSocket? _socket;
private Task? _task;
private CancellationTokenSource _tokenSource = new CancellationTokenSource();
private DeflateStream? _decompressor;
private MemoryStream? _decompressionBuffer;

private WebSocketClient CreateSocket()
{
_socket?.Dispose();
_socket = new WebSocketClient();
_socket.TextMessage += HandleTextMessage;
_socket.BinaryMessage += HandleBinaryMessage;
_socket.Closed += HandleClosed;
return _socket;
}


private void SetupCompression()
{
_decompressionBuffer = new MemoryStream();
_decompressor = new DeflateStream(_decompressionBuffer, CompressionMode.Decompress);
}


/// <summary>
/// Sets up a connection to the gateway.
/// </summary>
/// <exception cref="Exception">An exception will be thrown when connection fails, but not when the handshake fails.</exception>
public async Task Connect(string token)
{
_token = token;
await ConnectAsync();
}

public async Task ConnectAsync()
{
GatewayStatus = GatewayStatus == GatewayStatus.Initialized ? GatewayStatus.Connecting : GatewayStatus.Reconnecting;
await ConnectAsync(_gatewayConfig.GetFullGatewayUrl("json", "9", "&compress=zlib-stream"));
_task = Task.Run(async () =>
{
await ListenOnSocket();
_socket = null;
});
}

/// <summary>
/// Resumes a connection to the gateway.
/// </summary>
/// <exception cref="Exception">An exception will be thrown when connection fails, but not when the handshake fails.</exception>
public async Task ResumeAsync()
{
GatewayStatus = GatewayStatus.Resuming;
await ConnectAsync(_gatewayConfig.GetFullGatewayUrl("json", "9", "&compress=zlib-stream"));
_task = Task.Run(ListenOnSocket);
}

private async Task ListenOnSocket()
{
var buffer = new ArraySegment<byte>(new byte[16 * 1024]);
while (_tokenSource.IsCancellationRequested && _socket!.State == WebSocketState.Open)
{
WebSocketReceiveResult socketResult = await _socket.ReceiveAsync(buffer, _tokenSource.Token).ConfigureAwait(false);
if (socketResult.MessageType == WebSocketMessageType.Close)
{
switch (socketResult.CloseStatus)
{
case (WebSocketCloseStatus)4000:
case (WebSocketCloseStatus)4001:
case (WebSocketCloseStatus)4002:
case (WebSocketCloseStatus)4003:
case (WebSocketCloseStatus)4005:
case (WebSocketCloseStatus)4007:
case (WebSocketCloseStatus)4008:
case (WebSocketCloseStatus)4009:
GatewayStatus = GatewayStatus.Reconnecting;
_ = ConnectAsync();
return;

case (WebSocketCloseStatus)4004:
default:
GatewayStatus = GatewayStatus.Disconnected;
return;

}
}

byte[] bytes = buffer.Array;
int length = socketResult.Count;

if (!socketResult.EndOfMessage)
{
// This is a large message (likely just READY), lets create a temporary expandable stream
var stream = new MemoryStream();
await stream.WriteAsync(buffer.Array, 0, socketResult.Count).ConfigureAwait(false);
do
{
if (_tokenSource.Token.IsCancellationRequested)
{
return;
}
socketResult = await _socket.ReceiveAsync(buffer, _tokenSource.Token).ConfigureAwait(false);
await stream.WriteAsync(buffer.Array, 0, socketResult.Count).ConfigureAwait(false);
}
while (!socketResult.EndOfMessage);

bytes = stream.GetBuffer();
length = (int)stream.Length;
}

if (socketResult.MessageType == WebSocketMessageType.Text)
{
HandleTextMessage(bytes);
}
else
{
HandleBinaryMessage(bytes, length);
}
}
}

private async Task ReconnectAsync()
{
await CloseSocket();
await ConnectAsync(_connectionUrl!);
}
private async Task ConnectAsync(string connectionUrl)
{
_connectionUrl = connectionUrl;
SetupCompression();
_tokenSource = new CancellationTokenSource();
_socket ??= new ClientWebSocket();


if (_socket.State is WebSocketState.Connecting or WebSocketState.Open)
{
throw new Exception("Tried to connect to socket while already connected");
}
await _socket.ConnectAsync(new Uri(connectionUrl), CancellationToken.None);
}

private async Task SendMessageAsync<T>(SocketFrame<T> frame)
{
var stream = new MemoryStream();
Expand All @@ -46,56 +156,42 @@ private async Task SendMessageAsync<T>(SocketFrame<T> frame)

private async Task SendMessageAsync(MemoryStream stream)
{
try
{
await _socket.SendAsync(stream.GetBuffer(), 0, (int)stream.Length, true);
}
catch (WebSocketClosedException exception)
{
GatewayClosed(exception);
}
await _socket!.SendAsync(new ArraySegment<byte>(stream.GetBuffer(), 0, (int)stream.Length), WebSocketMessageType.Text, true, _tokenSource.Token);
}

private void HandleTextMessage(string message)
private void HandleTextMessage(byte[] buffer)
{
using var reader = new StreamReader(new MemoryStream(Encoding.ASCII.GetBytes(message)));
HandleMessage(reader);
HandleMessage(new MemoryStream(buffer));
}

private void HandleBinaryMessage(byte[] bytes, int _, int count)
private async void HandleBinaryMessage(byte[] buffer, int count)
{
Guard.IsNotNull(_decompressor, nameof(_decompressor));
Guard.IsNotNull(_decompressionBuffer, nameof(_decompressionBuffer));

using var ms = new MemoryStream(bytes);
ms.Position = 0;
byte[] data = new byte[count];
ms.Read(data, 0, count);
int index = 0;

using var decompressed = new MemoryStream();
if (data[0] == 0x78)

if (buffer[0] == 0x78)
{
_decompressionBuffer.Write(data, index + 2, count - 2);
await _decompressionBuffer.WriteAsync(buffer, 2, count - 2);
_decompressionBuffer.SetLength(count - 2);
}
else
{
_decompressionBuffer.Write(data, index, count);
await _decompressionBuffer.WriteAsync(buffer, 0, count);
_decompressionBuffer.SetLength(count);
}

_decompressionBuffer.Position = 0;
_decompressor.CopyTo(decompressed);
await _decompressor.CopyToAsync(decompressed);
_decompressionBuffer.Position = 0;
decompressed.Position = 0;

using var reader = new StreamReader(decompressed);
HandleMessage(reader);

HandleMessage(decompressed);
}

private async void HandleMessage(TextReader reader)
private async void HandleMessage(Stream stream)
{
Stream stream = ((StreamReader)reader).BaseStream;
SocketFrame? frame = await ParseFrame(stream);
if (frame is null) return;

Expand All @@ -107,18 +203,14 @@ private async void HandleMessage(TextReader reader)
ProcessEvents(frame);
}

private void HandleClosed(Exception exception)
{
GatewayClosed(exception);
}

private async Task CloseSocket()
{
if (_socket != null)
{
await _socket.DisconnectAsync((WebSocketCloseStatus)4000);
await _socket.DisconnectAsync();
}
if(_socket is { State: WebSocketState.Open })
await _socket.CloseAsync((WebSocketCloseStatus)4000, string.Empty, CancellationToken.None);
_tokenSource.Cancel();
if(_task != null)
await _task;
_task = null;
}

private async Task<SocketFrame?> ParseFrame(Stream stream)
Expand Down
Loading

0 comments on commit d807e16

Please sign in to comment.