Skip to content

Commit

Permalink
Merge pull request #26 from ion-elgreco/fix--literal-support
Browse files Browse the repository at this point in the history
fix: support broadcasting
  • Loading branch information
ion-elgreco authored Nov 17, 2024
2 parents 57eacef + 65f98d5 commit 40cd9bd
Show file tree
Hide file tree
Showing 6 changed files with 601 additions and 187 deletions.
2 changes: 1 addition & 1 deletion polars_distance/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_distance"
version = "0.5.0"
version = "0.5.1"
edition = "2021"

[lib]
Expand Down
334 changes: 250 additions & 84 deletions polars_distance/src/array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use distances::vectors::minkowski;
use polars::prelude::arity::try_binary_elementwise;
use polars::prelude::arity::{try_binary_elementwise, try_unary_elementwise};
use polars::prelude::*;
use polars_arrow::array::{Array, PrimitiveArray};
use polars_arrow::array::{new_null_array, Array, PrimitiveArray};

fn collect_into_vecf64(arr: Box<dyn Array>) -> Vec<f64> {
arr.as_any()
Expand Down Expand Up @@ -41,18 +41,46 @@ pub fn distance_calc_numeric_inp(
let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_vecf64(a);
let b = &collect_into_vecf64(b);
Ok(Some(f(a, b)))
// A literal operand arrives as a length-1 array. Swap so the shorter side is always `b` (the RHS);
// the `1 =>` arm below can then read its only element with `get_unchecked(0)`.
let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) };
match b.len() {
1 => match unsafe { b.get_unchecked(0) } {
Some(b_value) => {
if b_value.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
try_unary_elementwise(a, |a| match a {
Some(a) => {
if a.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
let a = &collect_into_vecf64(a);
let b = &collect_into_vecf64(b_value.clone());
Ok(Some(f(a, b)))
}
_ => Ok(None),
})
}
}
_ => Ok(None),
})
None => unsafe {
Ok(ChunkedArray::from_chunks(
a.name().clone(),
vec![new_null_array(ArrowDataType::Float64, a.len())],
))
},
},
_ => try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_vecf64(a);
let b = &collect_into_vecf64(b);
Ok(Some(f(a, b)))
}
}
_ => Ok(None),
}),
}
}

pub fn distance_calc_uint_inp(
Expand All @@ -75,18 +103,45 @@ pub fn distance_calc_uint_inp(
let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_uint64(a);
let b = &collect_into_uint64(b);
Ok(Some(f(a, b)))
let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) };
match b.len() {
1 => match unsafe { b.get_unchecked(0) } {
Some(b_value) => {
if b_value.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
try_unary_elementwise(a, |a| match a {
Some(a) => {
if a.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
let a = &collect_into_uint64(a);
let b = &collect_into_uint64(b_value.clone());
Ok(Some(f(a, b)))
}
_ => Ok(None),
})
}
}
_ => Ok(None),
})
None => unsafe {
Ok(ChunkedArray::from_chunks(
a.name().clone(),
vec![new_null_array(ArrowDataType::Float64, a.len())],
))
},
},
_ => try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_uint64(a);
let b = &collect_into_uint64(b);
Ok(Some(f(a, b)))
}
}
_ => Ok(None),
}),
}
}

pub fn euclidean_dist(
Expand All @@ -108,28 +163,65 @@ pub fn euclidean_dist(
let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = a
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
let b = b
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
Ok(Some(
a.zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt(),
let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) };
match b.len() {
1 => match unsafe { b.get_unchecked(0) } {
Some(b_value) => {
if b_value.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
try_unary_elementwise(a, |a| match a {
Some(a) => {
if a.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
let a = a
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
let b = b_value
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
Ok(Some(
a.zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt(),
))
}
_ => Ok(None),
})
}
None => unsafe {
Ok(ChunkedArray::from_chunks(
a.name().clone(),
vec![new_null_array(ArrowDataType::Float64, a.len())],
))
},
},
_ => try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = a
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
let b = b
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
Ok(Some(
a.zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt(),
))
}
}
}
_ => Ok(None),
})
_ => Ok(None),
}),
}
}

