Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add new-streaming first/last aggregations #20716

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/polars-expr/src/reduce/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use polars_plan::prelude::*;
use polars_utils::arena::{Arena, Node};

use super::*;
use crate::reduce::first_last::{new_first_reduction, new_last_reduction};
use crate::reduce::len::LenReduce;
use crate::reduce::mean::new_mean_reduction;
use crate::reduce::min_max::{new_max_reduction, new_min_reduction};
Expand Down Expand Up @@ -39,6 +40,8 @@ pub fn into_reduction(
IRAggExpr::Std(input, ddof) => {
(new_var_std_reduction(get_dt(*input)?, true, *ddof), *input)
},
IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input),
IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input),
_ => todo!(),
},
AExpr::Len => {
Expand Down
202 changes: 202 additions & 0 deletions crates/polars-expr/src/reduce/first_last.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
use std::marker::PhantomData;

use polars_core::frame::row::AnyValueBufferTrusted;

use super::*;

/// Builds a grouped reduction for `dtype` that selects values with the
/// [`First`] policy.
pub fn new_first_reduction(dtype: DataType) -> Box<dyn GroupedReduction> {
    new_reduction_with_policy::<First>(dtype)
}

/// Builds a grouped reduction for `dtype` that selects values with the
/// [`Last`] policy.
pub fn new_last_reduction(dtype: DataType) -> Box<dyn GroupedReduction> {
    new_reduction_with_policy::<Last>(dtype)
}

/// Creates a [`GenericFirstLastGroupedReduction`] driven by policy `P` and
/// returns it as a boxed trait object.
fn new_reduction_with_policy<P>(dtype: DataType) -> Box<dyn GroupedReduction>
where
    P: Policy + 'static,
{
    let reduction = GenericFirstLastGroupedReduction::<P>::new(dtype);
    Box::new(reduction)
}

/// Selection strategy used by the first/last-style grouped reductions.
///
/// Sequence ids passed to [`Policy::should_replace`] use `0` to mean
/// "no value stored yet".
trait Policy {
    /// Position, within a batch of `len` values, of the element this policy
    /// wants to keep.
    fn index(len: usize) -> usize;

    /// Whether a candidate carrying sequence id `new` should overwrite the
    /// value currently stored with sequence id `old`.
    fn should_replace(new: u64, old: u64) -> bool;
}

/// Policy that keeps the value carrying the lowest non-zero sequence id.
struct First;
impl Policy for First {
    /// The first element of a batch is always at position 0.
    fn index(_len: usize) -> usize {
        0
    }

    /// Returns `true` when `new` ranks strictly before `old`.
    fn should_replace(new: u64, old: u64) -> bool {
        // Shifting every id down by one (with wrap-around) keeps the relative
        // order of real ids intact while sending 0, the "no value" marker, to
        // `u64::MAX`, so an empty slot loses against any real candidate.
        let rank = |seq: u64| seq.wrapping_sub(1);
        rank(new) < rank(old)
    }
}

/// Policy that keeps the most recent value: the one carrying the highest
/// sequence id, and among equal ids the one offered last.
struct Last;
impl Policy for Last {
    /// Position of the final element of a batch of `len` values.
    ///
    /// `len` must be non-zero, otherwise `len - 1` underflows.
    fn index(len: usize) -> usize {
        len - 1
    }

    /// Returns `true` when `new` is at least as recent as `old`.
    ///
    /// The comparison is deliberately inclusive. `update_groups` stamps every
    /// row of a batch with the same sequence id, so with a strict `>` only the
    /// first row of a group within a batch would ever be stored and all later
    /// rows (the ones "last" is supposed to return) would be rejected as ties.
    fn should_replace(new: u64, old: u64) -> bool {
        new >= old
    }
}

/// Policy that keeps whichever value reaches an empty slot first and never
/// replaces it afterwards.
#[expect(dead_code)]
struct Arbitrary;
impl Policy for Arbitrary {
    /// Always picks the element at position 0 of a batch.
    fn index(_len: usize) -> usize {
        0
    }

    /// Accepts a candidate only while the slot is still empty; a stored
    /// sequence id of 0 is the "no value" marker.
    fn should_replace(_new: u64, old: u64) -> bool {
        matches!(old, 0)
    }
}

/// Per-group state for first/last-style reductions; the type parameter `P`
/// names the policy that decides which candidate value is kept.
pub struct GenericFirstLastGroupedReduction<P> {
    /// Data type handed to the output buffer when the result is built.
    dtype: DataType,
    /// Currently selected value of each group, indexed by group index.
    values: Vec<AnyValue<'static>>,
    /// Sequence id stored alongside each entry of `values`; 0 means the group
    /// has not received a value yet.
    seqs: Vec<u64>,
    /// Zero-sized marker that records `P` without storing a value of it.
    policy: PhantomData<fn() -> P>,
}

impl<P> GenericFirstLastGroupedReduction<P> {
    /// Creates a reduction for `dtype` that does not hold any groups yet.
    fn new(dtype: DataType) -> Self {
        let values = Vec::new();
        let seqs = Vec::new();
        Self {
            dtype,
            values,
            seqs,
            policy: PhantomData,
        }
    }
}

/// Grouped reduction that stores, for every group, one value together with
/// the sequence id it arrived with, and lets the policy `P` decide whether a
/// later candidate replaces it.
///
/// A stored sequence id of `0` means "no value yet". Incoming ids are
/// incremented by one before they are compared or stored, so an incoming id
/// of `0` is not mistaken for that marker.
impl<P: Policy + 'static> GroupedReduction for GenericFirstLastGroupedReduction<P> {
    /// Returns a fresh reduction with the same dtype and no groups.
    fn new_empty(&self) -> Box<dyn GroupedReduction> {
        Box::new(Self {
            dtype: self.dtype.clone(),
            values: Vec::new(),
            seqs: Vec::new(),
            policy: PhantomData,
        })
    }

    /// Reserves capacity for `additional` more groups in both buffers.
    fn reserve(&mut self, additional: usize) {
        self.values.reserve(additional);
        self.seqs.reserve(additional);
    }

    /// Resizes the state to `num_groups` groups; newly created groups start
    /// as `AnyValue::Null` with sequence id 0 ("no value").
    fn resize(&mut self, num_groups: IdxSize) {
        self.values.resize(num_groups as usize, AnyValue::Null);
        self.seqs.resize(num_groups as usize, 0);
    }

    /// Offers the whole batch `values` to the single group `group_idx`.
    ///
    /// Only the element at `P::index(values.len())` is considered.
    ///
    /// # Panics
    /// Panics if `group_idx` is out of bounds for the current number of groups.
    fn update_group(
        &mut self,
        values: &Series,
        group_idx: IdxSize,
        seq_id: u64,
    ) -> PolarsResult<()> {
        // An empty batch has no candidate element, so the state is left as is.
        if values.len() > 0 {
            let seq_id = seq_id + 1; // We use 0 for 'no value'.
            if P::should_replace(seq_id, self.seqs[group_idx as usize]) {
                self.values[group_idx as usize] = values.get(P::index(values.len()))?.into_static();
                self.seqs[group_idx as usize] = seq_id;
            }
        }
        Ok(())
    }

    /// Offers `values[i]` to group `group_idxs[i]` for every row `i`.
    ///
    /// All rows of the batch are compared and stored with the same shifted
    /// sequence id, so which row wins among several rows of one group is
    /// decided by how `P::should_replace` treats equal ids.
    ///
    /// # Panics
    /// Panics if `values` and `group_idxs` differ in length.
    ///
    /// # Safety
    /// Every entry of `group_idxs` must be in bounds for the current number
    /// of groups; the state is accessed without bounds checks.
    unsafe fn update_groups(
        &mut self,
        values: &Series,
        group_idxs: &[IdxSize],
        seq_id: u64,
    ) -> PolarsResult<()> {
        // `values.get_unchecked(i)` below is only valid for `i < values.len()`,
        // so reject mismatched inputs up front instead of reading out of bounds.
        assert!(values.len() == group_idxs.len());
        let seq_id = seq_id + 1; // We use 0 for 'no value'.
        for (i, g) in group_idxs.iter().enumerate() {
            if P::should_replace(seq_id, *self.seqs.get_unchecked(*g as usize)) {
                *self.values.get_unchecked_mut(*g as usize) = values.get_unchecked(i).into_static();
                *self.seqs.get_unchecked_mut(*g as usize) = seq_id;
            }
        }
        Ok(())
    }

    /// Merges group `i` of `other` into group `group_idxs[i]` of `self`.
    ///
    /// Sequence ids stored in `other` are already shifted, so they are
    /// compared as they are.
    ///
    /// # Panics
    /// Panics if `other` is not a `GenericFirstLastGroupedReduction<P>`.
    ///
    /// # Safety
    /// `other` must hold at least `group_idxs.len()` groups and every entry of
    /// `group_idxs` must be in bounds for `self`; both are accessed without
    /// bounds checks.
    unsafe fn combine(
        &mut self,
        other: &dyn GroupedReduction,
        group_idxs: &[IdxSize],
    ) -> PolarsResult<()> {
        let other = other.as_any().downcast_ref::<Self>().unwrap();
        for (i, g) in group_idxs.iter().enumerate() {
            if P::should_replace(
                *other.seqs.get_unchecked(i),
                *self.seqs.get_unchecked(*g as usize),
            ) {
                *self.values.get_unchecked_mut(*g as usize) = other.values.get_unchecked(i).clone();
                *self.seqs.get_unchecked_mut(*g as usize) = *other.seqs.get_unchecked(i);
            }
        }
        Ok(())
    }

    /// Merges group `subset[i]` of `other` into group `group_idxs[i]` of
    /// `self`.
    ///
    /// # Panics
    /// Panics if `other` is not a `GenericFirstLastGroupedReduction<P>`.
    ///
    /// # Safety
    /// `subset` must have at least `group_idxs.len()` entries, every entry of
    /// `subset` that is read must be in bounds for `other`, and every entry of
    /// `group_idxs` must be in bounds for `self`; all three are accessed
    /// without bounds checks.
    unsafe fn gather_combine(
        &mut self,
        other: &dyn GroupedReduction,
        subset: &[IdxSize],
        group_idxs: &[IdxSize],
    ) -> PolarsResult<()> {
        let other = other.as_any().downcast_ref::<Self>().unwrap();
        for (i, g) in group_idxs.iter().enumerate() {
            let si = *subset.get_unchecked(i) as usize;
            if P::should_replace(
                *other.seqs.get_unchecked(si),
                *self.seqs.get_unchecked(*g as usize),
            ) {
                *self.values.get_unchecked_mut(*g as usize) =
                    other.values.get_unchecked(si).clone();
                *self.seqs.get_unchecked_mut(*g as usize) = *other.seqs.get_unchecked(si);
            }
        }
        Ok(())
    }

    /// Splits this reduction into several reductions that share its dtype.
    ///
    /// `values` and `seqs` are split with the same `partition_sizes` and
    /// `partition_idxs`, and the resulting pieces are paired up in order.
    ///
    /// # Safety
    /// Both arguments are forwarded unchanged to `partition::partition_vec`.
    /// NOTE(review): its exact requirements are not visible here; presumably
    /// the indices and sizes must be consistent with the number of groups.
    unsafe fn partition(
        self: Box<Self>,
        partition_sizes: &[IdxSize],
        partition_idxs: &[IdxSize],
    ) -> Vec<Box<dyn GroupedReduction>> {
        let values = partition::partition_vec(self.values, partition_sizes, partition_idxs);
        let seqs = partition::partition_vec(self.seqs, partition_sizes, partition_idxs);
        std::iter::zip(values, seqs)
            .map(|(values, seqs)| {
                Box::new(Self {
                    dtype: self.dtype.clone(),
                    values,
                    seqs,
                    policy: PhantomData,
                }) as _
            })
            .collect()
    }

    /// Builds the output series with one entry per group, in group order.
    ///
    /// Both internal buffers are emptied by this call.
    fn finalize(&mut self) -> PolarsResult<Series> {
        self.seqs.clear();
        // NOTE(review): the unchecked push assumes every stored value matches
        // `self.dtype`; confirm against `AnyValueBufferTrusted`'s contract.
        unsafe {
            let mut buf = AnyValueBufferTrusted::new(&self.dtype, self.values.len());
            for v in core::mem::take(&mut self.values) {
                buf.add_unchecked_owned_physical(&v);
            }
            Ok(buf.into_series())
        }
    }

    /// Exposes `self` as `Any` so that `combine` and `gather_combine` can
    /// downcast their `other` argument.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
8 changes: 7 additions & 1 deletion crates/polars-expr/src/reduce/len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ impl GroupedReduction for LenReduce {
self.groups.resize(num_groups as usize, 0);
}

fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> {
/// Adds the number of rows in `values` to the counter of group `group_idx`.
/// The sequence id is not used by this reduction.
fn update_group(
    &mut self,
    values: &Series,
    group_idx: IdxSize,
    _seq_id: u64,
) -> PolarsResult<()> {
    let count = &mut self.groups[group_idx as usize];
    *count += values.len() as u64;
    Ok(())
}
Expand All @@ -30,6 +35,7 @@ impl GroupedReduction for LenReduce {
&mut self,
values: &Series,
group_idxs: &[IdxSize],
_seq_id: u64,
) -> PolarsResult<()> {
assert!(values.len() == group_idxs.len());
unsafe {
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-expr/src/reduce/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ pub fn new_mean_reduction(dtype: DataType) -> Box<dyn GroupedReduction> {
},
#[cfg(feature = "dtype-decimal")]
Decimal(_, _) => Box::new(VGR::new(dtype, NumMeanReducer::<Int128Type>(PhantomData))),
_ => unimplemented!(),

// For compatibility with the current engine, should probably be an error.
String | Binary => Box::new(super::NullGroupedReduction::new(dtype)),

_ => unimplemented!("{dtype:?} is not supported by mean reduction"),
}
}

Expand Down
16 changes: 14 additions & 2 deletions crates/polars-expr/src/reduce/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,12 @@ impl GroupedReduction for BoolMinGroupedReduction {
self.mask.resize(num_groups as usize, false);
}

fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> {
fn update_group(
&mut self,
values: &Series,
group_idx: IdxSize,
_seq_id: u64,
) -> PolarsResult<()> {
// TODO: we should really implement a sum-as-other-type operation instead
// of doing this materialized cast.
assert!(values.dtype() == &DataType::Boolean);
Expand All @@ -312,6 +317,7 @@ impl GroupedReduction for BoolMinGroupedReduction {
&mut self,
values: &Series,
group_idxs: &[IdxSize],
_seq_id: u64,
) -> PolarsResult<()> {
// TODO: we should really implement a sum-as-other-type operation instead
// of doing this materialized cast.
Expand Down Expand Up @@ -430,7 +436,12 @@ impl GroupedReduction for BoolMaxGroupedReduction {
self.mask.resize(num_groups as usize, false);
}

fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> {
fn update_group(
&mut self,
values: &Series,
group_idx: IdxSize,
_seq_id: u64,
) -> PolarsResult<()> {
// TODO: we should really implement a sum-as-other-type operation instead
// of doing this materialized cast.
assert!(values.dtype() == &DataType::Boolean);
Expand All @@ -448,6 +459,7 @@ impl GroupedReduction for BoolMaxGroupedReduction {
&mut self,
values: &Series,
group_idxs: &[IdxSize],
_seq_id: u64,
) -> PolarsResult<()> {
// TODO: we should really implement a sum-as-other-type operation instead
// of doing this materialized cast.
Expand Down
Loading
Loading