Skip to content

Commit

Permalink
Make sure callback is completed
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Jan 25, 2024
1 parent 37b7b52 commit 862b054
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@

import com.google.common.util.concurrent.*;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.*;
import java.util.function.Function;
import java.util.stream.StreamSupport;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -50,6 +47,7 @@ void waitForFutures(Iterable<F> futures)

/** A {@link Base} implementation for Guava {@link ListenableFuture}. */
public interface Guava<V> extends Base<ListenableFuture<V>, V> {
Executor getCallbackExecutor();

@Override
default void waitForFutures(Iterable<ListenableFuture<V>> futures)
Expand All @@ -73,6 +71,15 @@ default ListenableFuture<V> addCallback(
// Futures#transform doesn't allow onFailure callback while Futures#addCallback doesn't
// guarantee that callbacks are called before ListenableFuture#get() unblocks
SettableFuture<V> f = SettableFuture.create();
// if executor rejects the callback, we have to fail the future
Executor rejectPropagationExecutor =
command -> {
try {
getCallbackExecutor().execute(command);
} catch (RejectedExecutionException e) {
f.setException(e);
}
};
Futures.addCallback(
future,
new FutureCallback<V>() {
Expand Down Expand Up @@ -103,7 +110,7 @@ public void onFailure(Throwable t) {
}
}
},
MoreExecutors.directExecutor());
rejectPropagationExecutor);

return f;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.Try;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.commons.lang3.tuple.Pair;

Expand All @@ -42,6 +44,8 @@ public abstract class GuavaAsyncBatchLookupDoFn<
Try<Output>>
implements FutureHandlers.Guava<BatchResponse> {

private transient Executor executor;

public GuavaAsyncBatchLookupDoFn(
int batchSize,
SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
Expand All @@ -67,6 +71,14 @@ public GuavaAsyncBatchLookupDoFn(
cacheSupplier);
}

@Override
public Executor getCallbackExecutor() {
if (executor == null) {
executor = Executors.newSingleThreadExecutor();
}
return executor;
}

@Override
public Try<Output> success(Output output) {
return new Try<>(output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package com.spotify.scio.transforms;

import com.google.common.util.concurrent.ListenableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.apache.beam.sdk.transforms.DoFn;

/**
Expand All @@ -26,4 +28,14 @@
*/
public abstract class GuavaAsyncDoFn<InputT, OutputT, ResourceT>
extends BaseAsyncDoFn<InputT, OutputT, ResourceT, ListenableFuture<OutputT>>
implements FutureHandlers.Guava<OutputT> {}
implements FutureHandlers.Guava<OutputT> {
private transient Executor executor;

@Override
public Executor getCallbackExecutor() {
if (executor == null) {
executor = Executors.newSingleThreadExecutor();
}
return executor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package com.spotify.scio.transforms;

import com.google.common.util.concurrent.ListenableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.apache.beam.sdk.transforms.DoFn;

/**
Expand All @@ -32,6 +34,8 @@ public abstract class GuavaAsyncLookupDoFn<A, B, C>
extends BaseAsyncLookupDoFn<A, B, C, ListenableFuture<B>, BaseAsyncLookupDoFn.Try<B>>
implements FutureHandlers.Guava<B> {

private transient Executor executor;

/** Create a {@link GuavaAsyncLookupDoFn} instance. */
public GuavaAsyncLookupDoFn() {
super();
Expand Down Expand Up @@ -75,6 +79,14 @@ public GuavaAsyncLookupDoFn(
super(maxPendingRequests, deduplicate, cacheSupplier);
}

@Override
public Executor getCallbackExecutor() {
if (executor == null) {
executor = Executors.newSingleThreadExecutor();
}
return executor;
}

@Override
public BaseAsyncLookupDoFn.Try<B> success(B output) {
return new Try<>(output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@

package com.spotify.scio.transforms

import com.google.common.util.concurrent.{ListenableFuture, SettableFuture}
import com.google.common.util.concurrent.{ListenableFuture, MoreExecutors, SettableFuture}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import java.time.{Duration => JDuration}
import java.util.concurrent.CompletableFuture
import java.util.concurrent.{CompletableFuture, Executor}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent._
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

class GuavaFutureHandler extends FutureHandlers.Guava[String] {
override def getCallbackExecutor: Executor = MoreExecutors.directExecutor()
override def getTimeout: JDuration = JDuration.ofMillis(500)
}
class JavaFutureHandler extends FutureHandlers.Java[String] {
Expand Down

0 comments on commit 862b054

Please sign in to comment.