Skip to content

Commit

Permalink
extract semaphore logic out of WeightBoundedQueue to allow a weigher …
Browse files Browse the repository at this point in the history
…to be injected and shared
  • Loading branch information
m-trieu committed Oct 22, 2024
1 parent 4f4853e commit 881d43f
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,17 @@ public final class StreamingDataflowWorker {
*/
public static final int MAX_SINK_BYTES = 10_000_000;

public static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL =
"streaming_engine_use_job_settings_for_heartbeat_pool";
private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class);

/**
* Maximum number of threads for processing. Currently, each thread processes one key at a time.
*/
private static final int MAX_PROCESSING_THREADS = 300;

/** The idGenerator to generate unique id globally. */
private static final IdGenerator ID_GENERATOR = IdGenerators.decrementingLongs();

/** Maximum size of the result of a GetWork request. */
private static final long MAX_GET_WORK_FETCH_BYTES = 64L << 20; // 64m

/** Maximum number of failure stacktraces to report in each update sent to backend. */
private static final int MAX_FAILURES_TO_REPORT_IN_UPDATE = 1000;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,49 +18,40 @@
package org.apache.beam.runners.dataflow.worker.streaming;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;

/** Bounded set of queues, with a maximum total weight. */
/** Queue bounded by a {@link WeightedSemaphore}. */
public final class WeightedBoundedQueue<V> {

private final LinkedBlockingQueue<V> queue;
private final int maxWeight;
private final Semaphore limit;
private final Function<V, Integer> weigher;
private final WeightedSemaphore<V> weigher;

private WeightedBoundedQueue(
LinkedBlockingQueue<V> linkedBlockingQueue,
int maxWeight,
Semaphore limit,
Function<V, Integer> weigher) {
LinkedBlockingQueue<V> linkedBlockingQueue, WeightedSemaphore<V> weigher) {
this.queue = linkedBlockingQueue;
this.maxWeight = maxWeight;
this.limit = limit;
this.weigher = weigher;
}

public static <V> WeightedBoundedQueue<V> create(int maxWeight, Function<V, Integer> weigherFn) {
return new WeightedBoundedQueue<>(
new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn);
public static <V> WeightedBoundedQueue<V> create(WeightedSemaphore<V> weigher) {
return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weigher);
}

/**
* Adds the value to the queue, blocking if this would cause the overall weight to exceed the
* limit.
*/
public void put(V value) {
limit.acquireUninterruptibly(weigher.apply(value));
weigher.acquire(value);
queue.add(value);
}

/** Returns and removes the next value, or null if there is no such value. */
public @Nullable V poll() {
V result = queue.poll();
if (result != null) {
limit.release(weigher.apply(result));
weigher.release(result);
}
return result;
}
Expand All @@ -78,24 +69,26 @@ public void put(V value) {
public @Nullable V poll(long timeout, TimeUnit unit) throws InterruptedException {
V result = queue.poll(timeout, unit);
if (result != null) {
limit.release(weigher.apply(result));
weigher.release(result);
}
return result;
}

/** Returns and removes the next value, or blocks until one is available. */
public @Nullable V take() throws InterruptedException {
V result = queue.take();
limit.release(weigher.apply(result));
weigher.release(result);
return result;
}

/** Returns the current weight of the queue. */
public int queuedElementsWeight() {
return maxWeight - limit.availablePermits();
@VisibleForTesting
int queuedElementsWeight() {
return weigher.currentWeight();
}

public int size() {
@VisibleForTesting
int size() {
return queue.size();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.streaming;

import java.util.concurrent.Semaphore;
import java.util.function.Function;

public final class WeightedSemaphore<V> {
private final int maxWeight;
private final Semaphore limit;
private final Function<V, Integer> weigher;

private WeightedSemaphore(int maxWeight, Semaphore limit, Function<V, Integer> weigher) {
this.maxWeight = maxWeight;
this.limit = limit;
this.weigher = weigher;
}

public static <V> WeightedSemaphore<V> create(int maxWeight, Function<V, Integer> weigherFn) {
return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn);
}

void acquire(V value) {
limit.acquireUninterruptibly(weigher.apply(value));
}

void release(V value) {
limit.release(weigher.apply(value));
}

public int currentWeight() {
return maxWeight - limit.availablePermits();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
Expand Down Expand Up @@ -55,7 +56,9 @@ private StreamingApplianceWorkCommitter(
this.commitWorkFn = commitWorkFn;
this.commitQueue =
WeightedBoundedQueue.create(
MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
WeightedSemaphore.create(
MAX_COMMIT_QUEUE_BYTES,
commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())));
this.commitWorkers =
Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
Expand All @@ -73,10 +76,9 @@ public static StreamingApplianceWorkCommitter create(
}

@Override
@SuppressWarnings("FutureReturnValueIgnored")
public void start() {
if (!commitWorkers.isShutdown()) {
commitWorkers.submit(this::commitLoop);
commitWorkers.execute(this::commitLoop);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
Expand All @@ -46,7 +47,7 @@
public final class StreamingEngineWorkCommitter implements WorkCommitter {
private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
private static final int TARGET_COMMIT_BATCH_KEYS = 5;
private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
private static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB
private static final String NO_BACKEND_WORKER_TOKEN = "";

private final Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory;
Expand All @@ -61,11 +62,10 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter {
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
int numCommitSenders,
Consumer<CompleteCommit> onCommitComplete,
String backendWorkerToken) {
String backendWorkerToken,
WeightedSemaphore<Commit> weigher) {
this.commitWorkStreamFactory = commitWorkStreamFactory;
this.commitQueue =
WeightedBoundedQueue.create(
MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
this.commitQueue = WeightedBoundedQueue.create(weigher);
this.commitSenders =
Executors.newFixedThreadPool(
numCommitSenders,
Expand All @@ -86,16 +86,19 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter {
public static Builder builder() {
return new AutoBuilder_StreamingEngineWorkCommitter_Builder()
.setBackendWorkerToken(NO_BACKEND_WORKER_TOKEN)
.setWeigher(
WeightedSemaphore.create(
MAX_QUEUED_COMMITS_BYTES,
commit -> Math.min(MAX_QUEUED_COMMITS_BYTES, commit.getSize())))
.setNumCommitSenders(1);
}

@Override
@SuppressWarnings("FutureReturnValueIgnored")
public void start() {
Preconditions.checkState(
isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start().");
for (int i = 0; i < numCommitSenders; i++) {
commitSenders.submit(this::streamingCommitLoop);
commitSenders.execute(this::streamingCommitLoop);
}
}

Expand Down Expand Up @@ -258,6 +261,8 @@ public interface Builder {
Builder setCommitWorkStreamFactory(
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory);

Builder setWeigher(WeightedSemaphore<Commit> weigher);

Builder setNumCommitSenders(int numCommitSenders);

Builder setOnCommitComplete(Consumer<CompleteCommit> onCommitComplete);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@

@RunWith(JUnit4.class)
public class WeightBoundedQueueTest {
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private static final int MAX_WEIGHT = 10;
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);

@Test
public void testPut_hasCapacity() {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));

int insertedValue = 1;

Expand All @@ -50,7 +51,8 @@ public void testPut_hasCapacity() {
@Test
public void testPut_noCapacity() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));

// Insert value that takes all the capacity into the queue.
queue.put(MAX_WEIGHT);
Expand Down Expand Up @@ -87,7 +89,8 @@ public void testPut_noCapacity() throws InterruptedException {
@Test
public void testPoll() {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));

int insertedValue1 = 1;
int insertedValue2 = 2;
Expand All @@ -104,7 +107,8 @@ public void testPoll() {
@Test
public void testPoll_withTimeout() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));
int pollWaitTimeMillis = 10000;
int insertedValue1 = 1;

Expand Down Expand Up @@ -132,7 +136,8 @@ public void testPoll_withTimeout() throws InterruptedException {
@Test
public void testPoll_withTimeout_timesOut() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));
int defaultPollResult = -10;
int pollWaitTimeMillis = 100;
int insertedValue1 = 1;
Expand Down Expand Up @@ -164,15 +169,17 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException {
@Test
public void testPoll_emptyQueue() {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));

assertNull(queue.poll());
}

@Test
public void testTake() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue.create(
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)));

AtomicInteger value = new AtomicInteger();
// Should block until value is available
Expand All @@ -194,4 +201,46 @@ public void testTake() throws InterruptedException {

assertEquals(MAX_WEIGHT, value.get());
}

@Test
public void testPut_sharedWeigher() throws InterruptedException {
WeightedSemaphore<Integer> weigher =
WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
WeightedBoundedQueue<Integer> queue1 = WeightedBoundedQueue.create(weigher);
WeightedBoundedQueue<Integer> queue2 = WeightedBoundedQueue.create(weigher);

// Insert value that takes all the weight into the queue1.
queue1.put(MAX_WEIGHT);

// Try to insert a value into the queue2. This will block since there is no capacity in the
// weigher.
Thread putThread =
new Thread(
() -> {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
queue2.put(MAX_WEIGHT);
});
putThread.start();

// Should only see the first value in the queue, since the queue is at capacity. putThread
// should be blocked. The weight should be the same however, since queue1 and queue2 are sharing
// the weigher.
assertEquals(MAX_WEIGHT, queue1.queuedElementsWeight());
assertEquals(1, queue1.size());
assertEquals(MAX_WEIGHT, queue2.queuedElementsWeight());
assertEquals(0, queue2.size());

// Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher.
queue1.poll();

// Wait for the putThread which was previously blocked due to the weigher being at capacity.
putThread.join();

assertEquals(MAX_WEIGHT, queue2.queuedElementsWeight());
assertEquals(1, queue2.size());
}
}

0 comments on commit 881d43f

Please sign in to comment.