Skip to content

Commit

Permalink
Support SQL federation CTE (#28888)
Browse files Browse the repository at this point in the history
* Support MySQL CTE SQL parsing

* Add mysql cte sql parse test

* Refactor sql federation WithConverter

* Support sql federation SelectStatement with convert

* Add sql federation cte execution plan test

* Format test sql

* Format parse code

* Change SelectStatementHandler mysql test
  • Loading branch information
zihaoAK47 authored Oct 29, 2023
1 parent 60144f0 commit c0ce962
Show file tree
Hide file tree
Showing 11 changed files with 400 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
import org.apache.calcite.sql.SqlWithItem;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonTableExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.ColumnConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.select.SelectStatementConverter;

import java.util.Collection;
Expand All @@ -48,21 +49,20 @@ public final class WithConverter {
* @return sql node list
*/
public static Optional<SqlNode> convert(final WithSegment withSegment, final SqlNode sqlNode) {
    // Each common table expression of the WITH segment becomes one item of the with list;
    // the list is then attached to the given query body (sqlNode) through a single SqlWith node.
    // No element is read eagerly here, so an empty expression collection yields an empty with list
    // instead of failing on iterator().next().
    return Optional.of(new SqlWith(SqlParserPos.ZERO, convertWithItem(withSegment.getCommonTableExpressions()), sqlNode));
}

/**
 * Convert common table expression segments to a list of with items.
 *
 * @param commonTableExpressionSegments common table expression segments
 * @return sql node list holding one with item per common table expression, in iteration order
 */
private static SqlNodeList convertWithItem(final Collection<CommonTableExpressionSegment> commonTableExpressionSegments) {
    SqlNodeList result = new SqlNodeList(SqlParserPos.ZERO);
    for (CommonTableExpressionSegment each : commonTableExpressionSegments) {
        SqlIdentifier name = new SqlIdentifier(each.getIdentifier().getValue(), SqlParserPos.ZERO);
        // A common table expression declared without a column list is passed on with a null column list.
        SqlNodeList columns = each.getColumns().isEmpty() ? null : convertColumns(each.getColumns());
        result.add(new SqlWithItem(SqlParserPos.ZERO, name, columns, new SelectStatementConverter().convert(each.getSubquery().getSelect())));
    }
    return result;
}

/**
 * Convert column segments to a sql node list.
 *
 * @param columnSegments column segments
 * @return sql node list with one converted node per column segment
 */
private static SqlNodeList convertColumns(final Collection<ColumnSegment> columnSegments) {
    SqlNodeList result = new SqlNodeList(SqlParserPos.ZERO);
    for (ColumnSegment each : columnSegments) {
        // A column that cannot be converted is treated as an illegal state.
        result.add(ColumnConverter.convert(each).orElseThrow(IllegalStateException::new));
    }
    return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.projection.ProjectionsConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.where.WhereConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.window.WindowConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.with.WithConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.SQLStatementConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.type.CombineOperatorConverter;

Expand All @@ -50,7 +51,8 @@ public final class SelectStatementConverter implements SQLStatementConverter<Sel
@Override
public SqlNode convert(final SelectStatement selectStatement) {
SqlSelect sqlSelect = convertSelect(selectStatement);
SqlNode sqlCombine = convertCombine(sqlSelect, selectStatement);
SqlNode sqlWith = convertWith(sqlSelect, selectStatement);
SqlNode sqlCombine = convertCombine(null != sqlWith ? sqlWith : sqlSelect, selectStatement);
SqlNodeList orderBy = selectStatement.getOrderBy().flatMap(OrderByConverter::convert).orElse(SqlNodeList.EMPTY);
Optional<LimitSegment> limit = SelectStatementHandler.getLimitSegment(selectStatement);
if (limit.isPresent()) {
Expand All @@ -61,6 +63,10 @@ public SqlNode convert(final SelectStatement selectStatement) {
return orderBy.isEmpty() ? sqlCombine : new SqlOrderBy(SqlParserPos.ZERO, sqlCombine, orderBy, null, null);
}

/**
 * Wrap the converted select in a with node when the statement carries a with segment.
 *
 * @param sqlSelect converted select node used as the body of the with node
 * @param selectStatement select statement
 * @return with node, or null when there is no with segment or it cannot be converted
 */
private SqlNode convertWith(final SqlNode sqlSelect, final SelectStatement selectStatement) {
    Optional<SqlNode> converted = SelectStatementHandler.getWithSegment(selectStatement).flatMap(each -> WithConverter.convert(each, sqlSelect));
    if (converted.isPresent()) {
        return converted.get();
    }
    return null;
}

private SqlSelect convertSelect(final SelectStatement selectStatement) {
SqlNodeList distinct = DistinctConverter.convert(selectStatement.getProjections()).orElse(null);
SqlNodeList projection = ProjectionsConverter.convert(selectStatement.getProjections()).orElseThrow(IllegalStateException::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,4 +436,16 @@
<test-case sql="SELECT * FROM multi_types_first first JOIN multi_types_second second ON first.id = second.id WHERE second.long_varchar_column = '1'">
<assertion expected-result="EnumerableHashJoin(condition=[=($0, $22)], joinType=[inner]) EnumerableScan(table=[[federate_jdbc, multi_types_first]], sql=[SELECT * FROM `federate_jdbc`.`multi_types_first`], dynamicParameters=[null]) EnumerableScan(table=[[federate_jdbc, multi_types_second]], sql=[SELECT * FROM `federate_jdbc`.`multi_types_second` WHERE `long_varchar_column` = '1'], dynamicParameters=[null]) " />
</test-case>

<test-case sql="WITH cte AS (SELECT 1 AS col1, 2 AS col2 UNION ALL SELECT 3, 4) SELECT col1, col2 FROM cte">
<assertion expected-result="EnumerableUnion(all=[true]) EnumerableValues(tuples=[[{ 1, 2 }]]) EnumerableValues(tuples=[[{ 3, 4 }]]) " />
</test-case>

<test-case sql="WITH cte1(col1, col2, col3) AS (SELECT id, bit_column, tiny_int_column FROM multi_types_first), cte2(col1, col2, col3) AS (SELECT id, bit_column, tiny_int_column FROM multi_types_second) SELECT * FROM cte1 inner join cte2 on cte1.col1 = cte2.col1">
<assertion expected-result="EnumerableHashJoin(condition=[=($0, $3)], joinType=[inner]) EnumerableScan(table=[[federate_jdbc, multi_types_first]], sql=[SELECT `id`, `bit_column`, `tiny_int_column` FROM `federate_jdbc`.`multi_types_first`], dynamicParameters=[null]) EnumerableScan(table=[[federate_jdbc, multi_types_second]], sql=[SELECT `id`, `bit_column`, `tiny_int_column` FROM `federate_jdbc`.`multi_types_second`], dynamicParameters=[null]) " />
</test-case>

<test-case sql="WITH cte1(col1, col2, col3) AS (SELECT 1, 2, 3 UNION ALL SELECT 4, 5, 6), cte2(col1, col2, col3) AS (SELECT 1, 2, 3 UNION ALL SELECT 4, 5, 6) SELECT cte1.* FROM cte1 inner join cte2 on cte1.col1 = cte2.col1 WHERE cte1.col1 = 1">
<assertion expected-result="EnumerableCalc(expr#0..3=[{inputs}], proj#0..2=[{exprs}]) EnumerableHashJoin(condition=[=($0, $3)], joinType=[inner]) EnumerableUnion(all=[true]) EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[=($t0, $t3)], proj#0..2=[{exprs}], $condition=[$t4]) EnumerableValues(tuples=[[{ 1, 2, 3 }]]) EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[=($t0, $t3)], proj#0..2=[{exprs}], $condition=[$t4]) EnumerableValues(tuples=[[{ 4, 5, 6 }]]) EnumerableUnion(all=[true]) EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[=($t0, $t3)], EXPR$0=[$t0], $condition=[$t4]) EnumerableValues(tuples=[[{ 1, 2, 3 }]]) EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[=($t0, $t3)], EXPR$0=[$t0], $condition=[$t4]) EnumerableValues(tuples=[[{ 4, 5, 6 }]]) " />
</test-case>
</test-cases>
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.ConstraintNameContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.ConvertFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CurrentUserFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CteClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.DataTypeContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.DeleteContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.DuplicateSpecificationContext;
Expand Down Expand Up @@ -147,6 +148,7 @@
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WindowFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WindowItemContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WindowSpecificationContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WithClauseContext;
import org.apache.shardingsphere.sql.parser.sql.common.enums.AggregationType;
import org.apache.shardingsphere.sql.parser.sql.common.enums.CombineType;
import org.apache.shardingsphere.sql.parser.sql.common.enums.JoinType;
Expand Down Expand Up @@ -181,6 +183,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.UnaryOperationExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonTableExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.SimpleExpressionSegment;
Expand Down Expand Up @@ -220,6 +223,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.common.value.collection.CollectionValue;
import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
Expand Down Expand Up @@ -712,6 +716,9 @@ public ASTNode visitQueryExpression(final QueryExpressionContext ctx) {
if (null != ctx.limitClause()) {
result.setLimit((LimitSegment) visit(ctx.limitClause()));
}
if (null != result && null != ctx.withClause()) {
result.setWithSegment((WithSegment) visit(ctx.withClause()));
}
return result;
}

Expand All @@ -727,6 +734,27 @@ public ASTNode visitSelectWithInto(final SelectWithIntoContext ctx) {
return result;
}

/**
 * Visit with clause and collect every cte clause into a with segment.
 *
 * @param ctx with clause context
 * @return with segment spanning the whole with clause
 */
@Override
public ASTNode visitWithClause(final WithClauseContext ctx) {
    Collection<CommonTableExpressionSegment> segments = new LinkedList<>();
    ctx.cteClause().forEach(each -> segments.add((CommonTableExpressionSegment) visit(each)));
    int startIndex = ctx.getStart().getStartIndex();
    int stopIndex = ctx.getStop().getStopIndex();
    return new WithSegment(startIndex, stopIndex, segments);
}

/**
 * Visit cte clause and build a common table expression segment from it.
 *
 * @param ctx cte clause context
 * @return common table expression segment with its identifier, subquery and optional column names
 */
@SuppressWarnings("unchecked")
@Override
public ASTNode visitCteClause(final CteClauseContext ctx) {
    // The subquery segment must span the subquery itself, so that its indexes agree with the
    // statement and original text which are both taken from ctx.subquery(), not the whole cte clause.
    SubquerySegment subquerySegment = new SubquerySegment(ctx.subquery().getStart().getStartIndex(), ctx.subquery().getStop().getStopIndex(),
            (MySQLSelectStatement) visit(ctx.subquery()), getOriginalText(ctx.subquery()));
    CommonTableExpressionSegment result = new CommonTableExpressionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (IdentifierValue) visit(ctx.identifier()), subquerySegment);
    // Column names are optional in a cte clause.
    if (null != ctx.columnNames()) {
        CollectionValue<ColumnSegment> columns = (CollectionValue<ColumnSegment>) visit(ctx.columnNames());
        result.getColumns().addAll(columns.getValue());
    }
    return result;
}

@Override
public ASTNode visitQueryExpressionBody(final QueryExpressionBodyContext ctx) {
if (1 == ctx.getChildCount() && ctx.getChild(0) instanceof QueryPrimaryContext) {
Expand Down Expand Up @@ -1592,7 +1620,6 @@ private List<SimpleTableSegment> generateTablesFromTableAliasRefList(final Table

@Override
public ASTNode visitSelect(final SelectContext ctx) {
// TODO :Unsupported for withClause.
MySQLSelectStatement result;
if (null != ctx.queryExpression()) {
result = (MySQLSelectStatement) visit(ctx.queryExpression());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ public static Optional<WithSegment> getWithSegment(final SelectStatement selectS
if (selectStatement instanceof SQLServerSelectStatement) {
return ((SQLServerSelectStatement) selectStatement).getWithSegment();
}
if (selectStatement instanceof MySQLSelectStatement) {
return ((MySQLSelectStatement) selectStatement).getWithSegment();
}
return Optional.empty();
}

Expand All @@ -199,6 +202,9 @@ public static void setWithSegment(final SelectStatement selectStatement, final W
if (selectStatement instanceof SQLServerSelectStatement) {
((SQLServerSelectStatement) selectStatement).setWithSegment(withSegment);
}
if (selectStatement instanceof MySQLSelectStatement) {
((MySQLSelectStatement) selectStatement).setWithSegment(withSegment);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.LockSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WindowSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.MySQLStatement;
Expand All @@ -41,6 +42,8 @@ public final class MySQLSelectStatement extends SelectStatement implements MySQL

private WindowSegment window;

private WithSegment withSegment;

/**
* Get order by segment.
*
Expand Down Expand Up @@ -76,4 +79,13 @@ public Optional<WindowSegment> getWindow() {
public Optional<SimpleTableSegment> getTable() {
return Optional.ofNullable(table);
}

/**
 * Get with segment.
 *
 * @return with segment, or empty when no with segment has been set
 */
public Optional<WithSegment> getWithSegment() {
    if (null == withSegment) {
        return Optional.empty();
    }
    return Optional.of(withSegment);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ void assertGetWithSegmentForOracle() {
assertFalse(SelectStatementHandler.getWithSegment(new OracleSelectStatement()).isPresent());
}

@Test
void assertGetWithSegmentForMysql() {
    // A statement without a with segment yields an empty result.
    assertFalse(SelectStatementHandler.getWithSegment(new MySQLSelectStatement()).isPresent());
    // A statement with a with segment yields the very segment held by the statement.
    WithSegment expected = new WithSegment(0, 2, new LinkedList<>());
    MySQLSelectStatement statement = new MySQLSelectStatement();
    statement.setWithSegment(expected);
    Optional<WithSegment> actual = SelectStatementHandler.getWithSegment(statement);
    assertTrue(actual.isPresent());
    assertThat(actual.get(), is(statement.getWithSegment().get()));
}

@Test
void assertGetWithSegmentForSQLServer() {
SQLServerSelectStatement selectStatement = new SQLServerSelectStatement();
Expand Down
4 changes: 2 additions & 2 deletions test/it/optimizer/src/test/resources/converter/delete.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@
<test-cases sql-case-id="delete_with_alias" expected-sql="DELETE FROM [t_order] AS [o] AS [o] WHERE [status] = 'init'" db-types="SQLServer" sql-case-types="LITERAL" />
<test-cases sql-case-id="delete_with_alias" expected-sql="DELETE FROM `t_order` AS `o` AS `o` WHERE `status` = ?" db-types="MySQL" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="delete_with_alias" expected-sql="DELETE FROM [t_order] AS [o] AS [o] WHERE [status] = ?" db-types="SQLServer" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="delete_with_with_clause" expected-sql="(WITH [cte] ([order_id], [user_id]) AS (SELECT [order_id], [user_id] FROM [t_order]) DELETE FROM ([cte], [t_order]) WHERE [t_order].[order_id] = [cte].[order_id])" db-types="SQLServer" />
<test-cases sql-case-id="delete_without_columns_with_with_clause" expected-sql="(WITH [cte] AS (SELECT [order_id], [user_id] FROM [t_order]) DELETE FROM ([cte], [t_order]) WHERE [t_order].[order_id] = [cte].[order_id])" db-types="SQLServer" />
<test-cases sql-case-id="delete_with_with_clause" expected-sql="WITH [cte] ([order_id], [user_id]) AS (SELECT [order_id], [user_id] FROM [t_order]) DELETE FROM ([cte], [t_order]) WHERE [t_order].[order_id] = [cte].[order_id]" db-types="SQLServer" />
<test-cases sql-case-id="delete_without_columns_with_with_clause" expected-sql="WITH [cte] AS (SELECT [order_id], [user_id] FROM [t_order]) DELETE FROM ([cte], [t_order]) WHERE [t_order].[order_id] = [cte].[order_id]" db-types="SQLServer" />
</sql-node-converter-test-cases>
24 changes: 24 additions & 0 deletions test/it/optimizer/src/test/resources/converter/select-with.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Licensed to the Apache Software Foundation (ASF) under one or more
~ contributor license agreements. See the NOTICE file distributed with
~ this work for additional information regarding copyright ownership.
~ The ASF licenses this file to You under the Apache License, Version 2.0
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<sql-node-converter-test-cases>
<!-- Expected converted SQL for MySQL SELECT statements that start with a WITH clause; each entry is keyed by its sql-case-id. -->
<test-cases sql-case-id="select_with_single_subquery" expected-sql="WITH `t` AS (SELECT `a` + 2 AS `c`, `b` FROM `t1`) SELECT `c`, `b` FROM `t`" db-types="MySQL" />
<test-cases sql-case-id="select_with_multiple_subquery" expected-sql="WITH `cte1` (`col1`, `col2`, `col3`) AS (SELECT `emp_no`, `first_name`, `last_name` FROM `employees` WHERE `emp_no` = 10012), `cte2` (`col1`, `col2`, `col3`) AS (SELECT `emp_no`, `first_name`, `last_name` FROM `employees` WHERE `emp_no` = 10012) SELECT `col1`, `col2`, `col3` FROM `cte1`" db-types="MySQL" />
<test-cases sql-case-id="select_with_recursive_union_all1" expected-sql="WITH `DirectoryCTE` AS (SELECT * FROM `table1` WHERE `id` = 1 AND `project_id` = 2 UNION ALL SELECT * FROM `project_file_catalog` AS `t` INNER JOIN `DirectoryCTE` AS `cte` ON `t`.`project_id` = `cte`.`project_id` AND `t`.`parent_id` = `cte`.`id`) SELECT * FROM `DirectoryCTE` ORDER BY `level`" db-types="MySQL" />
<test-cases sql-case-id="select_with_recursive_union_all2" expected-sql="WITH `cte` AS (SELECT 1 AS `col1`, 2 AS `col2` UNION ALL SELECT 3, 4) SELECT `col1`, `col2` FROM `cte`" db-types="MySQL" />
</sql-node-converter-test-cases>
Loading

0 comments on commit c0ce962

Please sign in to comment.