From 881d43f91e1fd951566031d57cf1221d0f7a6f81 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Tue, 22 Oct 2024 15:23:01 -0700 Subject: [PATCH] extract semaphore logic out of WeightBoundedQueue to allow a weigher to be injected and shared --- .../worker/StreamingDataflowWorker.java | 6 +- .../streaming/WeightedBoundedQueue.java | 37 +++++------ .../worker/streaming/WeightedSemaphore.java | 49 ++++++++++++++ .../StreamingApplianceWorkCommitter.java | 8 ++- .../commits/StreamingEngineWorkCommitter.java | 19 ++++-- .../streaming/WeightBoundedQueueTest.java | 65 ++++++++++++++++--- 6 files changed, 140 insertions(+), 44 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 524906023722..5d50aba0a15e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -118,19 +118,17 @@ public final class StreamingDataflowWorker { */ public static final int MAX_SINK_BYTES = 10_000_000; + public static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL = + "streaming_engine_use_job_settings_for_heartbeat_pool"; private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class); - /** * Maximum number of threads for processing. Currently, each thread processes one key at a time. */ private static final int MAX_PROCESSING_THREADS = 300; - /** The idGenerator to generate unique id globally. */ private static final IdGenerator ID_GENERATOR = IdGenerators.decrementingLongs(); - /** Maximum size of the result of a GetWork request. */ private static final long MAX_GET_WORK_FETCH_BYTES = 64L << 20; // 64m - /** Maximum number of failure stacktraces to report in each update sent to backend. */ private static final int MAX_FAILURES_TO_REPORT_IN_UPDATE = 1000; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java index f2893f3e7191..0132138cda03 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java @@ -18,33 +18,24 @@ package org.apache.beam.runners.dataflow.worker.streaming; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.function.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; -/** Bounded set of queues, with a maximum total weight. */ +/** Queue bounded by a {@link WeightedSemaphore}. */ public final class WeightedBoundedQueue { private final LinkedBlockingQueue queue; - private final int maxWeight; - private final Semaphore limit; - private final Function weigher; + private final WeightedSemaphore weigher; private WeightedBoundedQueue( - LinkedBlockingQueue linkedBlockingQueue, - int maxWeight, - Semaphore limit, - Function weigher) { + LinkedBlockingQueue linkedBlockingQueue, WeightedSemaphore weigher) { this.queue = linkedBlockingQueue; - this.maxWeight = maxWeight; - this.limit = limit; this.weigher = weigher; } - public static WeightedBoundedQueue create(int maxWeight, Function weigherFn) { - return new WeightedBoundedQueue<>( - new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn); + public static WeightedBoundedQueue create(WeightedSemaphore weigher) { + return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weigher); } /** @@ -52,7 +43,7 @@ public static WeightedBoundedQueue create(int maxWeight, Function { + private final int maxWeight; + private final Semaphore limit; + private final Function weigher; + + private WeightedSemaphore(int maxWeight, Semaphore limit, Function weigher) { + this.maxWeight = maxWeight; + this.limit = limit; + this.weigher = weigher; + } + + public static WeightedSemaphore create(int maxWeight, Function weigherFn) { + return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn); + } + + void acquire(V value) { + limit.acquireUninterruptibly(weigher.apply(value)); + } + + void release(V value) { + limit.release(weigher.apply(value)); + } + + public int currentWeight() { + return maxWeight - limit.availablePermits(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index d092ebf53fc1..ad51daed748e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -27,6 +27,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -55,7 +56,9 @@ private StreamingApplianceWorkCommitter( this.commitWorkFn = commitWorkFn; this.commitQueue = WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + WeightedSemaphore.create( + MAX_COMMIT_QUEUE_BYTES, + commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()))); this.commitWorkers = Executors.newSingleThreadScheduledExecutor( new ThreadFactoryBuilder() @@ -73,10 +76,9 @@ public static StreamingApplianceWorkCommitter create( } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { if (!commitWorkers.isShutdown()) { - commitWorkers.submit(this::commitLoop); + commitWorkers.execute(this::commitLoop); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index bf1007bc4bfb..91a3f3b745af 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -28,6 +28,7 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; @@ -46,7 +47,7 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); private static final int TARGET_COMMIT_BATCH_KEYS = 5; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB + private static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB private static final String NO_BACKEND_WORKER_TOKEN = ""; private final Supplier> commitWorkStreamFactory; @@ -61,11 +62,10 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { Supplier> commitWorkStreamFactory, int numCommitSenders, Consumer onCommitComplete, - String backendWorkerToken) { + String backendWorkerToken, + WeightedSemaphore weigher) { this.commitWorkStreamFactory = commitWorkStreamFactory; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(weigher); this.commitSenders = Executors.newFixedThreadPool( numCommitSenders, @@ -86,16 +86,19 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { public static Builder builder() { return new AutoBuilder_StreamingEngineWorkCommitter_Builder() .setBackendWorkerToken(NO_BACKEND_WORKER_TOKEN) + .setWeigher( + WeightedSemaphore.create( + MAX_QUEUED_COMMITS_BYTES, + commit -> Math.min(MAX_QUEUED_COMMITS_BYTES, commit.getSize()))) .setNumCommitSenders(1); } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { Preconditions.checkState( isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start()."); for (int i = 0; i < numCommitSenders; i++) { - commitSenders.submit(this::streamingCommitLoop); + commitSenders.execute(this::streamingCommitLoop); } } @@ -258,6 +261,8 @@ public interface Builder { Builder setCommitWorkStreamFactory( Supplier> commitWorkStreamFactory); + Builder setWeigher(WeightedSemaphore weigher); + Builder setNumCommitSenders(int numCommitSenders); Builder setOnCommitComplete(Consumer onCommitComplete); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java index 4f035c88774c..01e90ef7d360 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java @@ -30,13 +30,14 @@ @RunWith(JUnit4.class) public class WeightBoundedQueueTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final int MAX_WEIGHT = 10; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Test public void testPut_hasCapacity() { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int insertedValue = 1; @@ -50,7 +51,8 @@ public void testPut_hasCapacity() { @Test public void testPut_noCapacity() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); // Insert value that takes all the capacity into the queue. queue.put(MAX_WEIGHT); @@ -87,7 +89,8 @@ public void testPut_noCapacity() throws InterruptedException { @Test public void testPoll() { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int insertedValue1 = 1; int insertedValue2 = 2; @@ -104,7 +107,8 @@ public void testPoll() { @Test public void testPoll_withTimeout() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int pollWaitTimeMillis = 10000; int insertedValue1 = 1; @@ -132,7 +136,8 @@ public void testPoll_withTimeout() throws InterruptedException { @Test public void testPoll_withTimeout_timesOut() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int defaultPollResult = -10; int pollWaitTimeMillis = 100; int insertedValue1 = 1; @@ -164,7 +169,8 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { @Test public void testPoll_emptyQueue() { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); assertNull(queue.poll()); } @@ -172,7 +178,8 @@ public void testPoll_emptyQueue() { @Test public void testTake() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); AtomicInteger value = new AtomicInteger(); // Should block until value is available @@ -194,4 +201,46 @@ public void testTake() throws InterruptedException { assertEquals(MAX_WEIGHT, value.get()); } + + @Test + public void testPut_sharedWeigher() throws InterruptedException { + WeightedSemaphore weigher = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue1 = WeightedBoundedQueue.create(weigher); + WeightedBoundedQueue queue2 = WeightedBoundedQueue.create(weigher); + + // Insert value that takes all the weight into the queue1. + queue1.put(MAX_WEIGHT); + + // Try to insert a value into the queue2. This will block since there is no capacity in the + // weigher. + Thread putThread = + new Thread( + () -> { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + queue2.put(MAX_WEIGHT); + }); + putThread.start(); + + // Should only see the first value in the queue, since the queue is at capacity. putThread + // should be blocked. The weight should be the same however, since queue1 and queue2 are sharing + // the weigher. + assertEquals(MAX_WEIGHT, queue1.queuedElementsWeight()); + assertEquals(1, queue1.size()); + assertEquals(MAX_WEIGHT, queue2.queuedElementsWeight()); + assertEquals(0, queue2.size()); + + // Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher. + queue1.poll(); + + // Wait for the putThread which was previously blocked due to the weigher being at capacity. + putThread.join(); + + assertEquals(MAX_WEIGHT, queue2.queuedElementsWeight()); + assertEquals(1, queue2.size()); + } }