From 862b054498ff134533b14f967e38abbe4bddb223 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Thu, 25 Jan 2024 11:51:11 +0100 Subject: [PATCH] Make sure callback is completed --- .../spotify/scio/transforms/FutureHandlers.java | 17 ++++++++++++----- .../transforms/GuavaAsyncBatchLookupDoFn.java | 12 ++++++++++++ .../spotify/scio/transforms/GuavaAsyncDoFn.java | 14 +++++++++++++- .../scio/transforms/GuavaAsyncLookupDoFn.java | 12 ++++++++++++ .../scio/transforms/FutureHandlersTest.scala | 5 +++-- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/FutureHandlers.java b/scio-core/src/main/java/com/spotify/scio/transforms/FutureHandlers.java index 61bac51f91..bec6b57489 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/FutureHandlers.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/FutureHandlers.java @@ -19,10 +19,7 @@ import com.google.common.util.concurrent.*; import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import java.util.concurrent.*; import java.util.function.Function; import java.util.stream.StreamSupport; import javax.annotation.Nullable; @@ -50,6 +47,7 @@ void waitForFutures(Iterable futures) /** A {@link Base} implementation for Guava {@link ListenableFuture}. */ public interface Guava extends Base, V> { + Executor getCallbackExecutor(); @Override default void waitForFutures(Iterable> futures) @@ -73,6 +71,15 @@ default ListenableFuture addCallback( // Futures#transform doesn't allow onFailure callback while Futures#addCallback doesn't // guarantee that callbacks are called before ListenableFuture#get() unblocks SettableFuture f = SettableFuture.create(); + // if executor rejects the callback, we have to fail the future + Executor rejectPropagationExecutor = + command -> { + try { + getCallbackExecutor().execute(command); + } catch (RejectedExecutionException e) { + f.setException(e); + } + }; Futures.addCallback( future, new FutureCallback() { @@ -103,7 +110,7 @@ public void onFailure(Throwable t) { } } }, - MoreExecutors.directExecutor()); + rejectPropagationExecutor); return f; } diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java index 462e7084dc..9e9a05aaa1 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java @@ -20,6 +20,8 @@ import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier; import com.spotify.scio.transforms.BaseAsyncLookupDoFn.Try; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.commons.lang3.tuple.Pair; @@ -42,6 +44,8 @@ public abstract class GuavaAsyncBatchLookupDoFn< Try> implements FutureHandlers.Guava { + private transient Executor executor; + public GuavaAsyncBatchLookupDoFn( int batchSize, SerializableFunction, BatchRequest> batchRequestFn, @@ -67,6 +71,14 @@ public GuavaAsyncBatchLookupDoFn( cacheSupplier); } + @Override + public Executor getCallbackExecutor() { + if (executor == null) { + executor = Executors.newSingleThreadExecutor(); + } + return executor; + } + @Override public Try success(Output output) { return new Try<>(output); diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncDoFn.java index 1f93e916c3..e6048b11ff 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncDoFn.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncDoFn.java @@ -18,6 +18,8 @@ package com.spotify.scio.transforms; import com.google.common.util.concurrent.ListenableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import org.apache.beam.sdk.transforms.DoFn; /** @@ -26,4 +28,14 @@ */ public abstract class GuavaAsyncDoFn extends BaseAsyncDoFn> - implements FutureHandlers.Guava {} + implements FutureHandlers.Guava { + private transient Executor executor; + + @Override + public Executor getCallbackExecutor() { + if (executor == null) { + executor = Executors.newSingleThreadExecutor(); + } + return executor; + } +} diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncLookupDoFn.java index 2a84627d7f..83ac092857 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncLookupDoFn.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncLookupDoFn.java @@ -18,6 +18,8 @@ package com.spotify.scio.transforms; import com.google.common.util.concurrent.ListenableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import org.apache.beam.sdk.transforms.DoFn; /** @@ -32,6 +34,8 @@ public abstract class GuavaAsyncLookupDoFn extends BaseAsyncLookupDoFn, BaseAsyncLookupDoFn.Try> implements FutureHandlers.Guava { + private transient Executor executor; + /** Create a {@link GuavaAsyncLookupDoFn} instance. */ public GuavaAsyncLookupDoFn() { super(); @@ -75,6 +79,14 @@ public GuavaAsyncLookupDoFn( super(maxPendingRequests, deduplicate, cacheSupplier); } + @Override + public Executor getCallbackExecutor() { + if (executor == null) { + executor = Executors.newSingleThreadExecutor(); + } + return executor; + } + @Override public BaseAsyncLookupDoFn.Try success(B output) { return new Try<>(output); diff --git a/scio-test/src/test/scala/com/spotify/scio/transforms/FutureHandlersTest.scala b/scio-test/src/test/scala/com/spotify/scio/transforms/FutureHandlersTest.scala index 42a60f4516..d07fede7e2 100644 --- a/scio-test/src/test/scala/com/spotify/scio/transforms/FutureHandlersTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/transforms/FutureHandlersTest.scala @@ -16,12 +16,12 @@ package com.spotify.scio.transforms -import com.google.common.util.concurrent.{ListenableFuture, SettableFuture} +import com.google.common.util.concurrent.{ListenableFuture, MoreExecutors, SettableFuture} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import java.time.{Duration => JDuration} -import java.util.concurrent.CompletableFuture +import java.util.concurrent.{CompletableFuture, Executor} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent._ @@ -29,6 +29,7 @@ import scala.jdk.CollectionConverters._ import scala.util.{Failure, Success, Try} class GuavaFutureHandler extends FutureHandlers.Guava[String] { + override def getCallbackExecutor: Executor = MoreExecutors.directExecutor() override def getTimeout: JDuration = JDuration.ofMillis(500) } class JavaFutureHandler extends FutureHandlers.Java[String] {