pub fn cosine_dist(
Expand All @@ -151,36 +243,81 @@ pub fn cosine_dist(
let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = a
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
let b = b
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();

let dot_prod: f64 = a.clone().zip(b.clone()).map(|(x, y)| x * y).sum();
let mag1: f64 = a.map(|x| x.powi(2)).sum::<f64>().sqrt();
let mag2: f64 = b.map(|y| y.powi(2)).sum::<f64>().sqrt();

let res = if mag1 == 0.0 || mag2 == 0.0 {
0.0
let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) };
match b.len() {
1 => match unsafe { b.get_unchecked(0) } {
Some(b_value) => {
if b_value.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
try_unary_elementwise(a, |a| match a {
Some(a) => {
if a.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
let a = a
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
let b = b_value
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();

let dot_prod: f64 = a.clone().zip(b.clone()).map(|(x, y)| x * y).sum();
let mag1: f64 = a.map(|x| x.powi(2)).sum::<f64>().sqrt();
let mag2: f64 = b.map(|y| y.powi(2)).sum::<f64>().sqrt();

let res = if mag1 == 0.0 || mag2 == 0.0 {
0.0
} else {
1.0 - (dot_prod / (mag1 * mag2))
};
Ok(Some(res))
}
_ => Ok(None),
})
}
None => unsafe {
Ok(ChunkedArray::from_chunks(
a.name().clone(),
vec![new_null_array(ArrowDataType::Float64, a.len())],
))
},
},
_ => try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
1.0 - (dot_prod / (mag1 * mag2))
};
Ok(Some(res))
let a = a
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();
let b = b
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.values_iter();

let dot_prod: f64 = a.clone().zip(b.clone()).map(|(x, y)| x * y).sum();
let mag1: f64 = a.map(|x| x.powi(2)).sum::<f64>().sqrt();
let mag2: f64 = b.map(|y| y.powi(2)).sum::<f64>().sqrt();

let res = if mag1 == 0.0 || mag2 == 0.0 {
0.0
} else {
1.0 - (dot_prod / (mag1 * mag2))
};
Ok(Some(res))
}
}
}
_ => Ok(None),
})
_ => Ok(None),
}),
}
}

pub fn minkowski_dist(
Expand All @@ -203,17 +340,46 @@ pub fn minkowski_dist(
let a: &ArrayChunked = s1.array()?;
let b: &ArrayChunked = s2.array()?;

try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_vecf64(a);
let b = &collect_into_vecf64(b);
let metric = minkowski(p);
Ok(Some(metric(a, b)))
// A literal operand arrives as a length-1 array. Swap so the shorter side is always `b` (the RHS);
// the `1 =>` arm below can then read its only element with `get_unchecked(0)`.
let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) };
match b.len() {
1 => match unsafe { b.get_unchecked(0) } {
Some(b_value) => {
if b_value.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
try_unary_elementwise(a, |a| match a {
Some(a) => {
if a.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
}
let a = &collect_into_vecf64(a);
let b = &collect_into_vecf64(b_value.clone());
let metric = minkowski(p);
Ok(Some(metric(a, b)))
}
_ => Ok(None),
})
}
None => unsafe {
Ok(ChunkedArray::from_chunks(
a.name().clone(),
vec![new_null_array(ArrowDataType::Float64, a.len())],
))
},
},
_ => try_binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
if a.null_count() > 0 || b.null_count() > 0 {
polars_bail!(ComputeError: "array cannot contain nulls")
} else {
let a = &collect_into_vecf64(a);
let b = &collect_into_vecf64(b);
let metric = minkowski(p);
Ok(Some(metric(a, b)))
}
}
}
_ => Ok(None),
})
_ => Ok(None),
}),
}
}
Loading

0 comments on commit 40cd9bd

Please sign in to comment.