From 1c5b7663ef53f4ef746e5e6038eef7521873a621 Mon Sep 17 00:00:00 2001 From: ion-elgreco <15728914+ion-elgreco@users.noreply.github.com> Date: Thu, 21 Dec 2023 17:33:34 +0100 Subject: [PATCH 1/3] add tversky --- .../polars_distance/__init__.py | 10 ++++ .../polars_distance/src/expressions.rs | 20 +++++--- polars_distance/polars_distance/src/list.rs | 49 +++++++++++++++++++ 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/polars_distance/polars_distance/polars_distance/__init__.py b/polars_distance/polars_distance/polars_distance/__init__.py index 86dd101..4f3a60a 100644 --- a/polars_distance/polars_distance/polars_distance/__init__.py +++ b/polars_distance/polars_distance/polars_distance/__init__.py @@ -244,6 +244,16 @@ def jaccard_index(self, other: IntoExpr) -> pl.Expr: is_elementwise=True, ) + def tversky_index(self, other: IntoExpr, alpha: float, beta: float) -> pl.Expr: + """Returns tversky index between two lists. Each list is converted to a set.""" + return self._expr.register_plugin( + lib=lib, + args=[other], + kwargs={"alpha": alpha, "beta": beta}, + symbol="tversky_index_list", + is_elementwise=True, + ) + def sorensen_index(self, other: IntoExpr) -> pl.Expr: """Returns sorensen index between two lists. 
Each list is converted to a set.""" return self._expr.register_plugin( diff --git a/polars_distance/polars_distance/src/expressions.rs b/polars_distance/polars_distance/src/expressions.rs index eb588e9..1628f41 100644 --- a/polars_distance/polars_distance/src/expressions.rs +++ b/polars_distance/polars_distance/src/expressions.rs @@ -1,5 +1,5 @@ use crate::array::{cosine_dist, distance_calc_float_inp, euclidean_dist}; -use crate::list::{cosine_set_distance, jaccard_index, overlap_coef, sorensen_index}; +use crate::list::{cosine_set_distance, jaccard_index, overlap_coef, sorensen_index, tversky_index}; use crate::string::{ dam_levenshtein_dist, dam_levenshtein_normalized_dist, hamming_dist, hamming_normalized_dist, indel_dist, indel_normalized_dist, jaro_dist, jaro_normalized_dist, jaro_winkler_dist, @@ -10,12 +10,13 @@ use crate::string::{ use distances::vectors::{canberra, chebyshev}; use polars::prelude::*; use pyo3_polars::derive::polars_expr; -// use serde::Deserialize; +use serde::Deserialize; -// #[derive(Deserialize)] -// struct StringDistanceKwargs { -// normalized: bool, -// } +#[derive(Deserialize)] +struct TverskyIndexKwargs { + alpha: f64, + beta: f64, +} #[polars_expr(output_type=UInt32)] fn hamming_str(inputs: &[Series]) -> PolarsResult { @@ -341,3 +342,10 @@ fn cosine_list(inputs: &[Series]) -> PolarsResult { let y: &ChunkedArray = inputs[1].list()?; cosine_set_distance(x, y).map(|ca| ca.into_series()) } + +#[polars_expr(output_type=Float64)] +fn tversky_index_list(inputs: &[Series], kwargs: TverskyIndexKwargs) -> PolarsResult { + let x: &ChunkedArray = inputs[0].list()?; + let y: &ChunkedArray = inputs[1].list()?; + tversky_index(x, y, kwargs.alpha, kwargs.beta).map(|ca| ca.into_series()) +} \ No newline at end of file diff --git a/polars_distance/polars_distance/src/list.rs b/polars_distance/polars_distance/src/list.rs index ccff568..648f09f 100644 --- a/polars_distance/polars_distance/src/list.rs +++ b/polars_distance/polars_distance/src/list.rs 
@@ -179,3 +179,52 @@ pub fn cosine_set_distance(a: &ListChunked, b: &ListChunked) -> PolarsResult PolarsResult { + polars_ensure!( + a.inner_dtype() == b.inner_dtype(), + ComputeError: "inner data types don't match" + ); + + if a.inner_dtype().is_integer() { + with_match_physical_integer_type!(a.inner_dtype(), |$T| { + Ok(binary_elementwise(a, b, |a, b| match (a, b) { + (Some(a), Some(b)) => { + let a = a.as_any().downcast_ref::>().unwrap(); + let b = b.as_any().downcast_ref::>().unwrap(); + let s1 = a.into_iter().collect::>(); + let s2 = b.into_iter().collect::>(); + let len_intersect = s1.intersection(&s2).count() as f64; + let len_diff1 = s1.difference(&s2).count(); + let len_diff2 = s2.difference(&s1).count(); + + Some(len_intersect / (len_intersect + (alpha * len_diff1 as f64) + (beta * len_diff2 as f64))) + } + _ => None, + })) + + }) + } else { + match a.inner_dtype() { + DataType::Utf8 => { + Ok(binary_elementwise(a, b, |a, b| match (a, b) { + (Some(a), Some(b)) => { + let a = a.as_any().downcast_ref::>().unwrap(); + let b = b.as_any().downcast_ref::>().unwrap(); + let s1 = a.into_iter().collect::>(); + let s2 = b.into_iter().collect::>(); + let len_intersect = s1.intersection(&s2).count() as f64; + let len_diff1 = s1.difference(&s2).count(); + let len_diff2 = s2.difference(&s1).count(); + + Some(len_intersect / (len_intersect + (alpha * len_diff1 as f64) + (beta * len_diff2 as f64))) + } + _ => None, + })) + }, + _ => Err(PolarsError::ComputeError( + format!("tversky index distance only works on inner dtype Utf8 or integer. 
Use of {} is not supported", a.inner_dtype()).into(), + )) + } + } +} From c192ce4a5710f73d88e93c0558dba50e4f49857d Mon Sep 17 00:00:00 2001 From: ion-elgreco <15728914+ion-elgreco@users.noreply.github.com> Date: Thu, 21 Dec 2023 17:34:12 +0100 Subject: [PATCH 2/3] fmt --- polars_distance/polars_distance/src/expressions.rs | 6 ++++-- polars_distance/polars_distance/src/list.rs | 12 ++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/polars_distance/polars_distance/src/expressions.rs b/polars_distance/polars_distance/src/expressions.rs index 1628f41..4495821 100644 --- a/polars_distance/polars_distance/src/expressions.rs +++ b/polars_distance/polars_distance/src/expressions.rs @@ -1,5 +1,7 @@ use crate::array::{cosine_dist, distance_calc_float_inp, euclidean_dist}; -use crate::list::{cosine_set_distance, jaccard_index, overlap_coef, sorensen_index, tversky_index}; +use crate::list::{ + cosine_set_distance, jaccard_index, overlap_coef, sorensen_index, tversky_index, +}; use crate::string::{ dam_levenshtein_dist, dam_levenshtein_normalized_dist, hamming_dist, hamming_normalized_dist, indel_dist, indel_normalized_dist, jaro_dist, jaro_normalized_dist, jaro_winkler_dist, @@ -348,4 +350,4 @@ fn tversky_index_list(inputs: &[Series], kwargs: TverskyIndexKwargs) -> PolarsRe let x: &ChunkedArray = inputs[0].list()?; let y: &ChunkedArray = inputs[1].list()?; tversky_index(x, y, kwargs.alpha, kwargs.beta).map(|ca| ca.into_series()) -} \ No newline at end of file +} diff --git a/polars_distance/polars_distance/src/list.rs b/polars_distance/polars_distance/src/list.rs index 648f09f..befdf9f 100644 --- a/polars_distance/polars_distance/src/list.rs +++ b/polars_distance/polars_distance/src/list.rs @@ -180,7 +180,12 @@ pub fn cosine_set_distance(a: &ListChunked, b: &ListChunked) -> PolarsResult PolarsResult { +pub fn tversky_index( + a: &ListChunked, + b: &ListChunked, + alpha: f64, + beta: f64, +) -> PolarsResult { polars_ensure!( a.inner_dtype() == 
b.inner_dtype(), ComputeError: "inner data types don't match" @@ -197,12 +202,12 @@ pub fn tversky_index(a: &ListChunked, b: &ListChunked, alpha: f64, beta: f64) -> let len_intersect = s1.intersection(&s2).count() as f64; let len_diff1 = s1.difference(&s2).count(); let len_diff2 = s2.difference(&s1).count(); - + Some(len_intersect / (len_intersect + (alpha * len_diff1 as f64) + (beta * len_diff2 as f64))) } _ => None, })) - + }) } else { match a.inner_dtype() { @@ -216,7 +221,6 @@ pub fn tversky_index(a: &ListChunked, b: &ListChunked, alpha: f64, beta: f64) -> let len_intersect = s1.intersection(&s2).count() as f64; let len_diff1 = s1.difference(&s2).count(); let len_diff2 = s2.difference(&s1).count(); - Some(len_intersect / (len_intersect + (alpha * len_diff1 as f64) + (beta * len_diff2 as f64))) } _ => None, From 239929f7092c13a901281315478ef47307d3f5c7 Mon Sep 17 00:00:00 2001 From: ion-elgreco <15728914+ion-elgreco@users.noreply.github.com> Date: Thu, 21 Dec 2023 17:35:40 +0100 Subject: [PATCH 3/3] add test --- polars_distance/tests/test_distance_arr.py | 23 ++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/polars_distance/tests/test_distance_arr.py b/polars_distance/tests/test_distance_arr.py index 8cd66ae..329b764 100644 --- a/polars_distance/tests/test_distance_arr.py +++ b/polars_distance/tests/test_distance_arr.py @@ -192,3 +192,26 @@ def test_cosine_set_distance(data_sets): assert_frame_equal(result, expected) assert_frame_equal(result_int, expected) + + +def test_tversky_index(data_sets): + result = data_sets.select( + pld.col("x_str") + .dist_list.tversky_index("y_str", alpha=1, beta=1) + .alias("tversky") + ) + + result_int = data_sets.select( + pld.col("x_int") + .dist_list.tversky_index("y_int", alpha=1, beta=1) + .alias("tversky") + ) + + expected = pl.DataFrame( + [ + pl.Series("tversky", [0.3333333333333333], dtype=pl.Float64), + ] + ) + + assert_frame_equal(result, expected) + assert_frame_equal(result_int, expected)