Skip to content

Commit

Permalink
draft right mark
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanc-n committed Nov 4, 2024
1 parent b7f4db4 commit a6c2e06
Show file tree
Hide file tree
Showing 28 changed files with 256 additions and 27 deletions.
2 changes: 1 addition & 1 deletion datafusion/common/src/functional_dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ impl FunctionalDependencies {
// These joins preserve functional dependencies of the left side:
left_func_dependencies
}
JoinType::RightSemi | JoinType::RightAnti => {
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
// These joins preserve functional dependencies of the right side:
right_func_dependencies
}
Expand Down
9 changes: 8 additions & 1 deletion datafusion/common/src/join_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub enum JoinType {
LeftAnti,
/// Right Anti Join
RightAnti,
/// Left Mark join
/// Left Mark Join
///
/// Returns one record for each record from the left input. The output contains an additional
/// column "mark" which is true if there is at least one match in the right input where the
Expand All @@ -58,6 +58,11 @@ pub enum JoinType {
///
/// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf
LeftMark,
/// Right Mark Join
///
/// Same logic as the LeftMark Join above, however it returns a record for each record from the
/// right input.
RightMark,
}

impl JoinType {
Expand All @@ -78,6 +83,7 @@ impl Display for JoinType {
JoinType::LeftAnti => "LeftAnti",
JoinType::RightAnti => "RightAnti",
JoinType::LeftMark => "LeftMark",
JoinType::RightMark => "RightMark",
};
write!(f, "{join_type}")
}
Expand All @@ -98,6 +104,7 @@ impl FromStr for JoinType {
"LEFTANTI" => Ok(JoinType::LeftAnti),
"RIGHTANTI" => Ok(JoinType::RightAnti),
"LEFTMARK" => Ok(JoinType::LeftMark),
"RIGHTMARK" => Ok(JoinType::RightMark),
_ => _not_impl_err!("The join type {s} does not exist or is not implemented"),
}
}
Expand Down
4 changes: 3 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3865,6 +3865,7 @@ mod tests {
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::LeftMark,
JoinType::RightMark,
];

