Skip to content

Commit

Permalink
add haversine, bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Jan 1, 2024
1 parent 0f26938 commit 8b0eb10
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 5 deletions.
4 changes: 2 additions & 2 deletions polars_distance/polars_distance/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[package]
name = "polars_distance"
version = "0.3.0"
version = "0.3.1"
edition = "2021"

[lib]
name = "polars_distance"
crate-type = ["cdylib"]

[dependencies]
polars = { version = "*" , features = ["dtype-array", 'dtype-categorical', 'dtype-u16', 'dtype-u8', 'dtype-i8','dtype-i16']}
polars = { version = "*" , features = ["dtype-struct", "dtype-array", 'dtype-categorical', 'dtype-u16', 'dtype-u8', 'dtype-i8','dtype-i16']}
polars-core = {version = "*"}
polars-arrow = {version = "*"}
pyo3 = { version = "0.20", features = ["extension-module"] }
Expand Down
51 changes: 49 additions & 2 deletions polars_distance/polars_distance/polars_distance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,50 @@
import polars as pl
from polars.utils.udfs import _get_shared_lib_location
from typing import Protocol, Iterable, cast
from typing import Protocol, Iterable, cast, Literal
from polars.type_aliases import PolarsDataType, IntoExpr

lib = _get_shared_lib_location(__file__)

__version__ = "0.3.0"
__version__ = "0.3.1"


@pl.api.register_expr_namespace("dist")
class DistancePairWise:
    """Distance expressions registered on ``pl.Expr`` under the ``dist`` namespace."""

    def __init__(self, expr: pl.Expr):
        # Expression the namespace methods operate on (left-hand operand).
        self._expr = expr

    def haversine(
        self, other: IntoExpr, unit: Literal["km", "miles"] = "km"
    ) -> pl.Expr:
        """Haversine distance between this struct expression and ``other``.

        Both sides are expected to be structs with the keys ``latitude`` and
        ``longitude``. The work is delegated to the ``haversine_struct``
        plugin symbol of the shared library.

        Args:
            other: Expression (or column name) holding the second point.
            unit: Unit of the returned distance, ``"km"`` (default) or ``"miles"``.

        Returns:
            An element-wise plugin expression yielding the distance.

        Example:
        ```python
        df = pl.DataFrame(
            {
                "x": [{"latitude": 38.898556, "longitude": -77.037852}],
                "y": [{"latitude": 38.897147, "longitude": -77.043934}],
            }
        )
        df.select(pld.col('x').dist.haversine('y', 'km').alias('haversine'))
        shape: (1, 1)
        ┌───────────┐
        │ haversine │
        │ ---       │
        │ f64       │
        ╞═══════════╡
        │ 0.549156  │
        └───────────┘
        ```
        """
        plugin_kwargs = {"unit": unit}
        return self._expr.register_plugin(
            lib=lib,
            symbol="haversine_struct",
            args=[other],
            kwargs=plugin_kwargs,
            is_elementwise=True,
        )


@pl.api.register_expr_namespace("dist_arr")
Expand Down Expand Up @@ -332,6 +371,10 @@ def cosine(self, other: IntoExpr) -> pl.Expr:


class DExpr(pl.Expr):
@property
def dist(self) -> DistancePairWise:
return DistancePairWise(self)

