From 30a7e347cbb0c099e3f4e30df5f7ca6743d7816b Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 21 Jan 2025 19:29:58 +0400 Subject: [PATCH] fix: Remove `assert` that panics on `group_by` followed by `head(n)`, where `n` is larger than the frame height (#20819) --- .../src/frame/group_by/position.rs | 8 ++++-- py-polars/polars/series/series.py | 2 +- py-polars/tests/unit/test_queries.py | 25 +++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/crates/polars-core/src/frame/group_by/position.rs b/crates/polars-core/src/frame/group_by/position.rs index 58734cd50b74..8c38596faf74 100644 --- a/crates/polars-core/src/frame/group_by/position.rs +++ b/crates/polars-core/src/frame/group_by/position.rs @@ -620,8 +620,12 @@ impl Default for GroupPositions { impl GroupPositions { pub fn slice(&self, offset: i64, len: usize) -> Self { let offset = self.offset + offset; - assert!(len <= self.len); - slice_groups(self.original.clone(), offset, len) + slice_groups( + self.original.clone(), + offset, + // invariant that len should be in bounds, so truncate if not + if len > self.len { self.len } else { len }, + ) } pub fn sort(&mut self) { diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 875aef8be515..51b3c4423aec 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -168,7 +168,7 @@ @expr_dispatch class Series: """ - A Series represents a single column in a polars DataFrame. + A Series represents a single column in a Polars DataFrame. 
Parameters ---------- diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index e7323661978a..d5620f2d2ce8 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -172,6 +172,31 @@ def test_group_by_agg_equals_zero_3535() -> None: } +def test_group_by_followed_by_limit() -> None: + lf = pl.LazyFrame( + { + "key": ["xx", "yy", "zz", "xx", "zz", "zz"], + "val1": [15, 25, 10, 20, 20, 20], + "val2": [-33, 20, 44, -2, 16, 71], + } + ) + grp = lf.group_by("key", maintain_order=True).agg(pl.col("val1", "val2").sum()) + assert sorted(grp.collect().rows()) == [ + ("xx", 35, -35), + ("yy", 25, 20), + ("zz", 50, 131), + ] + assert sorted(grp.head(2).collect().rows()) == [ + ("xx", 35, -35), + ("yy", 25, 20), + ] + assert sorted(grp.head(10).collect().rows()) == [ + ("xx", 35, -35), + ("yy", 25, 20), + ("zz", 50, 131), + ] + + def test_dtype_concat_3735() -> None: for dt in NUMERIC_DTYPES: d1 = pl.DataFrame([pl.Series("val", [1, 2], dtype=dt)])