let default_partition_count = SessionConfig::new().target_partitions();
Expand Down Expand Up @@ -3898,7 +3899,8 @@ mod tests {
JoinType::Inner
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti => {
| JoinType::RightAnti
| JoinType::RightMark => {
let right_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(Column::new_with_schema("c2_c1", &join_schema)?),
Arc::new(Column::new_with_schema("c2_c2", &join_schema)?),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ fn adjust_input_keys_ordering(
left.schema().fields().len(),
)
.unwrap_or_default(),
JoinType::RightSemi | JoinType::RightAnti => {
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
requirements.data.clone()
}
JoinType::Left
Expand Down Expand Up @@ -1963,6 +1963,7 @@ pub(crate) mod tests {
JoinType::LeftMark,
JoinType::RightSemi,
JoinType::RightAnti,
JoinType::RightMark,
];

// Join on (a == b1)
Expand Down Expand Up @@ -2036,7 +2037,7 @@ pub(crate) mod tests {
assert_optimized!(expected, top_join.clone(), true);
assert_optimized!(expected, top_join, false);
}
JoinType::RightSemi | JoinType::RightAnti => {}
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {}
}

match join_type {
Expand Down Expand Up @@ -2101,7 +2102,7 @@ pub(crate) mod tests {
assert_optimized!(expected, top_join.clone(), true);
assert_optimized!(expected, top_join, false);
}
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {}
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::RightMark => {}
}
}

Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/src/physical_optimizer/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
JoinType::LeftMark => {
unreachable!("LeftMark join type does not support swapping")
}
JoinType::RightMark => {
unreachable!("RightMark join type does not support swapping")
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/src/physical_optimizer/sort_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ fn expr_source_side(
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::LeftMark => {
| JoinType::LeftMark
| JoinType::RightMark => {
let all_column_sides = required_exprs
.iter()
.filter_map(|r| {
Expand Down
24 changes: 24 additions & 0 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,30 @@ async fn test_left_mark_join_1k_filtered() {
.await
}

/// Fuzz test for `JoinType::RightMark` without a join filter.
///
/// Both inputs come from `make_staggered_batches(1000)`; the case is run in the
/// `HjSmj` and `NljHj` comparison modes with the debug flag disabled.
#[tokio::test]
async fn test_right_mark_join_1k() {
    let comparison_modes = [JoinTestType::HjSmj, JoinTestType::NljHj];
    let left_input = make_staggered_batches(1000);
    let right_input = make_staggered_batches(1000);

    let fuzz_case =
        JoinFuzzTestCase::new(left_input, right_input, JoinType::RightMark, None);
    fuzz_case.run_test(&comparison_modes, false).await
}

/// Fuzz test for `JoinType::RightMark` with the `col_lt_col_filter` join filter.
///
/// Both inputs come from `make_staggered_batches(1000)`; the case is run in the
/// `HjSmj` and `NljHj` comparison modes with the debug flag disabled.
#[tokio::test]
async fn test_right_mark_join_1k_filtered() {
    let comparison_modes = [JoinTestType::HjSmj, JoinTestType::NljHj];
    let left_input = make_staggered_batches(1000);
    let right_input = make_staggered_batches(1000);

    let fuzz_case = JoinFuzzTestCase::new(
        left_input,
        right_input,
        JoinType::RightMark,
        Some(Box::new(col_lt_col_filter)),
    );
    fuzz_case.run_test(&comparison_modes, false).await
}

type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>;

struct JoinFuzzTestCase {
Expand Down
4 changes: 4 additions & 0 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1422,6 +1422,10 @@ pub fn build_join_schema(
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect()
}
JoinType::RightMark => right_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.chain(once(mark_field(left)))
.collect(),
};
let func_dependencies = left.functional_dependencies().join(
right.functional_dependencies(),
Expand Down
8 changes: 6 additions & 2 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,9 @@ impl LogicalPlan {
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
left.head_output_expr()
}
JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(),
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
right.head_output_expr()
}
},
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
static_term.head_output_expr()
Expand Down Expand Up @@ -1309,7 +1311,9 @@ impl LogicalPlan {
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
left.max_rows()
}
JoinType::RightSemi | JoinType::RightAnti => right.max_rows(),
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
right.max_rows()
}
},
LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(),
LogicalPlan::Union(Union { inputs, .. }) => inputs
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re
check_inner_plan(left, can_contain_outer_ref)?;
check_inner_plan(right, false)
}
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
check_inner_plan(left, false)?;
check_inner_plan(right, can_contain_outer_ref)
}
Expand Down
3 changes: 2 additions & 1 deletion datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,8 @@ fn split_join_requirements(
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::LeftMark => {
| JoinType::LeftMark
| JoinType::RightMark => {
// Decrease right side indices by `left_len` so that they point to valid
// positions within the right child:
indices.split_off(left_len)
Expand Down
9 changes: 5 additions & 4 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
JoinType::Right => (false, true),
JoinType::Full => (false, false),
// No columns from the right side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
// predicates for semi/anti/mark joins, so whether we specify t/f doesn't matter.
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
// No columns from the left side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::RightSemi | JoinType::RightAnti => (false, true),
// predicates for semi/anti/mark joins, so whether we specify t/f doesn't matter.
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
}
}

Expand All @@ -189,6 +189,7 @@ pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
JoinType::LeftAnti => (false, true),
JoinType::RightAnti => (true, false),
JoinType::LeftMark => (false, true),
JoinType::RightMark => (true, false),
}
}

Expand Down Expand Up @@ -740,7 +741,7 @@ fn infer_join_predicates_from_on_filters(
inferred_predicates,
)
}
JoinType::Right | JoinType::RightSemi => {
JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
infer_join_predicates_impl::<false, true>(
join_col_keys,
on_filters,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/push_down_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {
match join.join_type {
Left | Right | Full | Inner => (Some(limit), Some(limit)),
LeftAnti | LeftSemi | LeftMark => (Some(limit), None),
RightAnti | RightSemi => (None, Some(limit)),
RightAnti | RightSemi | RightMark => (None, Some(limit)),
}
} else {
match join.join_type {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ impl EquivalenceGroup {
result
}
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => right_equivalences.clone(),
}
}
}
Expand Down
101 changes: 101 additions & 0 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3180,6 +3180,96 @@ mod tests {
Ok(())
}

/// Checks the output of a `JoinType::RightMark` join collected via `join_collect`.
///
/// The join key is `b1` on both sides. The expected result is one row per
/// right-side row, carrying the right columns plus a boolean `mark` column that
/// is `true` when the row's `b1` value also appears on the left.
#[apply(batch_sizes)]
#[tokio::test]
async fn join_right_mark(batch_size: usize) -> Result<()> {
    let task_ctx = prepare_task_ctx(batch_size);
    let left = build_table(
        ("a1", &vec![1, 2, 3]),
        ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
        ("c1", &vec![7, 8, 9]),
    );
    let right = build_table(
        ("a2", &vec![10, 20, 30]),
        ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
        ("c2", &vec![70, 80, 90]),
    );
    let on = vec![(
        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
        Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
    )];

    let (columns, batches) = join_collect(
        Arc::clone(&left),
        Arc::clone(&right),
        on.clone(),
        &JoinType::RightMark,
        false,
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]);

    // Cells are padded to the column widths given by the border row; the
    // comparison is against the pretty-printed batches, so spacing matters.
    let expected = [
        "+----+----+----+-------+",
        "| a2 | b1 | c2 | mark  |",
        "+----+----+----+-------+",
        "| 10 | 4  | 70 | true  |",
        "| 20 | 5  | 80 | true  |",
        "| 30 | 6  | 90 | false |",
        "+----+----+----+-------+",
    ];
    assert_batches_sorted_eq!(expected, &batches);

    Ok(())
}

