Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(optimizer): Implement LIKE expression rule for query optimization #96

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
22 changes: 18 additions & 4 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ impl<S: Storage> Database<S> {
/// Limit(1)
/// Project(a,b)
let source_plan = binder.bind(&stmts[0])?;
// println!("source_plan plan: {:#?}", source_plan);
//println!("source_plan plan: {:#?}", source_plan);

let best_plan = Self::default_optimizer(source_plan).find_best()?;
// println!("best_plan plan: {:#?}", best_plan);
//println!("best_plan plan: {:#?}", best_plan);

let transaction = RefCell::new(transaction);
let mut stream = build(best_plan, &transaction);
Expand All @@ -78,10 +78,14 @@ impl<S: Storage> Database<S> {
.batch(
"Simplify Filter".to_string(),
HepBatchStrategy::fix_point_topdown(10),
vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation],
vec![
RuleImpl::LikeRewrite,
RuleImpl::SimplifyFilter,
RuleImpl::ConstantCalculation,
],
)
.batch(
"Predicate Pushdown".to_string(),
"Predicate Pushown".to_string(),
loloxwg marked this conversation as resolved.
Show resolved Hide resolved
HepBatchStrategy::fix_point_topdown(10),
vec![
RuleImpl::PushPredicateThroughJoin,
Expand Down Expand Up @@ -206,6 +210,12 @@ mod test {
let _ = kipsql
.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)")
.await?;
let _ = kipsql
.run("create table t4 (a int primary key, b varchar(100))")
.await?;
let _ = kipsql
.run("insert into t4 (a, b) values (1, 'abc'), (2, 'abdc'), (3, 'abcd'), (4, 'ddabc')")
.await?;

println!("show tables:");
let tuples_show_tables = kipsql.run("show tables").await?;
Expand Down Expand Up @@ -371,6 +381,10 @@ mod test {
let tuples_decimal = kipsql.run("select * from t3").await?;
println!("{}", create_table(&tuples_decimal));

println!("like rewrite:");
let tuples_like_rewrite = kipsql.run("select * from t4 where b like 'abc%'").await?;
println!("{}", create_table(&tuples_like_rewrite));

Ok(())
}
}
5 changes: 4 additions & 1 deletion src/optimizer/rule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use crate::optimizer::rule::pushdown_limit::{
};
use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan;
use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin;
use crate::optimizer::rule::simplification::ConstantCalculation;
use crate::optimizer::rule::simplification::SimplifyFilter;
use crate::optimizer::rule::simplification::{ConstantCalculation, LikeRewrite};
use crate::optimizer::OptimizerError;

mod column_pruning;
Expand All @@ -37,6 +37,7 @@ pub enum RuleImpl {
// Simplification
SimplifyFilter,
ConstantCalculation,
LikeRewrite,
}

impl Rule for RuleImpl {
Expand All @@ -53,6 +54,7 @@ impl Rule for RuleImpl {
RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(),
RuleImpl::SimplifyFilter => SimplifyFilter.pattern(),
RuleImpl::ConstantCalculation => ConstantCalculation.pattern(),
RuleImpl::LikeRewrite => LikeRewrite.pattern(),
}
}

Expand All @@ -69,6 +71,7 @@ impl Rule for RuleImpl {
RuleImpl::SimplifyFilter => SimplifyFilter.apply(node_id, graph),
RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(node_id, graph),
RuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph),
RuleImpl::LikeRewrite => LikeRewrite.apply(node_id, graph),
}
}
}
Expand Down
117 changes: 109 additions & 8 deletions src/optimizer/rule/simplification.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
use crate::expression::{BinaryOperator, ScalarExpression};
use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
use crate::optimizer::core::rule::Rule;
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::optimizer::OptimizerError;
use crate::planner::operator::join::JoinCondition;
use crate::planner::operator::Operator;
use crate::types::value::{DataValue, ValueRef};
use lazy_static::lazy_static;
lazy_static! {
static ref LIKE_REWRITE_RULE: Pattern = {
Pattern {
predicate: |op| matches!(op, Operator::Filter(_)),
children: PatternChildrenPredicate::None,
}
};
static ref CONSTANT_CALCULATION_RULE: Pattern = {
Pattern {
predicate: |_| true,
Expand Down Expand Up @@ -109,6 +117,91 @@ impl Rule for SimplifyFilter {
}
}

pub struct LikeRewrite;

impl Rule for LikeRewrite {
fn pattern(&self) -> &Pattern {
&LIKE_REWRITE_RULE
}

fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() {
Copy link
Member

@KKould KKould Dec 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the operator_mut method to modify directly instead of replace

// if is like expression
if let ScalarExpression::Binary {
op: BinaryOperator::Like,
left_expr,
right_expr,
ty,
} = &mut filter_op.predicate
{
// if left is column and right is constant
if let ScalarExpression::ColumnRef(_) = left_expr.as_ref() {
loloxwg marked this conversation as resolved.
Show resolved Hide resolved
if let ScalarExpression::Constant(value) = right_expr.as_ref() {
Copy link
Member

@KKould KKould Nov 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reduce unnecessary nesting and matching

if let ScalarExpression::Constant(DataValue::Utf8(mut val)) = right_expr.as_ref() {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

knock knock

match value.as_ref() {
DataValue::Utf8(val_str) => {
let mut value = val_str.clone().unwrap_or_else(|| "".to_string());

if value.ends_with('%') {
value.pop(); // remove '%'
loloxwg marked this conversation as resolved.
Show resolved Hide resolved
if let Some(last_char) = value.clone().pop() {
if let Some(next_char) = increment_char(last_char) {
let mut new_value = value.clone();
new_value.pop();
new_value.push(next_char);

let new_expr = ScalarExpression::Binary {
op: BinaryOperator::And,
left_expr: Box::new(ScalarExpression::Binary {
op: BinaryOperator::GtEq,
left_expr: left_expr.clone(),
right_expr: Box::new(
ScalarExpression::Constant(ValueRef::from(
DataValue::Utf8(Some(value)),
)),
),
ty: ty.clone(),
}),
right_expr: Box::new(ScalarExpression::Binary {
op: BinaryOperator::Lt,
left_expr: left_expr.clone(),
right_expr: Box::new(
ScalarExpression::Constant(ValueRef::from(
DataValue::Utf8(Some(new_value)),
)),
),
ty: ty.clone(),
}),
ty: ty.clone(),
};
filter_op.predicate = new_expr;
}
}
}
}
_ => {
graph.version += 1;
loloxwg marked this conversation as resolved.
Show resolved Hide resolved
return Ok(());
}
}
}
}
}
graph.replace_node(node_id, Operator::Filter(filter_op))
}
// mark changed to skip this rule batch
graph.version += 1;
Ok(())
}
}

fn increment_char(v: char) -> Option<char> {
match v {
'z' => None,
'Z' => None,
_ => std::char::from_u32(v as u32 + 1),
}
}

#[cfg(test)]
mod test {
use crate::binder::test::select_sql_run;
Expand All @@ -118,6 +211,7 @@ mod test {
use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator};
use crate::optimizer::heuristic::batch::HepBatchStrategy;
use crate::optimizer::heuristic::optimizer::HepOptimizer;
use crate::optimizer::rule::simplification::increment_char;
use crate::optimizer::rule::RuleImpl;
use crate::planner::operator::filter::FilterOperator;
use crate::planner::operator::Operator;
Expand All @@ -127,6 +221,13 @@ mod test {
use std::collections::Bound;
use std::sync::Arc;

#[test]
fn test_increment_char() {
assert_eq!(increment_char('a'), Some('b'));
assert_eq!(increment_char('z'), None);
assert_eq!(increment_char('A'), Some('B'));
}

#[tokio::test]
async fn test_constant_calculation_omitted() -> Result<(), DatabaseError> {
// (2 + (-1)) < -(c1 + 1)
Expand Down Expand Up @@ -343,7 +444,7 @@ mod test {
cb_1_c1,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))),
})
);

Expand All @@ -353,7 +454,7 @@ mod test {
cb_1_c2,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -363,7 +464,7 @@ mod test {
cb_2_c1,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -373,7 +474,7 @@ mod test {
cb_1_c1,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))),
})
);

Expand All @@ -383,7 +484,7 @@ mod test {
cb_3_c1,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))),
})
);

Expand All @@ -393,7 +494,7 @@ mod test {
cb_3_c2,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -403,7 +504,7 @@ mod test {
cb_4_c1,
Some(ConstantBinary::Scope {
min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))),
max: Bound::Unbounded
max: Bound::Unbounded,
})
);

Expand All @@ -413,7 +514,7 @@ mod test {
cb_4_c2,
Some(ConstantBinary::Scope {
min: Bound::Unbounded,
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1))))
max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))),
})
);

Expand Down
Loading