diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs
index 902fb48bfd1b..a51e7764ed94 100644
--- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs
+++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs
@@ -1,5 +1,78 @@
 use super::*;
+use crate::series::index_of;
 
+/// Find, for every sublist of `ca`, the position of the corresponding needle.
+///
+/// `needles` must hold either a single value, which is searched for in every
+/// sublist, or exactly one value per row of `ca`. The output has one entry per
+/// row: the index reported by `index_of`, or null when the sublist is null or
+/// `index_of` finds no match.
+///
+/// # Errors
+///
+/// Returns an error if `needles` has neither length 1 nor the length of `ca`,
+/// if a single needle does not survive the cast to the list's inner dtype
+/// unchanged, if `index_of` fails, or if a found index does not fit in the
+/// index type.
 pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult<Series> {
-    todo!("Implement me");
+    let mut builder = PrimitiveChunkedBuilder::<IdxType>::new(ca.name().clone(), ca.len());
+    if needles.len() == 1 {
+        // The needle is cast to the inner dtype here; `index_of` seemingly does not coerce it.
+        let inner_dtype = ca
+            .dtype()
+            .inner_dtype()
+            .expect("a ListChunked always has a list dtype");
+        let needle = needles.get(0)?;
+        let cast_needle = needle.cast(inner_dtype);
+        polars_ensure!(
+            cast_needle == needle,
+            InvalidOperation: "cannot cast needle {:?} to the list's inner dtype {} without changing its value",
+            needle,
+            inner_dtype
+        );
+        let needle = Scalar::new(cast_needle.dtype().clone(), cast_needle.into_static());
+        for opt_sublist in ca.amortized_iter() {
+            match opt_sublist {
+                // `index_of` takes the needle by value, hence one clone per row.
+                Some(sublist) => append_index_of(&mut builder, sublist.as_ref(), needle.clone())?,
+                None => builder.append_null(),
+            }
+        }
+    } else {
+        // `zip` stops at the shorter input, which would silently yield an output of the
+        // wrong length, so reject mismatched lengths up front.
+        polars_ensure!(
+            ca.len() == needles.len(),
+            ComputeError: "shapes don't match: expected {} elements in 'index_of_in' comparison, got {}",
+            ca.len(),
+            needles.len()
+        );
+        for (opt_sublist, needle) in ca.amortized_iter().zip(needles.iter()) {
+            match opt_sublist {
+                Some(sublist) => {
+                    let needle = Scalar::new(needles.dtype().clone(), needle.into_static());
+                    append_index_of(&mut builder, sublist.as_ref(), needle)?;
+                },
+                None => builder.append_null(),
+            }
+        }
+    }
+    Ok(builder.finish().into())
 }
+
+/// Search `sublist` for `needle` and append the index found (null if absent) to `builder`.
+fn append_index_of(
+    builder: &mut PrimitiveChunkedBuilder<IdxType>,
+    sublist: &Series,
+    needle: Scalar,
+) -> PolarsResult<()> {
+    let position = index_of(sublist, needle)?
+        .map(|idx| {
+            IdxSize::try_from(idx).map_err(|_| {
+                polars_err!(ComputeError: "index {} does not fit in the index type", idx)
+            })
+        })
+        .transpose()?;
+    builder.append_option(position);
+    Ok(())
+}
diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs
index 49e96d4cc6c9..7e374031b905 100644
--- a/crates/polars-plan/src/dsl/function_expr/list.rs
+++ b/crates/polars-plan/src/dsl/function_expr/list.rs
@@ -560,8 +560,7 @@ pub(super) fn index_of_in(args: &[Column]) -> PolarsResult<Column> {
     let s = &args[0];
     let needles = &args[1];
     let ca = s.list()?;
-    todo!("Implement me");
-    //list_count_matches(ca, needles).map(Column::from)
+    list_index_of_in(ca, needles.as_materialized_series()).map(Column::from)
 }
 
 pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs
index 087609c6c299..9f75089a6478 100644
--- a/crates/polars-plan/src/dsl/list.rs
+++ b/crates/polars-plan/src/dsl/list.rs
@@ -325,7 +325,14 @@ impl ListNameSpace {
     #[cfg(feature = "list_index_of_in")]
     /// Find the index of needle in the list.
     pub fn index_of_in<N: Into<Expr>>(self, needle: N) -> Expr {
-        todo!("Implement me");
+        let other = needle.into();
+
+        self.0.map_many_private(
+            FunctionExpr::ListExpr(ListFunction::IndexOfIn),
+            &[other],
+            false,
+            None,
+        )
     }
 
     #[cfg(feature = "list_sets")]
diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py
index 0cb406507ebf..2ff5db4e25ed 100644
--- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py
+++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py
@@ -8,7 +8,7 @@ def test_index_of_in_from_constant() -> None:
     df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]})
     assert_frame_equal(
         df.select(pl.col("lists").list.index_of_in(1)),
-        pl.DataFrame({"lists": [1, None, 2]}),
+        pl.DataFrame({"lists": [1, None, 2]}, schema={"lists": pl.get_index_type()}),
     )
 
 
@@ -16,5 +16,5 @@ def test_index_of_in_from_column() -> None:
     df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]], "values": [1, 2, 6]})
     assert_frame_equal(
         df.select(pl.col("lists").list.index_of_in(pl.col("values"))),
-        pl.DataFrame({"lists": [1, None, 2]}),
+        pl.DataFrame({"lists": [1, 0, None]}, schema={"lists": pl.get_index_type()}),
     )