diff --git a/polars_distance/polars_distance/polars_distance/__init__.py b/polars_distance/polars_distance/polars_distance/__init__.py index 86dd101..4f3a60a 100644 --- a/polars_distance/polars_distance/polars_distance/__init__.py +++ b/polars_distance/polars_distance/polars_distance/__init__.py @@ -244,6 +244,16 @@ def jaccard_index(self, other: IntoExpr) -> pl.Expr: is_elementwise=True, ) + def tversky_index(self, other: IntoExpr, alpha: float, beta: float) -> pl.Expr: + """Returns tversky index between two lists. Each list is converted to a set.""" + return self._expr.register_plugin( + lib=lib, + args=[other], + kwargs={"alpha": alpha, "beta": beta}, + symbol="tversky_index_list", + is_elementwise=True, + ) + def sorensen_index(self, other: IntoExpr) -> pl.Expr: """Returns sorensen index between two lists. Each list is converted to a set.""" return self._expr.register_plugin( diff --git a/polars_distance/polars_distance/src/expressions.rs b/polars_distance/polars_distance/src/expressions.rs index eb588e9..4495821 100644 --- a/polars_distance/polars_distance/src/expressions.rs +++ b/polars_distance/polars_distance/src/expressions.rs @@ -1,5 +1,7 @@ use crate::array::{cosine_dist, distance_calc_float_inp, euclidean_dist}; -use crate::list::{cosine_set_distance, jaccard_index, overlap_coef, sorensen_index}; +use crate::list::{ + cosine_set_distance, jaccard_index, overlap_coef, sorensen_index, tversky_index, +}; use crate::string::{ dam_levenshtein_dist, dam_levenshtein_normalized_dist, hamming_dist, hamming_normalized_dist, indel_dist, indel_normalized_dist, jaro_dist, jaro_normalized_dist, jaro_winkler_dist, @@ -10,12 +12,13 @@ use crate::string::{ use distances::vectors::{canberra, chebyshev}; use polars::prelude::*; use pyo3_polars::derive::polars_expr; -// use serde::Deserialize; +use serde::Deserialize; -// #[derive(Deserialize)] -// struct StringDistanceKwargs { -// normalized: bool, -// } +#[derive(Deserialize)] +struct TverskyIndexKwargs { + alpha: f64, + beta: f64, +} #[polars_expr(output_type=UInt32)] fn hamming_str(inputs: &[Series]) -> PolarsResult { @@ -341,3 +344,10 @@ fn cosine_list(inputs: &[Series]) -> PolarsResult { let y: &ChunkedArray = inputs[1].list()?; cosine_set_distance(x, y).map(|ca| ca.into_series()) } + +#[polars_expr(output_type=Float64)] +fn tversky_index_list(inputs: &[Series], kwargs: TverskyIndexKwargs) -> PolarsResult { + let x: &ChunkedArray = inputs[0].list()?; + let y: &ChunkedArray = inputs[1].list()?; + tversky_index(x, y, kwargs.alpha, kwargs.beta).map(|ca| ca.into_series()) +} diff --git a/polars_distance/polars_distance/src/list.rs b/polars_distance/polars_distance/src/list.rs index ccff568..befdf9f 100644 --- a/polars_distance/polars_distance/src/list.rs +++ b/polars_distance/polars_distance/src/list.rs @@ -179,3 +179,56 @@ pub fn cosine_set_distance(a: &ListChunked, b: &ListChunked) -> PolarsResult PolarsResult { + polars_ensure!( + a.inner_dtype() == b.inner_dtype(), + ComputeError: "inner data types don't match" + ); + + if a.inner_dtype().is_integer() { + with_match_physical_integer_type!(a.inner_dtype(), |$T| { + Ok(binary_elementwise(a, b, |a, b| match (a, b) { + (Some(a), Some(b)) => { + let a = a.as_any().downcast_ref::>().unwrap(); + let b = b.as_any().downcast_ref::>().unwrap(); + let s1 = a.into_iter().collect::>(); + let s2 = b.into_iter().collect::>(); + let len_intersect = s1.intersection(&s2).count() as f64; + let len_diff1 = s1.difference(&s2).count(); + let len_diff2 = s2.difference(&s1).count(); + + Some(len_intersect / (len_intersect + (alpha * len_diff1 as f64) + (beta * len_diff2 as f64))) + } + _ => None, + })) + + }) + } else { + match a.inner_dtype() { + DataType::Utf8 => { + Ok(binary_elementwise(a, b, |a, b| match (a, b) { + (Some(a), Some(b)) => { + let a = a.as_any().downcast_ref::>().unwrap(); + let b = b.as_any().downcast_ref::>().unwrap(); + let s1 = a.into_iter().collect::>(); + let s2 = b.into_iter().collect::>(); + let len_intersect = s1.intersection(&s2).count() as f64; + let len_diff1 = s1.difference(&s2).count(); + let len_diff2 = s2.difference(&s1).count(); + Some(len_intersect / (len_intersect + (alpha * len_diff1 as f64) + (beta * len_diff2 as f64))) + } + _ => None, + })) + }, + _ => Err(PolarsError::ComputeError( + format!("tversky index distance only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(), + )) + } + } +} diff --git a/polars_distance/tests/test_distance_arr.py b/polars_distance/tests/test_distance_arr.py index 8cd66ae..329b764 100644 --- a/polars_distance/tests/test_distance_arr.py +++ b/polars_distance/tests/test_distance_arr.py @@ -192,3 +192,26 @@ def test_cosine_set_distance(data_sets): assert_frame_equal(result, expected) assert_frame_equal(result_int, expected) + + +def test_cosine_set_distance(data_sets): + result = data_sets.select( + pld.col("x_str") + .dist_list.tversky_index("y_str", alpha=1, beta=1) + .alias("tversky") + ) + + result_int = data_sets.select( + pld.col("x_int") + .dist_list.tversky_index("y_int", alpha=1, beta=1) + .alias("tversky") + ) + + expected = pl.DataFrame( + [ + pl.Series("tversky", [0.3333333333333333], dtype=pl.Float64), + ] + ) + + assert_frame_equal(result, expected) + assert_frame_equal(result_int, expected)