Commit 7c9d9c5

Signed-off-by: Marc Handalian <[email protected]>
mch2 committed Jan 28, 2025
1 parent 07b238d commit 7c9d9c5

Showing 14 changed files with 103 additions and 152 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -64,4 +64,5 @@ testfixtures_shared/
.ci/jobs/

# build files generated
doc-tools/missing-doclet/bin/
doc-tools/missing-doclet/bin/
/libs/datafusion/jni/target/
PartitionedStreamProducer.java
@@ -10,7 +10,9 @@

import java.util.Set;

/**
 * A StreamProducer whose output is split across multiple partition streams,
 * each addressed by its own StreamTicket.
 */
public interface PartitionedStreamProducer extends StreamProducer {
Set<StreamTicket> partitions();
void setRootTicket(StreamTicket ticket);
}
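
For orientation, a minimal sketch of an implementer of this contract; the class name and fixed-partition behavior are hypothetical, and the inherited StreamProducer methods (createRoot, createJob, close, and the row-count estimate) are elided, so this is illustrative rather than compilable as-is:

// Hypothetical producer serving a fixed set of partition tickets.
public class FixedPartitionProducer implements PartitionedStreamProducer {
    private final Set<StreamTicket> partitions;
    private StreamTicket rootTicket;

    public FixedPartitionProducer(Set<StreamTicket> partitions) {
        this.partitions = partitions;
    }

    @Override
    public Set<StreamTicket> partitions() {
        return partitions;                 // one ticket per partition stream
    }

    @Override
    public void setRootTicket(StreamTicket ticket) {
        this.rootTicket = ticket;          // assigned when the root stream is registered
    }

    // ... remaining StreamProducer methods elided ...
}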
StreamProducer.java
@@ -99,6 +99,11 @@ public interface StreamProducer extends Closeable {
*/
BatchedJob createJob(BufferAllocator allocator);

/**
 * Tickets of the partition streams backing this producer;
 * empty by default for producers that are not partitioned.
 */
default Set<StreamTicket> partitions() {
return Collections.emptySet();
}

/**
* Provides an estimate of the total number of rows that will be produced.
*
5 changes: 0 additions & 5 deletions libs/datafusion/build.gradle
@@ -15,11 +15,6 @@ dependencies {
base {
archivesName = 'opensearch-datafusion'
}
// logging
implementation "org.apache.logging.log4j:log4j-api:${versions.log4j}"
implementation "org.apache.logging.log4j:log4j-core:${versions.log4j}"
implementation "org.apache.logging.log4j:log4j-slf4j-impl:${versions.log4j}"

// testing
testImplementation "com.carrotsearch.randomizedtesting:randomizedtesting-runner:${versions.randomizedrunner}"
testImplementation "junit:junit:${versions.junit}"
70 changes: 10 additions & 60 deletions libs/datafusion/jni/src/provider.rs
@@ -1,12 +1,9 @@
use arrow_flight::FlightEndpoint;
use arrow_flight::{flight_service_client::FlightServiceClient, FlightDescriptor, FlightInfo};
use arrow_flight::{flight_service_client::FlightServiceClient, FlightDescriptor};
use bytes::Bytes;
use datafusion::catalog::TableProvider;
use datafusion::common::JoinType;
use datafusion::common::Result;
use datafusion::error::DataFusionError;
use datafusion::functions_aggregate::expr_fn::sum;
use datafusion::prelude::Expr;
use futures::TryFutureExt;
use std::{collections::HashMap, sync::Arc};

@@ -16,7 +13,6 @@ use datafusion_table_providers::flight::{
FlightDriver, FlightMetadata, FlightProperties, FlightTableFactory, FlightTable
};
use datafusion::logical_expr::test::function_stub::count;
use futures::future::try_join_all;
use tonic::async_trait;
use tonic::transport::Channel;
mod test;
@@ -63,20 +59,6 @@ pub async fn join(
println!("LEFT FRAME:");
// left_df.clone().show().await;
println!("RIGHT FRAME:");
// right_df.clone().show().await;

// // select all the cols returned by right but alias the join field
// let select_cols: Vec<Expr> = right_df.schema()
// .fields()
// .iter()
// .map(|field| {
// if field.name() == &join_field {
// col(&join_field).alias("right.join_field")
// } else {
// col(field.name())
// }
// })
// .collect();

return left_df
.join(
@@ -95,8 +77,9 @@
pub async fn aggregate(
ctx: SessionContext,
ticket: Bytes,
entry_point: String
) -> datafusion::common::Result<DataFrame> {
let df = dataframe_for_index(&ctx, "theIndex".to_owned(), ticket).await?;
let df = dataframe_for_index(&ctx, "theIndex".to_owned(), ticket, entry_point).await?;
df.aggregate(vec![col("")], vec![count(col("a"))])
.map_err(|e| DataFusionError::Execution(format!("Failed to aggregate DataFrame: {}", e)))
}
@@ -105,66 +88,33 @@ pub async fn aggregate(
async fn dataframe_for_index(
ctx: &SessionContext,
prefix: String,
ticket: Bytes
ticket: Bytes,
entry_point: String
) -> Result<DataFrame> {
let table_name = format!("{}-s", prefix);
get_dataframe_for_tickets(ctx, table_name, ticket.clone()).await
get_dataframe_for_tickets(ctx, table_name, ticket.clone(), entry_point.clone()).await
}

// Return a single dataframe for an entire index.
// Each ticket in tickets represents a single shard.
// async fn dataframe_for_index(
// ctx: &SessionContext,
// prefix: String,
// tickets: Vec<Bytes>,
// ) -> Result<DataFrame> {
// println!("UNION");
// let inner_futures = tickets
// .into_iter()
// .enumerate()
// .map(|(j, bytes)| {
// let table_name = format!("{}-s-{}", prefix, j);
// get_dataframe_for_tickets(ctx, table_name, vec![bytes.clone()])
// })
// .collect::<Vec<_>>();

// let frames = try_join_all(inner_futures)
// .await
// .map_err(|e| DataFusionError::Execution(format!("Failed to join futures: {}", e)))?;
// frames[0].clone().show().await?;
// union_df(frames)
// }

// // Union a list of DataFrames
// fn union_df(frames: Vec<DataFrame>) -> Result<DataFrame> {
// Ok(frames
// .into_iter()
// .reduce(|acc, df| match acc.union(df) {
// Ok(unioned_df) => unioned_df,
// Err(e) => panic!("Failed to union DataFrames: {}", e),
// })
// .ok_or_else(|| DataFusionError::Execution("No frames to union".to_string()))?)
// }

// registers a single table from the list of given tickets, then reads it immediately returning a dataframe.
// intended to be used to register and get a df for a single shard.
async fn get_dataframe_for_tickets(
ctx: &SessionContext,
name: String,
ticket: Bytes,
entry_point: String
) -> Result<DataFrame> {
println!("Register");
register_table(ctx, name.clone(), ticket)
register_table(ctx, name.clone(), ticket, entry_point.clone())
.and_then(|_| ctx.table(&name))
.await
}
// registers a single table with datafusion using DataFusion TableProviders.
// Uses a TicketedFlightDriver to register the table with the list of given tickets.
async fn register_table(ctx: &SessionContext, name: String, ticket: Bytes) -> Result<()> {
async fn register_table(ctx: &SessionContext, name: String, ticket: Bytes, entry_point: String) -> Result<()> {
let driver: TicketedFlightDriver = TicketedFlightDriver { ticket };
let table_factory: FlightTableFactory = FlightTableFactory::new(Arc::new(driver));
let table: FlightTable = table_factory
.open_table(format!("http://localhost:{}", "8815"), HashMap::new())
.open_table(entry_point, HashMap::new())
.await
.map_err(|e| DataFusionError::Execution(format!("Error creating table: {}", e)))?;
println!("Registering table {:?}", table);
DataFrameStreamProducer.java
@@ -26,6 +26,8 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

@@ -40,10 +42,11 @@ public class DataFrameStreamProducer implements PartitionedStreamProducer {
private StreamTicket rootTicket;
private Set<StreamTicket> partitions;

public DataFrameStreamProducer(Set<StreamTicket> partitions, Function<StreamTicket, CompletableFuture<DataFrame>> frameSupplier) {
public DataFrameStreamProducer(Function<StreamProducer, StreamTicket> streamRegistrar, Set<StreamTicket> partitions, Function<StreamTicket, CompletableFuture<DataFrame>> frameSupplier) {
logger.info("Constructed DataFrameFlightProducer");
this.frameSupplier = frameSupplier;
this.partitions = partitions;
this.rootTicket = streamRegistrar.apply(this);
}

@Override
@@ -59,7 +62,6 @@ public VectorSchemaRoot createRoot(BufferAllocator allocator) {

@Override
public BatchedJob createJob(BufferAllocator allocator) {
assert rootTicket != null;
return new BatchedJob() {

private DataFrame df;
@@ -68,6 +70,7 @@
@Override
public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
try {
assert rootTicket != null;
df = frameSupplier.apply(rootTicket).join();
recordBatchStream = df.getStream(allocator).get();
while (recordBatchStream.loadNextBatch().join()) {
@@ -125,8 +128,7 @@ public Set<StreamTicket> partitions() {
return partitions;
}

@Override
public void setRootTicket(StreamTicket ticket) {
this.rootTicket = ticket;
public StreamTicket getRootTicket() {
return rootTicket;
}
}
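
A sketch of how the reworked constructor might now be wired, assuming the registrar argument is backed by the stream manager's registerStream (streamManager, parentTaskId, and partitionTickets are illustrative names, not from this commit):

// Assumed wiring: the producer hands itself to the registrar, which returns
// the root ticket, replacing the old register-then-setRootTicket handshake.
DataFrameStreamProducer producer = new DataFrameStreamProducer(
    p -> streamManager.registerStream(p, parentTaskId),   // registrar yields the root ticket
    partitionTickets,                                      // one StreamTicket per shard
    ticket -> DataFusion.agg(ticket.toBytes())             // frame supplier for the root ticket
);
StreamTicket root = producer.getRootTicket();              // set during construction

One design note: registering this from inside the constructor hands a partially constructed object to the stream manager, which is only safe as long as the registrar merely stores the reference and returns a ticket.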
DataFusion.java
@@ -19,7 +19,7 @@
* Main DataFusion Entrypoint.
*/
public class DataFusion {

public static Logger logger = LogManager.getLogger(DataFusion.class);
static {
System.loadLibrary("datafusion_jni");
}
@@ -29,8 +29,6 @@ public class DataFusion {

static native void agg(long runtime, long ctx, byte[] ticket, ObjectResultCallback callback);

static native void join(long runtime, long ctx, String joinField, byte[] left, byte[] right, ObjectResultCallback callback);

// collect the DataFrame
static native void collect(long runtime, long df, BiConsumer<String, byte[]> callback);

@@ -50,8 +48,6 @@ public static CompletableFuture<DataFrame> query(byte[] ticket) {
return future;
}

public static Logger logger = LogManager.getLogger(DataFusion.class);

public static CompletableFuture<DataFrame> agg(byte[] ticket) {
SessionContext ctx = new SessionContext();
CompletableFuture<DataFrame> future = new CompletableFuture<>();
@@ -66,18 +62,4 @@ public static CompletableFuture<DataFrame> agg(byte[] ticket) {
});
return future;
}

public static CompletableFuture<DataFrame> join(byte[] left, byte[] right, String joinField) {
SessionContext ctx = new SessionContext();
CompletableFuture<DataFrame> future = new CompletableFuture<>();
DataFusion.join(ctx.getRuntime(), ctx.getPointer(), joinField, left, right, (err, ptr) -> {
if (err != null) {
future.completeExceptionally(new RuntimeException(err));
} else {
DataFrame df = new DataFrame(ctx, ptr);
future.complete(df);
}
});
return future;
}
}
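
For completeness, a small sketch of driving the aggregate entry point from Java, using only calls that appear in this diff (DataFusion.agg, DataFrame.getStream, loadNextBatch); package imports for DataFusion and DataFrame are omitted because the diff does not show their package, and the ticket source is a placeholder:

// Assumed end-to-end usage of the JNI wrapper shown above; a sketch, not
// the module's actual API surface.
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;

class AggExample {
    static void run(byte[] ticketBytes) throws Exception {   // ticketBytes: placeholder input
        try (BufferAllocator allocator = new RootAllocator()) {
            DataFrame df = DataFusion.agg(ticketBytes).join();   // async JNI call
            // getStream(...).get() and loadNextBatch().join() mirror the usage
            // in DataFrameStreamProducer above.
            var stream = df.getStream(allocator).get();
            while (stream.loadNextBatch().join()) {
                // one Arrow record batch is loaded per iteration; consume here
            }
        }
    }
}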
SessionContext.java
@@ -25,6 +25,8 @@ public class SessionContext implements AutoCloseable {

static native long createSessionContext();

// static native long createTable();

static native long createRuntime();

static native void destroyRuntime(long pointer);
BaseFlightProducer.java
@@ -47,9 +47,9 @@ public class BaseFlightProducer extends NoOpFlightProducer {
* Constructs a new BaseFlightProducer.
*
* @param flightClientManager The FlightClientManager to handle client connections.
* @param streamManager The StreamManager to handle stream operations, including
* retrieving and removing streams based on tickets.
* @param allocator The BufferAllocator for memory management in Arrow operations.
* @param streamManager The StreamManager to handle stream operations, including
* retrieving and removing streams based on tickets.
* @param allocator The BufferAllocator for memory management in Arrow operations.
*/
public BaseFlightProducer(FlightClientManager flightClientManager, FlightStreamManager streamManager, BufferAllocator allocator) {
this.flightClientManager = flightClientManager;
@@ -62,8 +62,8 @@ public BaseFlightProducer(FlightClientManager flightClientManager, FlightStreamManager streamManager, BufferAllocator allocator) {
* This method orchestrates the entire process of setting up the stream,
* managing backpressure, and handling data flow to the client.
*
* @param context The call context (unused in this implementation)
* @param ticket The ticket containing stream information
* @param context The call context (unused in this implementation)
* @param ticket The ticket containing stream information
* @param listener The server stream listener to handle the data flow
*/
@Override
@@ -127,7 +127,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
/**
* Retrieves FlightInfo for the given FlightDescriptor, handling both local and remote cases.
*
* @param context The call context
* @param context The call context
* @param descriptor The FlightDescriptor containing stream information
* @return FlightInfo for the requested stream
*/
@@ -143,17 +143,16 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
throw CallStatus.NOT_FOUND.withDescription("FlightInfo not found").toRuntimeException();
}
StreamProducer producer = streamProducerHolder.getProducer();
if (producer instanceof PartitionedStreamProducer) {
Set<StreamTicket> partitions = ((PartitionedStreamProducer) producer).partitions();
for (StreamTicket partition : partitions) {
Location location = flightClientManager.getFlightClientLocation(streamTicket.getNodeId());
if (location == null) {
throw CallStatus.UNAVAILABLE.withDescription("Internal error while determining location information from ticket.")
.toRuntimeException();
}
endpoints.add(new FlightEndpoint(new Ticket(partition.toBytes()), location));
Set<StreamTicket> partitions = producer.partitions();
for (StreamTicket partition : partitions) {
Location location = flightClientManager.getFlightClientLocation(streamTicket.getNodeId());
if (location == null) {
throw CallStatus.UNAVAILABLE.withDescription("Internal error while determining location information from ticket.")
.toRuntimeException();
}
endpoints.add(new FlightEndpoint(new Ticket(partition.toBytes()), location));
}

// Location location = flightClientManager.getFlightClientLocation(streamTicket.getNodeId());
// FlightEndpoint endpoint = new FlightEndpoint(new Ticket(descriptor.getCommand()), location);
FlightInfo.Builder infoBuilder = FlightInfo.builder(
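
On the consuming side, each FlightEndpoint built here carries one partition ticket, so a stock Arrow Flight client can fan out over them. A sketch using the standard Flight Java API; the address and rootTicketBytes are placeholders, and a multi-node deployment would dial each endpoint's own location instead of reusing one channel:

// Standard Arrow Flight consumption of the per-partition endpoints.
try (BufferAllocator allocator = new RootAllocator();
     FlightClient client = FlightClient.builder(
             allocator, Location.forGrpcInsecure("localhost", 8815)).build()) {
    FlightInfo info = client.getInfo(FlightDescriptor.command(rootTicketBytes));
    for (FlightEndpoint endpoint : info.getEndpoints()) {
        try (FlightStream stream = client.getStream(endpoint.getTicket())) {
            while (stream.next()) {
                VectorSchemaRoot batch = stream.getRoot();   // one partition's batches
                // consume the batch ...
            }
        }
    }
}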
FlightStreamManager.java
@@ -79,9 +79,6 @@ public void setClientManager(FlightClientManager clientManager) {
public StreamTicket registerStream(StreamProducer provider, TaskId parentTaskId) {
StreamTicket ticket = ticketFactory.newTicket();
streamProducers.put(ticket.getTicketId(), new StreamProducerHolder(provider, allocatorSupplier.get()));
if (provider instanceof PartitionedStreamProducer) {
((PartitionedStreamProducer) provider).setRootTicket(ticket);
}
return ticket;
}
