Skip to content

Commit

Permalink
Merge pull request #4 from ion-elgreco/feat/set_metrics
Browse files Browse the repository at this point in the history
feat: add more set metrics
  • Loading branch information
ion-elgreco authored Dec 19, 2023
2 parents f2e793f + bf7d011 commit 43b722e
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 45 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Hellooo :)

This plugin is a work-in progress, main goal is to provide distance metrics on list, arrays and string.
This plugin is a work-in progress, main goal is to provide distance metrics on list, arrays and string datatypes.

The plugin provides two namespaces:
The plugin provides three namespaces:

- dist_str
- hamming
Expand All @@ -12,6 +12,11 @@ The plugin provides two namespaces:
- cosine
- chebyshev
- canberra
- dist_list (these act as set similary metrics)
- jaccard_index
- sorensen_index
- overlap_coef
- cosine

## Examples

Expand Down
9 changes: 7 additions & 2 deletions polars_distance/polars_distance/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Hellooo :)

This plugin is a work-in progress, main goal is to provide distance metrics on list, arrays and string.
This plugin is a work-in progress, main goal is to provide distance metrics on list, arrays and string datatypes.

The plugin provides two namespaces:
The plugin provides three namespaces:

- dist_str
- hamming
Expand All @@ -12,6 +12,11 @@ The plugin provides two namespaces:
- cosine
- chebyshev
- canberra
- dist_list (these act as set similary metrics)
- jaccard_index
- sorensen_index
- overlap_coef
- cosine

## Examples

Expand Down
27 changes: 27 additions & 0 deletions polars_distance/polars_distance/polars_distance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,33 @@ def jaccard_index(self, other: IntoExpr) -> pl.Expr:
is_elementwise=True,
)

def sorensen_index(self, other: IntoExpr) -> pl.Expr:
"""Returns sorensen index between two lists. Each list is converted to a set."""
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="sorensen_index_list",
is_elementwise=True,
)

def overlap_coef(self, other: IntoExpr) -> pl.Expr:
"""Returns overlap coef between two lists. Each list is converted to a set."""
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="overlap_coef_list",
is_elementwise=True,
)

def cosine(self, other: IntoExpr) -> pl.Expr:
"""Returns cosine distance between two lists. Each list is converted to a set."""
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="cosine_list",
is_elementwise=True,
)


class DExpr(pl.Expr):
@property
Expand Down
23 changes: 22 additions & 1 deletion polars_distance/polars_distance/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::array::{cosine_dist, distance_calc_float_inp, euclidean_dist};
use crate::list::jaccard_index;
use crate::list::{cosine_set_distance, jaccard_index, overlap_coef, sorensen_index};
use crate::string::{hamming_distance_string, levenshtein_distance_string};
use distances::vectors::{canberra, chebyshev};
use polars::prelude::*;
Expand Down Expand Up @@ -91,3 +91,24 @@ fn jaccard_index_list(inputs: &[Series]) -> PolarsResult<Series> {
let y: &ChunkedArray<ListType> = inputs[1].list()?;
jaccard_index(x, y).map(|ca| ca.into_series())
}

#[polars_expr(output_type=Float64)]
fn sorensen_index_list(inputs: &[Series]) -> PolarsResult<Series> {
let x: &ChunkedArray<ListType> = inputs[0].list()?;
let y: &ChunkedArray<ListType> = inputs[1].list()?;
sorensen_index(x, y).map(|ca| ca.into_series())
}

#[polars_expr(output_type=Float64)]
fn overlap_coef_list(inputs: &[Series]) -> PolarsResult<Series> {
let x: &ChunkedArray<ListType> = inputs[0].list()?;
let y: &ChunkedArray<ListType> = inputs[1].list()?;
overlap_coef(x, y).map(|ca| ca.into_series())
}

