diff --git a/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java b/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java index 335450d5eb..0591186598 100644 --- a/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java +++ b/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java @@ -142,11 +142,14 @@ public CompletableFuture close() { List> futures = new ArrayList<>(); - forEach((connectionKey, closeable) -> { - - futures.add(closeable.closeAsync()); - connections.remove(connectionKey); - }); + for (K k : connections.keySet()) { + Sync remove = connections.remove(k); + if (remove != null) { + CompletionStage closeFuture = remove.future.thenAccept(AsyncCloseable::closeAsync); + // always synchronously add the future, made it immutably in Futures.allOf() + futures.add(closeFuture.toCompletableFuture()); + } + } return Futures.allOf(futures); } @@ -160,9 +163,8 @@ public void close(K key) { LettuceAssert.notNull(key, "ConnectionKey must not be null!"); - Sync sync = connections.get(key); + Sync sync = connections.remove(key); if (sync != null) { - connections.remove(key); sync.doWithConnection(AsyncCloseable::closeAsync); } } @@ -217,7 +219,6 @@ static class Sync> { @SuppressWarnings("unchecked") public Sync(K key, F future) { - this.key = key; this.future = (F) future.whenComplete((connection, throwable) -> { diff --git a/src/test/java/io/lettuce/core/cluster/AsyncConnectionProviderIntegrationTests.java b/src/test/java/io/lettuce/core/cluster/AsyncConnectionProviderIntegrationTests.java index c2352a7898..8b531ac787 100644 --- a/src/test/java/io/lettuce/core/cluster/AsyncConnectionProviderIntegrationTests.java +++ b/src/test/java/io/lettuce/core/cluster/AsyncConnectionProviderIntegrationTests.java @@ -32,7 +32,9 @@ import javax.inject.Inject; import org.apache.commons.lang3.time.StopWatch; +import org.junit.Assert; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -125,7 +127,7 @@ void shouldCloseConnections() throws IOException { ConnectionKey connectionKey = new ConnectionKey(ConnectionIntent.READ, TestSettings.host(), TestSettings.port()); sut.getConnection(connectionKey); - TestFutures.awaitOrTimeout(sut.close()); + assertThatThrownBy(() -> TestFutures.awaitOrTimeout(sut.close())).isInstanceOf(IllegalStateException.class); assertThat(sut.getConnectionCount()).isEqualTo(0); TestFutures.awaitOrTimeout(sut.close()); diff --git a/src/test/java/io/lettuce/core/internal/AsyncConnectionProviderTest.java b/src/test/java/io/lettuce/core/internal/AsyncConnectionProviderTest.java new file mode 100644 index 0000000000..80e7f9ee10 --- /dev/null +++ b/src/test/java/io/lettuce/core/internal/AsyncConnectionProviderTest.java @@ -0,0 +1,99 @@ +package io.lettuce.core.internal; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class AsyncConnectionProviderTest { + + @Test + public void testFutureListLength() throws InterruptedException, ExecutionException, TimeoutException { + + CountDownLatch slowCreate = new CountDownLatch(1); + CountDownLatch slowShutdown = new CountDownLatch(1); + + // create a provider with a slow connection creation + AsyncConnectionProvider> provider = new AsyncConnectionProvider<>( + key -> { + return countDownFuture(slowCreate, new io.lettuce.core.api.AsyncCloseable() { + + @Override + public CompletableFuture closeAsync() { + return CompletableFuture.completedFuture(null); + } + + }); + }); + + // add slow shutdown connection first + SlowCloseFuture slowCloseFuture = new SlowCloseFuture(slowShutdown); + provider.register("slowShutdown", new io.lettuce.core.api.AsyncCloseable() { + + @Override + public CompletableFuture closeAsync() { + return slowCloseFuture; + } + + }); + + // add slow creation connection + CompletableFuture createFuture = provider.getConnection("slowCreate"); + + // close the connection. + CompletableFuture closeFuture = provider.close(); + + // the connection has not been created yet, so the close futures array always has 1 element + // we block the iterator on the slowCloseFuture + // then we count down the creation, the close future will be added to the list + slowCreate.countDown(); + + // the close future is added to the list, we unblock the iterator + slowShutdown.countDown(); + + // assert close future is completed, and no exceptions are thrown + closeFuture.get(10, TimeUnit.SECONDS); + Assert.assertTrue(createFuture.isDone()); + } + + private CompletableFuture countDownFuture(CountDownLatch countDownLatch, T value) { + return CompletableFuture.runAsync(() -> { + try { + countDownLatch.await(1, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }).thenApply(v -> value); + } + + static class SlowCloseFuture extends CompletableFuture { + + private final CountDownLatch countDownLatch; + + SlowCloseFuture(CountDownLatch countDownLatch) { + this.countDownLatch = countDownLatch; + } + + @Override + public CompletableFuture toCompletableFuture() { + // we block the iterator on here + try { + countDownLatch.await(1, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return super.toCompletableFuture(); + } + + @Override + public Void get() { + return null; + } + + } + +}