Skip to content

Commit

Permalink
feat: Add 'allow_exact_matches' to 'join_asof' (#20723)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 15, 2025
1 parent 417dd44 commit 46b2714
Show file tree
Hide file tree
Showing 9 changed files with 285 additions and 73 deletions.
84 changes: 56 additions & 28 deletions crates/polars-ops/src/frame/join/asof/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ use super::{
AsofJoinBackwardState, AsofJoinForwardState, AsofJoinNearestState, AsofJoinState, AsofStrategy,
};

fn join_asof_impl<'a, T, S, F>(left: &'a T::Array, right: &'a T::Array, mut filter: F) -> IdxCa
fn join_asof_impl<'a, T, S, F>(
left: &'a T::Array,
right: &'a T::Array,
mut filter: F,
allow_eq: bool,
) -> IdxCa
where
T: PolarsDataType,
S: AsofJoinState<T::Physical<'a>>,
Expand All @@ -20,7 +25,7 @@ where

let mut out = vec![0; left.len()];
let mut mask = vec![0; (left.len() + 7) / 8];
let mut state = S::default();
let mut state = S::new(allow_eq);

if left.null_count() == 0 && right.null_count() == 0 {
for (i, val_l) in left.values_iter().enumerate() {
Expand All @@ -31,9 +36,11 @@ where
right.len() as IdxSize,
) {
// SAFETY: r_idx is non-null and valid.
let val_r = unsafe { right.value_unchecked(r_idx as usize) };
out[i] = r_idx;
mask[i / 8] |= (filter(val_l, val_r) as u8) << (i % 8);
unsafe {
let val_r = right.value_unchecked(r_idx as usize);
*out.get_unchecked_mut(i) = r_idx;
*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);
}
}
}
} else {
Expand All @@ -46,9 +53,11 @@ where
right.len() as IdxSize,
) {
// SAFETY: r_idx is non-null and valid.
let val_r = unsafe { right.value_unchecked(r_idx as usize) };
out[i] = r_idx;
mask[i / 8] |= (filter(val_l, val_r) as u8) << (i % 8);
unsafe {
let val_r = right.value_unchecked(r_idx as usize);
*out.get_unchecked_mut(i) = r_idx;
*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);
}
}
}
}
Expand All @@ -58,38 +67,54 @@ where
IdxCa::from_vec_validity(PlSmallStr::EMPTY, out, Some(bitmap))
}

fn join_asof_forward<'a, T, F>(left: &'a T::Array, right: &'a T::Array, filter: F) -> IdxCa
fn join_asof_forward<'a, T, F>(
left: &'a T::Array,
right: &'a T::Array,
filter: F,
allow_eq: bool,
) -> IdxCa
where
T: PolarsDataType,
T::Physical<'a>: PartialOrd,
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
{
join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter)
join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter, allow_eq)
}

fn join_asof_backward<'a, T, F>(left: &'a T::Array, right: &'a T::Array, filter: F) -> IdxCa
fn join_asof_backward<'a, T, F>(
left: &'a T::Array,
right: &'a T::Array,
filter: F,
allow_eq: bool,
) -> IdxCa
where
T: PolarsDataType,
T::Physical<'a>: PartialOrd,
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
{
join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter)
join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)
}

fn join_asof_nearest<'a, T, F>(left: &'a T::Array, right: &'a T::Array, filter: F) -> IdxCa
fn join_asof_nearest<'a, T, F>(
left: &'a T::Array,
right: &'a T::Array,
filter: F,
allow_eq: bool,
) -> IdxCa
where
T: PolarsDataType,
T::Physical<'a>: NumericNative,
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
{
join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter)
join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter, allow_eq)
}

pub(crate) fn join_asof_numeric<T: PolarsNumericType>(
input_ca: &ChunkedArray<T>,
other: &Series,
strategy: AsofStrategy,
tolerance: Option<AnyValue<'static>>,
allow_eq: bool,
) -> PolarsResult<IdxCa> {
let other = input_ca.unpack_series_matching_type(other)?;

Expand All @@ -103,16 +128,16 @@ pub(crate) fn join_asof_numeric<T: PolarsNumericType>(
let abs_tolerance = native_tolerance.abs_diff(T::Native::zero());
let filter = |l: T::Native, r: T::Native| l.abs_diff(r) <= abs_tolerance;
match strategy {
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter),
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter),
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter),
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),
}
} else {
let filter = |_l: T::Native, _r: T::Native| true;
match strategy {
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter),
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter),
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter),
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),
}
};
Ok(out)
Expand All @@ -122,6 +147,7 @@ pub(crate) fn join_asof<T>(
input_ca: &ChunkedArray<T>,
other: &Series,
strategy: AsofStrategy,
allow_eq: bool,
) -> PolarsResult<IdxCa>
where
T: PolarsDataType,
Expand All @@ -136,9 +162,11 @@ where

let filter = |_l: T::Physical<'_>, _r: T::Physical<'_>| true;
Ok(match strategy {
AsofStrategy::Forward => join_asof_impl::<T, AsofJoinForwardState, _>(left, right, filter),
AsofStrategy::Forward => {
join_asof_impl::<T, AsofJoinForwardState, _>(left, right, filter, allow_eq)
},
AsofStrategy::Backward => {
join_asof_impl::<T, AsofJoinBackwardState, _>(left, right, filter)
join_asof_impl::<T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)
},
AsofStrategy::Nearest => unimplemented!(),
})
Expand All @@ -155,31 +183,31 @@ mod test {
let a = PrimitiveArray::from_slice([-1, 2, 3, 3, 3, 4]);
let b = PrimitiveArray::from_slice([1, 2, 3, 3]);

let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true);
let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
assert_eq!(tuples.len(), a.len());
assert_eq!(
tuples.to_vec(),
&[None, Some(1), Some(3), Some(3), Some(3), Some(3)]
);

let b = PrimitiveArray::from_slice([1, 2, 4, 5]);
let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true);
let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
assert_eq!(
tuples.to_vec(),
&[None, Some(1), Some(1), Some(1), Some(1), Some(2)]
);

let a = PrimitiveArray::from_slice([2, 4, 4, 4]);
let b = PrimitiveArray::from_slice([1, 2, 3, 3]);
let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true);
let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
assert_eq!(tuples.to_vec(), &[Some(1), Some(3), Some(3), Some(3)]);
}

#[test]
fn test_asof_backward_tolerance() {
let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40]);
let b = PrimitiveArray::from_slice([10, 20, 30, 30]);
let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32);
let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);
assert_eq!(
tuples.to_vec(),
&[None, Some(1), None, Some(3), Some(3), None]
Expand All @@ -190,7 +218,7 @@ mod test {
fn test_asof_forward_tolerance() {
let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40, 52]);
let b = PrimitiveArray::from_slice([10, 20, 33, 55]);
let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32);
let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);
assert_eq!(
tuples.to_vec(),
&[None, Some(1), None, Some(2), Some(2), None, Some(3)]
Expand All @@ -202,7 +230,7 @@ mod test {
let a = PrimitiveArray::from_slice([-1, 1, 2, 4, 6]);
let b = PrimitiveArray::from_slice([1, 2, 4, 5]);

let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true);
let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true, true);
assert_eq!(tuples.len(), a.len());
assert_eq!(tuples.to_vec(), &[Some(0), Some(0), Some(1), Some(2), None]);
}
Expand Down
Loading

0 comments on commit 46b2714

Please sign in to comment.