#[polars_expr(output_type=Float64)]
fn cosine_list(inputs: &[Series]) -> PolarsResult<Series> {
let x: &ChunkedArray<ListType> = inputs[0].list()?;
let y: &ChunkedArray<ListType> = inputs[1].list()?;
cosine_set_distance(x, y).map(|ca| ca.into_series())
}
186 changes: 146 additions & 40 deletions polars_distance/polars_distance/src/list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use core::hash::Hash;
use distances::Number;
use polars::prelude::arity::binary_elementwise;
use polars::prelude::*;
use polars_arrow::array::{PrimitiveArray, Utf8Array};
Expand All @@ -8,69 +9,174 @@ use polars_core::with_match_physical_integer_type;
fn jacc_int_array<T: NativeType + Hash + Eq>(a: &PrimitiveArray<T>, b: &PrimitiveArray<T>) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();

let len_intersect = s1.intersection(&s2).count();

len_intersect as f64 / (s1.len() + s2.len() - len_intersect) as f64
}

fn jacc_str_array(a: &Utf8Array<i64>, b: &Utf8Array<i64>) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count();

len_intersect as f64 / (s1.len() + s2.len() - len_intersect) as f64
}

fn sorensen_int_array<T: NativeType + Hash + Eq>(
a: &PrimitiveArray<T>,
b: &PrimitiveArray<T>,
) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count();

(2 * len_intersect) as f64 / (s1.len() + s2.len()) as f64
}

fn sorensen_str_array(a: &Utf8Array<i64>, b: &Utf8Array<i64>) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count();

(2 * len_intersect) as f64 / (s1.len() + s2.len()) as f64
}

fn overlap_int_array<T: NativeType + Hash + Eq>(
a: &PrimitiveArray<T>,
b: &PrimitiveArray<T>,
) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count();

len_intersect as f64 / std::cmp::min(s1.len(), s2.len()) as f64
}

fn overlap_str_array(a: &Utf8Array<i64>, b: &Utf8Array<i64>) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count();

len_intersect as f64 / std::cmp::min(s1.len(), s2.len()) as f64
}

fn cosine_int_array<T: NativeType + Hash + Eq>(
a: &PrimitiveArray<T>,
b: &PrimitiveArray<T>,
) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count();

len_intersect as f64 / (s1.len() as f64).sqrt() * (s2.len() as f64).sqrt()
}

fn cosine_str_array(a: &Utf8Array<i64>, b: &Utf8Array<i64>) -> f64 {
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count();

len_intersect as f64 / (s1.len() as f64).sqrt() * (s2.len() as f64).sqrt()
}

pub fn elementwise_int_inp<T: NativeType + Hash + Eq>(
a: &ListChunked,
b: &ListChunked,
f: fn(&PrimitiveArray<T>, &PrimitiveArray<T>) -> f64,
) -> PolarsResult<Float64Chunked> {
Ok(binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
let a = a.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let b = b.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
Some(f(a, b))
}
_ => None,
}))
}

pub fn elementwise_string_inp(
a: &ListChunked,
b: &ListChunked,
f: fn(&Utf8Array<i64>, &Utf8Array<i64>) -> f64,
) -> PolarsResult<Float64Chunked> {
Ok(binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
let a = a.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
let b = b.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
Some(f(a, b))
}
_ => None,
}))
}

