refactor: Simplify hive predicate handling in NEW_MULTIFILE (#20730)
coastalwhite authored Jan 16, 2025
1 parent 8cef5c0 commit cae6788
Showing 5 changed files with 72 additions and 123 deletions.
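In short: before this change, hive partition values were pushed into a file's predicate by rewriting the expression tree through a `replace_elementwise_const_columns` hook that every `PhysicalExpr` implementation had to cooperate with. This change removes that hook and instead wraps the predicate in a new `PhysicalExprWithConstCols` expression that appends the constant hive columns to the input `DataFrame` at evaluation time. A minimal, self-contained sketch of the new approach, using hypothetical toy types in place of Polars' own:

```rust
use std::collections::HashMap;

// Toy expression trait standing in for `PhysicalExpr`.
trait Expr {
    fn evaluate(&self, columns: &HashMap<String, i64>) -> i64;
}

// A bare column reference.
struct Col(String);

impl Expr for Col {
    fn evaluate(&self, columns: &HashMap<String, i64>) -> i64 {
        columns[&self.0]
    }
}

// Analogue of `PhysicalExprWithConstCols`: injects known-constant columns
// (e.g. hive partition values) into the input before delegating to the child,
// so the child expression itself never has to be rewritten.
struct WithConstCols {
    constants: Vec<(String, i64)>,
    child: Box<dyn Expr>,
}

impl Expr for WithConstCols {
    fn evaluate(&self, columns: &HashMap<String, i64>) -> i64 {
        let mut columns = columns.clone();
        columns.extend(self.constants.iter().cloned());
        self.child.evaluate(&columns)
    }
}

fn main() {
    // "year" plays the role of a hive partition column with a known value.
    let predicate = WithConstCols {
        constants: vec![("year".to_string(), 2024)],
        child: Box::new(Col("year".to_string())),
    };
    assert_eq!(predicate.evaluate(&HashMap::new()), 2024);
}
```

Because the wrapper leaves the child expression untouched, no per-expression rewrite logic is needed, and a predicate mixing hive and non-hive columns keeps working unchanged.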
37 changes: 0 additions & 37 deletions crates/polars-expr/src/expressions/apply.rs
@@ -428,43 +428,6 @@ impl PhysicalExpr for ApplyExpr {
             i.collect_live_columns(lv);
         }
     }
-    fn replace_elementwise_const_columns(
-        &self,
-        const_columns: &PlHashMap<PlSmallStr, AnyValue<'static>>,
-    ) -> Option<Arc<dyn PhysicalExpr>> {
-        if self.collect_groups == ApplyOptions::ElementWise {
-            let mut new_inputs = Vec::new();
-            for i in 0..self.inputs.len() {
-                match self.inputs[i].replace_elementwise_const_columns(const_columns) {
-                    None => continue,
-                    Some(new) => {
-                        new_inputs.reserve(self.inputs.len());
-                        new_inputs.extend(self.inputs[..i].iter().cloned());
-                        new_inputs.push(new);
-                        break;
-                    },
-                }
-            }
-
-            // Only copy inputs if it is actually needed
-            if new_inputs.is_empty() {
-                return None;
-            }
-
-            new_inputs.extend(self.inputs[new_inputs.len()..].iter().map(|i| {
-                match i.replace_elementwise_const_columns(const_columns) {
-                    None => i.clone(),
-                    Some(new) => new,
-                }
-            }));
-
-            let mut slf = self.clone();
-            slf.inputs = new_inputs;
-            return Some(Arc::new(slf));
-        }
-
-        None
-    }
 
     fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
         self.expr.to_field(input_schema, Context::Default)
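One detail worth noting in the deleted `ApplyExpr` code above is its copy-on-first-change pattern: the inputs vector is only rebuilt once a replacement is actually found. A standalone sketch of the same pattern, with a hypothetical `rewrite` closure standing in for the recursive `replace_elementwise_const_columns` call:

```rust
// Returns `None` when nothing changed, mirroring the deleted implementation.
fn replace_all(inputs: &[i32], rewrite: impl Fn(i32) -> Option<i32>) -> Option<Vec<i32>> {
    let mut new_inputs = Vec::new();
    for (i, &input) in inputs.iter().enumerate() {
        if let Some(new) = rewrite(input) {
            // First change found: copy the untouched prefix, then stop scanning.
            new_inputs.reserve(inputs.len());
            new_inputs.extend_from_slice(&inputs[..i]);
            new_inputs.push(new);
            break;
        }
    }

    // No replacement anywhere: return without ever copying the inputs.
    if new_inputs.is_empty() {
        return None;
    }

    // Rewrite (or keep) the remaining suffix.
    new_inputs.extend(inputs[new_inputs.len()..].iter().map(|&i| rewrite(i).unwrap_or(i)));
    Some(new_inputs)
}

fn main() {
    // Replace every negative input with zero.
    let rewrite = |x: i32| (x < 0).then_some(0);
    assert_eq!(replace_all(&[1, -2, 3], rewrite), Some(vec![1, 0, 3]));
    assert_eq!(replace_all(&[1, 2, 3], rewrite), None);
}
```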
20 changes: 0 additions & 20 deletions crates/polars-expr/src/expressions/binary.rs
@@ -272,26 +272,6 @@ impl PhysicalExpr for BinaryExpr {
         self.left.collect_live_columns(lv);
         self.right.collect_live_columns(lv);
     }
-    fn replace_elementwise_const_columns(
-        &self,
-        const_columns: &PlHashMap<PlSmallStr, AnyValue<'static>>,
-    ) -> Option<Arc<dyn PhysicalExpr>> {
-        let rcc_left = self.left.replace_elementwise_const_columns(const_columns);
-        let rcc_right = self.right.replace_elementwise_const_columns(const_columns);
-
-        if rcc_left.is_some() || rcc_right.is_some() {
-            let mut slf = self.clone();
-            if let Some(left) = rcc_left {
-                slf.left = left;
-            }
-            if let Some(right) = rcc_right {
-                slf.right = right;
-            }
-            return Some(Arc::new(slf));
-        }
-
-        None
-    }
 
     fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
         self.expr.to_field(input_schema, Context::Default)
12 changes: 0 additions & 12 deletions crates/polars-expr/src/expressions/column.rs
@@ -182,18 +182,6 @@ impl PhysicalExpr for ColumnExpr {
     fn collect_live_columns(&self, lv: &mut PlIndexSet<PlSmallStr>) {
         lv.insert(self.name.clone());
     }
-    fn replace_elementwise_const_columns(
-        &self,
-        const_columns: &PlHashMap<PlSmallStr, AnyValue<'static>>,
-    ) -> Option<Arc<dyn PhysicalExpr>> {
-        if let Some(av) = const_columns.get(&self.name) {
-            let lv = LiteralValue::from(av.clone());
-            let le = LiteralExpr::new(lv, self.expr.clone());
-            return Some(Arc::new(le));
-        }
-
-        None
-    }
 
     fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
         input_schema.get_field(&self.name).ok_or_else(|| {
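The `ColumnExpr` case above was the leaf of the removed rewrite: a column reference whose name appears in the constant map becomes a literal. A self-contained sketch of that leaf rule, with toy types in place of `ColumnExpr`/`LiteralExpr`:

```rust
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
enum Expr {
    Column(String),
    Literal(i64),
}

// `None` means "keep the original expression", matching the trait contract.
fn replace_const_column(expr: &Expr, const_columns: &HashMap<String, i64>) -> Option<Expr> {
    match expr {
        Expr::Column(name) => const_columns.get(name).map(|v| Expr::Literal(*v)),
        Expr::Literal(_) => None,
    }
}

fn main() {
    let consts = HashMap::from([("year".to_string(), 2024)]);
    let expr = Expr::Column("year".to_string());
    assert_eq!(replace_const_column(&expr, &consts), Some(Expr::Literal(2024)));
}
```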
12 changes: 0 additions & 12 deletions crates/polars-expr/src/expressions/mod.rs
@@ -606,18 +606,6 @@ pub trait PhysicalExpr: Send + Sync {
     /// This can contain duplicates.
     fn collect_live_columns(&self, lv: &mut PlIndexSet<PlSmallStr>);
 
-    /// Replace columns that are known to be a constant value with their const value.
-    ///
-    /// This should not replace values that are calculated non-elementwise e.g. col.max(),
-    /// col.std(), etc.
-    fn replace_elementwise_const_columns(
-        &self,
-        const_columns: &PlHashMap<PlSmallStr, AnyValue<'static>>,
-    ) -> Option<Arc<dyn PhysicalExpr>> {
-        _ = const_columns;
-        None
-    }
-
     /// Can take &dyn Statistics and determine of a file should be
     /// read -> `true`
     /// or not -> `false`
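The hook removed above relied on Rust's default trait method mechanism: the base `PhysicalExpr` trait supplied a no-op default returning `None`, so only expressions that could guarantee an element-wise rewrite had to override it. A minimal sketch of that shape, with simplified stand-in types:

```rust
use std::sync::Arc;

trait PhysicalExprLike: Send + Sync {
    // Default: nothing to replace; implementors override only when they can
    // guarantee an element-wise rewrite.
    fn replace_elementwise_const_columns(
        &self,
        const_columns: &[(String, i64)],
    ) -> Option<Arc<dyn PhysicalExprLike>> {
        let _ = const_columns;
        None
    }
}

struct Opaque; // An expression that must not be rewritten, e.g. col().max().

impl PhysicalExprLike for Opaque {} // Inherits the no-op default.

fn main() {
    let expr = Opaque;
    assert!(expr.replace_elementwise_const_columns(&[]).is_none());
}
```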
114 changes: 72 additions & 42 deletions crates/polars-mem-engine/src/executors/multi_file_scan.rs
@@ -20,6 +20,55 @@
 use crate::executors::JsonExec;
 use crate::executors::ParquetExec;
 use crate::prelude::*;
 
+pub struct PhysicalExprWithConstCols {
+    constants: Vec<(PlSmallStr, Scalar)>,
+    child: Arc<dyn PhysicalExpr>,
+}
+
+impl PhysicalExpr for PhysicalExprWithConstCols {
+    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
+        let mut df = df.clone();
+        for (name, scalar) in &self.constants {
+            df.with_column(Column::new_scalar(
+                name.clone(),
+                scalar.clone(),
+                df.height(),
+            ))?;
+        }
+
+        self.child.evaluate(&df, state)
+    }
+
+    fn evaluate_on_groups<'a>(
+        &self,
+        df: &DataFrame,
+        groups: &'a GroupPositions,
+        state: &ExecutionState,
+    ) -> PolarsResult<AggregationContext<'a>> {
+        let mut df = df.clone();
+        for (name, scalar) in &self.constants {
+            df.with_column(Column::new_scalar(
+                name.clone(),
+                scalar.clone(),
+                df.height(),
+            ))?;
+        }
+
+        self.child.evaluate_on_groups(&df, groups, state)
+    }
+
+    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
+        self.child.to_field(input_schema)
+    }
+
+    fn collect_live_columns(&self, lv: &mut PlIndexSet<PlSmallStr>) {
+        self.child.collect_live_columns(lv)
+    }
+    fn is_scalar(&self) -> bool {
+        self.child.is_scalar()
+    }
+}
+
 /// An [`Executor`] that scans over some IO.
 pub trait ScanExec {
     /// Read the source.
@@ -287,8 +336,6 @@ impl MultiScanExec {
         let verbose = config::verbose();
         let mut dfs = Vec::with_capacity(self.sources.len());
 
-        let mut const_columns = PlHashMap::new();
-
         // @TODO: This should be moved outside of the FileScan::Parquet
         let use_statistics = match &self.scan_type {
             #[cfg(feature = "parquet")]
@@ -334,51 +381,34 @@
                 do_skip_file |= allow_slice_skip;
             }
 
-            let mut file_predicate = predicate.clone();
-
             // Insert the hive partition values into the predicate. This allows the predicate
             // to function even when there is a combination of hive and non-hive columns being
             // used.
+            let mut file_predicate = predicate.clone();
             if has_live_hive_columns {
                 let hive_part = hive_part.unwrap();
-                let predicate = predicate.as_ref().unwrap();
-                const_columns.clear();
-                for (idx, column) in hive_column_set.iter().enumerate() {
-                    let value = hive_part.get_statistics().column_stats()[idx]
-                        .to_min()
-                        .unwrap()
-                        .get(0)
-                        .unwrap()
-                        .into_static();
-                    const_columns.insert(column.clone(), value);
-                }
-                // @TODO: It would be nice to get this somehow.
-                // for (_, (missing_column, _)) in &missing_columns {
-                //     const_columns.insert((*missing_column).clone(), AnyValue::Null);
-                // }
-
-                file_predicate = predicate.replace_elementwise_const_columns(&const_columns);
-
-                // @TODO: Set predicate to `None` if it's constant evaluated to true.
-
-                // At this point the file_predicate should not contain any references to the
-                // hive columns anymore.
-                //
-                // Note that, replace_elementwise_const_columns does not actually guarantee the
-                // replacement of all reference to the const columns. But any expression which
-                // does not guarantee this should not be pushed down as an IO predicate.
-                if cfg!(debug_assertions) {
-                    let mut live_columns = PlIndexSet::new();
-                    file_predicate
-                        .as_ref()
-                        .unwrap()
-                        .collect_live_columns(&mut live_columns);
-                    for hive_column in hive_part.get_statistics().column_stats() {
-                        assert!(
-                            !live_columns.contains(hive_column.field_name()),
-                            "Predicate still contains hive column"
-                        );
-                    }
-                }
+                let child = file_predicate.unwrap();
+
+                file_predicate = Some(Arc::new(PhysicalExprWithConstCols {
+                    constants: hive_column_set
+                        .iter()
+                        .enumerate()
+                        .map(|(idx, column)| {
+                            let series = hive_part.get_statistics().column_stats()[idx]
+                                .to_min()
+                                .unwrap();
+                            (
+                                column.clone(),
+                                Scalar::new(
+                                    series.dtype().clone(),
+                                    series.get(0).unwrap().into_static(),
+                                ),
+                            )
+                        })
+                        .collect(),
+                    child,
+                }));
             }
 
             let stats_evaluator = file_predicate.as_ref().and_then(|p| p.as_stats_evaluator());
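For context, a hedged sketch of the user-facing pattern this code path serves (assumes the `polars` crate with the `lazy` and `parquet` features; the dataset path and column names are hypothetical): a filter mixing a hive partition column with a regular file column can still be applied across a multi-file scan.

```rust
use polars::prelude::*;

fn main() -> PolarsResult<()> {
    // Hive-partitioned layout, e.g. data/year=2024/part.parquet (hypothetical).
    let df = LazyFrame::scan_parquet("data/**/*.parquet", ScanArgsParquet::default())?
        // "year" is a hive partition column; "value" is a regular file column.
        .filter(col("year").eq(lit(2024)).and(col("value").gt(lit(0))))
        .collect()?;
    println!("{df}");
    Ok(())
}
```

With this commit, the `year = 2024` part of such a predicate is supplied to the per-file predicate as a constant column via `PhysicalExprWithConstCols` rather than by rewriting the expression tree.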
