diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 7abe2455e61f..b5ff7f7c144f 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -4,6 +4,7 @@ mod quantile; mod sum; mod variance; use std::fmt::Debug; +use std::hash::{Hash, Hasher}; pub use mean::*; pub use min_max::*; @@ -83,6 +84,22 @@ pub enum QuantileMethod { Equiprobable, } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct QuantileOptions { + pub prob: f64, + pub method: QuantileMethod, +} + +impl Eq for QuantileOptions {} + +impl Hash for QuantileOptions { + fn hash<H: Hasher>(&self, state: &mut H) { + self.prob.to_bits().hash(state); + self.method.hash(state); + } +} + #[deprecated(note = "use QuantileMethod instead")] pub type QuantileInterpolOptions = QuantileMethod; diff --git a/crates/polars-arrow/src/legacy/prelude.rs b/crates/polars-arrow/src/legacy/prelude.rs index 6afeb0c6c9be..4786be70581b 100644 --- a/crates/polars-arrow/src/legacy/prelude.rs +++ b/crates/polars-arrow/src/legacy/prelude.rs @@ -2,7 +2,7 @@ use crate::array::{BinaryArray, ListArray, Utf8Array}; pub use crate::legacy::array::default_arrays::*; pub use crate::legacy::array::*; pub use crate::legacy::index::*; -pub use crate::legacy::kernels::rolling::no_nulls::QuantileMethod; +pub use crate::legacy::kernels::rolling::no_nulls::{QuantileMethod, QuantileOptions}; pub use crate::legacy::kernels::rolling::{ RollingFnParams, RollingQuantileParams, RollingVarParams, }; diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 6e477ccf6c3f..53589ea01a0d 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -396,6 +396,13 @@ impl SeriesTrait 
for SeriesWrap<DecimalChunked> { self.0.mean().map(|v| v / self.scale_factor() as f64) } + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> { + polars_ensure!((0.0..=1.0).contains(&quantile), + ComputeError: "quantile should be between 0.0 and 1.0", + ); + Ok(self.0.quantile(quantile, method).unwrap()) + } + fn median(&self) -> Option<f64> { self.0.median().map(|v| v / self.scale_factor() as f64) } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 51426f1b94e6..296f58d828aa 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -307,6 +307,13 @@ impl SeriesTrait for SeriesWrap<DurationChunked> { self.0.median() } + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> { + polars_ensure!((0.0..=1.0).contains(&quantile), + ComputeError: "quantile should be between 0.0 and 1.0", + ); + Ok(self.0.quantile(quantile, method).unwrap()) + } + fn std(&self, ddof: u8) -> Option<f64> { self.0.std(ddof) } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 9ccbb1d8d958..645723785113 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -206,6 +206,13 @@ macro_rules! 
impl_dyn_series { self.0.median().map(|v| v as f64) } + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> { + polars_ensure!((0.0..=1.0).contains(&quantile), + ComputeError: "quantile should be between 0.0 and 1.0", + ); + Ok(self.0.quantile(quantile, method).unwrap().map(|v| v as f64)) + } + fn std(&self, ddof: u8) -> Option<f64> { self.0.std(ddof) } diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 9df0e7695127..48c488d0b7f7 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -276,6 +276,13 @@ macro_rules! impl_dyn_series { self.0.median() } + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> { + polars_ensure!((0.0..=1.0).contains(&quantile), + ComputeError: "quantile should be between 0.0 and 1.0", + ); + Ok(self.0.quantile(quantile, method).unwrap().map(|v| v as f64)) + } + fn std(&self, ddof: u8) -> Option<f64> { self.0.std(ddof) } diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 178dd3729c44..25fc10979399 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -359,6 +359,12 @@ pub trait SeriesTrait: None } + /// Returns the quantile value in the array. + /// Returns a result of option because the array is nullable and the quantile can be out of bounds. + fn quantile(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult<Option<f64>> { + Ok(None) + } + /// Create a new Series filled with values from the given index. 
/// /// # Example diff --git a/crates/polars-ops/src/chunked_array/list/dispersion.rs b/crates/polars-ops/src/chunked_array/list/dispersion.rs index a4521c71c9d0..99be35eb58ab 100644 --- a/crates/polars-ops/src/chunked_array/list/dispersion.rs +++ b/crates/polars-ops/src/chunked_array/list/dispersion.rs @@ -24,6 +24,54 @@ pub(super) fn median_with_nulls(ca: &ListChunked) -> Series { } } +pub(super) fn quantile_with_nulls( + ca: &ListChunked, + quantile: f64, + method: QuantileMethod, +) -> Series { + match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| { + s.and_then(|s| { + s.as_ref() + .quantile(quantile, method) + .unwrap_or(Some(f64::NAN)) + .map(|v| v as f32) + }) + }) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| { + s.and_then(|s| { + s.as_ref() + .quantile(quantile, method) + .unwrap_or(Some(f64::NAN)) + .map(|v| v as i64) + }) + }) + .with_name(ca.name().clone()); + out.into_duration(*tu).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| { + s.and_then(|s| { + s.as_ref() + .quantile(quantile, method) + .unwrap_or(Some(f64::NAN)) + }) + }) + .with_name(ca.name().clone()); + out.into_series() + }, + } +} + pub(super) fn std_with_nulls(ca: &ListChunked, ddof: u8) -> Series { match ca.inner_dtype() { DataType::Float32 => { diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 5c21a7e65ac3..bafbd37f7487 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -223,6 +223,11 @@ pub trait ListNameSpaceImpl: AsList { dispersion::median_with_nulls(ca) } + fn lst_quantile(&self, quantile: f64, method: QuantileMethod) -> Series { + let ca = self.as_list(); + 
dispersion::quantile_with_nulls(ca, quantile, method) + } + fn lst_std(&self, ddof: u8) -> Series { let ca = self.as_list(); dispersion::std_with_nulls(ca, ddof) diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index e6b1468f4f82..1618916b1709 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -34,6 +34,7 @@ pub enum ListFunction { Min, Mean, Median, + Quantile(QuantileOptions), Std(u8), Var(u8), ArgMin, @@ -85,6 +86,7 @@ impl ListFunction { Max => mapper.map_to_list_and_array_inner_dtype(), Mean => mapper.with_dtype(DataType::Float64), Median => mapper.map_to_float_dtype(), + Quantile(_) => mapper.map_to_float_dtype(), Std(_) => mapper.map_to_float_dtype(), // Need to also have this sometimes marked as float32 or duration.. Var(_) => mapper.map_to_float_dtype(), ArgMin => mapper.with_dtype(IDX_DTYPE), @@ -152,6 +154,7 @@ impl Display for ListFunction { Max => "max", Mean => "mean", Median => "median", + Quantile(_) => "quantile", Std(_) => "std", Var(_) => "var", ArgMin => "arg_min", @@ -222,6 +225,7 @@ impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> { Min => map!(min), Mean => map!(mean), Median => map!(median), + Quantile(options) => map!(quantile, options.prob, options.method), Std(ddof) => map!(std, ddof), Var(ddof) => map!(var, ddof), ArgMin => map!(arg_min), @@ -571,6 +575,10 @@ pub(super) fn median(s: &Column) -> PolarsResult<Column> { Ok(s.list()?.lst_median().into()) } +pub(super) fn quantile(s: &Column, quantile: f64, method: QuantileMethod) -> PolarsResult<Column> { + Ok(s.list()?.lst_quantile(quantile, method).into()) +} + pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> { Ok(s.list()?.lst_std(ddof).into()) } diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index d5c2622b5afb..8fb4d2b39e84 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -112,6 +112,13 @@ impl ListNameSpace { 
.map_private(FunctionExpr::ListExpr(ListFunction::Median)) } + pub fn quantile(self, prob: f64, method: QuantileMethod) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::Quantile( + QuantileOptions { prob, method }, + ))) + } + pub fn std(self, ddof: u8) -> Expr { self.0 .map_private(FunctionExpr::ListExpr(ListFunction::Std(ddof))) diff --git a/crates/polars-python/src/expr/list.rs b/crates/polars-python/src/expr/list.rs index b8f10fc60c3e..d1dcac74227e 100644 --- a/crates/polars-python/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -89,6 +89,15 @@ impl PyExpr { .into() } + fn list_quantile(&self, quantile: f64, method: Wrap<QuantileMethod>) -> Self { + self.inner + .clone() + .list() + .quantile(quantile, method.0) + .with_fmt("list.quantile") + .into() + } + fn list_std(&self, ddof: u8) -> Self { self.inner .clone() diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index b362e2116502..5bd2abfa99f7 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -165,6 +165,14 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: RollingInterpolationMethod: TypeAlias = Literal[ "nearest", "higher", "lower", "midpoint", "linear" ] # QuantileInterpolOptions +QuantileMethod: TypeAlias = Literal[ + "lower", + "higher", + "nearest", + "linear", + "midpoint", + "equiprobable", +] ListToStructWidthStrategy: TypeAlias = Literal["first_non_null", "max_width"] # The following have no equivalent on the Rust side diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index de98cf5869ce..5d5d35a945fc 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -20,6 +20,7 @@ IntoExprColumn, ListToStructWidthStrategy, NullBehavior, + QuantileMethod, ) @@ -294,6 +295,27 @@ def median(self) -> Expr: """ return wrap_expr(self._pyexpr.list_median()) + def quantile(self, quantile: float, method: QuantileMethod) -> Expr: + """ + Compute the specified quantile 
value of the lists in the array. + + Examples + -------- + >>> df = pl.DataFrame({"values": [[-1, 0, 1], [1, 10]]}) + >>> expr = pl.col("values").list.quantile(0.1, "linear").alias("10percent") + >>> df.with_columns(expr) + shape: (2, 2) + ┌────────────┬───────────┐ + │ values ┆ 10percent │ + │ --- ┆ --- │ + │ list[i64] ┆ f64 │ + ╞════════════╪═══════════╡ + │ [-1, 0, 1] ┆ -0.8 │ + │ [1, 10] ┆ 1.9 │ + └────────────┴───────────┘ + """ + return wrap_expr(self._pyexpr.list_quantile(quantile, method)) + def std(self, ddof: int = 1) -> Expr: """ Compute the std value of the lists in the array. diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 53c401ec110e..12f39b2e14b3 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -752,6 +752,58 @@ def test_list_median(data_dispersion: pl.DataFrame) -> None: assert_frame_equal(result, expected) +def test_list_quantile(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").list.quantile(0.3, "linear").name.suffix("_median"), + pl.col("float").list.quantile(0.3, "linear").name.suffix("_median"), + pl.col("duration").list.quantile(0.3, "linear").name.suffix("_median"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_median", [2.2], dtype=pl.Float64), + pl.Series("float_median", [2.2], dtype=pl.Float64), + pl.Series( + "duration_median", + [timedelta(microseconds=2200)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_list_quantile_extremities(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + assert_frame_equal( + df.select( + pl.col(col).list.quantile(0.0, "linear") + for col in ["int", "float", "duration"] + ), + df.select(pl.col(col).list.min() for col in ["int", "float", "duration"]), + check_dtypes=False, + ) + assert_frame_equal( + df.select( + pl.col(col).list.quantile(1.0, "linear") + for 
col in ["int", "float", "duration"] + ), + df.select(pl.col(col).list.max() for col in ["int", "float", "duration"]), + check_dtypes=False, + ) + assert_frame_equal( + df.select( + pl.col(col).list.quantile(0.5, "linear") + for col in ["int", "float", "duration"] + ), + df.select(pl.col(col).list.median() for col in ["int", "float", "duration"]), + check_dtypes=False, + ) + + def test_list_gather_null_struct_14927() -> None: df = pl.DataFrame( [