Skip to content

Commit

Permalink
Fix concurrency issues in FlightClientManager and FlightStreamManager
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Maurya <[email protected]>
  • Loading branch information
rishabhmaurya committed Jan 30, 2025
1 parent 3ff3dd5 commit 2a6590f
Show file tree
Hide file tree
Showing 20 changed files with 635 additions and 452 deletions.
4 changes: 4 additions & 0 deletions libs/arrow-memory-shaded/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ tasks.named("shadowJar").configure {
include(dependency('com.fasterxml.jackson.core:'))
include(dependency('commons-codec:'))
}
manifest {
attributes('Multi-Release': 'true')
}
exclude 'META-INF/maven/**'
relocate 'io.netty', 'org.opensearch.shaded.io.netty'
relocate 'org.checkerframework', 'org.opensearch.shaded.org.checkerframework'
relocate 'org.apache.commons.codec', 'org.opensearch.shaded.commons.codec'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.tasks.TaskId;

import java.io.Closeable;
Expand Down Expand Up @@ -86,7 +87,7 @@ public interface StreamProducer extends Closeable {
* @param allocator The allocator to use for creating vectors
* @return A new VectorSchemaRoot instance
*/
VectorSchemaRoot createRoot(BufferAllocator allocator);
VectorSchemaRoot createRoot(BufferAllocator allocator) throws Exception;

/**
* Creates a job that will produce the stream data in batches. The job will populate
Expand All @@ -97,6 +98,14 @@ public interface StreamProducer extends Closeable {
*/
BatchedJob createJob(BufferAllocator allocator);

/**
* Returns the deadline for the job execution.
* After this deadline, the job should be considered expired.
*
* @return TimeValue representing the job's deadline
*/
TimeValue getJobDeadline();

/**
* Provides an estimate of the total number of rows that will be produced.
*
Expand All @@ -122,7 +131,7 @@ interface BatchedJob {
* @param root The VectorSchemaRoot to populate with data
* @param flushSignal Signal to coordinate with consumers
*/
void run(VectorSchemaRoot root, FlushSignal flushSignal);
void run(VectorSchemaRoot root, FlushSignal flushSignal) throws Exception;

/**
* Called to signal producer when the job is canceled.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ public interface StreamReader extends Closeable {
* Blocking request to load next batch into root.
*
* @return true if more data was found, false if the stream is exhausted
* @throws Exception if an error occurs while loading the next batch
*/
boolean next();
boolean next() throws Exception;

/**
* Returns the VectorSchemaRoot associated with this iterator.
* The content of this root is updated with each successful call to next().
*
* @return the VectorSchemaRoot
* @throws Exception if an error occurs while retrieving the root
*/
VectorSchemaRoot getRoot();
VectorSchemaRoot getRoot() throws Exception;
}
1 change: 1 addition & 0 deletions libs/flight-core-shaded/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ shadowJar {
dependencies {
include(dependency('org.apache.arrow:'))
}
exclude 'META-INF/maven/**'
relocate 'io.grpc.netty', 'io.grpc.netty.shaded.io.grpc.netty'
relocate 'io.netty', 'io.grpc.netty.shaded.io.netty'
mergeServiceFiles()
Expand Down
4 changes: 4 additions & 0 deletions plugins/arrow-flight-rpc/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ dependencies {
}
}

// NOTE(review): presumably opens java.nio internals to Arrow's memory module for
// direct-buffer management on JDK 9+ (Arrow allocators reflect into DirectByteBuffer)
// — confirm against arrow-memory-core's documented JVM-arg requirements.
tasks.internalClusterTest {
jvmArgs += ["--add-opens", "java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED"]
}

tasks.named('test').configure {
jacoco {
excludes = ['org/apache/arrow/flight/**']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.arrow.flight.CallOptions;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.OSFlightClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
Expand All @@ -20,26 +21,26 @@
import org.opensearch.arrow.flight.bootstrap.FlightService;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.arrow.spi.StreamProducer;
import org.opensearch.arrow.spi.StreamReader;
import org.opensearch.arrow.spi.StreamTicket;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.FeatureFlagSetter;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 5)
@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 3)
public class ArrowFlightServerIT extends OpenSearchIntegTestCase {

private FlightClientManager flightClientManager;
private FlightService flightService;

@BeforeClass
public static void setupFeatureFlags() {
FeatureFlagSetter.set(FeatureFlags.ARROW_STREAMS_SETTING.getKey());
Expand All @@ -55,80 +56,208 @@ public void setUp() throws Exception {
super.setUp();
ensureGreen();
Thread.sleep(1000);
flightService = internalCluster().getInstance(FlightService.class);
flightClientManager = flightService.getFlightClientManager();
}

public void testArrowFlightEndpoint() throws Exception {
public void testArrowFlightEndpoint() {
for (DiscoveryNode node : getClusterState().nodes()) {
OSFlightClient flightClient = flightClientManager.getFlightClient(node.getId());
FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName());
FlightClientManager flightClientManager = flightService.getFlightClientManager();
OSFlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get();
assertNotNull(flightClient);
flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS));
}
}

