Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust,python): Add quantiles method to expression list namespace #20782

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
17 changes: 17 additions & 0 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod quantile;
mod sum;
mod variance;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};

pub use mean::*;
pub use min_max::*;
Expand Down Expand Up @@ -83,6 +84,22 @@ pub enum QuantileMethod {
Equiprobable,
}

/// Parameters of a quantile computation: the probability to evaluate and the
/// interpolation method used to resolve it.
///
/// Equality and hashing are both defined on a canonical bit pattern of `prob`
/// so that the `Hash`/`Eq` contract holds (`a == b` implies equal hashes) and
/// equality is reflexive, which `Eq` requires:
/// * `0.0` and `-0.0` compare equal and hash identically;
/// * every NaN compares equal to every other NaN and hashes identically.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Copy, Clone)]
pub struct QuantileOptions {
    pub prob: f64,
    pub method: QuantileMethod,
}

impl QuantileOptions {
    /// Canonical bit representation of `prob` shared by `eq` and `hash`.
    fn prob_key(&self) -> u64 {
        if self.prob.is_nan() {
            // Collapse all NaN payloads/signs into a single representative.
            f64::NAN.to_bits()
        } else {
            // Adding `+0.0` turns `-0.0` into `+0.0` and leaves every other
            // value untouched.
            (self.prob + 0.0).to_bits()
        }
    }
}

impl PartialEq for QuantileOptions {
    fn eq(&self, other: &Self) -> bool {
        self.prob_key() == other.prob_key() && self.method == other.method
    }
}

impl Eq for QuantileOptions {}

impl Hash for QuantileOptions {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.prob_key().hash(state);
        self.method.hash(state);
    }
}

/// Deprecated alias for [`QuantileMethod`]; the deprecation note directs
/// users to the new name.
#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/legacy/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::array::{BinaryArray, ListArray, Utf8Array};
pub use crate::legacy::array::default_arrays::*;
pub use crate::legacy::array::*;
pub use crate::legacy::index::*;
pub use crate::legacy::kernels::rolling::no_nulls::QuantileMethod;
pub use crate::legacy::kernels::rolling::no_nulls::{QuantileMethod, QuantileOptions};
pub use crate::legacy::kernels::rolling::{
RollingFnParams, RollingQuantileParams, RollingVarParams,
};
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-core/src/series/implementations/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,13 @@ impl SeriesTrait for SeriesWrap<DecimalChunked> {
self.0.mean().map(|v| v / self.scale_factor() as f64)
}

/// Returns the requested quantile of the decimal values as an `f64`.
///
/// The wrapped array's quantile is divided by `scale_factor()`, exactly as
/// `mean` and `median` do in this impl, so the result is expressed in
/// decimal units rather than in the unscaled representation.
///
/// # Errors
/// Returns a `ComputeError` when `quantile` is outside `0.0..=1.0`, and
/// propagates any error reported by the wrapped array's `quantile`.
fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
    polars_ensure!((0.0..=1.0).contains(&quantile),
        ComputeError: "quantile should be between 0.0 and 1.0",
    );
    // Propagate failures with `?` instead of panicking, then rescale.
    Ok(self
        .0
        .quantile(quantile, method)?
        .map(|v| v / self.scale_factor() as f64))
}

/// Median of the decimal values, divided by `scale_factor()`.
/// Yields `None` when the wrapped array reports no median.
fn median(&self) -> Option<f64> {
    let unscaled = self.0.median()?;
    Some(unscaled / self.scale_factor() as f64)
}
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ impl SeriesTrait for SeriesWrap<DurationChunked> {
self.0.median()
}

/// Returns the requested quantile of the wrapped duration array as an `f64`.
///
/// # Errors
/// Returns a `ComputeError` when `quantile` is outside `0.0..=1.0`, and
/// propagates any error reported by the wrapped array's `quantile`
/// (previously such an error caused a panic via `unwrap`).
fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
    polars_ensure!((0.0..=1.0).contains(&quantile),
        ComputeError: "quantile should be between 0.0 and 1.0",
    );
    self.0.quantile(quantile, method)
}

