Skip to content

Commit

Permalink
Allow timeout for async DoFn
Browse files Browse the repository at this point in the history
Add unit testing

Fit unit testing

Make sure callback is completed

Remove old branch relicate
  • Loading branch information
RustedBones committed Dec 11, 2024
1 parent cc1e584 commit 4b7ca53
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.beam.sdk.transforms.DoFn;
Expand Down Expand Up @@ -238,7 +239,7 @@ public void finishBundle(FinishBundleContext context) {
Thread.currentThread().interrupt();
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
} catch (ExecutionException e) {
} catch (ExecutionException | TimeoutException e) {
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
Expand Down Expand Up @@ -64,7 +65,7 @@ public void finishBundle(FinishBundleContext context) {
Thread.currentThread().interrupt();
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
} catch (ExecutionException e) {
} catch (ExecutionException | TimeoutException e) {
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@
import java.util.Collections;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.CheckForNull;
Expand Down Expand Up @@ -233,7 +229,7 @@ public void finishBundle(FinishBundleContext context) {
Thread.currentThread().interrupt();
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
} catch (ExecutionException e) {
} catch (ExecutionException | TimeoutException e) {
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
package com.spotify.scio.transforms;

import com.google.common.util.concurrent.*;
import java.util.concurrent.*;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.StreamSupport;
import javax.annotation.Nullable;
Expand All @@ -33,7 +40,13 @@ public class FutureHandlers {
* @param <V> value type.
*/
public interface Base<F, V> {
void waitForFutures(Iterable<F> futures) throws InterruptedException, ExecutionException;

default Duration getTimeout() {
return null;
}

void waitForFutures(Iterable<F> futures)
throws InterruptedException, ExecutionException, TimeoutException;

F addCallback(F future, Function<V, Void> onSuccess, Function<Throwable, Void> onFailure);
}
Expand All @@ -53,10 +66,16 @@ default Executor getCallbackExecutor() {

@Override
default void waitForFutures(Iterable<ListenableFuture<V>> futures)
throws InterruptedException, ExecutionException {
throws InterruptedException, ExecutionException, TimeoutException {
// use Future#successfulAsList instead of Futures#allAsList which only works if all
// futures succeed
Futures.successfulAsList(futures).get();
ListenableFuture<?> f = Futures.successfulAsList(futures);
Duration timeout = getTimeout();
if (timeout != null) {
f.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
} else {
f.get();
}
}

@Override
Expand Down Expand Up @@ -116,10 +135,16 @@ public void onFailure(Throwable t) {
public interface Java<V> extends Base<CompletableFuture<V>, V> {
@Override
default void waitForFutures(Iterable<CompletableFuture<V>> futures)
throws InterruptedException, ExecutionException {
throws InterruptedException, ExecutionException, TimeoutException {
CompletableFuture[] array =
StreamSupport.stream(futures.spliterator(), false).toArray(CompletableFuture[]::new);
CompletableFuture.allOf(array).exceptionally(t -> null).get();
CompletableFuture<?> f = CompletableFuture.allOf(array).exceptionally(t -> null);
Duration timeout = getTimeout();
if (timeout != null) {
f.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
} else {
f.get();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ import com.google.common.util.concurrent.{ListenableFuture, SettableFuture}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import java.time.{Duration => JDuration}
import java.util.concurrent.{CompletableFuture, Executor, RejectedExecutionException}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent._
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

class GuavaFutureHandler extends FutureHandlers.Guava[String]

class JavaFutureHandler extends FutureHandlers.Java[String]
class GuavaFutureHandler extends FutureHandlers.Guava[String] {
override def getTimeout: JDuration = JDuration.ofMillis(500)
}
class JavaFutureHandler extends FutureHandlers.Java[String] {
override def getTimeout: JDuration = JDuration.ofMillis(500)
}

class RejectFutureHandler extends FutureHandlers.Guava[String] {
override def getCallbackExecutor: Executor = _ => throw new RejectedExecutionException("Rejected")
Expand Down Expand Up @@ -185,14 +189,34 @@ class FutureHandlersTest extends AnyFlatSpec with Matchers {
}
cause.getSuppressed.headOption.map(_.getMessage) shouldBe expectedSuppressed
}

it should "wait for futures to complete" in {
import scala.concurrent.ExecutionContext.Implicits.global
val successFuture = create()
val failureFuture = create()
val cancelFuture = create()
Future {
Thread.sleep(100)
complete(successFuture)("success")
fail(failureFuture)(new Exception("failure"))
cancel(cancelFuture)
}
handler.waitForFutures(Iterable[F](successFuture, failureFuture, cancelFuture).asJava)
}

it should "throw a timeout exception " in {
val f = create()
a[TimeoutException] shouldBe thrownBy(handler.waitForFutures(Iterable[F](f).asJava))
}

}

"Guava handler" should behave like futureHandler[
ListenableFuture[String],
SettableFuture[String]
](
new GuavaFutureHandler,
SettableFuture.create[String],
() => SettableFuture.create[String](),
_.set,
_.setException,
_.cancel(true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.spotify.scio.transforms.BaseAsyncLookupDoFn;
import com.spotify.scio.transforms.GuavaAsyncLookupDoFn;
import java.io.IOException;
import java.time.Duration;
import org.apache.beam.sdk.transforms.DoFn;

/**
Expand Down Expand Up @@ -99,6 +100,11 @@ public ResourceType getResourceType() {
return ResourceType.PER_INSTANCE;
}

@Override
public Duration getTimeout() {
return Duration.ofMillis(options.getCallOptionsConfig().getMutateRpcTimeoutMs());
}

protected BigtableSession newClient() {
try {
return new BigtableSession(options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package com.spotify.scio.bigtable

import com.google.cloud.bigtable.config.BigtableOptions

import java.util.concurrent.ConcurrentLinkedQueue
import com.google.cloud.bigtable.grpc.BigtableSession
import com.google.common.cache.{Cache, CacheBuilder}
Expand Down Expand Up @@ -66,21 +68,26 @@ object BigtableDoFnTest {
val queue: ConcurrentLinkedQueue[Int] = new ConcurrentLinkedQueue[Int]()
}

class TestBigtableDoFn extends BigtableDoFn[Int, String](null) {
class TestBigtableDoFn extends BigtableDoFn[Int, String](BigtableOptions.getDefaultOptions) {
override def newClient(): BigtableSession = null
override def asyncLookup(session: BigtableSession, input: Int): ListenableFuture[String] =
Futures.immediateFuture(input.toString)
}

class TestCachingBigtableDoFn extends BigtableDoFn[Int, String](null, 100, new TestCacheSupplier) {
class TestCachingBigtableDoFn
extends BigtableDoFn[Int, String](
BigtableOptions.getDefaultOptions,
100,
new TestCacheSupplier
) {
override def newClient(): BigtableSession = null
override def asyncLookup(session: BigtableSession, input: Int): ListenableFuture[String] = {
BigtableDoFnTest.queue.add(input)
Futures.immediateFuture(input.toString)
}
}

class TestFailingBigtableDoFn extends BigtableDoFn[Int, String](null) {
class TestFailingBigtableDoFn extends BigtableDoFn[Int, String](BigtableOptions.getDefaultOptions) {
override def newClient(): BigtableSession = null
override def asyncLookup(session: BigtableSession, input: Int): ListenableFuture[String] =
if (input % 2 == 0) {
Expand Down

0 comments on commit 4b7ca53

Please sign in to comment.