/**
 * Registers a producer on an arbitrary node, then verifies the stream can be
 * read back from every node in the cluster: 10 batches of 10 monotonically
 * increasing docIDs (0..99 overall).
 */
public void testFlightStreamReader() throws Exception {
for (DiscoveryNode node : getClusterState().nodes()) {
StreamManager streamManagerRandomNode = getStreamManagerRandomNode();
StreamTicket ticket = streamManagerRandomNode.registerStream(getStreamProducer(), null);
StreamManager streamManagerCurrentNode = getStreamManager(node.getName());
// reader should be accessible from any node in the cluster due to the use of ProxyStreamProducer
try (StreamReader reader = streamManagerCurrentNode.getStreamReader(ticket)) {
int totalBatches = 0;
while (reader.next()) {
IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID");
assertEquals(10, docIDVector.getValueCount());
// Values continue across batches: batch b holds b*10 .. b*10+9.
for (int i = 0; i < 10; i++) {
assertEquals(docIDVector.toString(), i + (totalBatches * 10L), docIDVector.get(i));
}
totalBatches++;
}
// Producer flushes exactly 10 batches of 10 rows.
assertEquals(10, totalBatches);
}
}
}

/**
 * Verifies early client-side cancellation: a reader that closes after consuming
 * only the first of ten batches must cause the server-side producer job to be
 * cancelled and closed, and a further read must surface CANCELLED to the caller.
 */
public void testEarlyCancel() throws Exception {
DiscoveryNode previousNode = null;
for (DiscoveryNode node : getClusterState().nodes()) {
// Pair each node (server) with the previously visited node (client) so the
// stream is always consumed across the transport, not in-process.
if (previousNode == null) {
previousNode = node;
continue;
}
StreamManager streamManagerServer = getStreamManager(node.getName());
TestStreamProducer streamProducer = getStreamProducer();
StreamTicket ticket = streamManagerServer.registerStream(streamProducer, null);
StreamManager streamManagerClient = getStreamManager(previousNode.getName());

CountDownLatch readerComplete = new CountDownLatch(1);
AtomicReference<Exception> readerException = new AtomicReference<>();
AtomicReference<StreamReader> readerRef = new AtomicReference<>();

// Start reader thread
Thread readerThread = new Thread(() -> {
try {
StreamReader reader = streamManagerClient.getStreamReader(ticket);
readerRef.set(reader);
assertNotNull(reader.getRoot());
IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID");
assertNotNull(docIDVector);

// Read first batch
reader.next();
assertEquals(10, docIDVector.getValueCount());
for (int i = 0; i < 10; i++) {
assertEquals(docIDVector.toString(), i, docIDVector.get(i));
}
// Closing after one of ten batches exercises the early-cancel path.
reader.close();
} catch (Exception e) {
readerException.set(e);
} finally {
readerComplete.countDown();
}
}, "flight-reader-thread");

readerThread.start();
assertTrue("Reader thread did not complete in time", readerComplete.await(1, TimeUnit.SECONDS));

// Check for any exceptions in reader thread
if (readerException.get() != null) {
throw readerException.get();
}

StreamReader reader = readerRef.get();

// Reading past the client-side close must fail with a CANCELLED Flight status.
try {
reader.next();
fail("Expected FlightRuntimeException");
} catch (FlightRuntimeException e) {
assertEquals("CANCELLED", e.status().code().name());
assertEquals("Stream closed before end", e.getMessage());
reader.close();
}

// Wait for onCancel to complete
// Due to https://github.com/grpc/grpc-java/issues/5882, there is a logic in FlightStream.java
// where it exhausts the stream on the server side before it is actually cancelled.
assertTrue(
"Timeout waiting for stream cancellation on server [" + node.getName() + "]",
streamProducer.waitForClose(2, TimeUnit.SECONDS)
);
previousNode = node;
}
}