/// Checks the output of a `JoinType::RightMark` join collected via
/// `partitioned_join_collect`.
///
/// The join key is `b1` on both sides, and the right input repeats the key `4`.
/// The expected result is one row per right-side row (duplicates included),
/// carrying the right columns plus a boolean `mark` column that is `true` when
/// the row's `b1` value also appears on the left.
#[apply(batch_sizes)]
#[tokio::test]
async fn partitioned_join_right_mark(batch_size: usize) -> Result<()> {
    let task_ctx = prepare_task_ctx(batch_size);
    let left = build_table(
        ("a1", &vec![1, 2, 3]),
        ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
        ("c1", &vec![7, 8, 9]),
    );
    let right = build_table(
        ("a2", &vec![10, 20, 30, 40]),
        ("b1", &vec![4, 4, 5, 6]), // 6 does not exist on the left
        ("c2", &vec![60, 70, 80, 90]),
    );
    let on = vec![(
        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
        Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
    )];

    let (columns, batches) = partitioned_join_collect(
        Arc::clone(&left),
        Arc::clone(&right),
        on.clone(),
        &JoinType::RightMark,
        false,
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]);

    // Cells are padded to the column widths given by the border row; the
    // comparison is against the pretty-printed batches, so spacing matters.
    let expected = [
        "+----+----+----+-------+",
        "| a2 | b1 | c2 | mark  |",
        "+----+----+----+-------+",
        "| 10 | 4  | 60 | true  |",
        "| 20 | 4  | 70 | true  |",
        "| 30 | 5  | 80 | true  |",
        "| 40 | 6  | 90 | false |",
        "+----+----+----+-------+",
    ];
    assert_batches_sorted_eq!(expected, &batches);

    Ok(())
}


#[test]
fn join_with_hash_collision() -> Result<()> {
let mut hashmap_left = RawTable::with_capacity(2);
Expand Down Expand Up @@ -3574,6 +3664,15 @@ mod tests {
"| 3 | 7 | 9 | false |",
"+----+----+----+-------+",
];
let expected_right_mark = vec![
"+----+----+----+-------+",
"| a2 | b2 | c2 | mark |",
"+----+----+----+-------+",
"| 10 | 4 | 70 | true |",
"| 20 | 5 | 80 | true |",
"| 30 | 6 | 90 | false |",
"+----+----+----+-------+",
];

let test_cases = vec![
(JoinType::Inner, expected_inner),
Expand All @@ -3585,6 +3684,7 @@ mod tests {
(JoinType::RightSemi, expected_right_semi),
(JoinType::RightAnti, expected_right_anti),
(JoinType::LeftMark, expected_left_mark),
(JoinType::RightMark, expected_right_mark),
];

for (join_type, expected) in test_cases {
Expand Down Expand Up @@ -3868,6 +3968,7 @@ mod tests {
JoinType::RightSemi,
JoinType::RightAnti,
JoinType::LeftMark,
JoinType::RightMark,
];

for join_type in join_types {
Expand Down
Loading

0 comments on commit a6c2e06

Please sign in to comment.