Skip to content

Commit

Permalink
Merge pull request #7 from ion-elgreco/feat/more_arr_dist
Browse files Browse the repository at this point in the history
feat: more array dist
  • Loading branch information
ion-elgreco authored Dec 21, 2023
2 parents 412ea4e + 7ad94a3 commit e0da914
Show file tree
Hide file tree
Showing 7 changed files with 309 additions and 19 deletions.
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,28 @@ The plugin provides three namespaces:
- dist_str
- hamming
- levenshtein
- damerau_levenshtein
- indel
- jaro
- jaro_winkler
- lcs_seq
- osa
- postfix
- prefix
- dist_arr
- euclidean
- cosine
- chebyshev
- canberra
- bray_curtis
- manhatten
- minkowski
- l3_norm
- l4_norm
- dist_list (these act as set similary metrics)
- jaccard_index
- sorensen_index
- tversky_index
- overlap_coef
- cosine

Expand All @@ -37,7 +51,7 @@ df.select(
---
│ u32 │
╞══════╡
1
7
└──────┘


Expand Down
2 changes: 1 addition & 1 deletion polars_distance/polars_distance/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "polars_distance"
crate-type = ["cdylib"]

[dependencies]
polars = { version = "*" , features = ["dtype-array", 'dtype-categorical']}
polars = { version = "*" , features = ["dtype-array", 'dtype-categorical', 'dtype-u16', 'dtype-u8', 'dtype-i8','dtype-i16']}
polars-core = {version = "*"}
polars-arrow = {version = "*"}
pyo3 = { version = "0.20", features = ["extension-module"] }
Expand Down
16 changes: 15 additions & 1 deletion polars_distance/polars_distance/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,28 @@ The plugin provides three namespaces:
- dist_str
- hamming
- levenshtein
- damerau_levenshtein
- indel
- jaro
- jaro_winkler
- lcs_seq
- osa
- postfix
- prefix
- dist_arr
- euclidean
- cosine
- chebyshev
- canberra
- bray_curtis
- manhatten
- minkowski
- l3_norm
- l4_norm
- dist_list (these act as set similary metrics)
- jaccard_index
- sorensen_index
- tversky_index
- overlap_coef
- cosine

Expand All @@ -37,7 +51,7 @@ df.select(
---
│ u32 │
╞══════╡
1
7
└──────┘


Expand Down
51 changes: 50 additions & 1 deletion polars_distance/polars_distance/polars_distance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,63 @@ def canberra(self, other: IntoExpr) -> pl.Expr:
is_elementwise=True,
)

def bray_curtis(self, other: IntoExpr) -> pl.Expr:
"""Returns chebyshev distance between two vectors"""
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="bray_curtis_arr",
is_elementwise=True,
)

def manhatten(self, other: IntoExpr) -> pl.Expr:
"""Returns manhatten distance between two vectors"""
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="manhatten_arr",
is_elementwise=True,
)

def minkowski(self, other: IntoExpr, p: int) -> pl.Expr:
"""Returns minkowski distance between two vectors"""
return self._expr.register_plugin(
lib=lib,
args=[other],
kwargs={"p": p},
symbol="minkowski_arr",
is_elementwise=True,
)

def l3_norm(self, other: IntoExpr) -> pl.Expr:
"""Returns l3_norm distance between two vectors"""
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="l3_norm_arr",
is_elementwise=True,
)

def l4_norm(self, other: IntoExpr) -> pl.Expr:
"""Returns l4_norm distance between two vectors"""
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="l4_norm_arr",
is_elementwise=True,
)


@pl.api.register_expr_namespace("dist_str")
class DistancePairWiseString:
def __init__(self, expr: pl.Expr):
self._expr = expr

