Skip to content

Commit

Permalink
add categorical dtype in set distances
Browse files: browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Dec 30, 2023
1 parent 8f6e7e8 commit c4c97c2
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions polars_distance/polars_distance/src/list.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -119,8 +119,9 @@ pub fn jaccard_index(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Ch
} else {
match a.inner_dtype() {
DataType::Utf8 => elementwise_string_inp(a,b, jacc_str_array),
DataType::Categorical(_) => elementwise_int_inp(a,b, jacc_int_array::<u32>),
_ => Err(PolarsError::ComputeError(
format!("jaccard index only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
format!("jaccard index only works on inner dtype Utf8, Categorical and integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
Expand All @@ -137,8 +138,9 @@ pub fn sorensen_index(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64C
} else {
match a.inner_dtype() {
DataType::Utf8 => elementwise_string_inp(a,b, sorensen_str_array),
DataType::Categorical(_) => elementwise_int_inp(a,b, sorensen_int_array::<u32>),
_ => Err(PolarsError::ComputeError(
format!("sorensen index only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
format!("sorensen index only works on inner dtype Utf8, Categorical and integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
Expand All @@ -155,8 +157,9 @@ pub fn overlap_coef(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chu
} else {
match a.inner_dtype() {
DataType::Utf8 => elementwise_string_inp(a,b, overlap_str_array),
DataType::Categorical(_) => elementwise_int_inp(a,b, overlap_int_array::<u32>),
_ => Err(PolarsError::ComputeError(
format!("overlap coefficient only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
format!("overlap coefficient only works on inner dtype Utf8, Categorical and integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
Expand All @@ -173,8 +176,9 @@ pub fn cosine_set_distance(a: &ListChunked, b: &ListChunked) -> PolarsResult<Flo
} else {
match a.inner_dtype() {
DataType::Utf8 => elementwise_string_inp(a,b, cosine_str_array),
DataType::Categorical(_) => elementwise_int_inp(a,b, cosine_int_array::<u32>),
_ => Err(PolarsError::ComputeError(
format!("cosine set distance only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
format!("cosine set distance only works on inner dtype Utf8, Categorical and integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
Expand Down Expand Up @@ -226,8 +230,23 @@ pub fn tversky_index(
_ => None,
}))
},
DataType::Categorical(_) => {
Ok(binary_elementwise(a, b, |a, b| match (a, b) {
(Some(a), Some(b)) => {
let a = a.as_any().downcast_ref::<PrimitiveArray<u32>>().unwrap();
let b = b.as_any().downcast_ref::<PrimitiveArray<u32>>().unwrap();
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
let len_intersect = s1.intersection(&s2).count() as f64;
let len_diff1 = s1.difference(&s2).count();
let len_diff2 = s2.difference(&s1).count();
Some(len_intersect / (len_intersect + (alpha * len_diff1 as f64) + (beta * len_diff2 as f64)))
}
_ => None,
}))
},
_ => Err(PolarsError::ComputeError(
format!("tversky index distance only works on inner dtype Utf8 or integer. Use of {} is not supported", a.inner_dtype()).into(),
format!("tversky index distance only works on inner dtype Utf8, Categorical and integer. Use of {} is not supported", a.inner_dtype()).into(),
))
}
}
Expand Down

0 comments on commit c4c97c2

Please sign in to comment.