@property
def dist_arr(self) -> DistancePairWiseArray:
return DistancePairWiseArray(self)
Expand All @@ -356,6 +399,10 @@ def __call__(
def __getattr__(self, name: str) -> pl.Expr:
...

@property
def dist(self) -> DistancePairWise:
...

@property
def dist_arr(self) -> DistancePairWiseArray:
...
Expand Down
2 changes: 1 addition & 1 deletion polars_distance/polars_distance/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
version = "0.3.0"
version = "0.3.1"
authors = [
{ name="Ion Koutsours", email="[email protected]"},
]
Expand Down
51 changes: 51 additions & 0 deletions polars_distance/polars_distance/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::array::{
use crate::list::{
cosine_set_distance, jaccard_index, overlap_coef, sorensen_index, tversky_index,
};
use crate::other_dist::haversine_dist;
use crate::string::{
dam_levenshtein_dist, dam_levenshtein_normalized_dist, hamming_dist, hamming_normalized_dist,
indel_dist, indel_normalized_dist, jaro_dist, jaro_normalized_dist, jaro_winkler_dist,
Expand All @@ -27,6 +28,11 @@ struct MinkowskiKwargs {
p: i32,
}

/// Keyword arguments accepted by the `haversine_struct` expression.
#[derive(Deserialize)]
struct HaversineKwargs {
/// Name of the distance unit; handed to `haversine_dist` unchanged.
unit: String,
}

// STR EXPRESSIONS
#[polars_expr(output_type=UInt32)]
fn hamming_str(inputs: &[Series]) -> PolarsResult<Series> {
Expand Down Expand Up @@ -436,3 +442,48 @@ fn tversky_index_list(inputs: &[Series], kwargs: TverskyIndexKwargs) -> PolarsRe
let y: &ChunkedArray<ListType> = inputs[1].list()?;
tversky_index(x, y, kwargs.alpha, kwargs.beta).map(|ca| ca.into_series())
}

#[polars_expr(output_type=Float64)]
fn haversine_struct(inputs: &[Series], kwargs: HaversineKwargs) -> PolarsResult<Series> {
let ca_x: &StructChunked = inputs[0].struct_()?;
let ca_y: &StructChunked = inputs[1].struct_()?;

let x_lat = ca_x.field_by_name("latitude")?;
let x_long = ca_x.field_by_name("longitude")?;

let y_lat = ca_y.field_by_name("latitude")?;
let y_long = ca_y.field_by_name("longitude")?;

polars_ensure!(
x_lat.dtype() == x_long.dtype() && x_lat.dtype().is_float(),
ComputeError: "x data types should match"
);

polars_ensure!(
y_lat.dtype() == y_long.dtype() && y_lat.dtype().is_float(),
ComputeError: "y data types should match"
);

polars_ensure!(
x_lat.dtype() == y_lat.dtype(),
ComputeError: "x and y data types should match"
);

Ok(match *x_lat.dtype() {
DataType::Float32 => {
let x_lat = x_lat.f32().unwrap();
let x_long = x_long.f32().unwrap();
let y_lat = y_lat.f32().unwrap();
let y_long = y_long.f32().unwrap();
haversine_dist(x_lat, x_long, y_lat, y_long, kwargs.unit)?.into_series()
}
DataType::Float64 => {
let x_lat = x_lat.f64().unwrap();
let x_long = x_long.f64().unwrap();
let y_lat = y_lat.f64().unwrap();
let y_long = y_long.f64().unwrap();
haversine_dist(x_lat, x_long, y_lat, y_long, kwargs.unit)?.into_series()
}
_ => unimplemented!(),
})
}
1 change: 1 addition & 0 deletions polars_distance/polars_distance/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod array;
mod expressions;
mod list;
mod other_dist;
mod string;

#[cfg(target_os = "linux")]
Expand Down
54 changes: 54 additions & 0 deletions polars_distance/polars_distance/src/other_dist.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use polars::export::num::Float;
use polars::prelude::*;

/// Great-circle distance between two points using the haversine formula.
///
/// Latitudes and longitudes are given in degrees (they are converted with
/// `to_radians`). The result is expressed in the same unit as `radius`.
///
/// Floating-point rounding can push the intermediate haversine term slightly
/// above 1 for (near-)antipodal points, which would make `(1 - a).sqrt()` NaN;
/// the term is therefore capped at 1. NaN inputs still yield NaN.
///
/// # Panics
/// Panics if `radius` or the constant `2.0` cannot be converted into `T`.
fn haversine_elementwise<T: Float>(x_lat: T, x_long: T, y_lat: T, y_long: T, radius: f64) -> T {
    let radius = T::from(radius).expect("radius must be representable in the float type");
    let two = T::from(2.0).expect("2.0 must be representable in the float type");
    let one = T::one();

    let d_lat = (y_lat - x_lat).to_radians();
    let d_lon = (y_long - x_long).to_radians();
    let lat1 = x_lat.to_radians();
    let lat2 = y_lat.to_radians();

    // Each half-angle sine is needed squared; compute it once.
    let sin_half_d_lat = (d_lat / two).sin();
    let sin_half_d_lon = (d_lon / two).sin();

    let a = sin_half_d_lat * sin_half_d_lat
        + sin_half_d_lon * sin_half_d_lon * lat1.cos() * lat2.cos();
    // Explicit comparison (not `min`) so that a NaN term stays NaN.
    let a = if a > one { one } else { a };
    let c = two * a.sqrt().atan2((one - a).sqrt());
    radius * c
}

/// Row-wise haversine distance over four coordinate columns.
///
/// `unit` is matched case-insensitively (ASCII): `"km"` uses a radius of
/// 6371.0 and `"miles"` a radius of 3960.0. A row is null in the output when
/// any of its four coordinates is null. Iteration stops at the shortest
/// input. The returned array is named `"haversine"`.
///
/// # Errors
/// Returns `InvalidOperation` when `unit` is neither `"km"` nor `"miles"`.
pub fn haversine_dist<T>(
    x_lat: &ChunkedArray<T>,
    x_long: &ChunkedArray<T>,
    y_lat: &ChunkedArray<T>,
    y_long: &ChunkedArray<T>,
    unit: String,
) -> PolarsResult<ChunkedArray<T>>
where
    T: PolarsFloatType,
    T::Native: Float,
{
    let radius: f64 = if unit.eq_ignore_ascii_case("km") {
        6371.0
    } else if unit.eq_ignore_ascii_case("miles") {
        3960.0
    } else {
        polars_bail!(InvalidOperation: "Incorrect unit passed to haversine distance. Only 'km' or 'miles' are supported.")
    };

    // Pair up the coordinates of the first point and of the second point.
    let first_points = x_lat.into_iter().zip(x_long.into_iter());
    let second_points = y_lat.into_iter().zip(y_long.into_iter());

    let distances: ChunkedArray<T> = first_points
        .zip(second_points)
        .map(|((lat_a, lon_a), (lat_b, lon_b))| match (lat_a, lon_a, lat_b, lon_b) {
            (Some(lat_a), Some(lon_a), Some(lat_b), Some(lon_b)) => {
                Some(haversine_elementwise(lat_a, lon_a, lat_b, lon_b, radius))
            }
            // Any missing coordinate makes the distance null.
            _ => None,
        })
        .collect();

    Ok(distances.with_name("haversine"))
}
21 changes: 21 additions & 0 deletions polars_distance/tests/test_distance_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,24 @@ def test_tversky_set_distance(data_sets):

assert_frame_equal(result, expected)
assert_frame_equal(result_int, expected)


@pytest.mark.parametrize(
    "unit,value", [("km", 0.5491557912038084), ("miles", 0.341336828310639)]
)
def test_haversine(unit, value):
    """The ``dist.haversine`` expression returns the expected distance per unit."""
    point_x = {"latitude": 38.898556, "longitude": -77.037852}
    point_y = {"latitude": 38.897147, "longitude": -77.043934}
    frame = pl.DataFrame({"x": [point_x], "y": [point_y]})

    distance_expr = pld.col("x").dist.haversine("y", unit=unit).alias("haversine")
    actual = frame.select(distance_expr)

    wanted = pl.DataFrame([pl.Series("haversine", [value], dtype=pl.Float64)])

    assert_frame_equal(actual, wanted)

0 comments on commit 8b0eb10

Please sign in to comment.