def hamming(self, other: IntoExpr, normalized: bool = False) -> pl.Expr:
"""Returns hamming distance between two expressions"""
"""Returns hamming distance between two expressions.
The length of the shortest string is padded to the length of longest string.
"""
if normalized:
return self._expr.register_plugin(
lib=lib,
Expand Down
119 changes: 108 additions & 11 deletions polars_distance/polars_distance/src/array.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use distances::vectors::minkowski;
use polars::prelude::arity::try_binary_elementwise;
use polars::prelude::*;
use polars_arrow::array::{Array, PrimitiveArray};
Expand All @@ -7,11 +8,20 @@ fn collect_into_vecf64(arr: Box<dyn Array>) -> Vec<f64> {
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter()
.map(|v| *v)
.copied()
.collect::<Vec<_>>()
}

pub fn distance_calc_float_inp(
fn collect_into_uint64(arr: Box<dyn Array>) -> Vec<u64> {
arr.as_any()
.downcast_ref::<PrimitiveArray<_>>()
.unwrap()
.values_iter()
.copied()
.collect::<Vec<_>>()
}

pub fn distance_calc_numeric_inp(
a: &ChunkedArray<FixedSizeListType>,
b: &ChunkedArray<FixedSizeListType>,
f: fn(&[f64], &[f64]) -> f64,
Expand All @@ -21,11 +31,17 @@ pub fn distance_calc_float_inp(
ComputeError: "inner data types don't match"
);
polars_ensure!(
a.inner_dtype().is_float(),
ComputeError: "inner data types must be float"
a.inner_dtype().is_numeric(),
ComputeError: "inner data types must be numeric"
);

try_binary_elementwise(a, b, |a: Option<Box<dyn Array>>, b| match (a, b) {
let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;
let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;

let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
Expand All @@ -39,6 +55,40 @@ pub fn distance_calc_float_inp(
})
}

pub fn distance_calc_uint_inp(
a: &ChunkedArray<FixedSizeListType>,
b: &ChunkedArray<FixedSizeListType>,
f: fn(&[u64], &[u64]) -> f64,
) -> PolarsResult<Float64Chunked> {
polars_ensure!(
a.inner_dtype() == b.inner_dtype(),
ComputeError: "inner data types don't match"
);
polars_ensure!(
a.inner_dtype().is_unsigned_integer(),
ComputeError: "inner data types must be unsigned integer"
);

let s1 = a.cast(&DataType::Array(Box::new(DataType::UInt64), a.width()))?;
let s2 = b.cast(&DataType::Array(Box::new(DataType::UInt64), a.width()))?;

let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_uint64(a);
let b = &collect_into_uint64(b);
Ok(Some(f(a, b)))
}
}
_ => Ok(None),
})
}

pub fn euclidean_dist(
a: &ChunkedArray<FixedSizeListType>,
b: &ChunkedArray<FixedSizeListType>,
Expand All @@ -48,11 +98,17 @@ pub fn euclidean_dist(
ComputeError: "inner data types don't match"
);
polars_ensure!(
a.inner_dtype().is_float(),
ComputeError: "inner data types must be float"
a.inner_dtype().is_numeric(),
ComputeError: "inner data types must be numeric"
);

try_binary_elementwise(a, b, |a: Option<Box<dyn Array>>, b| match (a, b) {
let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;
let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;

let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
Expand Down Expand Up @@ -85,11 +141,17 @@ pub fn cosine_dist(
ComputeError: "inner data types don't match"
);
polars_ensure!(
a.inner_dtype().is_float(),
ComputeError: "inner data types must be float"
a.inner_dtype().is_numeric(),
ComputeError: "inner data types must be numeric"
);

try_binary_elementwise(a, b, |a: Option<Box<dyn Array>>, b| match (a, b) {
let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;
let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;

let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
Expand Down Expand Up @@ -120,3 +182,38 @@ pub fn cosine_dist(
_ => Ok(None),
})
}

pub fn minkowski_dist(
a: &ChunkedArray<FixedSizeListType>,
b: &ChunkedArray<FixedSizeListType>,
p: i32,
) -> PolarsResult<Float64Chunked> {
polars_ensure!(
a.inner_dtype() == b.inner_dtype(),
ComputeError: "inner data types don't match"
);
polars_ensure!(
a.inner_dtype().is_numeric(),
ComputeError: "inner data types must be numeric"
);

let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;
let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?;

let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_vecf64(a);
let b = &collect_into_vecf64(b);
let metric = minkowski(p);
Ok(Some(metric(a, b)))
}
}
_ => Ok(None),
})
}
Loading

0 comments on commit e0da914

Please sign in to comment.