first pass at implementing predicate pushdown, seems to work #16

Merged 2 commits on Apr 29, 2024
Changes from 1 commit
Cargo.toml: 9 additions & 1 deletion
@@ -30,9 +30,17 @@
arrow-cast = { version = "50.0.0" }
arrow-schema = { version = "50.0.0" }
arrow-data = { version = "50.0.0" }
datafusion = { version = "36.0", optional = true }
datafusion-expr = { version = "36.0", optional = true }
datafusion-common = { version = "36.0", optional = true }
datafusion-physical-expr = { version = "36.0", optional = true }

[features]
datafusion = ["dep:datafusion"]
datafusion = [
"dep:datafusion",
"dep:datafusion-physical-expr",
"dep:datafusion-expr",
"dep:datafusion-common",
]
all = ["datafusion"]

[dev-dependencies]
src/async_reader/mod.rs: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ impl<T: for<'a> ZarrReadAsync<'a> + Clone + Unpin + Send + 'static>

let mut predicate_stream: Option<ZarrStoreAsync<T>> = None;
if let Some(filter) = &self.filter {
let predicate_proj = filter.get_all_projections();
let predicate_proj = filter.get_all_projections()?;
predicate_stream = Some(
ZarrStoreAsync::new(
self.zarr_reader_async.clone(),
src/datafusion/file_opener.rs: 24 additions & 8 deletions
@@ -17,22 +17,26 @@

use arrow_schema::ArrowError;
use datafusion::{datasource::physical_plan::FileOpener, error::DataFusionError};
use datafusion_physical_expr::PhysicalExpr;
use futures::{StreamExt, TryStreamExt};
use std::sync::Arc;

use crate::{
async_reader::{ZarrPath, ZarrRecordBatchStreamBuilder},
async_reader::{ZarrPath, ZarrReadAsync, ZarrRecordBatchStreamBuilder},
reader::ZarrProjection,
};

use super::config::ZarrConfig;
use super::helpers::build_row_filter;

pub struct ZarrFileOpener {
config: ZarrConfig,
filters: Option<Arc<dyn PhysicalExpr>>,
}

impl ZarrFileOpener {
pub fn new(config: ZarrConfig) -> Self {
Self { config }
pub fn new(config: ZarrConfig, filters: Option<Arc<dyn PhysicalExpr>>) -> Self {
Self { config, filters }
}
}

@@ -43,15 +43,27 @@ impl FileOpener for ZarrFileOpener {
) -> datafusion::error::Result<datafusion::datasource::physical_plan::FileOpenFuture> {
let config = self.config.clone();

let filters_to_pushdown = self.filters.clone();
Ok(Box::pin(async move {
let zarr_path = ZarrPath::new(config.object_store, file_meta.object_meta.location);

let rng = file_meta.range.map(|r| (r.start as usize, r.end as usize));

let projection = ZarrProjection::from(config.projection.as_ref());

let batch_reader = ZarrRecordBatchStreamBuilder::new(zarr_path)
.with_projection(projection)
let mut batch_reader_builder =
ZarrRecordBatchStreamBuilder::new(zarr_path.clone()).with_projection(projection);
if let Some(filters) = filters_to_pushdown {
let schema = zarr_path
.get_zarr_metadata()
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?
.arrow_schema()
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let filters = build_row_filter(&filters, &schema)?;
if let Some(filters) = filters {
batch_reader_builder = batch_reader_builder.with_filter(filters);
}
}
let batch_reader = batch_reader_builder
.build_partial_reader(rng)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
@@ -81,7 +97,7 @@ mod tests {
let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string());

let config = ZarrConfig::new(Arc::new(local_fs));
let opener = ZarrFileOpener::new(config);
let opener = ZarrFileOpener::new(config, None);

let file_meta = FileMeta {
object_meta: ObjectMeta {
src/datafusion/helpers.rs: 228 additions (new file)
@@ -0,0 +1,228 @@
use crate::reader::{ZarrArrowPredicate, ZarrChunkFilter, ZarrProjection};
use arrow::array::BooleanArray;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion};
use datafusion_common::Result as DataFusionResult;
use datafusion_common::{internal_err, DataFusionError};
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::reassign_predicate_columns;
use datafusion_physical_expr::{split_conjunction, PhysicalExpr};
use std::collections::BTreeSet;
use std::sync::Arc;

// Checks whether the given expression can be resolved using only the columns `col_names`.
// Copied from datafusion, because it's not accessible from the outside.
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| match expr {
Expr::Column(datafusion_common::Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
if is_applicable {
Ok(VisitRecursion::Skip)
} else {
Ok(VisitRecursion::Stop)
}
}
Expr::Literal(_)
| Expr::Alias(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::ScalarVariable(_, _)
| Expr::Not(_)
| Expr::IsNotNull(_)
| Expr::IsNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Negative(_)
| Expr::Cast { .. }
| Expr::TryCast { .. }
| Expr::BinaryExpr { .. }
| Expr::Between { .. }
| Expr::Like { .. }
| Expr::SimilarTo { .. }
| Expr::InList { .. }
| Expr::Exists { .. }
| Expr::InSubquery(_)
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => Ok(VisitRecursion::Continue),

Expr::ScalarFunction(scalar_function) => match &scalar_function.func_def {
ScalarFunctionDefinition::BuiltIn(fun) => match fun.volatility() {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
}
},
ScalarFunctionDefinition::UDF(fun) => match fun.signature().volatility {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
}
},
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
},

Expr::AggregateFunction { .. }
| Expr::Sort { .. }
| Expr::WindowFunction { .. }
| Expr::Wildcard { .. }
| Expr::Unnest { .. }
| Expr::Placeholder(_) => {
is_applicable = false;
Ok(VisitRecursion::Stop)
}
})
.unwrap();
is_applicable
}
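
To make the check above concrete, here is a hypothetical usage sketch (not part of the PR), assuming datafusion-expr 36.0 and that `expr_applicable_for_cols` is in scope:

```rust
// Hypothetical usage sketch, not part of the PR: assumes `expr_applicable_for_cols`
// from this module and the `col`/`lit` helpers from datafusion_expr.
use datafusion_expr::{col, lit};

fn example() {
    // lat > 45.0 AND lon < 0.0 only references the "lat" and "lon" columns...
    let expr = col("lat").gt(lit(45.0)).and(col("lon").lt(lit(0.0)));

    // ...so it can be resolved when both columns are available,
    assert!(expr_applicable_for_cols(
        &["lat".to_string(), "lon".to_string()],
        &expr
    ));

    // ...but not when only "lat" is.
    assert!(!expr_applicable_for_cols(&["lat".to_string()], &expr));
}
```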

// Below is all the logic necessary (I think) to convert a PhysicalExpr into a ZarrChunkFilter.
// The logic is mostly copied from datafusion, and is simplified here for the zarr use case.
pub struct ZarrFilterCandidate {
Collaborator Author commented:
So in the end, this is still mostly index based. I'll make a few comments below to clarify the logic (I should actually write those comments in code, I'll do that before we merge), but column names are basically just used as an intermediate step.

expr: Arc<dyn PhysicalExpr>,
projection: Vec<usize>,
}

struct ZarrFilterCandidateBuilder<'a> {
expr: Arc<dyn PhysicalExpr>,
file_schema: &'a Schema,
required_column_indices: BTreeSet<usize>,
}

impl<'a> ZarrFilterCandidateBuilder<'a> {
pub fn new(expr: Arc<dyn PhysicalExpr>, file_schema: &'a Schema) -> Self {
Self {
expr,
file_schema,
required_column_indices: BTreeSet::default(),
}
}

pub fn build(mut self) -> DataFusionResult<Option<ZarrFilterCandidate>> {
Collaborator commented:
Since there aren't any builder methods on this, maybe TryFrom? Also not super clear to me why it's an Option?

Collaborator Author commented:
Good question. So I think we need a builder struct because the way the column indices get "extracted" from the predicate is through the call to rewrite, which takes a mut reference to self, on the line below. That function requires the TreeNodeRewriter trait on its argument, and you can see that as pre_visit is called, the indices get progressively filled in. To be honest, I didn't dig all the way down into how this works, I just followed the steps they follow for parquet since I didn't want to risk breaking something.

Regarding the Option, you're right that it's not clear from the code here; I believe it's like that because of the code here, https://github.com/datafusion-contrib/arrow-zarr/pull/16/files#diff-d61c0a121604c7680df3d272638903a3fc21fee9ac3381e34b5285c02b9deaf0R202-R213, specifically because the else statement returns None. Since the type of candidates is Vec<ZarrFilterCandidate>, I think the call to collect coerces options into the inner type (or skips the value if it's None)? And that means the type of candidate must be Option<...>, so that the if and else branches' return types match. Again, I mostly followed the parquet implementation.

I know that just following someone else's code and replicating it somewhat naively is not the best excuse haha, but like I said I wanted to minimize the risk of messing things up here, since I'm not yet comfortable with the code base. Overall does this all make sense?
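
For what it's worth, the collect behavior discussed above can be illustrated with a minimal standalone sketch (plain Rust, not tied to the PR code): because the closure passed to flat_map returns an Option, the Option is treated as an iterator of zero or one items, so Err and Ok(None) results are simply skipped rather than coerced.

```rust
// Minimal standalone sketch: `flat_map` flattens the `Option` returned by the
// closure, so `Err(..)` (mapped to `None` here) and `Ok(None)` both disappear,
// and the collected Vec holds the inner type.
fn main() {
    let results: Vec<Result<Option<i32>, ()>> = vec![Ok(Some(1)), Ok(None), Err(())];
    let kept: Vec<i32> = results
        .into_iter()
        .flat_map(|r| if let Ok(v) = r { v } else { None })
        .collect();
    assert_eq!(kept, vec![1]);
}
```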

Collaborator Author commented:
Also, when we start handling hive style partitions, and the logic gets more complicated, we might need to return a Ok(None) from build in some situations, I'm following the parquet logic but also simplifying it a lot (for now), so that might lead to code that looks a bit weird, temporarily.

Collaborator commented:
Cool cool, yeah that all sounds good to me to get started with. If we notice some perf issues w/ the cloning + rewriting we can reassess later.

Collaborator Author commented:
Yeah, I haven't been paying too much attention to everything that could impact performance so far, I'm thinking I'll revisit later when we have something fully functional.

let expr = self.expr.clone().rewrite(&mut self)?;

Ok(Some(ZarrFilterCandidate {
expr,
projection: self.required_column_indices.into_iter().collect(),
}))
}
}

impl<'a> TreeNodeRewriter for ZarrFilterCandidateBuilder<'a> {
type N = Arc<dyn PhysicalExpr>;

fn pre_visit(&mut self, node: &Arc<dyn PhysicalExpr>) -> DataFusionResult<RewriteRecursion> {
if let Some(column) = node.as_any().downcast_ref::<Column>() {
if let Ok(idx) = self.file_schema.index_of(column.name()) {
self.required_column_indices.insert(idx);
Collaborator Author commented:
So first, we accumulate indices of columns required (by a given predicate). These indices represent the position of the column in the file schema (which will eventually be the table schema, for now we don't have that distinction), e.g. if the predicate requires lat, lon and the file schema is float_data, lat, lon, we will end up setting the projection to [1, 2]. Since the set is ordered, I think even if in the predicate the order was lon, lat, we'd end up with [1, 2] as the projection.

}
}

Ok(RewriteRecursion::Continue)
}
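
Here is a minimal standalone sketch of the index accumulation described in the comment above, with hypothetical column names; because the indices go into an ordered BTreeSet, the resulting projection is the same regardless of the order in which the predicate references the columns.

```rust
// Standalone sketch, assuming a hypothetical file schema ["float_data", "lat", "lon"].
use std::collections::BTreeSet;

fn main() {
    let file_schema = ["float_data", "lat", "lon"];

    // Suppose the predicate references the columns in the order lon, lat.
    let mut required_column_indices: BTreeSet<usize> = BTreeSet::new();
    for col in ["lon", "lat"] {
        if let Some(idx) = file_schema.iter().position(|f| *f == col) {
            required_column_indices.insert(idx);
        }
    }

    // The set is ordered, so the projection comes out as [1, 2] either way.
    let projection: Vec<usize> = required_column_indices.into_iter().collect();
    assert_eq!(projection, vec![1, 2]);
}
```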

fn mutate(&mut self, expr: Arc<dyn PhysicalExpr>) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
Ok(expr)
}
}

#[derive(Clone)]
pub struct ZarrDatafusionArrowPredicate {
physical_expr: Arc<dyn PhysicalExpr>,
projection_mask: ZarrProjection,
projection: Vec<String>,
}

impl ZarrDatafusionArrowPredicate {
pub fn new(candidate: ZarrFilterCandidate, schema: &Schema) -> DataFusionResult<Self> {
let cols: Vec<_> = candidate
Collaborator Author commented:
This is where we convert the indices to the columns names, e.g. [1, 2] -> [lat, lon]. See below for how that's used.

.projection
.iter()
.map(|idx| schema.field(*idx).name().to_string())
.collect();

let schema = Arc::new(schema.project(&candidate.projection)?);
Collaborator Author commented:
Here we go from the file schema to the predicate schema, e.g. float_data, lat, lon -> lat, lon.

let physical_expr = reassign_predicate_columns(candidate.expr, &schema, true)?;
Collaborator Author commented:
If I understand correctly, the physical expression has the name of each column as well as an index for each. Since it was first created off of an Expr, using the file schema, the indices for each column don't necessarily match what they will be in the record batch we pass to the physical expression. Assuming we will pass the physical expression a record batch that only contains the columns it needs, we need to remap indices to columns, e.g. we go from (lat, 1), (lon, 2) to (lat, 0), (lon, 1).


Ok(Self {
physical_expr,
projection_mask: ZarrProjection::keep(cols.clone()),
projection: cols,
})
}
}
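
To make the remapping described in the comment above concrete, here is a minimal, self-contained sketch (hypothetical column names, plain indices rather than the DataFusion API) of the effect reassign_predicate_columns is expected to have: indices assigned against the file schema are rewritten to point into the projected predicate schema.

```rust
// Standalone sketch of the index remapping; it mimics the effect of
// `reassign_predicate_columns` without calling the DataFusion API.
fn main() {
    // Hypothetical predicate schema: the file schema ["float_data", "lat", "lon"]
    // projected down to the columns the predicate needs.
    let predicate_schema = ["lat", "lon"];

    // Columns as (name, index-in-file-schema) pairs, as they would appear in the
    // physical expression before remapping.
    let before = [("lat", 1usize), ("lon", 2usize)];

    // After remapping, each index points into the predicate schema instead.
    let after: Vec<(&str, usize)> = before
        .iter()
        .map(|(name, _)| {
            let idx = predicate_schema.iter().position(|f| f == name).unwrap();
            (*name, idx)
        })
        .collect();

    assert_eq!(after, vec![("lat", 0), ("lon", 1)]);
}
```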

impl ZarrArrowPredicate for ZarrDatafusionArrowPredicate {
fn projection(&self) -> &ZarrProjection {
&self.projection_mask
}

fn evaluate(&mut self, batch: &RecordBatch) -> Result<BooleanArray, ArrowError> {
let index_projection = self
Collaborator Author commented:
This is the bit that depends on the column names. Here, the incoming record batch can have any number of columns; it doesn't matter, as long as it contains at least the columns the predicate needs. In the parquet implementation, again if I understood correctly, the batch is expected to come in with only the required columns, but by using column names here, that's not required: we figure out the indices in the record batch based on the column names, and re-project it before passing it to the physical expression. The re-projection does still happen in the parquet implementation, I think to handle different column orderings, but here we use it to also drop unnecessary columns; that way, for example, if the predicate only requires the lon column, we can re-use a record batch that contains lat, lon.

.projection
.iter()
.map(|col| batch.schema().index_of(col))
.collect::<Result<Vec<_>, _>>()?;
let batch = batch.project(&index_projection[..])?;

match self
.physical_expr
.evaluate(&batch)
.and_then(|v| v.into_array(batch.num_rows()))
{
Ok(array) => {
let bool_arr = as_boolean_array(&array)?.clone();
Ok(bool_arr)
}
Err(e) => Err(ArrowError::ComputeError(format!(
"Error evaluating filter predicate: {e:?}"
))),
}
}
}
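
Below is a minimal sketch of the name-based re-projection described in the comment inside evaluate above, using hypothetical float_data/lat/lon columns; the arrow calls (index_of and RecordBatch::project) are the same ones the PR relies on, everything else is illustrative.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Float64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

// Standalone sketch: the incoming batch may carry extra columns; we resolve the
// indices of the columns the predicate needs by name and re-project the batch
// before handing it to the physical expression.
fn main() -> Result<(), arrow::error::ArrowError> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("float_data", DataType::Float64, false),
        Field::new("lat", DataType::Float64, false),
        Field::new("lon", DataType::Float64, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Float64Array::from(vec![1.0, 2.0])) as ArrayRef,
            Arc::new(Float64Array::from(vec![45.0, 46.0])) as ArrayRef,
            Arc::new(Float64Array::from(vec![-73.0, -74.0])) as ArrayRef,
        ],
    )?;

    // Suppose the predicate only needs "lon": resolve its index by name, then project.
    let index_projection: Vec<usize> = ["lon"]
        .iter()
        .map(|col| batch.schema().index_of(col))
        .collect::<Result<Vec<_>, _>>()?;
    let projected = batch.project(&index_projection)?;

    assert_eq!(projected.num_columns(), 1);
    assert_eq!(projected.schema().field(0).name(), "lon");
    Ok(())
}
```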

pub(crate) fn build_row_filter(
expr: &Arc<dyn PhysicalExpr>,
file_schema: &Schema,
) -> DataFusionResult<Option<ZarrChunkFilter>> {
let predicates = split_conjunction(expr);
let candidates: Vec<ZarrFilterCandidate> = predicates
.into_iter()
.flat_map(|expr| {
if let Ok(candidate) =
ZarrFilterCandidateBuilder::new(expr.clone(), file_schema).build()
{
candidate
} else {
None
}
})
.collect();

if candidates.is_empty() {
Ok(None)
} else {
let mut filters: Vec<Box<dyn ZarrArrowPredicate>> = vec![];
for candidate in candidates {
let filter = ZarrDatafusionArrowPredicate::new(candidate, file_schema)?;
filters.push(Box::new(filter));
}

let chunk_filter = ZarrChunkFilter::new(filters);

Ok(Some(chunk_filter))
}
}
src/datafusion/mod.rs: 1 addition
@@ -17,6 +17,7 @@

pub mod config;
pub mod file_opener;
mod helpers;
pub mod scanner;
pub mod table_factory;
pub mod table_provider;
src/datafusion/scanner.rs: 8 additions & 6 deletions
@@ -23,7 +23,7 @@ use datafusion::{
datasource::physical_plan::{FileScanConfig, FileStream},
physical_plan::{
metrics::ExecutionPlanMetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan,
Partitioning, SendableRecordBatchStream,
Partitioning, PhysicalExpr, SendableRecordBatchStream,
},
};

@@ -43,18 +43,22 @@ pub struct ZarrScan {

/// The statistics for the scan.
statistics: Statistics,

/// Filters that will be pushed down to the Zarr stream reader.
filters: Option<Arc<dyn PhysicalExpr>>,
}

impl ZarrScan {
/// Create a new Zarr scan.
pub fn new(base_config: FileScanConfig) -> Self {
pub fn new(base_config: FileScanConfig, filters: Option<Arc<dyn PhysicalExpr>>) -> Self {
let (projected_schema, statistics, _lex_sorting) = base_config.project();

Self {
base_config,
projected_schema,
metrics: ExecutionPlanMetricsSet::new(),
statistics,
filters,
}
}
}
@@ -100,9 +104,7 @@ impl ExecutionPlan for ZarrScan {

let config =
ZarrConfig::new(object_store).with_projection(self.base_config.projection.clone());

let opener = ZarrFileOpener::new(config);

let opener = ZarrFileOpener::new(config, self.filters.clone());
let stream = FileStream::new(&self.base_config, partition, opener, &self.metrics)?;

Ok(Box::pin(stream) as SendableRecordBatchStream)
@@ -169,7 +171,7 @@ mod tests {
output_ordering: vec![],
};

let scanner = ZarrScan::new(scan_config);
let scanner = ZarrScan::new(scan_config, None);

let session = datafusion::execution::context::SessionContext::new();
