Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove TablesContext#findTableNames method and implement select order by, group by bind logic #34123

Merged
merged 5 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
1. SQL Binder: Add sql bind logic for create table statement - [#34074](https://github.com/apache/shardingsphere/pull/34074)
1. SQL Binder: Support create index statement sql bind - [#34112](https://github.com/apache/shardingsphere/pull/34112)
1. SQL Parser: Support MySQL update with statement parse - [#34126](https://github.com/apache/shardingsphere/pull/34126)
1. SQL Binder: Remove TablesContext#findTableNames method and implement select order by, group by bind logic - [#34123](https://github.com/apache/shardingsphere/pull/34123)

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ColumnOrderByItemSegment;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -64,7 +62,7 @@ private boolean containsOrderByItem(final SelectStatementContext sqlStatementCon
public void check(final EncryptRule rule, final ShardingSphereDatabase database, final ShardingSphereSchema currentSchema, final SelectStatementContext sqlStatementContext) {
for (OrderByItem each : getOrderByItems(sqlStatementContext)) {
if (each.getSegment() instanceof ColumnOrderByItemSegment) {
checkColumnOrderByItem(rule, currentSchema, sqlStatementContext, ((ColumnOrderByItemSegment) each.getSegment()).getColumn());
checkColumnOrderByItem(rule, ((ColumnOrderByItemSegment) each.getSegment()).getColumn());
}
}
}
Expand All @@ -80,10 +78,8 @@ private Collection<OrderByItem> getOrderByItems(final SelectStatementContext sql
return result;
}

private void checkColumnOrderByItem(final EncryptRule rule, final ShardingSphereSchema schema, final SelectStatementContext sqlStatementContext, final ColumnSegment columnSegment) {
Map<String, String> columnTableNames = sqlStatementContext.getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
String tableName = columnTableNames.getOrDefault(columnSegment.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
private void checkColumnOrderByItem(final EncryptRule rule, final ColumnSegment columnSegment) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnSegment.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = columnSegment.getIdentifier().getValue();
ShardingSpherePreconditions.checkState(!encryptTable.isPresent() || !encryptTable.get().isEncryptColumn(columnName), () -> new UnsupportedEncryptSQLException("ORDER BY"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.checker.SupportedSQLChecker;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
Expand All @@ -40,7 +39,6 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -60,13 +58,12 @@ public void check(final EncryptRule rule, final ShardingSphereDatabase database,
Collection<BinaryOperationExpression> joinConditions = SQLStatementContextExtractor.getJoinConditions((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(joinConditions, rule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in join condition"));
check(rule, currentSchema, (WhereAvailable) sqlStatementContext);
check(rule, (WhereAvailable) sqlStatementContext);
}

private void check(final EncryptRule rule, final ShardingSphereSchema schema, final WhereAvailable sqlStatementContext) {
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(sqlStatementContext.getColumnSegments(), schema);
private void check(final EncryptRule rule, final WhereAvailable sqlStatementContext) {
for (ColumnSegment each : sqlStatementContext.getColumnSegments()) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnExpressionTableNames.getOrDefault(each.getExpression(), ""));
Optional<EncryptTable> encryptTable = rule.findEncryptTable(each.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = each.getIdentifier().getValue();
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(columnName) && includesLike(sqlStatementContext.getWhereSegments(), each)) {
String tableName = encryptTable.get().getTable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand All @@ -43,13 +39,11 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;

import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
Expand Down Expand Up @@ -89,44 +83,36 @@ public final class EncryptConditionEngine {
* Create encrypt conditions.
*
* @param whereSegments where segments
* @param columnSegments column segments
* @param sqlStatementContext SQL statement context
* @param databaseName database name
* @return encrypt conditions
*/
public Collection<EncryptCondition> createEncryptConditions(final Collection<WhereSegment> whereSegments, final Collection<ColumnSegment> columnSegments,
final SQLStatementContext sqlStatementContext, final String databaseName) {
public Collection<EncryptCondition> createEncryptConditions(final Collection<WhereSegment> whereSegments) {
Collection<EncryptCondition> result = new LinkedList<>();
String defaultSchema = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchema));
Map<String, String> expressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
for (WhereSegment each : whereSegments) {
Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(each.getExpr());
for (AndPredicate predicate : andPredicates) {
addEncryptConditions(result, predicate.getPredicates(), expressionTableNames);
addEncryptConditions(result, predicate.getPredicates());
}
}
return result;
}

private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final Collection<ExpressionSegment> predicates, final Map<String, String> expressionTableNames) {
private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final Collection<ExpressionSegment> predicates) {
Collection<Integer> stopIndexes = new HashSet<>(predicates.size(), 1F);
for (ExpressionSegment each : predicates) {
if (stopIndexes.add(each.getStopIndex())) {
addEncryptConditions(encryptConditions, each, expressionTableNames);
addEncryptConditions(encryptConditions, each);
}
}
}

private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final ExpressionSegment expression, final Map<String, String> expressionTableNames) {
private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final ExpressionSegment expression) {
if (!findNotContainsNullLiteralsExpression(expression).isPresent()) {
return;
}
for (ColumnSegment each : ColumnExtractor.extract(expression)) {
ColumnSegmentBoundInfo columnBoundInfo = each.getColumnBoundInfo();
String tableName = columnBoundInfo.getOriginalTable().getValue().isEmpty() ? expressionTableNames.getOrDefault(each.getExpression(), "") : columnBoundInfo.getOriginalTable().getValue();
String tableName = each.getColumnBoundInfo().getOriginalTable().getValue();
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getColumnBoundInfo().getOriginalColumn().getValue())) {
createEncryptCondition(expression, tableName).ifPresent(encryptConditions::add);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.builder.SQLTokenGeneratorBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;

import java.util.Collection;
Expand Down Expand Up @@ -93,9 +92,7 @@ private Collection<EncryptCondition> createEncryptConditions(final EncryptRule r
}
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = SQLStatementContextExtractor.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext,
sqlRewriteContext.getDatabase().getName());
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments);
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand All @@ -42,7 +40,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -51,14 +48,10 @@
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptGroupByItemTokenGenerator implements CollectionSQLTokenGenerator<SelectStatementContext>, SchemaMetaDataAware {
public final class EncryptGroupByItemTokenGenerator implements CollectionSQLTokenGenerator<SelectStatementContext> {

private final EncryptRule rule;

private Map<String, ShardingSphereSchema> schemas;

private ShardingSphereSchema defaultSchema;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof SelectStatementContext && containsGroupByItem((SelectStatementContext) sqlStatementContext);
Expand All @@ -79,21 +72,18 @@ private boolean containsGroupByItem(final SelectStatementContext sqlStatementCon
@Override
public Collection<SQLToken> generateSQLTokens(final SelectStatementContext sqlStatementContext) {
Collection<SQLToken> result = new LinkedList<>();
ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
for (OrderByItem each : getGroupByItems(sqlStatementContext)) {
if (each.getSegment() instanceof ColumnOrderByItemSegment) {
ColumnSegment columnSegment = ((ColumnOrderByItemSegment) each.getSegment()).getColumn();
Map<String, String> columnTableNames = sqlStatementContext.getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
generateSQLToken(columnSegment, columnTableNames, sqlStatementContext.getDatabaseType()).ifPresent(result::add);
generateSQLToken(columnSegment, sqlStatementContext.getDatabaseType()).ifPresent(result::add);
}
}
return result;
}

private Optional<SubstitutableColumnNameToken> generateSQLToken(final ColumnSegment columnSegment, final Map<String, String> columnTableNames, final DatabaseType databaseType) {
String tableName = columnTableNames.getOrDefault(columnSegment.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
String columnName = columnSegment.getIdentifier().getValue();
private Optional<SubstitutableColumnNameToken> generateSQLToken(final ColumnSegment columnSegment, final DatabaseType databaseType) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnSegment.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = columnSegment.getColumnBoundInfo().getOriginalColumn().getValue();
if (!encryptTable.isPresent() || !encryptTable.get().isEncryptColumn(columnName)) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ColumnOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.AliasSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
Expand Down Expand Up @@ -116,6 +118,8 @@ private SelectStatementContext mockSelectStatementContext(final String tableName
simpleTableSegment.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("foo_col"));
columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
columnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue(tableName), new IdentifierValue("foo_col")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getDatabaseType()).thenReturn(databaseType);
ColumnOrderByItemSegment columnOrderByItemSegment = new ColumnOrderByItemSegment(columnSegment, OrderDirection.ASC, NullsOrderType.FIRST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ private SQLStatementContext mockSelectStatementContextWithLike() {
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "LIKE", ""))));
when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap());
Expand All @@ -106,7 +105,6 @@ private SQLStatementContext mockSelectStatementContextWithEqual() {
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "=", ""))));
when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ void assertGenerateSQLTokenFromGenerateNewSQLToken() {
private Collection<EncryptCondition> getEncryptConditions(final UpdateStatementContext updateStatementContext) {
ShardingSphereDatabase database = new ShardingSphereDatabase("foo_db", mock(), mock(), mock(), Collections.singleton(new ShardingSphereSchema("foo_db")));
return new EncryptConditionEngine(EncryptGeneratorFixtureBuilder.createEncryptRule(), database)
.createEncryptConditions(updateStatementContext.getWhereSegments(), updateStatementContext.getColumnSegments(), updateStatementContext, "foo_db");
.createEncryptConditions(updateStatementContext.getWhereSegments());
}
}
Loading
Loading