/**
 * Registers a single stream once, then verifies getInfo() for that ticket is
 * answered from every node in the cluster with the expected record count.
 */
public void testFlightGetInfo() throws Exception {
    StreamTicket ticket = null;
    for (DiscoveryNode node : getClusterState().nodes()) {
        FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName());
        StreamManager streamManager = flightService.getStreamManager();
        // Register the producer only once (on the first node visited); all other
        // nodes must still resolve FlightInfo for the same ticket.
        if (ticket == null) {
            ticket = streamManager.registerStream(getStreamProducer(), null);
        }
        FlightClientManager flightClientManager = flightService.getFlightClientManager();
        OSFlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get();
        assertNotNull(flightClient);
        FlightDescriptor flightDescriptor = FlightDescriptor.command(ticket.toBytes());
        FlightInfo flightInfo = flightClient.getInfo(flightDescriptor, CallOptions.timeout(5000L, TimeUnit.MILLISECONDS));
        assertNotNull(flightInfo);
        // Matches TestStreamProducer.estimatedRowCount().
        assertEquals(100, flightInfo.getRecords());
    }
}

private StreamProducer getStreamProducer() {
return new StreamProducer() {
@Override
public VectorSchemaRoot createRoot(BufferAllocator allocator) {
IntVector docIDVector = new IntVector("docID", allocator);
FieldVector[] vectors = new FieldVector[] { docIDVector };
return new VectorSchemaRoot(Arrays.asList(vectors));
}
/** Returns the StreamManager owned by the FlightService bound on the named node. */
private StreamManager getStreamManager(String nodeName) {
    return internalCluster().getInstance(FlightService.class, nodeName).getStreamManager();
}

/** Returns the StreamManager of an arbitrary node (no node name passed to internalCluster()). */
private StreamManager getStreamManagerRandomNode() {
    return internalCluster().getInstance(FlightService.class).getStreamManager();
}

/** Builds a fresh producer per call so each test registers an independent stream. */
private TestStreamProducer getStreamProducer() {
    TestStreamProducer producer = new TestStreamProducer();
    return producer;
}

private static class TestStreamProducer implements StreamProducer {
volatile boolean isClosed = false;
private final CountDownLatch closeLatch = new CountDownLatch(1);

VectorSchemaRoot root;

@Override
public BatchedJob createJob(BufferAllocator allocator) {
return new BatchedJob() {
@Override
public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
IntVector docIDVector = (IntVector) root.getVector("docID");
for (int i = 0; i < 100; i++) {
docIDVector.setSafe(i % 10, i);
if (i >= 10) {
root.setRowCount(10);
flushSignal.awaitConsumption(1000);
}
@Override
public VectorSchemaRoot createRoot(BufferAllocator allocator) {
IntVector docIDVector = new IntVector("docID", allocator);
FieldVector[] vectors = new FieldVector[] { docIDVector };
root = new VectorSchemaRoot(Arrays.asList(vectors));
return root;
}

@Override
public BatchedJob createJob(BufferAllocator allocator) {
return new BatchedJob() {
@Override
public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
IntVector docIDVector = (IntVector) root.getVector("docID");
root.setRowCount(10);
for (int i = 0; i < 100; i++) {
docIDVector.setSafe(i % 10, i);
if ((i + 1) % 10 == 0) {
flushSignal.awaitConsumption(1000);
docIDVector.clear();
root.setRowCount(10);
}
}
}

@Override
public void onCancel() {
@Override
public void onCancel() {
root.close();
isClosed = true;
}

}
@Override
public boolean isCancelled() {
return isClosed;
}
};
}

@Override
public boolean isCancelled() {
return false;
}
};
}
@Override
public TimeValue getJobDeadline() {
return TimeValue.timeValueSeconds(5);
}

@Override
public int estimatedRowCount() {
return 100;
}
@Override
public int estimatedRowCount() {
return 100;
}

@Override
public String getAction() {
return "";
}
@Override
public String getAction() {
return "";
}

@Override
public void close() throws IOException {
@Override
public void close() {
root.close();
closeLatch.countDown();
isClosed = true;
}

}
};
public boolean waitForClose(long timeout, TimeUnit unit) throws InterruptedException {
return closeLatch.await(timeout, unit);
}
}
}
Loading

0 comments on commit 2a6590f

Please sign in to comment.