diff --git a/WatsonTcp/ClientMetadata.cs b/WatsonTcp/ClientMetadata.cs index cf9a2de..ec93a9f 100644 --- a/WatsonTcp/ClientMetadata.cs +++ b/WatsonTcp/ClientMetadata.cs @@ -34,7 +34,7 @@ public string IpPort public SemaphoreSlim ReadLock { get; set; } - public SemaphoreSlim SendLock { get; set; } + public SemaphoreSlim WriteLock { get; set; } #endregion @@ -58,7 +58,7 @@ public ClientMetadata(TcpClient tcp) _IpPort = tcp.Client.RemoteEndPoint.ToString(); ReadLock = new SemaphoreSlim(1); - SendLock = new SemaphoreSlim(1); + WriteLock = new SemaphoreSlim(1); } #endregion @@ -101,7 +101,7 @@ protected virtual void Dispose(bool disposing) } ReadLock.Dispose(); - SendLock.Dispose(); + WriteLock.Dispose(); _Disposed = true; } diff --git a/WatsonTcp/Message/WatsonMessage.cs b/WatsonTcp/Message/WatsonMessage.cs index 283f3db..341eca6 100644 --- a/WatsonTcp/Message/WatsonMessage.cs +++ b/WatsonTcp/Message/WatsonMessage.cs @@ -174,7 +174,7 @@ internal WatsonMessage(long contentLength, Stream stream, bool debug) } /// - /// Read from a TCP-based stream and construct a message. Call Build() to populate. + /// Read from a stream and construct a message. Call Build() to populate. /// /// NetworkStream. /// Enable or disable console debugging. diff --git a/WatsonTcp/WatsonTcp.csproj b/WatsonTcp/WatsonTcp.csproj index c34abc8..852cdf9 100644 --- a/WatsonTcp/WatsonTcp.csproj +++ b/WatsonTcp/WatsonTcp.csproj @@ -3,7 +3,7 @@ netstandard2.0;net452 true - 1.3.11 + 1.3.12 Joel Christner Joel Christner A simple C# async TCP server and client with integrated framing for reliable transmission and receipt of data @@ -13,7 +13,7 @@ https://github.com/jchristn/WatsonTcp Github https://github.com/jchristn/WatsonTcp/blob/master/LICENSE.TXT - More reliable and stable detection of client disconnect. + Reduce instances of calls to GetStream https://raw.githubusercontent.com/jchristn/watsontcp/master/assets/watson.ico diff --git a/WatsonTcp/WatsonTcpClient.cs b/WatsonTcp/WatsonTcpClient.cs index cc48dba..a5a6766 100644 --- a/WatsonTcp/WatsonTcpClient.cs +++ b/WatsonTcp/WatsonTcpClient.cs @@ -114,12 +114,15 @@ public int ReadStreamBufferSize private string _ServerIp; private int _ServerPort; private TcpClient _Client; + private NetworkStream _TcpStream; + private SslStream _SslStream; - private SslStream _Ssl; private X509Certificate2 _SslCertificate; private X509Certificate2Collection _SslCertificateCollection; - private SemaphoreSlim _SendLock; + private SemaphoreSlim _WriteLock; + private SemaphoreSlim _ReadLock; + private CancellationTokenSource _TokenSource; private CancellationToken _Token; @@ -142,8 +145,9 @@ public WatsonTcpClient( _Mode = Mode.Tcp; _ServerIp = serverIp; _ServerPort = serverPort; - _SendLock = new SemaphoreSlim(1); - _Ssl = null; + _WriteLock = new SemaphoreSlim(1); + _ReadLock = new SemaphoreSlim(1); + _SslStream = null; } /// @@ -165,7 +169,9 @@ public WatsonTcpClient( _Mode = Mode.Ssl; _ServerIp = serverIp; _ServerPort = serverPort; - _SendLock = new SemaphoreSlim(1); + _WriteLock = new SemaphoreSlim(1); + _ReadLock = new SemaphoreSlim(1); + _TcpStream = null; _SslCertificate = null; if (String.IsNullOrEmpty(pfxCertPass)) _SslCertificate = new X509Certificate2(pfxCertFile); else _SslCertificate = new X509Certificate2(pfxCertFile, pfxCertPass); @@ -220,6 +226,9 @@ public void Start() _SourceIp = ((IPEndPoint)_Client.Client.LocalEndPoint).Address.ToString(); _SourcePort = ((IPEndPoint)_Client.Client.LocalEndPoint).Port; + _TcpStream = _Client.GetStream(); + _SslStream = null; + Connected = true; } catch (Exception) @@ -258,31 +267,32 @@ public void Start() if (AcceptInvalidCertificates) { // accept invalid certs - _Ssl = new SslStream(_Client.GetStream(), false, new RemoteCertificateValidationCallback(AcceptCertificate)); + _SslStream = new SslStream(_Client.GetStream(), false, new RemoteCertificateValidationCallback(AcceptCertificate)); } else { // do not accept invalid SSL certificates - _Ssl = new SslStream(_Client.GetStream(), false); + _SslStream = new SslStream(_Client.GetStream(), false); } - _Ssl.AuthenticateAsClient(_ServerIp, _SslCertificateCollection, SslProtocols.Tls12, !AcceptInvalidCertificates); + _SslStream.AuthenticateAsClient(_ServerIp, _SslCertificateCollection, SslProtocols.Tls12, !AcceptInvalidCertificates); - if (!_Ssl.IsEncrypted) + if (!_SslStream.IsEncrypted) { throw new AuthenticationException("Stream is not encrypted"); } - if (!_Ssl.IsAuthenticated) + if (!_SslStream.IsAuthenticated) { throw new AuthenticationException("Stream is not authenticated"); } - if (MutuallyAuthenticate && !_Ssl.IsMutuallyAuthenticated) + if (MutuallyAuthenticate && !_SslStream.IsMutuallyAuthenticated) { throw new AuthenticationException("Mutual authentication failed"); } + Connected = true; } catch (Exception) @@ -384,32 +394,37 @@ protected virtual void Dispose(bool disposing) if (disposing) { - if (_Ssl != null) + if (_SslStream != null) { try { - _Ssl.Close(); + _WriteLock.Wait(1); + _ReadLock.Wait(1); + _SslStream.Close(); } catch (Exception) { } + finally + { + _WriteLock.Release(); + _ReadLock.Release(); + } } - if (_Client != null) - { - if (_Client.Connected) + if (_TcpStream != null) + { + try { - try - { - NetworkStream ns = _Client.GetStream(); - if (ns != null) ns.Close(); - } - catch (Exception) - { - - } + _WriteLock.Wait(1); + _ReadLock.Wait(1); + if (_TcpStream != null) _TcpStream.Close(); } + catch (Exception) + { + + } try { @@ -419,6 +434,11 @@ protected virtual void Dispose(bool disposing) { } + finally + { + _WriteLock.Release(); + _ReadLock.Release(); + } } _TokenSource.Cancel(); @@ -481,7 +501,7 @@ private async Task DataReceiver(CancellationToken? cancelToken=null) break; } - if (_Ssl != null && !_Ssl.CanRead) + if (_SslStream != null && !_SslStream.CanRead) { Log("*** DataReceiver cannot read from SSL stream"); break; @@ -493,31 +513,40 @@ private async Task DataReceiver(CancellationToken? cancelToken=null) WatsonMessage msg = null; - if (_Ssl != null) - { - msg = new WatsonMessage(_Ssl, Debug); + _ReadLock.Wait(1); - if (ReadDataStream) + try + { + if (_SslStream != null) { - await msg.Build(); + msg = new WatsonMessage(_SslStream, Debug); + + if (ReadDataStream) + { + await msg.Build(); + } + else + { + await msg.BuildStream(); + } } else { - await msg.BuildStream(); + msg = new WatsonMessage(_TcpStream, Debug); + + if (ReadDataStream) + { + await msg.Build(); + } + else + { + await msg.BuildStream(); + } } } - else + finally { - msg = new WatsonMessage(_Client.GetStream(), Debug); - - if (ReadDataStream) - { - await msg.Build(); - } - else - { - await msg.BuildStream(); - } + _ReadLock.Release(); } if (msg == null) @@ -614,22 +643,21 @@ private bool MessageWrite(WatsonMessage msg) byte[] headerBytes = msg.ToHeaderBytes(dataLen); - _SendLock.Wait(1); + _WriteLock.Wait(1); try { if (_Mode == Mode.Tcp) - { - NetworkStream ns = _Client.GetStream(); - ns.Write(headerBytes, 0, headerBytes.Length); - if (msg.Data != null && msg.Data.Length > 0) ns.Write(msg.Data, 0, msg.Data.Length); - ns.Flush(); + { + _TcpStream.Write(headerBytes, 0, headerBytes.Length); + if (msg.Data != null && msg.Data.Length > 0) _TcpStream.Write(msg.Data, 0, msg.Data.Length); + _TcpStream.Flush(); } else if (_Mode == Mode.Ssl) { - _Ssl.Write(headerBytes, 0, headerBytes.Length); - if (msg.Data != null && msg.Data.Length > 0) _Ssl.Write(msg.Data, 0, msg.Data.Length); - _Ssl.Flush(); + _SslStream.Write(headerBytes, 0, headerBytes.Length); + if (msg.Data != null && msg.Data.Length > 0) _SslStream.Write(msg.Data, 0, msg.Data.Length); + _SslStream.Flush(); } else { @@ -638,7 +666,7 @@ private bool MessageWrite(WatsonMessage msg) } finally { - _SendLock.Release(); + _WriteLock.Release(); } string logMessage = "MessageWrite sent " + Encoding.UTF8.GetString(headerBytes); @@ -728,14 +756,13 @@ private bool MessageWrite(long contentLength, Stream stream) long bytesRemaining = contentLength; byte[] buffer = new byte[_ReadStreamBufferSize]; - _SendLock.Wait(1); + _WriteLock.Wait(1); try { if (_Mode == Mode.Tcp) - { - NetworkStream ns = _Client.GetStream(); - ns.Write(headerBytes, 0, headerBytes.Length); + { + _TcpStream.Write(headerBytes, 0, headerBytes.Length); if (contentLength > 0) { @@ -744,17 +771,17 @@ private bool MessageWrite(long contentLength, Stream stream) bytesRead = stream.Read(buffer, 0, buffer.Length); if (bytesRead > 0) { - ns.Write(buffer, 0, bytesRead); + _TcpStream.Write(buffer, 0, bytesRead); bytesRemaining -= bytesRead; } } } - ns.Flush(); + _TcpStream.Flush(); } else if (_Mode == Mode.Ssl) { - _Ssl.Write(headerBytes, 0, headerBytes.Length); + _SslStream.Write(headerBytes, 0, headerBytes.Length); if (contentLength > 0) { @@ -763,13 +790,13 @@ private bool MessageWrite(long contentLength, Stream stream) bytesRead = stream.Read(buffer, 0, buffer.Length); if (bytesRead > 0) { - _Ssl.Write(buffer, 0, bytesRead); + _SslStream.Write(buffer, 0, bytesRead); bytesRemaining -= bytesRead; } } } - _Ssl.Flush(); + _SslStream.Flush(); } else { @@ -778,7 +805,7 @@ private bool MessageWrite(long contentLength, Stream stream) } finally { - _SendLock.Release(); + _WriteLock.Release(); } string logMessage = "MessageWrite sent " + Encoding.UTF8.GetString(headerBytes); @@ -868,14 +895,13 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) long bytesRemaining = contentLength; byte[] buffer = new byte[_ReadStreamBufferSize]; - await _SendLock.WaitAsync(); + await _WriteLock.WaitAsync(); try { if (_Mode == Mode.Tcp) - { - NetworkStream ns = _Client.GetStream(); - await ns.WriteAsync(headerBytes, 0, headerBytes.Length); + { + await _TcpStream.WriteAsync(headerBytes, 0, headerBytes.Length); if (contentLength > 0) { @@ -884,17 +910,17 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length); if (bytesRead > 0) { - await ns.WriteAsync(buffer, 0, bytesRead); + await _TcpStream.WriteAsync(buffer, 0, bytesRead); bytesRemaining -= bytesRead; } } } - await ns.FlushAsync(); + await _TcpStream.FlushAsync(); } else if (_Mode == Mode.Ssl) { - await _Ssl.WriteAsync(headerBytes, 0, headerBytes.Length); + await _SslStream.WriteAsync(headerBytes, 0, headerBytes.Length); if (contentLength > 0) { @@ -903,13 +929,13 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length); if (bytesRead > 0) { - await _Ssl.WriteAsync(buffer, 0, bytesRead); + await _SslStream.WriteAsync(buffer, 0, bytesRead); bytesRemaining -= bytesRead; } } } - await _Ssl.FlushAsync(); + await _SslStream.FlushAsync(); } else { @@ -918,7 +944,7 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) } finally { - _SendLock.Release(); + _WriteLock.Release(); } string logMessage = "MessageWriteAsync sent " + Encoding.UTF8.GetString(headerBytes); diff --git a/WatsonTcp/WatsonTcpServer.cs b/WatsonTcp/WatsonTcpServer.cs index b62d13f..fb9da1b 100644 --- a/WatsonTcp/WatsonTcpServer.cs +++ b/WatsonTcp/WatsonTcpServer.cs @@ -112,8 +112,7 @@ public int ReadStreamBufferSize private int _ActiveClients; private ConcurrentDictionary _Clients; private ConcurrentDictionary _UnauthenticatedClients; - - private readonly SemaphoreSlim _SendLock; + private CancellationTokenSource _TokenSource; private CancellationToken _Token; @@ -154,8 +153,7 @@ public WatsonTcpServer( _ActiveClients = 0; _Clients = new ConcurrentDictionary(); - _UnauthenticatedClients = new ConcurrentDictionary(); - _SendLock = new SemaphoreSlim(1); + _UnauthenticatedClients = new ConcurrentDictionary(); } /// @@ -203,8 +201,7 @@ public WatsonTcpServer( _Token = _TokenSource.Token; _ActiveClients = 0; _Clients = new ConcurrentDictionary(); - _UnauthenticatedClients = new ConcurrentDictionary(); - _SendLock = new SemaphoreSlim(1); + _UnauthenticatedClients = new ConcurrentDictionary(); } #endregion @@ -384,8 +381,7 @@ protected virtual void Dispose(bool disposing) } } } - - _SendLock.Dispose(); + _Disposed = true; } @@ -607,7 +603,7 @@ private bool IsConnected(ClientMetadata client) try { - client.SendLock.Wait(1); + client.WriteLock.Wait(1); sendLocked = true; client.TcpClient.Client.Send(tmp, 0, 0); success = true; @@ -631,7 +627,7 @@ private bool IsConnected(ClientMetadata client) } finally { - if (sendLocked) client.SendLock.Release(); + if (sendLocked) client.WriteLock.Release(); } if (success) return true; @@ -887,7 +883,7 @@ private bool MessageWrite(ClientMetadata client, WatsonMessage msg, long content long bytesRemaining = contentLength; byte[] buffer = new byte[_ReadStreamBufferSize]; - client.SendLock.Wait(1); + client.WriteLock.Wait(1); try { @@ -943,7 +939,7 @@ private bool MessageWrite(ClientMetadata client, WatsonMessage msg, long content } finally { - client.SendLock.Release(); + client.WriteLock.Release(); } } @@ -980,7 +976,7 @@ private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage long bytesRemaining = contentLength; byte[] buffer = new byte[_ReadStreamBufferSize]; - client.SendLock.Wait(1); + client.WriteLock.Wait(1); try { @@ -1036,7 +1032,7 @@ private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage } finally { - client.SendLock.Release(); + client.WriteLock.Release(); } }