diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java index b3f34b25e8..a281edc876 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java @@ -20,9 +20,7 @@ import com.google.common.cache.Cache; import com.google.common.collect.ImmutableList; -import com.google.common.collect.MoreCollectors; import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier; - import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; @@ -31,7 +29,6 @@ import java.util.concurrent.Semaphore; import java.util.function.Consumer; import java.util.stream.Collectors; - import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -48,322 +45,320 @@ * A {@link DoFn} that performs asynchronous lookup using the provided client. Lookup requests may * be deduplicated. * - * @param input element type. - * @param batched input element type + * @param input element type. + * @param batched input element type * @param batched output element type - * @param client lookup value type. - * @param client type. - * @param future type. - * @param client lookup value type wrapped in a Try. + * @param client lookup value type. + * @param client type. + * @param future type. + * @param client lookup value type wrapped in a Try. */ public abstract class BaseAsyncBatchLookupDoFn< Input, BatchRequest, BatchResponse, Output, ClientType, FutureType, TryWrapper> - extends DoFnWithResource, Pair>> - implements FutureHandlers.Base { - - private static final Logger LOG = LoggerFactory.getLogger(BaseAsyncBatchLookupDoFn.class); - - // Data structures for handling async requests - private final int batchSize; - private final SerializableFunction, BatchRequest> batchRequestFn; - private final SerializableFunction>> batchResponseFn; - private final SerializableFunction idExtractorFn; - private final int maxPendingRequests; - private final CacheSupplier cacheSupplier; - - private final Semaphore semaphore; - private final ConcurrentMap futures = new ConcurrentHashMap<>(); - private final ConcurrentMap>> inputs = - new ConcurrentHashMap<>(); - - private final Queue batch = new ArrayDeque<>(); - private final ConcurrentLinkedQueue>>>> - results = new ConcurrentLinkedQueue<>(); - private long inputCount; - private long outputCount; - - public BaseAsyncBatchLookupDoFn( - int batchSize, - SerializableFunction, BatchRequest> batchRequestFn, - SerializableFunction>> batchResponseFn, - SerializableFunction idExtractorFn, - int maxPendingRequests) { - this( - batchSize, - batchRequestFn, - batchResponseFn, - idExtractorFn, - maxPendingRequests, - new BaseAsyncLookupDoFn.NoOpCacheSupplier<>()); - } - - public BaseAsyncBatchLookupDoFn( - int batchSize, - SerializableFunction, BatchRequest> batchRequestFn, - SerializableFunction>> batchResponseFn, - SerializableFunction idExtractorFn, - int maxPendingRequests, - CacheSupplier cacheSupplier) { - this.batchSize = batchSize; - this.batchRequestFn = batchRequestFn; - this.batchResponseFn = batchResponseFn; - this.idExtractorFn = idExtractorFn; - this.maxPendingRequests = maxPendingRequests; - this.semaphore = new Semaphore(maxPendingRequests); - this.cacheSupplier = cacheSupplier; - } - - protected abstract ClientType newClient(); - - public abstract FutureType asyncLookup(ClientType client, BatchRequest input); - - public abstract TryWrapper success(Iterable output); - - public abstract TryWrapper failure(Throwable throwable); - - @Override - public Pair> createResource() { - return Pair.of(newClient(), cacheSupplier.get()); - } - - @Override - public void closeResource(Pair> resource) throws Exception { - final ClientType client = resource.getLeft(); - if (client instanceof AutoCloseable) { - ((AutoCloseable) client).close(); - } + extends DoFnWithResource, Pair>> + implements FutureHandlers.Base { + + private static final Logger LOG = LoggerFactory.getLogger(BaseAsyncBatchLookupDoFn.class); + + // Data structures for handling async requests + private final int batchSize; + private final SerializableFunction, BatchRequest> batchRequestFn; + private final SerializableFunction>> batchResponseFn; + private final SerializableFunction idExtractorFn; + private final int maxPendingRequests; + private final CacheSupplier cacheSupplier; + + private final Semaphore semaphore; + private final ConcurrentMap futures = new ConcurrentHashMap<>(); + private final ConcurrentMap>> inputs = + new ConcurrentHashMap<>(); + + private final Queue batch = new ArrayDeque<>(); + private final ConcurrentLinkedQueue>>>> + results = new ConcurrentLinkedQueue<>(); + private long inputCount; + private long outputCount; + + public BaseAsyncBatchLookupDoFn( + int batchSize, + SerializableFunction, BatchRequest> batchRequestFn, + SerializableFunction>> batchResponseFn, + SerializableFunction idExtractorFn, + int maxPendingRequests) { + this( + batchSize, + batchRequestFn, + batchResponseFn, + idExtractorFn, + maxPendingRequests, + new BaseAsyncLookupDoFn.NoOpCacheSupplier<>()); + } + + public BaseAsyncBatchLookupDoFn( + int batchSize, + SerializableFunction, BatchRequest> batchRequestFn, + SerializableFunction>> batchResponseFn, + SerializableFunction idExtractorFn, + int maxPendingRequests, + CacheSupplier cacheSupplier) { + this.batchSize = batchSize; + this.batchRequestFn = batchRequestFn; + this.batchResponseFn = batchResponseFn; + this.idExtractorFn = idExtractorFn; + this.maxPendingRequests = maxPendingRequests; + this.semaphore = new Semaphore(maxPendingRequests); + this.cacheSupplier = cacheSupplier; + } + + protected abstract ClientType newClient(); + + public abstract FutureType asyncLookup(ClientType client, BatchRequest input); + + public abstract TryWrapper success(Iterable output); + + public abstract TryWrapper failure(Throwable throwable); + + @Override + public Pair> createResource() { + return Pair.of(newClient(), cacheSupplier.get()); + } + + @Override + public void closeResource(Pair> resource) throws Exception { + final ClientType client = resource.getLeft(); + if (client instanceof AutoCloseable) { + ((AutoCloseable) client).close(); } - - public ClientType getResourceClient() { - return getResource().getLeft(); - } - - public Cache getResourceCache() { - return getResource().getRight(); + } + + public ClientType getResourceClient() { + return getResource().getLeft(); + } + + public Cache getResourceCache() { + return getResource().getRight(); + } + + @StartBundle + public void startBundle(StartBundleContext context) { + futures.clear(); + results.clear(); + inputs.clear(); + batch.clear(); + inputCount = 0; + outputCount = 0; + semaphore.drainPermits(); + semaphore.release(maxPendingRequests); + } + + // kept for binary compatibility. Must not be used + // TODO: remove in 0.15.0 + @Deprecated + public void processElement( + Input input, + Instant timestamp, + OutputReceiver> out, + BoundedWindow window) { + processElement(input, timestamp, window, null, out); + } + + @ProcessElement + public void processElement( + @Element Input input, + @Timestamp Instant timestamp, + BoundedWindow window, + PaneInfo pane, + OutputReceiver> out) { + inputCount++; + flush( + r -> { + final KV io = r.getValue(); + final Instant ts = r.getTimestamp(); + final Collection ws = Collections.singleton(r.getWindow()); + final PaneInfo p = r.getPane(); + out.outputWindowedValue(io, ts, ws, p); + }); + final Cache cache = getResourceCache(); + + try { + final String id = this.idExtractorFn.apply(input); + requireNonNull(id, "idExtractorFn returned null"); + + final Output cached = cache.getIfPresent(id); + + if (cached != null) { + // found in cache + out.output(KV.of(input, success(ImmutableList.of(cached)))); + outputCount++; + } else { + inputs.compute( + id, + (k, v) -> { + if (v == null) { + v = new LinkedList<>(); + batch.add(input); + } + v.add(ValueInSingleWindow.of(input, timestamp, window, pane)); + return v; + }); + } + + if (batch.size() >= batchSize) { + createRequest(); + } + + } catch (InterruptedException e) { + LOG.error("Failed to acquire semaphore", e); + throw new RuntimeException("Failed to acquire semaphore", e); + } catch (Exception e) { + LOG.error("Failed to process element", e); + throw e; } - - @StartBundle - public void startBundle(StartBundleContext context) { - futures.clear(); - results.clear(); - inputs.clear(); - batch.clear(); - inputCount = 0; - outputCount = 0; - semaphore.drainPermits(); - semaphore.release(maxPendingRequests); + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) { + + // send remaining + try { + if (!batch.isEmpty()) { + createRequest(); + } + if (!futures.isEmpty()) { + // Block until all pending futures are complete + waitForFutures(futures.values()); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.error("Failed to process futures", e); + throw new RuntimeException("Failed to process futures", e); + } catch (ExecutionException e) { + LOG.error("Failed to process futures", e); + throw new RuntimeException("Failed to process futures", e); } - - // kept for binary compatibility. Must not be used - // TODO: remove in 0.15.0 - @Deprecated - public void processElement( - Input input, - Instant timestamp, - OutputReceiver> out, - BoundedWindow window) { - processElement(input, timestamp, window, null, out); - } - - @ProcessElement - public void processElement( - @Element Input input, - @Timestamp Instant timestamp, - BoundedWindow window, - PaneInfo pane, - OutputReceiver> out) { - inputCount++; - flush( - r -> { - final KV io = r.getValue(); - final Instant ts = r.getTimestamp(); - final Collection ws = Collections.singleton(r.getWindow()); - final PaneInfo p = r.getPane(); - out.outputWindowedValue(io, ts, ws, p); - }); - final Cache cache = getResourceCache(); - - try { - final String id = this.idExtractorFn.apply(input); - requireNonNull(id, "idExtractorFn returned null"); - - final Output cached = cache.getIfPresent(id); - - if (cached != null) { - // found in cache - out.output(KV.of(input, success(ImmutableList.of(cached)))); - outputCount++; - } else { - inputs.compute( - id, - (k, v) -> { - if (v == null) { - v = new LinkedList<>(); - batch.add(input); - } - v.add(ValueInSingleWindow.of(input, timestamp, window, pane)); - return v; - }); - } - - if (batch.size() >= batchSize) { - createRequest(); - } - - } catch (InterruptedException e) { - LOG.error("Failed to acquire semaphore", e); - throw new RuntimeException("Failed to acquire semaphore", e); - } catch (Exception e) { - LOG.error("Failed to process element", e); - throw e; - } - } - - @FinishBundle - public void finishBundle(FinishBundleContext context) { - - // send remaining - try { - if (!batch.isEmpty()) { - createRequest(); - } - if (!futures.isEmpty()) { - // Block until all pending futures are complete - waitForFutures(futures.values()); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - LOG.error("Failed to process futures", e); - throw new RuntimeException("Failed to process futures", e); - } catch (ExecutionException e) { - LOG.error("Failed to process futures", e); - throw new RuntimeException("Failed to process futures", e); - } - flush(r -> context.output(r.getValue(), r.getTimestamp(), r.getWindow())); - - // Make sure all requests are processed - Preconditions.checkState( - inputCount == outputCount, - "Expected requestCount == responseCount, but %s != %s", - inputCount, - outputCount); - } - - private void createRequest() throws InterruptedException { - final ClientType client = getResourceClient(); - final Cache cache = getResourceCache(); - final UUID key = UUID.randomUUID(); - final List elems = new ArrayList<>(batch); - final BatchRequest request = batchRequestFn.apply(elems); - - // semaphore release is not performed on exception. - // let beam retry the bundle. startBundle will reset the semaphore to the - // maxPendingRequests permits. - semaphore.acquire(); - final FutureType future = asyncLookup(client, request); - // handle cache in fire & forget way - handleCache(future, cache); - // make sure semaphore are released when waiting for futures in finishBundle - final FutureType unlockedFuture = handleSemaphore(future); - - futures.put(key, handleOutput(unlockedFuture, elems, key)); - batch.clear(); - } - - private FutureType handleOutput(FutureType future, List batchInput, UUID key) { - return addCallback( - future, - response -> { - final Map> responses = batchResponseFn - .apply(response) - .stream() - .collect( - Collectors.groupingBy(Pair::getKey, - Collectors.mapping(Pair::getValue, Collectors.toList())) - ); - batchInput - .forEach( - element -> { - final String id = idExtractorFn.apply(element); - final List>> batchResult = - inputs.remove(id).stream() - .map(processInput -> { - final List output = - responses.getOrDefault(id, Collections.emptyList()); - final Input i = processInput.getValue(); - final TryWrapper o = success(output); - final Instant ts = processInput.getTimestamp(); - final BoundedWindow w = processInput.getWindow(); - final PaneInfo p = processInput.getPane(); - return ValueInSingleWindow.of(KV.of(i, o), ts, w, p); - }).collect(Collectors.toList()); - results.add(Pair.of(key, batchResult)); - } - ); - return null; - }, - throwable -> { - batchInput.forEach( - element -> { - final String id = idExtractorFn.apply(element); - final List>> batchResult = - inputs.remove(id).stream() - .map( - processInput -> { - final Input i = processInput.getValue(); - final TryWrapper o = failure(throwable); - final Instant ts = processInput.getTimestamp(); - final BoundedWindow w = processInput.getWindow(); - final PaneInfo p = processInput.getPane(); - return ValueInSingleWindow.of(KV.of(i, o), ts, w, p); - }) - .collect(Collectors.toList()); - results.add(Pair.of(key, batchResult)); - }); - return null; - }); - } - - private FutureType handleSemaphore(FutureType future) { - return addCallback( - future, - ouput -> { - semaphore.release(); - return null; - }, - throwable -> { - semaphore.release(); - return null; - }); - } - - private FutureType handleCache(FutureType future, Cache cache) { - return addCallback( - future, - response -> { - batchResponseFn - .apply(response) - .forEach( - pair -> { - final String id = pair.getLeft(); - final Output output = pair.getRight(); - cache.put(id, output); - }); - return null; - }, - throwable -> null); - } - - // Flush pending elements errors and results - private void flush(Consumer>> outputFn) { - Pair>>> r = results.poll(); - while (r != null) { - final UUID key = r.getKey(); - final List>> batchResult = r.getValue(); - batchResult.forEach(outputFn); - outputCount += batchResult.size(); - futures.remove(key); - r = results.poll(); - } + flush(r -> context.output(r.getValue(), r.getTimestamp(), r.getWindow())); + + // Make sure all requests are processed + Preconditions.checkState( + inputCount == outputCount, + "Expected requestCount == responseCount, but %s != %s", + inputCount, + outputCount); + } + + private void createRequest() throws InterruptedException { + final ClientType client = getResourceClient(); + final Cache cache = getResourceCache(); + final UUID key = UUID.randomUUID(); + final List elems = new ArrayList<>(batch); + final BatchRequest request = batchRequestFn.apply(elems); + + // semaphore release is not performed on exception. + // let beam retry the bundle. startBundle will reset the semaphore to the + // maxPendingRequests permits. + semaphore.acquire(); + final FutureType future = asyncLookup(client, request); + // handle cache in fire & forget way + handleCache(future, cache); + // make sure semaphore are released when waiting for futures in finishBundle + final FutureType unlockedFuture = handleSemaphore(future); + + futures.put(key, handleOutput(unlockedFuture, elems, key)); + batch.clear(); + } + + private FutureType handleOutput(FutureType future, List batchInput, UUID key) { + return addCallback( + future, + response -> { + final Map> responses = + batchResponseFn.apply(response).stream() + .collect( + Collectors.groupingBy( + Pair::getKey, Collectors.mapping(Pair::getValue, Collectors.toList()))); + batchInput.forEach( + element -> { + final String id = idExtractorFn.apply(element); + final List>> batchResult = + inputs.remove(id).stream() + .map( + processInput -> { + final List output = + responses.getOrDefault(id, Collections.emptyList()); + final Input i = processInput.getValue(); + final TryWrapper o = success(output); + final Instant ts = processInput.getTimestamp(); + final BoundedWindow w = processInput.getWindow(); + final PaneInfo p = processInput.getPane(); + return ValueInSingleWindow.of(KV.of(i, o), ts, w, p); + }) + .collect(Collectors.toList()); + results.add(Pair.of(key, batchResult)); + }); + return null; + }, + throwable -> { + batchInput.forEach( + element -> { + final String id = idExtractorFn.apply(element); + final List>> batchResult = + inputs.remove(id).stream() + .map( + processInput -> { + final Input i = processInput.getValue(); + final TryWrapper o = failure(throwable); + final Instant ts = processInput.getTimestamp(); + final BoundedWindow w = processInput.getWindow(); + final PaneInfo p = processInput.getPane(); + return ValueInSingleWindow.of(KV.of(i, o), ts, w, p); + }) + .collect(Collectors.toList()); + results.add(Pair.of(key, batchResult)); + }); + return null; + }); + } + + private FutureType handleSemaphore(FutureType future) { + return addCallback( + future, + ouput -> { + semaphore.release(); + return null; + }, + throwable -> { + semaphore.release(); + return null; + }); + } + + private FutureType handleCache(FutureType future, Cache cache) { + return addCallback( + future, + response -> { + batchResponseFn + .apply(response) + .forEach( + pair -> { + final String id = pair.getLeft(); + final Output output = pair.getRight(); + cache.put(id, output); + }); + return null; + }, + throwable -> null); + } + + // Flush pending elements errors and results + private void flush(Consumer>> outputFn) { + Pair>>> r = results.poll(); + while (r != null) { + final UUID key = r.getKey(); + final List>> batchResult = r.getValue(); + batchResult.forEach(outputFn); + outputCount += batchResult.size(); + futures.remove(key); + r = results.poll(); } + } } diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java index 66da02411a..3dba9166c6 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/GuavaAsyncBatchLookupDoFn.java @@ -20,7 +20,6 @@ import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier; import com.spotify.scio.transforms.BaseAsyncLookupDoFn.Try; import java.util.List; - import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.commons.lang3.tuple.Pair; diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/JavaAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/JavaAsyncBatchLookupDoFn.java index 6b5144a2f9..d0d48e91f1 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/JavaAsyncBatchLookupDoFn.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/JavaAsyncBatchLookupDoFn.java @@ -18,10 +18,7 @@ import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier; import com.spotify.scio.transforms.BaseAsyncLookupDoFn.Try; - -import java.util.Iterator; import java.util.List; -import java.util.Optional; import java.util.concurrent.CompletableFuture; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.commons.lang3.tuple.Pair; diff --git a/scio-core/src/main/scala/com/spotify/scio/transforms/ScalaAsyncBatchLookupDoFn.scala b/scio-core/src/main/scala/com/spotify/scio/transforms/ScalaAsyncBatchLookupDoFn.scala index ce50d10c77..4319eb22f1 100644 --- a/scio-core/src/main/scala/com/spotify/scio/transforms/ScalaAsyncBatchLookupDoFn.scala +++ b/scio-core/src/main/scala/com/spotify/scio/transforms/ScalaAsyncBatchLookupDoFn.scala @@ -66,7 +66,9 @@ abstract class ScalaAsyncBatchLookupDoFn[Input, BatchRequest, BatchResponse, Out ) with ScalaFutureHandlers[BatchResponse] { - override def success(output: java.lang.Iterable[Output]): Try[Option[Output]] = Success(output.asScala.headOption) + override def success(output: java.lang.Iterable[Output]): Try[Option[Output]] = Success( + output.asScala.headOption + ) override def failure(throwable: Throwable): Try[Option[Output]] = Failure(throwable) }