/// Delegates to the wrapped array's `std`, forwarding `ddof` unchanged.
fn std(&self, ddof: u8) -> Option<f64> {
self.0.std(ddof)
}
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ macro_rules! impl_dyn_series {
self.0.median().map(|v| v as f64)
}

/// Returns the requested quantile of the wrapped float array, widened to `f64`.
///
/// # Errors
/// Returns a `ComputeError` when `quantile` is outside `0.0..=1.0`, and
/// propagates any error reported by the wrapped array's `quantile`
/// (previously such an error caused a panic via `unwrap`).
fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
    polars_ensure!((0.0..=1.0).contains(&quantile),
        ComputeError: "quantile should be between 0.0 and 1.0",
    );
    Ok(self.0.quantile(quantile, method)?.map(|v| v as f64))
}

/// Delegates to the wrapped array's `std`, forwarding `ddof` unchanged.
fn std(&self, ddof: u8) -> Option<f64> {
self.0.std(ddof)
}
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ macro_rules! impl_dyn_series {
self.0.median()
}

/// Returns the requested quantile of the wrapped array as an `f64`.
///
/// # Errors
/// Returns a `ComputeError` when `quantile` is outside `0.0..=1.0`, and
/// propagates any error reported by the wrapped array's `quantile`
/// (previously such an error caused a panic via `unwrap`).
fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
    polars_ensure!((0.0..=1.0).contains(&quantile),
        ComputeError: "quantile should be between 0.0 and 1.0",
    );
    Ok(self.0.quantile(quantile, method)?.map(|v| v as f64))
}

/// Delegates to the wrapped array's `std`, forwarding `ddof` unchanged.
fn std(&self, ddof: u8) -> Option<f64> {
self.0.std(ddof)
}
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ pub trait SeriesTrait:
None
}

/// Returns the quantile of the values in the array.
///
/// The return type is a `PolarsResult<Option<f64>>`: the `Result` layer lets
/// implementations reject an invalid request (for example a `quantile`
/// outside `0.0..=1.0`), while the `Option` layer lets them report that no
/// value is available.
///
/// This default implementation ignores both arguments and returns `Ok(None)`.
fn quantile(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult<Option<f64>> {
Ok(None)
}

/// Create a new Series filled with values from the given index.
///
/// # Example
Expand Down
48 changes: 48 additions & 0 deletions crates/polars-ops/src/chunked_array/list/dispersion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,54 @@ pub(super) fn median_with_nulls(ca: &ListChunked) -> Series {
}
}

/// Computes the `quantile` (resolved with `method`) of every sublist in `ca`.
///
/// The output dtype follows the inner dtype of the list:
/// * `Float32` sublists produce a `Float32` column,
/// * `Duration` sublists produce a `Duration` column with the same time unit,
/// * everything else produces a `Float64` column.
///
/// A null sublist, or a sublist whose quantile is `None`, yields a null.
/// Because this function cannot return an error, a failing per-sublist
/// quantile is mapped to `NaN` in the float outputs. For `Duration` output a
/// failure is mapped to null instead: `NaN as i64` evaluates to `0`, which
/// would otherwise fabricate a zero-length duration.
pub(super) fn quantile_with_nulls(
    ca: &ListChunked,
    quantile: f64,
    method: QuantileMethod,
) -> Series {
    match ca.inner_dtype() {
        DataType::Float32 => {
            let out: Float32Chunked = ca
                .apply_amortized_generic(|s| {
                    s.and_then(|s| {
                        s.as_ref()
                            .quantile(quantile, method)
                            .unwrap_or(Some(f64::NAN))
                            .map(|v| v as f32)
                    })
                })
                .with_name(ca.name().clone());
            out.into_series()
        },
        #[cfg(feature = "dtype-duration")]
        DataType::Duration(tu) => {
            let out: Int64Chunked = ca
                .apply_amortized_generic(|s| {
                    s.and_then(|s| {
                        s.as_ref()
                            .quantile(quantile, method)
                            // An integer cannot carry NaN; surface errors as
                            // null rather than as a bogus `0`.
                            .ok()
                            .flatten()
                            .map(|v| v as i64)
                    })
                })
                .with_name(ca.name().clone());
            out.into_duration(*tu).into_series()
        },
        _ => {
            let out: Float64Chunked = ca
                .apply_amortized_generic(|s| {
                    s.and_then(|s| {
                        s.as_ref()
                            .quantile(quantile, method)
                            .unwrap_or(Some(f64::NAN))
                    })
                })
                .with_name(ca.name().clone());
            out.into_series()
        },
    }
}

