From 766316b3aef6fd68955b8cdb1a6d413d1d93e0d0 Mon Sep 17 00:00:00 2001
From: Maxime Dion
Date: Thu, 25 Apr 2024 00:02:11 -0500
Subject: [PATCH 1/2] first pass at implementing predicate pushdown, seems to
 work

---
 Cargo.toml                       |  10 +-
 src/async_reader/mod.rs          |   2 +-
 src/datafusion/file_opener.rs    |  32 +++--
 src/datafusion/helpers.rs        | 228 +++++++++++++++++++++++++++++++
 src/datafusion/mod.rs            |   1 +
 src/datafusion/scanner.rs        |  14 +-
 src/datafusion/table_factory.rs  | 132 +++++++++++++++++-
 src/datafusion/table_provider.rs |  37 ++++-
 src/reader/filters.rs            |   7 +-
 src/reader/mod.rs                |   4 +-
 src/reader/zarr_read.rs          |  87 ++++++++----
 11 files changed, 499 insertions(+), 55 deletions(-)
 create mode 100644 src/datafusion/helpers.rs

diff --git a/Cargo.toml b/Cargo.toml
index e2821f4..350f78d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,9 +30,17 @@ arrow-cast = { version = "50.0.0" }
 arrow-schema = { version = "50.0.0" }
 arrow-data = { version = "50.0.0" }
 datafusion = { version = "36.0", optional = true }
+datafusion-expr = { version = "36.0", optional = true }
+datafusion-common = { version = "36.0", optional = true }
+datafusion-physical-expr = { version = "36.0", optional = true }
 
 [features]
-datafusion = ["dep:datafusion"]
+datafusion = [
+    "dep:datafusion",
+    "dep:datafusion-physical-expr",
+    "dep:datafusion-expr",
+    "dep:datafusion-common",
+]
 all = ["datafusion"]
 
 [dev-dependencies]
diff --git a/src/async_reader/mod.rs b/src/async_reader/mod.rs
index 03fefe9..cc19016 100644
--- a/src/async_reader/mod.rs
+++ b/src/async_reader/mod.rs
@@ -469,7 +469,7 @@ impl<'a, T: ZarrReadAsync<'a> + Clone + Unpin + Send + 'static>
 
         let mut predicate_stream: Option<ZarrStoreAsync<T>> = None;
         if let Some(filter) = &self.filter {
-            let predicate_proj = filter.get_all_projections();
+            let predicate_proj = filter.get_all_projections()?;
             predicate_stream = Some(
                 ZarrStoreAsync::new(
                     self.zarr_reader_async.clone(),
diff --git a/src/datafusion/file_opener.rs b/src/datafusion/file_opener.rs
index 1663a62..8faec30 100644
--- a/src/datafusion/file_opener.rs
+++ b/src/datafusion/file_opener.rs
@@ -17,22 +17,26 @@
 
 use arrow_schema::ArrowError;
 use datafusion::{datasource::physical_plan::FileOpener, error::DataFusionError};
+use datafusion_physical_expr::PhysicalExpr;
 use futures::{StreamExt, TryStreamExt};
+use std::sync::Arc;
 
 use crate::{
-    async_reader::{ZarrPath, ZarrRecordBatchStreamBuilder},
+    async_reader::{ZarrPath, ZarrReadAsync, ZarrRecordBatchStreamBuilder},
     reader::ZarrProjection,
 };
 
 use super::config::ZarrConfig;
+use super::helpers::build_row_filter;
 
 pub struct ZarrFileOpener {
     config: ZarrConfig,
+    filters: Option<Arc<dyn PhysicalExpr>>,
 }
 
 impl ZarrFileOpener {
-    pub fn new(config: ZarrConfig) -> Self {
-        Self { config }
+    pub fn new(config: ZarrConfig, filters: Option<Arc<dyn PhysicalExpr>>) -> Self {
+        Self { config, filters }
     }
 }
 
@@ -43,15 +47,27 @@ impl FileOpener for ZarrFileOpener {
     fn open(
         &self,
         file_meta: FileMeta,
     ) -> datafusion::error::Result<FileOpenFuture> {
         let config = self.config.clone();
+        let filters_to_pushdown = self.filters.clone();
         Ok(Box::pin(async move {
             let zarr_path = ZarrPath::new(config.object_store, file_meta.object_meta.location);
             let rng = file_meta.range.map(|r| (r.start as usize, r.end as usize));
             let projection = ZarrProjection::from(config.projection.as_ref());
-            let batch_reader = ZarrRecordBatchStreamBuilder::new(zarr_path)
-                .with_projection(projection)
+            let mut batch_reader_builder =
+                ZarrRecordBatchStreamBuilder::new(zarr_path.clone()).with_projection(projection);
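+            // If any filters were pushed down, resolve them against the
+            // store's schema and convert them into a chunk filter that the
+            // zarr stream reader can apply as it reads.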
+            if let Some(filters) = filters_to_pushdown {
+                let schema = zarr_path
+                    .get_zarr_metadata()
+                    .await
+                    .map_err(|e| DataFusionError::External(Box::new(e)))?
+                    .arrow_schema()
+                    .map_err(|e| DataFusionError::External(Box::new(e)))?;
+                let filters = build_row_filter(&filters, &schema)?;
+                if let Some(filters) = filters {
+                    batch_reader_builder = batch_reader_builder.with_filter(filters);
+                }
+            }
+            let batch_reader = batch_reader_builder
                 .build_partial_reader(rng)
                 .await
                 .map_err(|e| DataFusionError::External(Box::new(e)))?;
@@ -81,7 +97,7 @@ mod tests {
         let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string());
         let config = ZarrConfig::new(Arc::new(local_fs));
 
-        let opener = ZarrFileOpener::new(config);
+        let opener = ZarrFileOpener::new(config, None);
 
         let file_meta = FileMeta {
             object_meta: ObjectMeta {
diff --git a/src/datafusion/helpers.rs b/src/datafusion/helpers.rs
new file mode 100644
index 0000000..770c34d
--- /dev/null
+++ b/src/datafusion/helpers.rs
@@ -0,0 +1,228 @@
+use crate::reader::{ZarrArrowPredicate, ZarrChunkFilter, ZarrProjection};
+use arrow::array::BooleanArray;
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use arrow_schema::Schema;
+use datafusion_common::cast::as_boolean_array;
+use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion};
+use datafusion_common::Result as DataFusionResult;
+use datafusion_common::{internal_err, DataFusionError};
+use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::utils::reassign_predicate_columns;
+use datafusion_physical_expr::{split_conjunction, PhysicalExpr};
+use std::collections::BTreeSet;
+use std::sync::Arc;
+
+// Checks whether the given expression can be resolved using only the columns `col_names`.
+// Copied from datafusion, because it's not accessible from the outside.
+pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
+    let mut is_applicable = true;
+    expr.apply(&mut |expr| match expr {
+        Expr::Column(datafusion_common::Column { ref name, .. }) => {
+            is_applicable &= col_names.contains(name);
+            if is_applicable {
+                Ok(VisitRecursion::Skip)
+            } else {
+                Ok(VisitRecursion::Stop)
+            }
+        }
+        Expr::Literal(_)
+        | Expr::Alias(_)
+        | Expr::OuterReferenceColumn(_, _)
+        | Expr::ScalarVariable(_, _)
+        | Expr::Not(_)
+        | Expr::IsNotNull(_)
+        | Expr::IsNull(_)
+        | Expr::IsTrue(_)
+        | Expr::IsFalse(_)
+        | Expr::IsUnknown(_)
+        | Expr::IsNotTrue(_)
+        | Expr::IsNotFalse(_)
+        | Expr::IsNotUnknown(_)
+        | Expr::Negative(_)
+        | Expr::Cast { .. }
+        | Expr::TryCast { .. }
+        | Expr::BinaryExpr { .. }
+        | Expr::Between { .. }
+        | Expr::Like { .. }
+        | Expr::SimilarTo { .. }
+        | Expr::InList { .. }
+        | Expr::Exists { .. }
+        | Expr::InSubquery(_)
+        | Expr::ScalarSubquery(_)
+        | Expr::GetIndexedField { .. }
+        | Expr::GroupingSet(_)
+        | Expr::Case { .. } => Ok(VisitRecursion::Continue),
+
+        Expr::ScalarFunction(scalar_function) => match &scalar_function.func_def {
+            ScalarFunctionDefinition::BuiltIn(fun) => match fun.volatility() {
+                Volatility::Immutable => Ok(VisitRecursion::Continue),
+                Volatility::Stable | Volatility::Volatile => {
+                    is_applicable = false;
+                    Ok(VisitRecursion::Stop)
+                }
+            },
+            ScalarFunctionDefinition::UDF(fun) => match fun.signature().volatility {
+                Volatility::Immutable => Ok(VisitRecursion::Continue),
+                Volatility::Stable | Volatility::Volatile => {
+                    is_applicable = false;
+                    Ok(VisitRecursion::Stop)
+                }
+            },
+            ScalarFunctionDefinition::Name(_) => {
+                internal_err!("Function `Expr` with name should be resolved.")
+            }
+        },
+
+        Expr::AggregateFunction { .. }
+        | Expr::Sort { .. }
+        | Expr::WindowFunction { .. }
+        | Expr::Wildcard { .. }
+        | Expr::Unnest { .. }
+        | Expr::Placeholder(_) => {
+            is_applicable = false;
+            Ok(VisitRecursion::Stop)
+        }
+    })
+    .unwrap();
+    is_applicable
+}
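+
+// For example (illustrative values only): with col_names = ["lat", "lon"],
+// an expression like `lat > 38.0 AND lon < -109.0` is applicable, while an
+// expression referencing another column, or one calling a volatile function
+// such as `random()`, is not.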
+
+// Below is all the logic necessary to convert a PhysicalExpr into a ZarrChunkFilter.
+// The logic is mostly copied from datafusion, and is simplified here for the zarr use case.
+pub struct ZarrFilterCandidate {
+    expr: Arc<dyn PhysicalExpr>,
+    projection: Vec<usize>,
+}
+
+struct ZarrFilterCandidateBuilder<'a> {
+    expr: Arc<dyn PhysicalExpr>,
+    file_schema: &'a Schema,
+    required_column_indices: BTreeSet<usize>,
+}
+
+impl<'a> ZarrFilterCandidateBuilder<'a> {
+    pub fn new(expr: Arc<dyn PhysicalExpr>, file_schema: &'a Schema) -> Self {
+        Self {
+            expr,
+            file_schema,
+            required_column_indices: BTreeSet::default(),
+        }
+    }
+
+    pub fn build(mut self) -> DataFusionResult<Option<ZarrFilterCandidate>> {
+        let expr = self.expr.clone().rewrite(&mut self)?;
+
+        Ok(Some(ZarrFilterCandidate {
+            expr,
+            projection: self.required_column_indices.into_iter().collect(),
+        }))
+    }
+}
+
+impl<'a> TreeNodeRewriter for ZarrFilterCandidateBuilder<'a> {
+    type N = Arc<dyn PhysicalExpr>;
+
+    fn pre_visit(&mut self, node: &Arc<dyn PhysicalExpr>) -> DataFusionResult<RewriteRecursion> {
+        if let Some(column) = node.as_any().downcast_ref::<Column>() {
+            if let Ok(idx) = self.file_schema.index_of(column.name()) {
+                self.required_column_indices.insert(idx);
+            }
+        }
+
+        Ok(RewriteRecursion::Continue)
+    }
+
+    fn mutate(&mut self, expr: Arc<dyn PhysicalExpr>) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
+        Ok(expr)
+    }
+}
+
+#[derive(Clone)]
+pub struct ZarrDatafusionArrowPredicate {
+    physical_expr: Arc<dyn PhysicalExpr>,
+    projection_mask: ZarrProjection,
+    projection: Vec<String>,
+}
+
+impl ZarrDatafusionArrowPredicate {
+    pub fn new(candidate: ZarrFilterCandidate, schema: &Schema) -> DataFusionResult<Self> {
+        let cols: Vec<_> = candidate
+            .projection
+            .iter()
+            .map(|idx| schema.field(*idx).name().to_string())
+            .collect();
+
+        let schema = Arc::new(schema.project(&candidate.projection)?);
+        let physical_expr = reassign_predicate_columns(candidate.expr, &schema, true)?;
+
+        Ok(Self {
+            physical_expr,
+            projection_mask: ZarrProjection::keep(cols.clone()),
+            projection: cols,
+        })
+    }
+}
+
+impl ZarrArrowPredicate for ZarrDatafusionArrowPredicate {
+    fn projection(&self) -> &ZarrProjection {
+        &self.projection_mask
+    }
+
+    fn evaluate(&mut self, batch: &RecordBatch) -> Result<BooleanArray, ArrowError> {
+        let index_projection = self
+            .projection
+            .iter()
+            .map(|col| batch.schema().index_of(col))
+            .collect::<Result<Vec<_>, _>>()?;
+        let batch = batch.project(&index_projection[..])?;
+
+        match self
+            .physical_expr
+            .evaluate(&batch)
+            .and_then(|v| v.into_array(batch.num_rows()))
+        {
+            Ok(array) => {
+                let bool_arr = as_boolean_array(&array)?.clone();
+                Ok(bool_arr)
+            }
+            Err(e) => Err(ArrowError::ComputeError(format!(
+                "Error evaluating filter predicate: {e:?}"
+            ))),
+        }
+    }
+}
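+
+// Convert each conjunct of the pushed-down expression into a
+// ZarrDatafusionArrowPredicate and bundle them into a ZarrChunkFilter.
+// Conjuncts that cannot be converted are simply dropped.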
+pub(crate) fn build_row_filter(
+    expr: &Arc<dyn PhysicalExpr>,
+    file_schema: &Schema,
+) -> DataFusionResult<Option<ZarrChunkFilter>> {
+    let predicates = split_conjunction(expr);
+    let candidates: Vec<ZarrFilterCandidate> = predicates
+        .into_iter()
+        .flat_map(|expr| {
+            if let Ok(candidate) =
+                ZarrFilterCandidateBuilder::new(expr.clone(), file_schema).build()
+            {
+                candidate
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    if candidates.is_empty() {
+        Ok(None)
+    } else {
+        let mut filters: Vec<Box<dyn ZarrArrowPredicate>> = vec![];
+        for candidate in candidates {
+            let filter = ZarrDatafusionArrowPredicate::new(candidate, file_schema)?;
+            filters.push(Box::new(filter));
+        }
+
+        let chunk_filter = ZarrChunkFilter::new(filters);
+
+        Ok(Some(chunk_filter))
+    }
+}
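+
+// Usage sketch (mirrors `ZarrFileOpener::open`; `builder` here stands for a
+// ZarrRecordBatchStreamBuilder and is illustrative):
+//
+//     if let Some(expr) = filters_to_pushdown {
+//         if let Some(filter) = build_row_filter(&expr, &schema)? {
+//             builder = builder.with_filter(filter);
+//         }
+//     }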
diff --git a/src/datafusion/mod.rs b/src/datafusion/mod.rs
index 0fa784f..c0390ff 100644
--- a/src/datafusion/mod.rs
+++ b/src/datafusion/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod config;
 pub mod file_opener;
+mod helpers;
 pub mod scanner;
 pub mod table_factory;
 pub mod table_provider;
diff --git a/src/datafusion/scanner.rs b/src/datafusion/scanner.rs
index bb0dc6b..a4356f3 100644
--- a/src/datafusion/scanner.rs
+++ b/src/datafusion/scanner.rs
@@ -23,7 +23,7 @@ use datafusion::{
     datasource::physical_plan::{FileScanConfig, FileStream},
     physical_plan::{
         metrics::ExecutionPlanMetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan,
-        Partitioning, SendableRecordBatchStream,
+        Partitioning, PhysicalExpr, SendableRecordBatchStream,
     },
 };
 
@@ -43,11 +43,14 @@ pub struct ZarrScan {
 
     /// The statistics for the scan.
     statistics: Statistics,
+
+    /// Filters that will be pushed down to the Zarr stream reader.
+    filters: Option<Arc<dyn PhysicalExpr>>,
 }
 
 impl ZarrScan {
     /// Create a new Zarr scan.
-    pub fn new(base_config: FileScanConfig) -> Self {
+    pub fn new(base_config: FileScanConfig, filters: Option<Arc<dyn PhysicalExpr>>) -> Self {
         let (projected_schema, statistics, _lex_sorting) = base_config.project();
 
         Self {
@@ -55,6 +58,7 @@ impl ZarrScan {
             projected_schema,
             metrics: ExecutionPlanMetricsSet::new(),
             statistics,
+            filters,
         }
     }
 }
@@ -100,9 +104,7 @@ impl ExecutionPlan for ZarrScan {
 
         let config = ZarrConfig::new(object_store)
             .with_projection(self.base_config.projection.clone());
-
-        let opener = ZarrFileOpener::new(config);
-
+        let opener = ZarrFileOpener::new(config, self.filters.clone());
         let stream = FileStream::new(&self.base_config, partition, opener, &self.metrics)?;
 
         Ok(Box::pin(stream) as SendableRecordBatchStream)
@@ -169,7 +171,7 @@ mod tests {
             output_ordering: vec![],
         };
 
-        let scanner = ZarrScan::new(scan_config);
+        let scanner = ZarrScan::new(scan_config, None);
 
         let session = datafusion::execution::context::SessionContext::new();
 
diff --git a/src/datafusion/table_factory.rs b/src/datafusion/table_factory.rs
index 40271c2..edc4b39 100644
--- a/src/datafusion/table_factory.rs
+++ b/src/datafusion/table_factory.rs
@@ -59,15 +59,39 @@ impl TableProviderFactory for ZarrListingTableFactory {
 
 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
+    use crate::reader::{ZarrError, ZarrResult};
+    use crate::tests::get_test_v2_data_path;
+    use arrow::record_batch::RecordBatch;
+    use arrow_array::cast::AsArray;
+    use arrow_array::types::*;
+    use arrow_buffer::ScalarBuffer;
 
     use datafusion::execution::{
         config::SessionConfig,
         context::{SessionContext, SessionState},
        runtime_env::RuntimeEnv,
     };
+    use itertools::enumerate;
+    use std::sync::Arc;
 
-    use crate::tests::get_test_v2_data_path;
+    fn extract_col<T>(
+        col_name: &str,
+        rec_batch: &RecordBatch,
+    ) -> ZarrResult<ScalarBuffer<T::Native>>
+    where
+        T: ArrowPrimitiveType,
+    {
+        for (idx, col) in enumerate(rec_batch.schema().fields.iter()) {
+            if col.name().as_str() == col_name {
+                let values = rec_batch.column(idx).as_primitive::<T>().values();
+                return Ok(values.clone());
+            }
+        }
+
+        Err(ZarrError::InvalidMetadata(
+            "column name not found".to_string(),
+        ))
+    }
 
     #[tokio::test]
     async fn test_create() -> Result<(), Box<dyn std::error::Error>> {
@@ -103,4 +127,108 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_predicates() -> Result<(), Box<dyn std::error::Error>> {
+        let mut state = SessionState::new_with_config_rt(
+            SessionConfig::default(),
+            Arc::new(RuntimeEnv::default()),
+        );
+
+        state
+            .table_factories_mut()
+            .insert("ZARR".into(), Arc::new(super::ZarrListingTableFactory {}));
+
+        let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string());
+
+        let sql = format!(
+            "CREATE EXTERNAL TABLE zarr_table STORED AS ZARR LOCATION '{}'",
+            test_data.display(),
+        );
+
+        let session = SessionContext::new_with_state(state);
+        session.sql(&sql).await?;
+
+        // Apply one predicate on one column.
+        let sql = "SELECT lat, lon FROM zarr_table WHERE lat > 38.21";
+        let df = session.sql(sql).await?;
+
+        let batches = df.collect().await?;
+        for batch in batches {
+            let values = extract_col::<Float64Type>("lat", &batch)?;
+            assert!(values.iter().all(|v| *v > 38.21));
+        }
+
+        // Apply two predicates, each on one column.
+        let sql = "SELECT lat, lon FROM zarr_table WHERE lat > 38.21 AND lon > -109.59";
+        let df = session.sql(sql).await?;
+
+        let batches = df.collect().await?;
+        for batch in batches {
+            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
+            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            assert!(lat_values
+                .iter()
+                .zip(lon_values.iter())
+                .all(|(lat, lon)| *lat > 38.21 && *lon > -109.59));
+        }
+
+        // Same as above, but flip the column order in the predicates.
+        let sql = "SELECT lat, lon FROM zarr_table WHERE lon > -109.59 AND lat > 38.21";
+        let df = session.sql(sql).await?;
+
+        let batches = df.collect().await?;
+        for batch in batches {
+            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
+            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            assert!(lat_values
+                .iter()
+                .zip(lon_values.iter())
+                .all(|(lat, lon)| *lat > 38.21 && *lon > -109.59));
+        }
+
+        // Apply one predicate that includes two columns.
+        let sql = "SELECT lat, lon FROM zarr_table WHERE lat + lon > -71.39";
+        let df = session.sql(sql).await?;
+
+        let batches = df.collect().await?;
+        for batch in batches {
+            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
+            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            assert!(lat_values
+                .iter()
+                .zip(lon_values.iter())
+                .all(|(lat, lon)| *lat + *lon > -71.39));
+        }
+
+        // Same as above, but flip the column order in the predicate.
+        let sql = "SELECT lat, lon FROM zarr_table WHERE lon + lat > -71.39";
+        let df = session.sql(sql).await?;
+
+        let batches = df.collect().await?;
+        for batch in batches {
+            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
+            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            assert!(lat_values
+                .iter()
+                .zip(lon_values.iter())
+                .all(|(lat, lon)| *lat + *lon > -71.39));
+        }
+
+        // Apply three predicates: two on single columns and one on two columns.
+        let sql = "SELECT lat, lon FROM zarr_table WHERE lat > 38.21 AND lon > -109.59 AND lat + lon > -71.09";
+        let df = session.sql(sql).await?;
+
+        let batches = df.collect().await?;
+        for batch in batches {
+            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
+            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            assert!(lat_values
+                .iter()
+                .zip(lon_values.iter())
+                .all(|(lat, lon)| *lat > 38.21 && *lon > -109.59 && *lat + *lon > -71.09));
+        }
+
+        Ok(())
+    }
 }
diff --git a/src/datafusion/table_provider.rs b/src/datafusion/table_provider.rs
index b264004..246dceb 100644
--- a/src/datafusion/table_provider.rs
+++ b/src/datafusion/table_provider.rs
@@ -20,22 +20,24 @@ use std::sync::Arc;
 use arrow_schema::{Schema, SchemaRef};
 use async_trait::async_trait;
 use datafusion::{
-    common::Statistics,
+    common::{Statistics, ToDFSchema},
     datasource::{
         listing::{ListingTableUrl, PartitionedFile},
         physical_plan::FileScanConfig,
         TableProvider, TableType,
     },
     execution::context::SessionState,
-    logical_expr::{Expr, TableProviderFilterPushDown},
+    logical_expr::{utils::conjunction, Expr, TableProviderFilterPushDown},
     physical_plan::ExecutionPlan,
 };
+use datafusion_physical_expr::create_physical_expr;
 
 use crate::{
     async_reader::{ZarrPath, ZarrReadAsync},
     reader::ZarrResult,
 };
 
+use super::helpers::expr_applicable_for_cols;
 use super::scanner::ZarrScan;
 
 pub struct ListingZarrTableOptions {}
@@ -99,18 +101,31 @@ impl TableProvider for ZarrTableProvider {
         &self,
         filters: &[&Expr],
     ) -> datafusion::error::Result<Vec<TableProviderFilterPushDown>> {
-        // TODO: which filters can we push down?
         Ok(filters
             .iter()
-            .map(|_| TableProviderFilterPushDown::Unsupported)
+            .map(|filter| {
+                if expr_applicable_for_cols(
+                    &self
+                        .table_schema
+                        .fields
+                        .iter()
+                        .map(|field| field.name().to_string())
+                        .collect::<Vec<_>>(),
+                    filter,
+                ) {
+                    TableProviderFilterPushDown::Exact
+                } else {
+                    TableProviderFilterPushDown::Unsupported
+                }
+            })
             .collect())
     }
 
     async fn scan(
         &self,
-        _state: &SessionState,
+        state: &SessionState,
         projection: Option<&Vec<usize>>,
-        _filters: &[Expr],
+        filters: &[Expr],
         limit: Option<usize>,
     ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
         let object_store_url = self.config.table_path.object_store();
 
@@ -118,6 +133,14 @@ impl TableProvider for ZarrTableProvider {
         let pf = PartitionedFile::new(self.config.table_path.prefix().clone(), 0);
         let file_groups = vec![vec![pf]];
 
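+        // Combine all pushed-down filters into a single conjunction, and
+        // convert the resulting logical expression into a physical one,
+        // resolved against the table schema.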
+        let filters = if let Some(expr) = conjunction(filters.to_vec()) {
+            let table_df_schema = self.table_schema.clone().to_dfschema()?;
+            let filters = create_physical_expr(&expr, &table_df_schema, state.execution_props())?;
+            Some(filters)
+        } else {
+            None
+        };
+
         let file_scan_config = FileScanConfig {
             object_store_url,
             file_schema: Arc::new(self.table_schema.clone()), // TODO differentiate between file and table schema
@@ -129,7 +152,7 @@ impl TableProvider for ZarrTableProvider {
             output_ordering: vec![],
         };
 
-        let scanner = ZarrScan::new(file_scan_config);
+        let scanner = ZarrScan::new(file_scan_config, filters);
 
         Ok(Arc::new(scanner))
     }
diff --git a/src/reader/filters.rs b/src/reader/filters.rs
index 9dac4e0..a28fadb 100644
--- a/src/reader/filters.rs
+++ b/src/reader/filters.rs
@@ -20,6 +20,7 @@ use arrow_schema::ArrowError;
 use dyn_clone::{clone_trait_object, DynClone};
 
 use crate::reader::ZarrProjection;
+use crate::reader::ZarrResult;
 
 /// A predicate operating on [`RecordBatch`]. Here we have the [`DynClone`] trait
 /// bound because when dealing with the async zarr reader, it's useful to be able
@@ -100,13 +101,13 @@ impl ZarrChunkFilter {
     }
 
     /// Get the combined projections for all the predicates in the filter.
-    pub fn get_all_projections(&self) -> ZarrProjection {
+    pub fn get_all_projections(&self) -> ZarrResult<ZarrProjection> {
         let mut proj = ZarrProjection::all();
         for pred in self.predicates.iter() {
-            proj.update(pred.projection().clone());
+            proj.update(pred.projection().clone())?;
         }
 
-        proj
+        Ok(proj)
     }
 }
 
diff --git a/src/reader/mod.rs b/src/reader/mod.rs
index 904e308..070584b 100644
--- a/src/reader/mod.rs
+++ b/src/reader/mod.rs
@@ -394,7 +394,7 @@ impl<T: ZarrRead> ZarrRecordBatchReaderBuilder<T> {
 
         let mut predicate_store: Option<ZarrStore<T>> = None;
         if let Some(filter) = &self.filter {
-            let predicate_proj = filter.get_all_projections();
+            let predicate_proj = filter.get_all_projections()?;
             predicate_store = Some(ZarrStore::new(
                 self.zarr_reader.clone(),
                 chunk_pos.clone(),
@@ -466,7 +466,7 @@ mod zarr_reader_tests {
         let mut matched = false;
         for (idx, col) in enumerate(rec.schema().fields.iter()) {
             if col.name().as_str() == col_name {
-                assert_eq!(rec.column(idx).as_primitive::<T>().values(), targets,);
+                assert_eq!(rec.column(idx).as_primitive::<T>().values(), targets);
                 matched = true;
             }
         }
diff --git a/src/reader/zarr_read.rs b/src/reader/zarr_read.rs
index 19fdb5e..c402d4b 100644
--- a/src/reader/zarr_read.rs
+++ b/src/reader/zarr_read.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use itertools::Itertools;
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::fs::{read, read_to_string};
 use std::path::PathBuf;
 
@@ -82,7 +82,7 @@ impl ZarrInMemoryChunk {
     }
 }
 
-#[derive(Clone, PartialEq)]
+#[derive(Clone, PartialEq, Debug)]
 pub(crate) enum ProjectionType {
     Select,
     SelectByIndex,
@@ -92,7 +92,7 @@ pub(crate) enum ProjectionType {
 
 /// A structure to handle skipping or selecting specific columns (zarr arrays) from
 /// a zarr store.
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct ZarrProjection {
     projection_type: ProjectionType,
     col_names: Option<Vec<String>>,
@@ -184,30 +184,67 @@ impl ZarrProjection {
         }
     }
 
-    pub(crate) fn update(&mut self, other_proj: ZarrProjection) {
-        if other_proj.projection_type == ProjectionType::Null {
-            return;
-        }
-
-        if self.projection_type == ProjectionType::Null {
-            self.projection_type = other_proj.projection_type;
-            self.col_names = other_proj.col_names;
-            return;
-        }
-
-        let col_names = self.col_names.as_mut().unwrap();
-        if other_proj.projection_type == self.projection_type {
-            let mut s: HashSet<String> = HashSet::from_iter(col_names.clone());
-            let other_cols = other_proj.col_names.unwrap();
-            s.extend::<HashSet<String>>(HashSet::from_iter(other_cols));
-            self.col_names = Some(s.into_iter().collect_vec());
-        } else {
-            for col in other_proj.col_names.as_ref().unwrap() {
-                if let Some(index) = col_names.iter().position(|value| value == col) {
-                    col_names.remove(index);
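+    // Merge another projection into this one. For example, a Select on
+    // ["lat"] merged with a Select on ["lon"] becomes a Select on
+    // ["lat", "lon"], while merging a Skip and a Select removes the
+    // selected columns from the skip list.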
+    pub(crate) fn update(&mut self, other_proj: ZarrProjection) -> ZarrResult<()> {
+        match (&self.projection_type, &other_proj.projection_type) {
+            (_, ProjectionType::Null) => (),
+            (ProjectionType::Null, _) => {
+                self.projection_type = other_proj.projection_type;
+                self.col_names = other_proj.col_names;
+                self.col_indices = other_proj.col_indices;
+            }
+            (ProjectionType::SelectByIndex, ProjectionType::SelectByIndex) => {
+                let mut indices = self
+                    .col_indices
+                    .take()
+                    .expect("ZarrProjection missing indices");
+                for i in other_proj
+                    .col_indices
+                    .expect("ZarrProjection update missing indices")
+                {
+                    if !indices.contains(&i) {
+                        indices.push(i);
+                    }
                 }
+                self.col_indices = Some(indices);
             }
-        }
+            (ProjectionType::Select, ProjectionType::Select) => {
+                let mut col_names = self
+                    .col_names
+                    .take()
+                    .expect("ZarrProjection missing col names");
+                for col in other_proj
+                    .col_names
+                    .expect("ZarrProjection update missing col names")
+                {
+                    if !col_names.contains(&col) {
+                        col_names.push(col);
+                    }
+                }
+                self.col_names = Some(col_names);
+            }
+            (ProjectionType::Skip, ProjectionType::Select)
+            | (ProjectionType::Select, ProjectionType::Skip) => {
+                let mut col_names = self
+                    .col_names
+                    .take()
+                    .expect("ZarrProjection missing col names");
+                for col in other_proj
+                    .col_names
+                    .expect("ZarrProjection update missing col names")
+                {
+                    if let Some(index) = col_names.iter().position(|value| value == &col) {
+                        col_names.remove(index);
+                    }
+                }
+                self.col_names = Some(col_names);
+            }
+            _ => {
+                return Err(ZarrError::InvalidPredicate(
+                    "Invalid ZarrProjection update".to_string(),
+                ))
+            }
+        };
+
+        Ok(())
     }
 }
 
From bbd4c5783bd82b59228dddf8fbaf294f6759795d Mon Sep 17 00:00:00 2001
From: Maxime Dion
Date: Fri, 26 Apr 2024 00:47:46 -0500
Subject: [PATCH 2/2] made predicate pushdowns inexact and removed row
 filtering logic

---
 src/async_reader/mod.rs          | 42 +++++++++++----
 src/datafusion/helpers.rs        | 17 ++++++
 src/datafusion/table_factory.rs  | 62 ++++++++++++----------
 src/datafusion/table_provider.rs |  3 +-
 src/reader/codecs.rs             | 19 ++-----
 src/reader/mod.rs                | 91 ++++++--------------------------
 src/reader/zarr_read.rs          | 42 +++++++--------
 7 files changed, 121 insertions(+), 155 deletions(-)

diff --git a/src/async_reader/mod.rs b/src/async_reader/mod.rs
index cc19016..f25263a 100644
--- a/src/async_reader/mod.rs
+++ b/src/async_reader/mod.rs
@@ -238,7 +238,6 @@ pub struct ZarrRecordBatchStream<T> {
     meta: ZarrStoreMetadata,
     filter: Option<ZarrChunkFilter>,
     state: ZarrStreamState<T>,
-    mask: Option<BooleanArray>,
 
     // an option so that we can "take" the wrapper and bundle it
     // in a future when polling the stream.
@@ -266,7 +265,6 @@ impl<T> ZarrRecordBatchStream<T> {
             predicate_store_wrapper,
             store_wrapper: Some(ZarrStoreWrapper::new(zarr_store)),
             state: ZarrStreamState::Init,
-            mask: None,
         }
     }
 }
@@ -355,7 +353,6 @@ where
                                 .skip_next_chunk();
                             self.state = ZarrStreamState::Init;
                         } else {
-                            self.mask = Some(mask);
                             let wrapper = self.store_wrapper.take().expect(LOST_STORE_ERR);
                             let fut = wrapper.get_next().boxed();
                             self.state = ZarrStreamState::Reading(fut);
@@ -373,16 +370,13 @@ where
                         let chunk = chunk?;
                         let container = ZarrInMemoryChunkContainer::new(chunk);
 
-                        let mut zarr_reader = ZarrRecordBatchReader::new(
+                        let zarr_reader = ZarrRecordBatchReader::new(
                             self.meta.clone(),
                             Some(container),
                             None,
                             None,
                         );
 
-                        if self.mask.is_some() {
-                            zarr_reader = zarr_reader.with_row_mask(self.mask.take().unwrap());
-                        }
                         self.state = ZarrStreamState::Decoding(zarr_reader);
                     } else {
                         // if store returns none, it's the end and it's time to return
@@ -840,11 +834,37 @@ mod zarr_async_reader_tests {
             ("float_data".to_string(), DataType::Float64),
         ]);
 
-        let rec = &records[1];
+        // Check the values in a chunk. The predicate pushdown only takes care
+        // of skipping whole chunks, so there is no guarantee that the values
+        // in the record batch fully satisfy the predicate; here we only check
+        // that the first chunk that was read is the first one with some values
+        // that satisfy the predicate.
+        let rec = &records[0];
         validate_names_and_types(&target_types, rec);
-        validate_primitive_column::<Float64Type, f64>("lat", rec, &[38.8, 38.9, 39.0]);
-        validate_primitive_column::<Float64Type, f64>("lon", rec, &[-109.7, -109.7, -109.7]);
-        validate_primitive_column::<Float64Type, f64>("float_data", rec, &[1042.0, 1043.0, 1044.0]);
+        validate_primitive_column::<Float64Type, f64>(
+            "lat",
+            rec,
+            &[
+                38.4, 38.5, 38.6, 38.7, 38.4, 38.5, 38.6, 38.7, 38.4, 38.5, 38.6, 38.7, 38.4, 38.5,
+                38.6, 38.7,
+            ],
+        );
+        validate_primitive_column::<Float64Type, f64>(
+            "lon",
+            rec,
+            &[
+                -110.0, -110.0, -110.0, -110.0, -109.9, -109.9, -109.9, -109.9, -109.8, -109.8,
+                -109.8, -109.8, -109.7, -109.7, -109.7, -109.7,
+            ],
+        );
+        validate_primitive_column::<Float64Type, f64>(
+            "float_data",
+            rec,
+            &[
+                1005.0, 1006.0, 1007.0, 1008.0, 1016.0, 1017.0, 1018.0, 1019.0, 1027.0, 1028.0,
+                1029.0, 1030.0, 1038.0, 1039.0, 1040.0, 1041.0,
+            ],
+        );
     }
 
     #[tokio::test]
diff --git a/src/datafusion/helpers.rs b/src/datafusion/helpers.rs
index 770c34d..e30f991 100644
--- a/src/datafusion/helpers.rs
+++ b/src/datafusion/helpers.rs
@@ -1,3 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
 use crate::reader::{ZarrArrowPredicate, ZarrChunkFilter, ZarrProjection};
 use arrow::array::BooleanArray;
 use arrow::error::ArrowError;
 use arrow::record_batch::RecordBatch;
diff --git a/src/datafusion/table_factory.rs b/src/datafusion/table_factory.rs
index edc4b39..4ac2a6b 100644
--- a/src/datafusion/table_factory.rs
+++ b/src/datafusion/table_factory.rs
@@ -59,7 +59,6 @@ impl TableProviderFactory for ZarrListingTableFactory {
 
 #[cfg(test)]
 mod tests {
-    use crate::reader::{ZarrError, ZarrResult};
     use crate::tests::get_test_v2_data_path;
     use arrow::record_batch::RecordBatch;
     use arrow_array::cast::AsArray;
@@ -71,26 +70,18 @@ mod tests {
         context::{SessionContext, SessionState},
         runtime_env::RuntimeEnv,
     };
-    use itertools::enumerate;
     use std::sync::Arc;
 
-    fn extract_col<T>(
-        col_name: &str,
-        rec_batch: &RecordBatch,
-    ) -> ZarrResult<ScalarBuffer<T::Native>>
+    fn extract_col<T>(col_name: &str, rec_batch: &RecordBatch) -> ScalarBuffer<T::Native>
     where
         T: ArrowPrimitiveType,
     {
-        for (idx, col) in enumerate(rec_batch.schema().fields.iter()) {
-            if col.name().as_str() == col_name {
-                let values = rec_batch.column(idx).as_primitive::<T>().values();
-                return Ok(values.clone());
-            }
-        }
-
-        Err(ZarrError::InvalidMetadata(
-            "column name not found".to_string(),
-        ))
+        rec_batch
+            .column_by_name(col_name)
+            .unwrap()
+            .as_primitive::<T>()
+            .values()
+            .clone()
     }
 
     #[tokio::test]
@@ -155,7 +146,7 @@ mod tests {
 
         let batches = df.collect().await?;
         for batch in batches {
-            let values = extract_col::<Float64Type>("lat", &batch)?;
+            let values = extract_col::<Float64Type>("lat", &batch);
             assert!(values.iter().all(|v| *v > 38.21));
         }
 
@@ -165,8 +156,8 @@ mod tests {
 
         let batches = df.collect().await?;
         for batch in batches {
-            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
-            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            let lat_values = extract_col::<Float64Type>("lat", &batch);
+            let lon_values = extract_col::<Float64Type>("lon", &batch);
             assert!(lat_values
                 .iter()
                 .zip(lon_values.iter())
@@ -179,8 +170,8 @@ mod tests {
 
         let batches = df.collect().await?;
         for batch in batches {
-            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
-            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            let lat_values = extract_col::<Float64Type>("lat", &batch);
+            let lon_values = extract_col::<Float64Type>("lon", &batch);
             assert!(lat_values
                 .iter()
                 .zip(lon_values.iter())
@@ -193,8 +184,8 @@ mod tests {
 
         let batches = df.collect().await?;
         for batch in batches {
-            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
-            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            let lat_values = extract_col::<Float64Type>("lat", &batch);
+            let lon_values = extract_col::<Float64Type>("lon", &batch);
             assert!(lat_values
                 .iter()
                 .zip(lon_values.iter())
@@ -207,8 +198,8 @@ mod tests {
 
         let batches = df.collect().await?;
         for batch in batches {
-            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
-            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            let lat_values = extract_col::<Float64Type>("lat", &batch);
+            let lon_values = extract_col::<Float64Type>("lon", &batch);
             assert!(lat_values
                 .iter()
                 .zip(lon_values.iter())
@@ -221,14 +212,31 @@ mod tests {
 
         let batches = df.collect().await?;
         for batch in batches {
-            let lat_values = extract_col::<Float64Type>("lat", &batch)?;
-            let lon_values = extract_col::<Float64Type>("lon", &batch)?;
+            let lat_values = extract_col::<Float64Type>("lat", &batch);
+            let lon_values = extract_col::<Float64Type>("lon", &batch);
             assert!(lat_values
                 .iter()
                 .zip(lon_values.iter())
                 .all(|(lat, lon)| *lat > 38.21 && *lon > -109.59 && *lat + *lon > -71.09));
         }
 
+        // Check a query that doesn't select the column needed in the
+        // predicate. The first query below is used to produce the reference
+        // values, and the second one is the one under test, since it has a
+        // predicate on lon but doesn't select lon.
+        let sql = "SELECT lat, lon FROM zarr_table WHERE lon > -109.59";
+        let df = session.sql(sql).await?;
+        let lat_lon_batches = df.collect().await?;
+
+        let sql = "SELECT lat FROM zarr_table WHERE lon > -109.59";
+        let df = session.sql(sql).await?;
+        let lat_batches = df.collect().await?;
+
+        for (lat_batch, lat_lon_batch) in lat_batches.iter().zip(lat_lon_batches.iter()) {
+            let lat_values_1 = extract_col::<Float64Type>("lat", lat_batch);
+            let lat_values_2 = extract_col::<Float64Type>("lat", lat_lon_batch);
+            assert_eq!(lat_values_1, lat_values_2);
+        }
+
         Ok(())
     }
 }
diff --git a/src/datafusion/table_provider.rs b/src/datafusion/table_provider.rs
index 246dceb..ecbc84f 100644
--- a/src/datafusion/table_provider.rs
+++ b/src/datafusion/table_provider.rs
@@ -101,6 +101,7 @@ impl TableProvider for ZarrTableProvider {
         &self,
         filters: &[&Expr],
     ) -> datafusion::error::Result<Vec<TableProviderFilterPushDown>> {
+        // TODO handle predicates on partition columns as Exact.
         Ok(filters
             .iter()
             .map(|filter| {
@@ -113,7 +114,7 @@ impl TableProvider for ZarrTableProvider {
                         .collect::<Vec<_>>(),
                     filter,
                 ) {
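+                    // Inexact: the scan only skips whole chunks, so
+                    // datafusion re-applies the predicate to the record
+                    // batches it gets back.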
-                    TableProviderFilterPushDown::Exact
+                    TableProviderFilterPushDown::Inexact
                 } else {
                     TableProviderFilterPushDown::Unsupported
                 }
diff --git a/src/reader/codecs.rs b/src/reader/codecs.rs
index 821c77c..fcb7463 100644
--- a/src/reader/codecs.rs
+++ b/src/reader/codecs.rs
@@ -715,14 +715,10 @@ pub(crate) fn apply_codecs(
     data_type: &ZarrDataType,
     codecs: &Vec<ZarrCodec>,
     sharding_params: Option<ShardingOptions>,
-    final_indices: Option<&Vec<usize>>,
 ) -> ZarrResult<(ArrayRef, FieldRef)> {
     macro_rules! return_array {
         ($func_name: tt, $data_t: expr, $array_t: ty) => {
-            let mut data = $func_name(raw_data, &chunk_dims, &real_dims, &codecs, sharding_params)?;
-            if let Some(indices) = final_indices {
-                keep_indices(&mut data, &indices);
-            };
+            let data = $func_name(raw_data, &chunk_dims, &real_dims, &codecs, sharding_params)?;
             let field = Field::new(col_name, $data_t, false);
             let arr: $array_t = data.into();
             return Ok((Arc::new(arr), Arc::new(field)))
@@ -732,10 +728,7 @@ pub(crate) fn apply_codecs(
     match data_type {
         ZarrDataType::Bool => {
             let data = decode_u8_chunk(raw_data, chunk_dims, real_dims, codecs, sharding_params)?;
-            let mut data: Vec<bool> = data.iter().map(|x| *x != 0).collect();
-            if let Some(indices) = final_indices {
-                keep_indices(&mut data, indices);
-            };
+            let data: Vec<bool> = data.iter().map(|x| *x != 0).collect();
             let field = Field::new(col_name, DataType::Boolean, false);
             let arr: BooleanArray = data.into();
             Ok((Arc::new(arr), Arc::new(field)))
@@ -817,7 +810,7 @@ pub(crate) fn apply_codecs(
                     pyunicode = true;
                     str_len_adjustment = PY_UNICODE_SIZE;
                 }
-                let mut data = decode_string_chunk(
+                let data = decode_string_chunk(
                     raw_data,
                     *s / str_len_adjustment,
                     chunk_dims,
@@ -826,9 +819,6 @@ pub(crate) fn apply_codecs(
                     sharding_params,
                     pyunicode,
                 )?;
-                if let Some(indices) = final_indices {
-                    keep_indices(&mut data, indices);
-                };
                 let field = Field::new(col_name, DataType::Utf8, false);
                 let arr: StringArray = data.into();
 
@@ -874,7 +864,6 @@ mod zarr_codecs_tests {
             &data_type,
             &codecs,
             sharding_params,
-            None,
         )
         .unwrap();
 
@@ -925,7 +914,6 @@ mod zarr_codecs_tests {
            &data_type,
             &codecs,
             sharding_params,
-            None,
         )
         .unwrap();
 
@@ -977,7 +965,6 @@ mod zarr_codecs_tests {
             &data_type,
             &codecs,
             sharding_params,
-            None,
         )
         .unwrap();
 
diff --git a/src/reader/mod.rs b/src/reader/mod.rs
index 070584b..d188324 100644
--- a/src/reader/mod.rs
+++ b/src/reader/mod.rs
@@ -158,7 +158,6 @@ pub struct ZarrRecordBatchReader<T: ZarrIterator> {
     zarr_store: Option<T>,
     filter: Option<ZarrChunkFilter>,
     predicate_projection_store: Option<T>,
-    row_mask: Option<BooleanArray>,
 }
 
@@ -173,22 +172,10 @@ impl<T: ZarrIterator> ZarrRecordBatchReader<T> {
             zarr_store,
             filter,
             predicate_projection_store,
-            row_mask: None,
         }
     }
 
-    pub(crate) fn with_row_mask(self, row_mask: BooleanArray) -> Self {
-        Self {
-            row_mask: Some(row_mask),
-            ..self
-        }
-    }
-
-    pub(crate) fn unpack_chunk(
-        &self,
-        mut chunk: ZarrInMemoryChunk,
-        final_indices: Option<&Vec<usize>>,
-    ) -> ZarrResult<RecordBatch> {
+    pub(crate) fn unpack_chunk(&self, mut chunk: ZarrInMemoryChunk) -> ZarrResult<RecordBatch> {
         let mut arrs: Vec<ArrayRef> = Vec::with_capacity(self.meta.get_num_columns());
         let mut fields: Vec<FieldRef> = Vec::with_capacity(self.meta.get_num_columns());
 
@@ -200,7 +187,6 @@ impl<T: ZarrIterator> ZarrRecordBatchReader<T> {
                 data,
                 chunk.get_real_dims(),
                 self.meta.get_chunk_dims(),
-                final_indices,
             )?;
             arrs.push(arr);
             fields.push(field);
@@ -215,7 +201,6 @@ impl<T: ZarrIterator> ZarrRecordBatchReader<T> {
         arr_chnk: ZarrInMemoryArray,
         real_dims: &Vec<usize>,
         chunk_dims: &Vec<usize>,
-        final_indices: Option<&Vec<usize>>,
     ) -> ZarrResult<(ArrayRef, FieldRef)> {
         // get the metadata for the array
         let meta = self.meta.get_array_meta(&col_name)?;
@@ -232,7 +217,6 @@ impl<T: ZarrIterator> ZarrRecordBatchReader<T> {
             meta.get_type(),
             meta.get_codecs(),
             meta.get_sharding_params(),
-            final_indices,
         )?;
 
         Ok((arr, field))
@@ -246,7 +230,6 @@ impl<T: ZarrIterator> Iterator for ZarrRecordBatchReader<T> {
 
     fn next(&mut self) -> Option<Self::Item> {
         // handle filters first.
-        let mut bool_arr: Option<BooleanArray> = self.row_mask.clone();
         if let Some(store) = self.predicate_projection_store.as_mut() {
             let predicate_proj_chunk = store.next_chunk();
 
@@ -254,9 +237,10 @@ impl<T: ZarrIterator> Iterator for ZarrRecordBatchReader<T> {
 
             let predicate_proj_chunk = unwrap_or_return!(predicate_proj_chunk.unwrap());
 
-            let predicate_rec = self.unpack_chunk(predicate_proj_chunk, None);
+            let predicate_rec = self.unpack_chunk(predicate_proj_chunk);
             let predicate_rec = unwrap_or_return!(predicate_rec);
 
+            let mut bool_arr: Option<BooleanArray> = None;
             if let Some(filter) = self.filter.as_mut() {
                 for predicate in filter.predicates.iter_mut() {
                     let mask = predicate.evaluate(&predicate_rec);
@@ -316,24 +300,9 @@ impl<T: ZarrIterator> Iterator for ZarrRecordBatchReader<T> {
         // main logic for the chunk
         let next_batch = self.zarr_store.as_mut().unwrap().next_chunk();
         next_batch.as_ref()?;
-        let next_batch = unwrap_or_return!(next_batch.unwrap());
 
-        let mut final_indices: Option<Vec<usize>> = None;
-
-        // if we have a bool array to mask some values, we get the indices (the rows)
-        // that we need to keep. those are then applied across all zarr array un the chunk.
-        if let Some(mask) = bool_arr {
-            let mask = mask.values();
-            final_indices = Some(
-                mask.iter()
-                    .enumerate()
-                    .filter(|x| x.1)
-                    .map(|x| x.0)
-                    .collect(),
-            );
-        }
+        let next_batch = unwrap_or_return!(next_batch.unwrap());
 
-        return Some(self.unpack_chunk(next_batch, final_indices.as_ref()));
+        Some(self.unpack_chunk(next_batch))
     }
 }
 
@@ -928,71 +897,41 @@ mod zarr_reader_tests {
         let reader = builder.build().unwrap();
         let records: Vec<RecordBatch> = reader.map(|x| x.unwrap()).collect();
 
-        // check the 4 chunks that have some in the specified lat/lon range
-        // center chunk
         let target_types = HashMap::from([
             ("lat".to_string(), DataType::Float64),
             ("lon".to_string(), DataType::Float64),
             ("float_data".to_string(), DataType::Float64),
         ]);
 
+        // Check the values in a chunk. The predicate pushdown only takes care
+        // of skipping whole chunks, so there is no guarantee that the values
+        // in the record batch fully satisfy the predicate; here we only check
+        // that the first chunk that was read is the first one with some values
+        // that satisfy the predicate.
         let rec = &records[0];
         validate_names_and_types(&target_types, rec);
-        validate_primitive_column::<Float64Type, f64>("lat", rec, &[38.6, 38.7]);
-        validate_primitive_column::<Float64Type, f64>("lon", rec, &[-109.7, -109.7]);
-        validate_primitive_column::<Float64Type, f64>("float_data", rec, &[1040.0, 1041.0]);
-
-        let rec = &records[1];
-        validate_names_and_types(&target_types, rec);
-        validate_primitive_column::<Float64Type, f64>("lat", rec, &[38.8, 38.9, 39.0]);
-        validate_primitive_column::<Float64Type, f64>("lon", rec, &[-109.7, -109.7, -109.7]);
-        validate_primitive_column::<Float64Type, f64>("float_data", rec, &[1042.0, 1043.0, 1044.0]);
-
-        let rec = &records[2];
-        validate_names_and_types(&target_types, rec);
-        validate_primitive_column::<Float64Type, f64>(
-            "lat",
-            rec,
-            &[38.6, 38.7, 38.6, 38.7, 38.6, 38.7, 38.6, 38.7],
-        );
-        validate_primitive_column::<Float64Type, f64>(
-            "lon",
-            rec,
-            &[
-                -109.6, -109.6, -109.5, -109.5, -109.4, -109.4, -109.3, -109.3,
-            ],
-        );
-        validate_primitive_column::<Float64Type, f64>(
-            "float_data",
-            rec,
-            &[
-                1051.0, 1052.0, 1062.0, 1063.0, 1073.0, 1074.0, 1084.0, 1085.0,
-            ],
-        );
-
-        let rec = &records[3];
-        validate_names_and_types(&target_types, rec);
         validate_primitive_column::<Float64Type, f64>(
             "lat",
             rec,
             &[
-                38.8, 38.9, 39.0, 38.8, 38.9, 39.0, 38.8, 38.9, 39.0, 38.8, 38.9, 39.0,
+                38.4, 38.5, 38.6, 38.7, 38.4, 38.5, 38.6, 38.7, 38.4, 38.5, 38.6, 38.7, 38.4, 38.5,
+                38.6, 38.7,
             ],
         );
         validate_primitive_column::<Float64Type, f64>(
             "lon",
             rec,
             &[
-                -109.6, -109.6, -109.6, -109.5, -109.5, -109.5, -109.4, -109.4, -109.4, -109.3,
-                -109.3, -109.3,
+                -110.0, -110.0, -110.0, -110.0, -109.9, -109.9, -109.9, -109.9, -109.8, -109.8,
+                -109.8, -109.8, -109.7, -109.7, -109.7, -109.7,
             ],
         );
         validate_primitive_column::<Float64Type, f64>(
             "float_data",
             rec,
             &[
-                1053.0, 1054.0, 1055.0, 1064.0, 1065.0, 1066.0, 1075.0, 1076.0, 1077.0, 1086.0,
-                1087.0, 1088.0,
+                1005.0, 1006.0, 1007.0, 1008.0, 1016.0, 1017.0, 1018.0, 1019.0, 1027.0, 1028.0,
+                1029.0, 1030.0, 1038.0, 1039.0, 1040.0, 1041.0,
             ],
         );
     }
 
diff --git a/src/reader/zarr_read.rs b/src/reader/zarr_read.rs
index c402d4b..cc9c763 100644
--- a/src/reader/zarr_read.rs
+++ b/src/reader/zarr_read.rs
@@ -193,14 +193,12 @@ impl ZarrProjection {
                 self.col_indices = other_proj.col_indices;
             }
             (ProjectionType::SelectByIndex, ProjectionType::SelectByIndex) => {
-                let mut indices = self
-                    .col_indices
-                    .take()
-                    .expect("ZarrProjection missing indices");
-                for i in other_proj
-                    .col_indices
-                    .expect("ZarrProjection update missing indices")
-                {
+                let mut indices = self.col_indices.take().ok_or(ZarrError::InvalidPredicate(
+                    "ZarrProjection missing indices".to_string(),
+                ))?;
+                for i in other_proj.col_indices.ok_or(ZarrError::InvalidPredicate(
+                    "ZarrProjection update missing indices".to_string(),
+                ))? {
                     if !indices.contains(&i) {
                         indices.push(i);
                     }
@@ -208,14 +206,12 @@ impl ZarrProjection {
                 }
                 self.col_indices = Some(indices);
             }
             (ProjectionType::Select, ProjectionType::Select) => {
-                let mut col_names = self
-                    .col_names
-                    .take()
-                    .expect("ZarrProjection missing col names");
-                for col in other_proj
-                    .col_names
-                    .expect("ZarrProjection update missing col names")
-                {
+                let mut col_names = self.col_names.take().ok_or(ZarrError::InvalidPredicate(
+                    "ZarrProjection missing col_names".to_string(),
+                ))?;
+                for col in other_proj.col_names.ok_or(ZarrError::InvalidPredicate(
+                    "ZarrProjection update missing col_names".to_string(),
+                ))? {
                     if !col_names.contains(&col) {
                         col_names.push(col);
                     }
@@ -224,14 +220,12 @@ impl ZarrProjection {
             }
             (ProjectionType::Skip, ProjectionType::Select)
             | (ProjectionType::Select, ProjectionType::Skip) => {
-                let mut col_names = self
-                    .col_names
-                    .take()
-                    .expect("ZarrProjection missing col names");
-                for col in other_proj
-                    .col_names
-                    .expect("ZarrProjection update missing col names")
-                {
+                let mut col_names = self.col_names.take().ok_or(ZarrError::InvalidPredicate(
+                    "ZarrProjection missing col_names".to_string(),
+                ))?;
+                for col in other_proj.col_names.ok_or(ZarrError::InvalidPredicate(
+                    "ZarrProjection update missing col_names".to_string(),
+                ))? {
                     if let Some(index) = col_names.iter().position(|value| value == &col) {
                         col_names.remove(index);
                     }