pub fn jaccard_index(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> {
polars_ensure!(
a.inner_dtype() == b.inner_dtype(),
ComputeError: "inner data types don't match"
);

if a.inner_dtype().is_integer() {
Ok(with_match_physical_integer_type!(a.inner_dtype(), |$T| {
binary_elementwise(a, b, |a, b| {
match (a, b) {
(Some(a), Some(b)) => {
let a = a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
let b = b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
Some(jacc_int_array(a, b))
},
_ => None
}
})
}
))
with_match_physical_integer_type!(a.inner_dtype(), |$T| {elementwise_int_inp(a, b, jacc_int_array::<$T>)})
} else {
match a.inner_dtype() {
DataType::Utf8 => {
Ok(binary_elementwise(a, b, |a, b| {
match (a, b) {
(Some(a), Some(b)) => {
let a = a.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
let b = b.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
Some(jacc_str_array(a, b))
},
_ => None
}
}))
},
// DataType::Categorical(_) => {
// let a = a.cast(&DataType::List(Box::new(DataType::Utf8)))?;
// let b = b.cast(&DataType::List(Box::new(DataType::Utf8)))?;
// Ok(binary_elementwise(a.list()?, b.list()?, |a, b| {
// match (a, b) {
// (Some(a), Some(b)) => {
// let a = a.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
// let b = b.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
// Some(5.0)
// },
// _ => None
// }
// }))
// },
DataType::Utf8 => elementwise_string_inp(a,b, jacc_str_array),
_ => Err(PolarsError::ComputeError(
format!("jaccard index only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
}

pub fn sorensen_index(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> {
polars_ensure!(
a.inner_dtype() == b.inner_dtype(),
ComputeError: "inner data types don't match"
);

if a.inner_dtype().is_integer() {
with_match_physical_integer_type!(a.inner_dtype(), |$T| {elementwise_int_inp(a, b, sorensen_int_array::<$T>)})
} else {
match a.inner_dtype() {
DataType::Utf8 => elementwise_string_inp(a,b, sorensen_str_array),
_ => Err(PolarsError::ComputeError(
format!("sorensen index only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
}

pub fn overlap_coef(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> {
polars_ensure!(
a.inner_dtype() == b.inner_dtype(),
ComputeError: "inner data types don't match"
);

if a.inner_dtype().is_integer() {
with_match_physical_integer_type!(a.inner_dtype(), |$T| {elementwise_int_inp(a, b, overlap_int_array::<$T>)})
} else {
match a.inner_dtype() {
DataType::Utf8 => elementwise_string_inp(a,b, overlap_str_array),
_ => Err(PolarsError::ComputeError(
format!("overlap coefficient only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
}

pub fn cosine_set_distance(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> {
polars_ensure!(
a.inner_dtype() == b.inner_dtype(),
ComputeError: "inner data types don't match"
);

if a.inner_dtype().is_integer() {
with_match_physical_integer_type!(a.inner_dtype(), |$T| {elementwise_int_inp(a, b, cosine_int_array::<$T>)})
} else {
match a.inner_dtype() {
DataType::Utf8 => elementwise_string_inp(a,b, cosine_str_array),
_ => Err(PolarsError::ComputeError(
format!("cosine set distance only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
}
57 changes: 57 additions & 0 deletions polars_distance/tests/test_distance_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,60 @@ def test_jaccard_index(data_sets):

assert_frame_equal(result, expected)
assert_frame_equal(result_int, expected)


def test_sorensen_index(data_sets):
result = data_sets.select(
pld.col("x_str").dist_list.sorensen_index("y_str").alias("sorensen_index")
)

result_int = data_sets.select(
pld.col("x_int").dist_list.sorensen_index("y_int").alias("sorensen_index")
)

expected = pl.DataFrame(
[
pl.Series("sorensen_index", [0.5], dtype=pl.Float64),
]
)

assert_frame_equal(result, expected)
assert_frame_equal(result_int, expected)


def test_overlap_coef(data_sets):
result = data_sets.select(
pld.col("x_str").dist_list.overlap_coef("y_str").alias("overlap")
)

result_int = data_sets.select(
pld.col("x_int").dist_list.overlap_coef("y_int").alias("overlap")
)

expected = pl.DataFrame(
[
pl.Series("overlap", [1.0], dtype=pl.Float64),
]
)

assert_frame_equal(result, expected)
assert_frame_equal(result_int, expected)


def test_cosine_set_distance(data_sets):
result = data_sets.select(
pld.col("x_str").dist_list.cosine("y_str").alias("cosine_set")
)

result_int = data_sets.select(
pld.col("x_int").dist_list.cosine("y_int").alias("cosine_set")
)

expected = pl.DataFrame(
[
pl.Series("cosine_set", [1.7320508075688772], dtype=pl.Float64),
]
)

assert_frame_equal(result, expected)
assert_frame_equal(result_int, expected)

0 comments on commit 43b722e

Please sign in to comment.