pub(super) fn std_with_nulls(ca: &ListChunked, ddof: u8) -> Series {
match ca.inner_dtype() {
DataType::Float32 => {
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ pub trait ListNameSpaceImpl: AsList {
dispersion::median_with_nulls(ca)
}

/// Per-sublist quantile; forwards the list view of `self` together with
/// `quantile` and `method` to `dispersion::quantile_with_nulls`.
fn lst_quantile(&self, quantile: f64, method: QuantileMethod) -> Series {
    dispersion::quantile_with_nulls(self.as_list(), quantile, method)
}

fn lst_std(&self, ddof: u8) -> Series {
let ca = self.as_list();
dispersion::std_with_nulls(ca, ddof)
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub enum ListFunction {
Min,
Mean,
Median,
Quantile(QuantileOptions),
Std(u8),
Var(u8),
ArgMin,
Expand Down Expand Up @@ -85,6 +86,7 @@ impl ListFunction {
Max => mapper.map_to_list_and_array_inner_dtype(),
Mean => mapper.with_dtype(DataType::Float64),
Median => mapper.map_to_float_dtype(),
Quantile(_) => mapper.map_to_float_dtype(),
Std(_) => mapper.map_to_float_dtype(), // Need to also have this sometimes marked as float32 or duration..
Var(_) => mapper.map_to_float_dtype(),
ArgMin => mapper.with_dtype(IDX_DTYPE),
Expand Down Expand Up @@ -152,6 +154,7 @@ impl Display for ListFunction {
Max => "max",
Mean => "mean",
Median => "median",
Quantile(_) => "quantile",
Std(_) => "std",
Var(_) => "var",
ArgMin => "arg_min",
Expand Down Expand Up @@ -222,6 +225,7 @@ impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
Min => map!(min),
Mean => map!(mean),
Median => map!(median),
Quantile(options) => map!(quantile, options.prob, options.method),
Std(ddof) => map!(std, ddof),
Var(ddof) => map!(var, ddof),
ArgMin => map!(arg_min),
Expand Down Expand Up @@ -571,6 +575,10 @@ pub(super) fn median(s: &Column) -> PolarsResult<Column> {
Ok(s.list()?.lst_median().into())
}

/// Applies `lst_quantile` to the list column `s`.
///
/// Fails if `s` cannot be viewed as a list column (the error from `s.list()`
/// is propagated).
pub(super) fn quantile(s: &Column, quantile: f64, method: QuantileMethod) -> PolarsResult<Column> {
    let lists = s.list()?;
    let quantiles = lists.lst_quantile(quantile, method);
    Ok(quantiles.into())
}

/// Applies `lst_std` with the given `ddof` to the list column `s`.
///
/// Fails if `s` cannot be viewed as a list column (the error from `s.list()`
/// is propagated).
pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
    let lists = s.list()?;
    let deviations = lists.lst_std(ddof);
    Ok(deviations.into())
}
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ impl ListNameSpace {
.map_private(FunctionExpr::ListExpr(ListFunction::Median))
}

/// Builds the expression that evaluates `ListFunction::Quantile` on the
/// wrapped expression, packing `prob` and `method` into `QuantileOptions`.
pub fn quantile(self, prob: f64, method: QuantileMethod) -> Expr {
    let options = QuantileOptions { prob, method };
    let function = FunctionExpr::ListExpr(ListFunction::Quantile(options));
    self.0.map_private(function)
}

pub fn std(self, ddof: u8) -> Expr {
self.0
.map_private(FunctionExpr::ListExpr(ListFunction::Std(ddof)))
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-python/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ impl PyExpr {
.into()
}

/// Builds the `list.quantile` expression on a clone of the inner expression,
/// labelling it "list.quantile" for formatting.
fn list_quantile(&self, quantile: f64, method: Wrap<QuantileMethod>) -> Self {
    let method = method.0;
    let expr = self
        .inner
        .clone()
        .list()
        .quantile(quantile, method)
        .with_fmt("list.quantile");
    expr.into()
}

fn list_std(&self, ddof: u8) -> Self {
self.inner
.clone()
Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
RollingInterpolationMethod: TypeAlias = Literal[
"nearest", "higher", "lower", "midpoint", "linear"
] # QuantileInterpolOptions
# Names of the interpolation methods accepted wherever a quantile is computed.
QuantileMethod: TypeAlias = Literal[
"lower",
"higher",
"nearest",
"linear",
"midpoint",
"equiprobable",
]  # QuantileMethod
ListToStructWidthStrategy: TypeAlias = Literal["first_non_null", "max_width"]

# The following have no equivalent on the Rust side
Expand Down
22 changes: 22 additions & 0 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
IntoExprColumn,
ListToStructWidthStrategy,
NullBehavior,
QuantileMethod,
)


Expand Down Expand Up @@ -294,6 +295,27 @@ def median(self) -> Expr:
"""
return wrap_expr(self._pyexpr.list_median())

def quantile(self, quantile: float, method: QuantileMethod) -> Expr:
    """
    Compute the requested quantile of every sublist in the column.

    Parameters
    ----------
    quantile
        Quantile between 0.0 and 1.0.
    method : {'lower', 'higher', 'nearest', 'linear', 'midpoint', 'equiprobable'}
        Interpolation method used to resolve the quantile.

    Examples
    --------
    >>> df = pl.DataFrame({"values": [[-1, 0, 1], [1, 10]]})
    >>> expr = pl.col("values").list.quantile(0.1, "linear").alias("10percent")
    >>> df.with_columns(expr)
    shape: (2, 2)
    ┌────────────┬───────────┐
    │ values     ┆ 10percent │
    │ ---        ┆ ---       │
    │ list[i64]  ┆ f64       │
    ╞════════════╪═══════════╡
    │ [-1, 0, 1] ┆ -0.8      │
    │ [1, 10]    ┆ 1.9       │
    └────────────┴───────────┘
    """
    pyexpr = self._pyexpr.list_quantile(quantile, method)
    return wrap_expr(pyexpr)

def std(self, ddof: int = 1) -> Expr:
"""
Compute the std value of the lists in the array.
Expand Down
52 changes: 52 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,58 @@ def test_list_median(data_dispersion: pl.DataFrame) -> None:
assert_frame_equal(result, expected)


def test_list_quantile(data_dispersion: pl.DataFrame) -> None:
    # Checks the 0.3 quantile (linear interpolation) of the int, float and
    # duration list columns. Output columns carry a "_quantile" suffix so the
    # names describe what is computed (they previously said "_median").
    df = data_dispersion

    result = df.select(
        pl.col("int").list.quantile(0.3, "linear").name.suffix("_quantile"),
        pl.col("float").list.quantile(0.3, "linear").name.suffix("_quantile"),
        pl.col("duration").list.quantile(0.3, "linear").name.suffix("_quantile"),
    )

    expected = pl.DataFrame(
        [
            pl.Series("int_quantile", [2.2], dtype=pl.Float64),
            pl.Series("float_quantile", [2.2], dtype=pl.Float64),
            pl.Series(
                "duration_quantile",
                [timedelta(microseconds=2200)],
                dtype=pl.Duration(time_unit="us"),
            ),
        ]
    )

    assert_frame_equal(result, expected)


def test_list_quantile_extremities(data_dispersion: pl.DataFrame) -> None:
    # The 0.0, 1.0 and 0.5 linear quantiles must agree (dtypes aside) with the
    # list min, max and median respectively.
    df = data_dispersion
    columns = ["int", "float", "duration"]
    references = [
        (0.0, lambda name: pl.col(name).list.min()),
        (1.0, lambda name: pl.col(name).list.max()),
        (0.5, lambda name: pl.col(name).list.median()),
    ]

    for prob, reference in references:
        result = df.select(
            pl.col(name).list.quantile(prob, "linear") for name in columns
        )
        expected = df.select(reference(name) for name in columns)
        assert_frame_equal(result, expected, check_dtypes=False)


def test_list_gather_null_struct_14927() -> None:
df = pl.DataFrame(
[
Expand Down
Loading