Skip to content

Commit

Permalink
Sketch of implementation, initial tests pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Jan 15, 2025
1 parent 1a09f7b commit 3083803
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 6 deletions.
46 changes: 45 additions & 1 deletion crates/polars-ops/src/chunked_array/list/index_of_in.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,49 @@
use super::*;
use crate::series::index_of;

pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult<Series> {
todo!("Implement me");
let mut builder = PrimitiveChunkedBuilder::<IdxType>::new(ca.name().clone(), ca.len());
if needles.len() == 1 {
// For some reason we need to do casting ourselves.
let needle = needles.get(0).unwrap();
let cast_needle = needle.cast(ca.dtype().inner_dtype().unwrap());
if cast_needle != needle {
todo!("nicer error handling");
}
let needle = Scalar::new(
cast_needle.dtype().clone(),
cast_needle.into_static(),
);
ca.amortized_iter().for_each(|opt_series| {
if let Some(subseries) = opt_series {
// TODO justify why unwrap()s are ok
builder.append_option(
// TODO clone() sucks, maybe need to change the API for index_of?
index_of(subseries.as_ref(), needle.clone())
.unwrap()
.map(|v| v.try_into().unwrap()),
);
} else {
builder.append_null();
}
});
} else {
ca.amortized_iter()
.zip(needles.iter())
.for_each(|(opt_series, needle)| {
match (opt_series, needle) {
(None, _) => builder.append_null(),
(Some(subseries), needle) => {
let needle = Scalar::new(needles.dtype().clone(), needle.into_static());
// TODO justify why unwrap()s are ok
builder.append_option(
index_of(subseries.as_ref(), needle)
.unwrap()
.map(|v| v.try_into().unwrap()),
);
},
}
});
}
Ok(builder.finish().into())
}
3 changes: 1 addition & 2 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,7 @@ pub(super) fn index_of_in(args: &[Column]) -> PolarsResult<Column> {
let s = &args[0];
let needles = &args[1];
let ca = s.list()?;
todo!("Implement me");
//list_count_matches(ca, needles).map(Column::from)
list_index_of_in(ca, needles.as_materialized_series()).map(Column::from)
}

pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
Expand Down
9 changes: 8 additions & 1 deletion crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,14 @@ impl ListNameSpace {
#[cfg(feature = "list_index_of_in")]
/// Find the index of `needle` in each list.
///
/// `needle` may be anything convertible into an [`Expr`], e.g. a literal or a
/// column expression. The returned expression wraps
/// `ListFunction::IndexOfIn`, applied to this list expression with the needle
/// expression as its additional input.
pub fn index_of_in<N: Into<Expr>>(self, needle: N) -> Expr {
    let other = needle.into();

    // NOTE(review): the trailing `false` / `None` arguments are forwarded to
    // `map_many_private` as-is; see that method for their meaning.
    self.0.map_many_private(
        FunctionExpr::ListExpr(ListFunction::IndexOfIn),
        &[other],
        false,
        None,
    )
}

#[cfg(feature = "list_sets")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ def test_index_of_in_from_constant() -> None:
df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]})
assert_frame_equal(
df.select(pl.col("lists").list.index_of_in(1)),
pl.DataFrame({"lists": [1, None, 2]}),
pl.DataFrame({"lists": [1, None, 2]}, schema={"lists": pl.get_index_type()}),
)


def test_index_of_in_from_column() -> None:
    """Each list is searched for the needle taken from the same row of ``values``."""
    df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]], "values": [1, 2, 6]})
    # 1 is at index 1 of [3, 1], 2 is at index 0 of [2, 4], and 6 does not
    # occur in [5, 3, 1]; the expected column uses the index dtype.
    assert_frame_equal(
        df.select(pl.col("lists").list.index_of_in(pl.col("values"))),
        pl.DataFrame({"lists": [1, 0, None]}, schema={"lists": pl.get_index_type()}),
    )

0 comments on commit 3083803

Please sign in to comment.