From cae6788cee192b1dded31f4df54490b2c5021be8 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Thu, 16 Jan 2025 10:31:04 +0100 Subject: [PATCH] refactor: Simplify hive predicate handling in `NEW_MULTIFILE` (#20730) --- crates/polars-expr/src/expressions/apply.rs | 37 ------ crates/polars-expr/src/expressions/binary.rs | 20 --- crates/polars-expr/src/expressions/column.rs | 12 -- crates/polars-expr/src/expressions/mod.rs | 12 -- .../src/executors/multi_file_scan.rs | 114 +++++++++++------- 5 files changed, 72 insertions(+), 123 deletions(-) diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 7b95a19c18b9..d3cd542e046e 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -428,43 +428,6 @@ impl PhysicalExpr for ApplyExpr { i.collect_live_columns(lv); } } - fn replace_elementwise_const_columns( - &self, - const_columns: &PlHashMap>, - ) -> Option> { - if self.collect_groups == ApplyOptions::ElementWise { - let mut new_inputs = Vec::new(); - for i in 0..self.inputs.len() { - match self.inputs[i].replace_elementwise_const_columns(const_columns) { - None => continue, - Some(new) => { - new_inputs.reserve(self.inputs.len()); - new_inputs.extend(self.inputs[..i].iter().cloned()); - new_inputs.push(new); - break; - }, - } - } - - // Only copy inputs if it is actually needed - if new_inputs.is_empty() { - return None; - } - - new_inputs.extend(self.inputs[new_inputs.len()..].iter().map(|i| { - match i.replace_elementwise_const_columns(const_columns) { - None => i.clone(), - Some(new) => new, - } - })); - - let mut slf = self.clone(); - slf.inputs = new_inputs; - return Some(Arc::new(slf)); - } - - None - } fn to_field(&self, input_schema: &Schema) -> PolarsResult { self.expr.to_field(input_schema, Context::Default) diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index 694afd0bbaff..bd23a893c388 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -272,26 +272,6 @@ impl PhysicalExpr for BinaryExpr { self.left.collect_live_columns(lv); self.right.collect_live_columns(lv); } - fn replace_elementwise_const_columns( - &self, - const_columns: &PlHashMap>, - ) -> Option> { - let rcc_left = self.left.replace_elementwise_const_columns(const_columns); - let rcc_right = self.right.replace_elementwise_const_columns(const_columns); - - if rcc_left.is_some() || rcc_right.is_some() { - let mut slf = self.clone(); - if let Some(left) = rcc_left { - slf.left = left; - } - if let Some(right) = rcc_right { - slf.right = right; - } - return Some(Arc::new(slf)); - } - - None - } fn to_field(&self, input_schema: &Schema) -> PolarsResult { self.expr.to_field(input_schema, Context::Default) diff --git a/crates/polars-expr/src/expressions/column.rs b/crates/polars-expr/src/expressions/column.rs index 4e31b9a557cd..3f288cd4bdb9 100644 --- a/crates/polars-expr/src/expressions/column.rs +++ b/crates/polars-expr/src/expressions/column.rs @@ -182,18 +182,6 @@ impl PhysicalExpr for ColumnExpr { fn collect_live_columns(&self, lv: &mut PlIndexSet) { lv.insert(self.name.clone()); } - fn replace_elementwise_const_columns( - &self, - const_columns: &PlHashMap>, - ) -> Option> { - if let Some(av) = const_columns.get(&self.name) { - let lv = LiteralValue::from(av.clone()); - let le = LiteralExpr::new(lv, self.expr.clone()); - return Some(Arc::new(le)); - } - - None - } fn to_field(&self, input_schema: &Schema) -> PolarsResult { input_schema.get_field(&self.name).ok_or_else(|| { diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 4fc9dab8c043..1e46ee167575 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -606,18 +606,6 @@ pub trait PhysicalExpr: Send + Sync { /// This can contain duplicates. fn collect_live_columns(&self, lv: &mut PlIndexSet); - /// Replace columns that are known to be a constant value with their const value. - /// - /// This should not replace values that are calculated non-elementwise e.g. col.max(), - /// col.std(), etc. - fn replace_elementwise_const_columns( - &self, - const_columns: &PlHashMap>, - ) -> Option> { - _ = const_columns; - None - } - /// Can take &dyn Statistics and determine of a file should be /// read -> `true` /// or not -> `false` diff --git a/crates/polars-mem-engine/src/executors/multi_file_scan.rs b/crates/polars-mem-engine/src/executors/multi_file_scan.rs index 679a822d59de..873b8cbbbaae 100644 --- a/crates/polars-mem-engine/src/executors/multi_file_scan.rs +++ b/crates/polars-mem-engine/src/executors/multi_file_scan.rs @@ -20,6 +20,55 @@ use crate::executors::JsonExec; use crate::executors::ParquetExec; use crate::prelude::*; +pub struct PhysicalExprWithConstCols { + constants: Vec<(PlSmallStr, Scalar)>, + child: Arc, +} + +impl PhysicalExpr for PhysicalExprWithConstCols { + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let mut df = df.clone(); + for (name, scalar) in &self.constants { + df.with_column(Column::new_scalar( + name.clone(), + scalar.clone(), + df.height(), + ))?; + } + + self.child.evaluate(&df, state) + } + + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut df = df.clone(); + for (name, scalar) in &self.constants { + df.with_column(Column::new_scalar( + name.clone(), + scalar.clone(), + df.height(), + ))?; + } + + self.child.evaluate_on_groups(&df, groups, state) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.child.to_field(input_schema) + } + + fn collect_live_columns(&self, lv: &mut PlIndexSet) { + self.child.collect_live_columns(lv) + } + fn is_scalar(&self) -> bool { + self.child.is_scalar() + } +} + /// An [`Executor`] that scans over some IO. pub trait ScanExec { /// Read the source. @@ -287,8 +336,6 @@ impl MultiScanExec { let verbose = config::verbose(); let mut dfs = Vec::with_capacity(self.sources.len()); - let mut const_columns = PlHashMap::new(); - // @TODO: This should be moved outside of the FileScan::Parquet let use_statistics = match &self.scan_type { #[cfg(feature = "parquet")] @@ -334,51 +381,34 @@ impl MultiScanExec { do_skip_file |= allow_slice_skip; } + let mut file_predicate = predicate.clone(); + // Insert the hive partition values into the predicate. This allows the predicate // to function even when there is a combination of hive and non-hive columns being // used. - let mut file_predicate = predicate.clone(); if has_live_hive_columns { let hive_part = hive_part.unwrap(); - let predicate = predicate.as_ref().unwrap(); - const_columns.clear(); - for (idx, column) in hive_column_set.iter().enumerate() { - let value = hive_part.get_statistics().column_stats()[idx] - .to_min() - .unwrap() - .get(0) - .unwrap() - .into_static(); - const_columns.insert(column.clone(), value); - } - // @TODO: It would be nice to get this somehow. - // for (_, (missing_column, _)) in &missing_columns { - // const_columns.insert((*missing_column).clone(), AnyValue::Null); - // } - - file_predicate = predicate.replace_elementwise_const_columns(&const_columns); - - // @TODO: Set predicate to `None` if it's constant evaluated to true. - - // At this point the file_predicate should not contain any references to the - // hive columns anymore. - // - // Note that, replace_elementwise_const_columns does not actually guarantee the - // replacement of all reference to the const columns. But any expression which - // does not guarantee this should not be pushed down as an IO predicate. - if cfg!(debug_assertions) { - let mut live_columns = PlIndexSet::new(); - file_predicate - .as_ref() - .unwrap() - .collect_live_columns(&mut live_columns); - for hive_column in hive_part.get_statistics().column_stats() { - assert!( - !live_columns.contains(hive_column.field_name()), - "Predicate still contains hive column" - ); - } - } + let child = file_predicate.unwrap(); + + file_predicate = Some(Arc::new(PhysicalExprWithConstCols { + constants: hive_column_set + .iter() + .enumerate() + .map(|(idx, column)| { + let series = hive_part.get_statistics().column_stats()[idx] + .to_min() + .unwrap(); + ( + column.clone(), + Scalar::new( + series.dtype().clone(), + series.get(0).unwrap().into_static(), + ), + ) + }) + .collect(), + child, + })); } let stats_evaluator = file_predicate.as_ref().and_then(|p| p.as_stats_evaluator());