made predicate pushdowns inexact and removed row filtering logic
maximedion2 committed Apr 26, 2024
1 parent 766316b commit bbd4c57
Showing 7 changed files with 121 additions and 155 deletions.
42 changes: 31 additions & 11 deletions src/async_reader/mod.rs
@@ -238,7 +238,6 @@ pub struct ZarrRecordBatchStream<T: ZarrStream> {
meta: ZarrStoreMetadata,
filter: Option<ZarrChunkFilter>,
state: ZarrStreamState<T>,
mask: Option<BooleanArray>,

// an option so that we can "take" the wrapper and bundle it
// in a future when polling the stream.
@@ -266,7 +265,6 @@ impl<T: ZarrStream> ZarrRecordBatchStream<T> {
predicate_store_wrapper,
store_wrapper: Some(ZarrStoreWrapper::new(zarr_store)),
state: ZarrStreamState::Init,
mask: None,
}
}
}
@@ -355,7 +353,6 @@ where
.skip_next_chunk();
self.state = ZarrStreamState::Init;
} else {
self.mask = Some(mask);
let wrapper = self.store_wrapper.take().expect(LOST_STORE_ERR);
let fut = wrapper.get_next().boxed();
self.state = ZarrStreamState::Reading(fut);
@@ -373,16 +370,13 @@ where

let chunk = chunk?;
let container = ZarrInMemoryChunkContainer::new(chunk);
let mut zarr_reader = ZarrRecordBatchReader::new(
let zarr_reader = ZarrRecordBatchReader::new(
self.meta.clone(),
Some(container),
None,
None,
);

if self.mask.is_some() {
zarr_reader = zarr_reader.with_row_mask(self.mask.take().unwrap());
}
self.state = ZarrStreamState::Decoding(zarr_reader);
} else {
// if store returns none, it's the end and it's time to return
@@ -840,11 +834,37 @@ mod zarr_async_reader_tests {
("float_data".to_string(), DataType::Float64),
]);

let rec = &records[1];
// check the values in a chunk. the predicate pushdown only takes care of
// skipping whole chunks, so there is no guarantee that the values in the
// record batch fully satisfy the predicate; here we only check that the
// first chunk that was read is the first one with some values that
// satisfy the predicate.
let rec = &records[0];
validate_names_and_types(&target_types, rec);
validate_primitive_column::<Float64Type, f64>("lat", rec, &[38.8, 38.9, 39.0]);
validate_primitive_column::<Float64Type, f64>("lon", rec, &[-109.7, -109.7, -109.7]);
validate_primitive_column::<Float64Type, f64>("float_data", rec, &[1042.0, 1043.0, 1044.0]);
validate_primitive_column::<Float64Type, f64>(
"lat",
rec,
&[
38.4, 38.5, 38.6, 38.7, 38.4, 38.5, 38.6, 38.7, 38.4, 38.5, 38.6, 38.7, 38.4, 38.5,
38.6, 38.7,
],
);
validate_primitive_column::<Float64Type, f64>(
"lon",
rec,
&[
-110.0, -110.0, -110.0, -110.0, -109.9, -109.9, -109.9, -109.9, -109.8, -109.8,
-109.8, -109.8, -109.7, -109.7, -109.7, -109.7,
],
);
validate_primitive_column::<Float64Type, f64>(
"float_data",
rec,
&[
1005.0, 1006.0, 1007.0, 1008.0, 1016.0, 1017.0, 1018.0, 1019.0, 1027.0, 1028.0,
1029.0, 1030.0, 1038.0, 1039.0, 1040.0, 1041.0,
],
);
}

#[tokio::test]
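Note: with the pushdown now inexact, a chunk is read whenever it contains at least one matching row, so batches coming straight out of the reader can still include non-matching rows. DataFusion re-applies the predicate on top of the scan, but code using the record batch stream directly may want to re-filter. A minimal sketch of that post-filtering with arrow's `filter_record_batch`, assuming a `lat > threshold` predicate like the one in the tests above; `refilter_lat` is a hypothetical helper, not part of this crate:

```rust
use arrow::array::{BooleanArray, Float64Array};
use arrow::compute::filter_record_batch;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

// Trim a batch produced by the (now inexact) chunk-level pushdown down to
// the rows that actually satisfy `lat > threshold`; chunks with no matching
// values at all were already skipped upstream.
fn refilter_lat(batch: &RecordBatch, threshold: f64) -> Result<RecordBatch, ArrowError> {
    let lat = batch
        .column_by_name("lat")
        .ok_or_else(|| ArrowError::SchemaError("missing lat column".to_string()))?
        .as_any()
        .downcast_ref::<Float64Array>()
        .ok_or_else(|| ArrowError::SchemaError("lat is not Float64".to_string()))?;
    let mask: BooleanArray = lat.iter().map(|v| v.map(|v| v > threshold)).collect();
    filter_record_batch(batch, &mask)
}
```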
17 changes: 17 additions & 0 deletions src/datafusion/helpers.rs
@@ -1,3 +1,20 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::reader::{ZarrArrowPredicate, ZarrChunkFilter, ZarrProjection};
use arrow::array::BooleanArray;
use arrow::error::ArrowError;
62 changes: 35 additions & 27 deletions src/datafusion/table_factory.rs
@@ -59,7 +59,6 @@ impl TableProviderFactory for ZarrListingTableFactory {

#[cfg(test)]
mod tests {
use crate::reader::{ZarrError, ZarrResult};
use crate::tests::get_test_v2_data_path;
use arrow::record_batch::RecordBatch;
use arrow_array::cast::AsArray;
@@ -71,26 +70,18 @@ mod tests {
context::{SessionContext, SessionState},
runtime_env::RuntimeEnv,
};
use itertools::enumerate;
use std::sync::Arc;

fn extract_col<T>(
col_name: &str,
rec_batch: &RecordBatch,
) -> ZarrResult<ScalarBuffer<T::Native>>
fn extract_col<T>(col_name: &str, rec_batch: &RecordBatch) -> ScalarBuffer<T::Native>
where
T: ArrowPrimitiveType,
{
for (idx, col) in enumerate(rec_batch.schema().fields.iter()) {
if col.name().as_str() == col_name {
let values = rec_batch.column(idx).as_primitive::<T>().values();
return Ok(values.clone());
}
}

Err(ZarrError::InvalidMetadata(
"column name not found".to_string(),
))
rec_batch
.column_by_name(col_name)
.unwrap()
.as_primitive::<T>()
.values()
.clone()
}

#[tokio::test]
@@ -155,7 +146,7 @@

let batches = df.collect().await?;
for batch in batches {
let values = extract_col::<Float64Type>("lat", &batch)?;
let values = extract_col::<Float64Type>("lat", &batch);
assert!(values.iter().all(|v| *v > 38.21));
}

@@ -165,8 +156,8 @@

let batches = df.collect().await?;
for batch in batches {
let lat_values = extract_col::<Float64Type>("lat", &batch)?;
let lon_values = extract_col::<Float64Type>("lon", &batch)?;
let lat_values = extract_col::<Float64Type>("lat", &batch);
let lon_values = extract_col::<Float64Type>("lon", &batch);
assert!(lat_values
.iter()
.zip(lon_values.iter())
@@ -179,8 +170,8 @@

let batches = df.collect().await?;
for batch in batches {
let lat_values = extract_col::<Float64Type>("lat", &batch)?;
let lon_values = extract_col::<Float64Type>("lon", &batch)?;
let lat_values = extract_col::<Float64Type>("lat", &batch);
let lon_values = extract_col::<Float64Type>("lon", &batch);
assert!(lat_values
.iter()
.zip(lon_values.iter())
@@ -193,8 +184,8 @@

let batches = df.collect().await?;
for batch in batches {
let lat_values = extract_col::<Float64Type>("lat", &batch)?;
let lon_values = extract_col::<Float64Type>("lon", &batch)?;
let lat_values = extract_col::<Float64Type>("lat", &batch);
let lon_values = extract_col::<Float64Type>("lon", &batch);
assert!(lat_values
.iter()
.zip(lon_values.iter())
@@ -207,8 +198,8 @@

let batches = df.collect().await?;
for batch in batches {
let lat_values = extract_col::<Float64Type>("lat", &batch)?;
let lon_values = extract_col::<Float64Type>("lon", &batch)?;
let lat_values = extract_col::<Float64Type>("lat", &batch);
let lon_values = extract_col::<Float64Type>("lon", &batch);
assert!(lat_values
.iter()
.zip(lon_values.iter())
@@ -221,14 +212,31 @@

let batches = df.collect().await?;
for batch in batches {
let lat_values = extract_col::<Float64Type>("lat", &batch)?;
let lon_values = extract_col::<Float64Type>("lon", &batch)?;
let lat_values = extract_col::<Float64Type>("lat", &batch);
let lon_values = extract_col::<Float64Type>("lon", &batch);
assert!(lat_values
.iter()
.zip(lon_values.iter())
.all(|(lat, lon)| *lat > 38.21 && *lon > -109.59 && *lat + *lon > -71.09));
}

// check a query that doesn't select the column used in the predicate. the
// first query below produces the reference values, and the second one is the
// one under test, since it has a predicate on lon but doesn't select lon.
let sql = "SELECT lat, lon FROM zarr_table WHERE lon > -109.59";
let df = session.sql(sql).await?;
let lat_lon_batches = df.collect().await?;

let sql = "SELECT lat FROM zarr_table WHERE lon > -109.59";
let df = session.sql(sql).await?;
let lat_batches = df.collect().await?;

for (lat_batch, lat_lon_batch) in lat_batches.iter().zip(lat_lon_batches.iter()) {
let lat_values_1 = extract_col::<Float64Type>("lat", lat_batch);
let lat_values_2 = extract_col::<Float64Type>("lat", lat_lon_batch);
assert_eq!(lat_values_1, lat_values_2);
}

Ok(())
}
}
3 changes: 2 additions & 1 deletion src/datafusion/table_provider.rs
@@ -101,6 +101,7 @@ impl TableProvider for ZarrTableProvider {
&self,
filters: &[&Expr],
) -> datafusion::error::Result<Vec<TableProviderFilterPushDown>> {
// TODO handle predicates on partition columns as Exact.
Ok(filters
.iter()
.map(|filter| {
@@ -113,7 +114,7 @@
.collect::<Vec<_>>(),
filter,
) {
TableProviderFilterPushDown::Exact
TableProviderFilterPushDown::Inexact
} else {
TableProviderFilterPushDown::Unsupported
}
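For context on the change above: `Exact` promises the planner that the scan returns only matching rows, so the filter is dropped from the plan, while `Inexact` keeps a `FilterExec` on top of the scan, which is what makes chunk-granularity pruning safe here. A small sketch of the three levels of DataFusion's `TableProviderFilterPushDown` enum (the `describe` helper is hypothetical):

```rust
use datafusion::logical_expr::TableProviderFilterPushDown;

// Hypothetical helper summarizing the planner's contract for each level.
fn describe(level: &TableProviderFilterPushDown) -> &'static str {
    match level {
        // the scan guarantees every returned row satisfies the filter,
        // so the planner drops the filter from the plan entirely.
        TableProviderFilterPushDown::Exact => "scan fully applies the filter",
        // the scan may return extra rows (here: whole chunks containing
        // at least one match), so the planner keeps a FilterExec on top.
        TableProviderFilterPushDown::Inexact => "scan prunes, planner re-filters",
        // the filter is not pushed into the scan at all.
        TableProviderFilterPushDown::Unsupported => "planner filters everything",
    }
}
```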
19 changes: 3 additions & 16 deletions src/reader/codecs.rs
@@ -715,14 +715,10 @@ pub(crate) fn apply_codecs(
data_type: &ZarrDataType,
codecs: &Vec<ZarrCodec>,
sharding_params: Option<ShardingOptions>,
final_indices: Option<&Vec<usize>>,
) -> ZarrResult<(ArrayRef, FieldRef)> {
macro_rules! return_array {
($func_name: tt, $data_t: expr, $array_t: ty) => {
let mut data = $func_name(raw_data, &chunk_dims, &real_dims, &codecs, sharding_params)?;
if let Some(indices) = final_indices {
keep_indices(&mut data, &indices);
};
let data = $func_name(raw_data, &chunk_dims, &real_dims, &codecs, sharding_params)?;
let field = Field::new(col_name, $data_t, false);
let arr: $array_t = data.into();
return Ok((Arc::new(arr), Arc::new(field)))
@@ -732,10 +728,7 @@
match data_type {
ZarrDataType::Bool => {
let data = decode_u8_chunk(raw_data, chunk_dims, real_dims, codecs, sharding_params)?;
let mut data: Vec<bool> = data.iter().map(|x| *x != 0).collect();
if let Some(indices) = final_indices {
keep_indices(&mut data, indices);
};
let data: Vec<bool> = data.iter().map(|x| *x != 0).collect();
let field = Field::new(col_name, DataType::Boolean, false);
let arr: BooleanArray = data.into();
Ok((Arc::new(arr), Arc::new(field)))
@@ -817,7 +810,7 @@ pub(crate) fn apply_codecs(
pyunicode = true;
str_len_adjustment = PY_UNICODE_SIZE;
}
let mut data = decode_string_chunk(
let data = decode_string_chunk(
raw_data,
*s / str_len_adjustment,
chunk_dims,
Expand All @@ -826,9 +819,6 @@ pub(crate) fn apply_codecs(
sharding_params,
pyunicode,
)?;
if let Some(indices) = final_indices {
keep_indices(&mut data, indices);
};
let field = Field::new(col_name, DataType::Utf8, false);
let arr: StringArray = data.into();

@@ -874,7 +864,6 @@ mod zarr_codecs_tests {
&data_type,
&codecs,
sharding_params,
None,
)
.unwrap();

@@ -925,7 +914,6 @@
&data_type,
&codecs,
sharding_params,
None,
)
.unwrap();

@@ -977,7 +965,6 @@
&data_type,
&codecs,
sharding_params,
None,
)
.unwrap();

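For reference, the row filtering dropped in this file worked by keeping only selected positions in the decoded vector before building the arrow array. A minimal sketch of that pattern, assuming `indices` is sorted ascending and in-bounds; this is an illustration, not the removed `keep_indices` implementation:

```rust
// keep only the elements of `data` whose positions appear in `indices`,
// compacting in place and truncating the tail.
fn keep_indices<T: Copy>(data: &mut Vec<T>, indices: &[usize]) {
    let mut write = 0;
    for &idx in indices {
        data[write] = data[idx];
        write += 1;
    }
    data.truncate(write);
}
```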
