diff --git a/src/CommonLib/ConnectionPoolManager.cs b/src/CommonLib/ConnectionPoolManager.cs index 5e001dec..5364f3fd 100644 --- a/src/CommonLib/ConnectionPoolManager.cs +++ b/src/CommonLib/ConnectionPoolManager.cs @@ -73,14 +73,14 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn } private bool GetPool(string identifier, out LdapConnectionPool pool) { - if (identifier == null) { + if (string.IsNullOrWhiteSpace(identifier)) { pool = default; return false; } var resolved = ResolveIdentifier(identifier); if (!_pools.TryGetValue(resolved, out pool)) { - pool = new LdapConnectionPool(identifier, resolved, _ldapConfig,scanner: _portScanner); + pool = new LdapConnectionPool(identifier, resolved, _ldapConfig, scanner: _portScanner); _pools.TryAdd(resolved, pool); } @@ -96,6 +96,7 @@ private bool GetPool(string identifier, out LdapConnectionPool pool) { if (globalCatalog) { return await pool.GetGlobalCatalogConnectionAsync(); } + return await pool.GetConnectionAsync(); } diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index f3e8735d..1073f960 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -32,8 +32,10 @@ internal class LdapConnectionPool : IDisposable { private const int MaxRetries = 3; private static readonly ConcurrentDictionary DCInfoCache = new(); - public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config, - PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { + // Tracks domains we know we've determined we shouldn't try to connect to + private static readonly ConcurrentHashSet _excludedDomains = new(); + + public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { _connections = new ConcurrentBag(); _globalCatalogConnection = new ConcurrentBag(); //TODO: Re-enable this once we track down the semaphore deadlock @@ -621,8 +623,11 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll return true; } - public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> - GetConnectionAsync() { + public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetConnectionAsync() { + if (_excludedDomains.Contains(_identifier)) { + return (false, null, $"Identifier {_identifier} excluded for connection attempt"); + } + if (!_connections.TryTake(out var connectionWrapper)) { var (success, connection, message) = await CreateNewConnection(); if (!success) { @@ -640,8 +645,11 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll return CreateNewConnectionForServer(server, globalCatalog); } - public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> - GetGlobalCatalogConnectionAsync() { + public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetGlobalCatalogConnectionAsync() { + if (_excludedDomains.Contains(_identifier)) { + return (false, null, $"Identifier {_identifier} excluded for connection attempt"); + } + if (!_globalCatalogConnection.TryTake(out var connectionWrapper)) { var (success, connection, message) = await CreateNewConnection(true); if (!success) { @@ -721,6 +729,7 @@ public void Dispose() { _log.LogDebug( "Could not get domain object from GetDomain, unable to create ldap connection for domain {Domain}", _identifier); + _excludedDomains.Add(_identifier); return (false, null, "Unable to get domain object for further strategies"); } @@ -755,8 +764,8 @@ public void Dispose() { } } } catch (Exception e) { - _log.LogInformation(e, "We will not be able to connect to domain {Domain} by any strategy, leaving it.", - _identifier); + _log.LogInformation(e, "We will not be able to connect to domain {Domain} by any strategy, leaving it.", _identifier); + _excludedDomains.Add(_identifier); } return (false, null, "All attempted connections failed");