diff --git a/polars_distance/src/expressions.rs b/polars_distance/src/expressions.rs index 33b2098..9a85680 100644 --- a/polars_distance/src/expressions.rs +++ b/polars_distance/src/expressions.rs @@ -39,7 +39,7 @@ fn elementwise_str_u32( y: &ChunkedArray, f: fn(&str, &str) -> u32, ) -> UInt32Chunked { - let (x, y) = if x.len() < y.len() { (x, y) } else { (y, x) }; + let (x, y) = if x.len() < y.len() { (y, x) } else { (x, y) }; match y.len() { 1 => match unsafe { y.get_unchecked(0) } { Some(y_value) => arity::unary_elementwise(x, |x| x.map(|x| f(x, y_value))), @@ -59,7 +59,7 @@ fn elementwise_str_f64( y: &ChunkedArray, f: fn(&str, &str) -> f64, ) -> Float64Chunked { - let (x, y) = if x.len() < y.len() { (x, y) } else { (y, x) }; + let (x, y) = if x.len() < y.len() { (y, x) } else { (x, y) }; match y.len() { 1 => match unsafe { y.get_unchecked(0) } { Some(y_value) => arity::unary_elementwise(x, |x| x.map(|x| f(x, y_value))), diff --git a/polars_distance/tests/test_distance_arr.py b/polars_distance/tests/test_distance_arr.py index 178e2f3..157f119 100644 --- a/polars_distance/tests/test_distance_arr.py +++ b/polars_distance/tests/test_distance_arr.py @@ -359,3 +359,16 @@ def test_broadcast(): ] ) assert_frame_equal(result, expected) + + df = pl.DataFrame( + { + "a1": ["test1", "hello", "test1", "hello", "test1", "hello"], + } + ) + result = df.select(d=pld.col("a1").dist_str.levenshtein(pl.lit("testaa"))) + expected = pl.DataFrame( + [ + pl.Series("d", [2, 5, 2, 5, 2, 5], dtype=pl.UInt32), + ] + ) + assert_frame_equal(result, expected)