Skip to content

Commit

Permalink
fix: Remove assert that panics on group_by followed by head(n),…
Browse files Browse the repository at this point in the history
… where `n` is larger then the frame height (#20819)
  • Loading branch information
alexander-beedie authored Jan 21, 2025
1 parent e567c79 commit 30a7e34
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
8 changes: 6 additions & 2 deletions crates/polars-core/src/frame/group_by/position.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,12 @@ impl Default for GroupPositions {
impl GroupPositions {
pub fn slice(&self, offset: i64, len: usize) -> Self {
let offset = self.offset + offset;
assert!(len <= self.len);
slice_groups(self.original.clone(), offset, len)
slice_groups(
self.original.clone(),
offset,
// invariant that len should be in bounds, so truncate if not
if len > self.len { self.len } else { len },
)
}

pub fn sort(&mut self) {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
@expr_dispatch
class Series:
"""
A Series represents a single column in a polars DataFrame.
A Series represents a single column in a Polars DataFrame.
Parameters
----------
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,31 @@ def test_group_by_agg_equals_zero_3535() -> None:
}


def test_group_by_followed_by_limit() -> None:
lf = pl.LazyFrame(
{
"key": ["xx", "yy", "zz", "xx", "zz", "zz"],
"val1": [15, 25, 10, 20, 20, 20],
"val2": [-33, 20, 44, -2, 16, 71],
}
)
grp = lf.group_by("key", maintain_order=True).agg(pl.col("val1", "val2").sum())
assert sorted(grp.collect().rows()) == [
("xx", 35, -35),
("yy", 25, 20),
("zz", 50, 131),
]
assert sorted(grp.head(2).collect().rows()) == [
("xx", 35, -35),
("yy", 25, 20),
]
assert sorted(grp.head(10).collect().rows()) == [
("xx", 35, -35),
("yy", 25, 20),
("zz", 50, 131),
]


def test_dtype_concat_3735() -> None:
for dt in NUMERIC_DTYPES:
d1 = pl.DataFrame([pl.Series("val", [1, 2], dtype=dt)])
Expand Down

0 comments on commit 30a7e34

Please sign in to comment.