diff --git a/partiql-ast/src/main/java/org/partiql/ast/v1/QueryBody.java b/partiql-ast/src/main/java/org/partiql/ast/v1/QueryBody.java index c75b89f629..d9a48ecc63 100644 --- a/partiql-ast/src/main/java/org/partiql/ast/v1/QueryBody.java +++ b/partiql-ast/src/main/java/org/partiql/ast/v1/QueryBody.java @@ -11,6 +11,17 @@ import java.util.List; public abstract class QueryBody extends AstNode { + @Override + public R accept(@NotNull AstVisitor visitor, C ctx) { + if (this instanceof QueryBody.SFW) { + return visitor.visitQueryBodySFW((QueryBody.SFW) this, ctx); + } else if (this instanceof QueryBody.SetOp) { + return visitor.visitQueryBodySetOp((QueryBody.SetOp) this, ctx); + } else { + throw new IllegalStateException("Unknown QueryBody type: " + this.getClass().getName()); + } + } + @Builder(builderClassName = "Builder") @EqualsAndHashCode(callSuper = false) public static class SFW extends QueryBody { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/AstRewriter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/AstRewriter.kt new file mode 100644 index 0000000000..8a567e20f4 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/AstRewriter.kt @@ -0,0 +1,778 @@ +package org.partiql.planner.internal.normalize + +import org.partiql.ast.v1.Ast.explain +import org.partiql.ast.v1.Ast.exprQuerySet +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.query +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.Exclude +import org.partiql.ast.v1.ExcludePath +import org.partiql.ast.v1.ExcludeStep +import org.partiql.ast.v1.Explain +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromExpr +import org.partiql.ast.v1.FromJoin +import org.partiql.ast.v1.FromTableRef +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.Identifier +import org.partiql.ast.v1.IdentifierChain +import org.partiql.ast.v1.Let +import org.partiql.ast.v1.OrderBy +import org.partiql.ast.v1.Query +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.Select +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectPivot +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.SelectValue +import org.partiql.ast.v1.SetOp +import org.partiql.ast.v1.Statement +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprAnd +import org.partiql.ast.v1.expr.ExprArray +import org.partiql.ast.v1.expr.ExprBag +import org.partiql.ast.v1.expr.ExprBetween +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprCase +import org.partiql.ast.v1.expr.ExprCast +import org.partiql.ast.v1.expr.ExprCoalesce +import org.partiql.ast.v1.expr.ExprExtract +import org.partiql.ast.v1.expr.ExprInCollection +import org.partiql.ast.v1.expr.ExprIsType +import org.partiql.ast.v1.expr.ExprLike +import org.partiql.ast.v1.expr.ExprLit +import org.partiql.ast.v1.expr.ExprMatch +import org.partiql.ast.v1.expr.ExprNot +import org.partiql.ast.v1.expr.ExprNullIf +import org.partiql.ast.v1.expr.ExprOperator +import org.partiql.ast.v1.expr.ExprOr +import org.partiql.ast.v1.expr.ExprOverlay +import org.partiql.ast.v1.expr.ExprParameter +import org.partiql.ast.v1.expr.ExprPath +import org.partiql.ast.v1.expr.ExprPosition +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.ExprSessionAttribute +import org.partiql.ast.v1.expr.ExprStruct +import org.partiql.ast.v1.expr.ExprSubstring +import org.partiql.ast.v1.expr.ExprTrim +import org.partiql.ast.v1.expr.ExprValues +import org.partiql.ast.v1.expr.ExprVarRef +import org.partiql.ast.v1.expr.ExprVariant +import org.partiql.ast.v1.expr.ExprWindow +import org.partiql.ast.v1.expr.PathStep +import org.partiql.ast.v1.graph.GraphMatch +import org.partiql.ast.v1.graph.GraphPattern +import org.partiql.ast.v1.graph.GraphQuantifier +import org.partiql.ast.v1.graph.GraphRestrictor +import org.partiql.ast.v1.graph.GraphSelector +import org.partiql.value.PartiQLValueExperimental + +internal abstract class AstRewriter : AstVisitor { + public override fun defaultReturn(node: AstNode, context: C): AstNode = node + + private inline fun _visitList( + nodes: List, + ctx: C, + method: (node: T, ctx: C) -> AstNode, + ): List { + if (nodes.isEmpty()) return nodes + var diff = false + val transformed = ArrayList(nodes.size) + nodes.forEach { + val n = method(it, ctx) as T + if (it !== n) diff = true + transformed.add(n) + } + return if (diff) transformed else nodes + } + + // expr + override fun visitExprAnd(node: ExprAnd, ctx: C): AstNode { + val lhs = node.lhs.accept(this, ctx) as Expr + val rhs = node.rhs.accept(this, ctx) as Expr + return if (lhs !== node.lhs || rhs !== node.rhs) { + ExprAnd(lhs, rhs) + } else { + node + } + } + + override fun visitExprArray(node: ExprArray, ctx: C): AstNode { + val values = _visitList(node.values, ctx, ::visitExpr) + return if (values !== node.values) { + ExprArray(values) + } else { + node + } + } + + override fun visitExprBag(node: ExprBag, ctx: C): AstNode { + val values = _visitList(node.values, ctx, ::visitExpr) + return if (values !== node.values) { + ExprBag(values) + } else { + node + } + } + + override fun visitExprBetween(node: ExprBetween, ctx: C): AstNode { + val value = visitExpr(node.value, ctx) as Expr + val from = visitExpr(node.from, ctx) as Expr + val to = visitExpr(node.to, ctx) as Expr + val not = node.not + return if (value !== node.value || from !== node.from || to !== node.to || not != node.not) { + ExprBetween(value, from, to, not) + } else { + node + } + } + + override fun visitExprCall(node: ExprCall, ctx: C): AstNode { + val function = node.function.accept(this, ctx) as IdentifierChain + val args = _visitList(node.args, ctx, ::visitExpr) + val setq = node.setq + return if (function !== node.function || args !== node.args || setq !== node.setq) { + ExprCall(function, args, setq) + } else { + node + } + } + + override fun visitExprCase(node: ExprCase, ctx: C): AstNode { + val expr = node.expr?.accept(this, ctx) as Expr? + val branches = _visitList(node.branches, ctx, ::visitExprCaseBranch) + val defaultExpr = node.defaultExpr?.accept(this, ctx) as Expr? + return if (expr !== node.expr || branches !== node.branches || defaultExpr !== node.defaultExpr) { + ExprCase(expr, branches, defaultExpr) + } else { + node + } + } + + override fun visitExprCaseBranch(node: ExprCase.Branch, ctx: C): AstNode { + val condition = node.condition.accept(this, ctx) as Expr + val expr = node.expr.accept(this, ctx) as Expr + return if (condition !== node.condition || expr !== node.expr) { + ExprCase.Branch(condition, expr) + } else { + node + } + } + + override fun visitExprCast(node: ExprCast, ctx: C): AstNode { + val value = node.value.accept(this, ctx) as Expr + val type = node.asType + return if (value !== node.value || type !== node.asType) { + ExprCast(value, type) + } else { + node + } + } + + override fun visitExprCoalesce(node: ExprCoalesce, ctx: C): AstNode { + val args = _visitList(node.args, ctx, ::visitExpr) + return if (args !== node.args) { + ExprCoalesce(args) + } else { + node + } + } + + override fun visitExprExtract(node: ExprExtract, ctx: C): AstNode { + val field = node.field + val source = node.source.accept(this, ctx) as Expr + return if (field !== node.field || source !== node.source) { + ExprExtract(field, source) + } else { + node + } + } + + override fun visitExprInCollection(node: ExprInCollection, ctx: C): AstNode { + val lhs = node.lhs.accept(this, ctx) as Expr + val rhs = node.rhs.accept(this, ctx) as Expr + val not = node.not + return if (lhs !== node.lhs || rhs !== node.rhs || not != node.not) { + ExprInCollection(lhs, rhs, not) + } else { + node + } + } + + override fun visitExprIsType(node: ExprIsType, ctx: C): AstNode { + val value = node.value.accept(this, ctx) as Expr + val type = node.type + val not = node.not + return if (value !== node.value || type !== node.type || not != node.not) { + ExprIsType(value, type, not) + } else { + node + } + } + + override fun visitExprLike(node: ExprLike, ctx: C): AstNode { + val value = node.value.accept(this, ctx) as Expr + val pattern = node.pattern.accept(this, ctx) as Expr + val escape = node.escape?.accept(this, ctx) as Expr? + val not = node.not + return if (value !== node.value || pattern !== node.pattern || escape !== node.escape || not != node.not) { + ExprLike(value, pattern, escape, not) + } else { + node + } + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitExprLit(node: ExprLit, ctx: C): AstNode { + val value = node.value + return node + } + + override fun visitExprMatch(node: ExprMatch, ctx: C): AstNode { + val expr = node.expr.accept(this, ctx) as Expr + val pattern = visitGraphMatch(node.pattern, ctx) as GraphMatch + return if (expr !== node.expr || pattern !== node.pattern) { + ExprMatch(expr, pattern) + } else { + node + } + } + + override fun visitExprNot(node: ExprNot, ctx: C): AstNode { + val expr = node.value.accept(this, ctx) as Expr + return if (expr !== node.value) { + ExprNot(expr) + } else { + node + } + } + + override fun visitExprNullIf(node: ExprNullIf, ctx: C): AstNode { + val v1 = node.v1.accept(this, ctx) as Expr + val v2 = node.v2.accept(this, ctx) as Expr + return if (v1 !== node.v1 || v2 !== node.v2) { + ExprNullIf(v1, v2) + } else { + node + } + } + + override fun visitExprOperator(node: ExprOperator, ctx: C): AstNode { + val symbol = node.symbol + val lhs = node.lhs?.accept(this, ctx) as Expr? + val rhs = node.rhs.accept(this, ctx) as Expr + return if (symbol !== node.symbol || lhs !== node.lhs || rhs !== node.rhs) { + ExprOperator(symbol, lhs, rhs) + } else { + node + } + } + + override fun visitExprOr(node: ExprOr, ctx: C): AstNode { + val lhs = node.lhs.accept(this, ctx) as Expr + val rhs = node.rhs.accept(this, ctx) as Expr + return if (lhs !== node.lhs || rhs !== node.rhs) { + ExprOr(lhs, rhs) + } else { + node + } + } + + override fun visitExprOverlay(node: ExprOverlay, ctx: C): AstNode { + val value = node.value.accept(this, ctx) as Expr + val overlay = node.placing.accept(this, ctx) as Expr + val from = node.from.accept(this, ctx) as Expr + val forLength = node.forLength?.accept(this, ctx) as Expr? + return if (value !== node.value || overlay !== node.placing || from !== node.from || forLength !== node.forLength) { + ExprOverlay(value, overlay, from, forLength) + } else { + node + } + } + + override fun visitExprParameter(node: ExprParameter, ctx: C): AstNode { + val index = node.index + return node + } + + override fun visitExprPath(node: ExprPath, ctx: C): AstNode { + val root = node.root.accept(this, ctx) as Expr + val next = node.next?.accept(this, ctx) as PathStep? + return if (root !== node.root || next !== node.next) { + ExprPath(root, next) + } else { + node + } + } + + override fun visitExprPosition(node: ExprPosition, ctx: C): AstNode { + val lhs = node.lhs.accept(this, ctx) as Expr + val rhs = node.rhs.accept(this, ctx) as Expr + return if (lhs !== node.lhs || rhs !== node.rhs) { + ExprPosition(lhs, rhs) + } else { + node + } + } + + override fun visitExprQuerySet(node: ExprQuerySet, ctx: C): AstNode { + val body = node.body.accept(this, ctx) as QueryBody + val orderBy = node.orderBy?.let { it.accept(this, ctx) as OrderBy? } + val limit = node.limit?.let { it.accept(this, ctx) as Expr? } + val offset = node.offset?.let { it.accept(this, ctx) as Expr? } + return if (body !== node.body || orderBy !== node.orderBy || limit !== node.limit || offset !== + node.offset + ) { + exprQuerySet(body, orderBy, limit, offset) + } else { + node + } + } + + override fun visitExprSessionAttribute(node: ExprSessionAttribute, ctx: C): AstNode { + val sessionAttribute = node.sessionAttribute + return node + } + + override fun visitExprStruct(node: ExprStruct, ctx: C): AstNode { + val fields = _visitList(node.fields, ctx, ::visitExprStructField) + return if (fields !== node.fields) { + ExprStruct(fields) + } else { + node + } + } + + override fun visitExprStructField(node: ExprStruct.Field, ctx: C): AstNode { + val name = node.name.accept(this, ctx) as Expr + val value = node.value.accept(this, ctx) as Expr + return if (name !== node.name || value !== node.value) { + ExprStruct.Field(name, value) + } else { + node + } + } + + override fun visitExprSubstring(node: ExprSubstring, ctx: C): AstNode { + val value = node.value.accept(this, ctx) as Expr + val start = node.start?.accept(this, ctx) as Expr? + val length = node.length?.accept(this, ctx) as Expr? + return if (value !== node.value || start !== node.start || length !== node.length) { + ExprSubstring(value, start, length) + } else { + node + } + } + + override fun visitExprTrim(node: ExprTrim, ctx: C): AstNode { + val value = node.value.accept(this, ctx) as Expr + val chars = node.chars?.accept(this, ctx) as Expr? + val trimSpec = node.trimSpec + return if (value !== node.value || chars !== node.chars || trimSpec !== node.trimSpec) { + ExprTrim(value, chars, trimSpec) + } else { + node + } + } + + override fun visitExprValues(node: ExprValues, ctx: C): AstNode { + val values = _visitList(node.rows, ctx, ::visitExprValuesRow) + return if (values !== node.rows) { + ExprValues(values) + } else { + node + } + } + + override fun visitExprValuesRow(node: ExprValues.Row, ctx: C): AstNode { + val values = _visitList(node.values, ctx, ::visitExpr) + return if (values !== node.values) { + ExprValues.Row(values) + } else { + node + } + } + + override fun visitExprVariant(node: ExprVariant, ctx: C): AstNode { + val value = node.value + val encoding = node.encoding + return node + } + + override fun visitExprVarRef(node: ExprVarRef, ctx: C): AstNode { + val identifierChain = node.identifierChain.accept(this, ctx) as IdentifierChain + val scope = node.scope + return if (identifierChain !== node.identifierChain || scope !== node.scope) { + ExprVarRef(identifierChain, scope) + } else { + node + } + } + + override fun visitExprWindow(node: ExprWindow, ctx: C): AstNode { + val windowFunction = node.windowFunction + val expression = node.expression.accept(this, ctx) as Expr + val offset = node.offset?.accept(this, ctx) as Expr? + val defaultValue = node.defaultValue?.accept(this, ctx) as Expr? + val over = node.over.accept(this, ctx) as ExprWindow.Over + return if (windowFunction !== node.windowFunction || expression !== node.expression || offset !== + node.offset || defaultValue !== node.defaultValue || over !== node.over + ) { + ExprWindow(windowFunction, expression, offset, defaultValue, over) + } else { + node + } + } + + override fun visitExprWindowOver(node: ExprWindow.Over, ctx: C): AstNode { + val partitions = node.partitions?.let { _visitList(it, ctx, ::visitExpr) } + val sorts = node.sorts?.let { _visitList(it, ctx, ::visitSort) } + return if (partitions !== node.partitions || sorts !== node.sorts) { + ExprWindow.Over(partitions, sorts) + } else { + node + } + } + + override fun visitPathStepField(node: PathStep.Field, ctx: C): AstNode { + val field = node.field.accept(this, ctx) as Identifier + val next = node.next?.accept(this, ctx) as PathStep? + return if (field !== node.field || next !== node.next) { + PathStep.Field(field, next) + } else { + node + } + } + + override fun visitPathStepElement(node: PathStep.Element, ctx: C): AstNode { + val element = node.element.accept(this, ctx) as Expr + val next = node.next?.accept(this, ctx) as PathStep? + return if (element !== node.element || next !== node.next) { + PathStep.Element(element, next) + } else { + node + } + } + + override fun visitPathStepAllFields(node: PathStep.AllFields, ctx: C): AstNode { + val next = node.next?.accept(this, ctx) as PathStep? + return if (next !== node.next) { + PathStep.AllFields(next) + } else { + node + } + } + + override fun visitPathStepAllElements(node: PathStep.AllElements, ctx: C): AstNode { + val next = node.next?.accept(this, ctx) as PathStep? + return if (next !== node.next) { + PathStep.AllElements(next) + } else { + node + } + } + + // graph + override fun visitGraphMatch(node: GraphMatch, ctx: C): AstNode { + val patterns = _visitList(node.patterns, ctx, ::visitGraphMatchPattern) + val selector = node.selector?.accept(this, ctx) as GraphSelector? + return if (patterns !== node.patterns || selector !== node.selector) { + GraphMatch(patterns, selector) + } else { + node + } + } + + // TODO rename the visitor + override fun visitGraphMatchPattern(node: GraphPattern, ctx: C): AstNode { + val restrictor = node.restrictor?.accept(this, ctx) as GraphRestrictor? + val prefilter = node.prefilter?.accept(this, ctx) as Expr? + val variable = node.variable + val quantifier = node.quantifier?.accept(this, ctx) as GraphQuantifier? + val parts = _visitList(node.parts, ctx, ::visitGraphPart) + return if (restrictor !== node.restrictor || prefilter !== node.prefilter || variable !== + node.variable || quantifier !== node.quantifier || parts !== node.parts + ) { + GraphPattern(restrictor, prefilter, variable, quantifier, parts) + } else { + node + } + } + + override fun visitGraphQuantifier(node: GraphQuantifier, ctx: C): AstNode { + val lower = node.lower + val upper = node.upper + return node + } + + override fun visitGraphSelectorAny(node: GraphSelector.Any, ctx: C): AstNode { + return node + } + + override fun visitGraphSelectorAnyK(node: GraphSelector.AnyK, ctx: C): AstNode { + val k = node.k + return node + } + + override fun visitGraphSelectorAllShortest(node: GraphSelector.AllShortest, ctx: C): AstNode { + return node + } + + override fun visitGraphSelectorAnyShortest(node: GraphSelector.AnyShortest, ctx: C): AstNode { + return node + } + + override fun visitGraphSelectorShortestK(node: GraphSelector.ShortestK, ctx: C): AstNode { + val k = node.k + return node + } + + override fun visitGraphSelectorShortestKGroup(node: GraphSelector.ShortestKGroup, ctx: C): AstNode { + val k = node.k + return node + } + + override fun visitExclude(node: Exclude, ctx: C): AstNode { + val excludePaths = _visitList(node.excludePaths, ctx, ::visitExcludePath) + return if (excludePaths !== node.excludePaths) { + Exclude(excludePaths) + } else { + node + } + } + + override fun visitExcludePath(node: ExcludePath, ctx: C): AstNode { + val root = node.root.accept(this, ctx) as ExprVarRef + val excludeSteps = _visitList(node.excludeSteps, ctx, ::visitExcludeStep) + return if (root !== node.root || excludeSteps !== node.excludeSteps) { + ExcludePath(root, excludeSteps) + } else { + node + } + } + + override fun visitExcludeStepCollIndex(node: ExcludeStep.CollIndex, ctx: C): AstNode { + val index = node.index + return node + } + + override fun visitExcludeStepStructField(node: ExcludeStep.StructField, ctx: C): AstNode { + val symbol = visitIdentifier(node.symbol, ctx) as Identifier + return if (symbol !== node.symbol) { + ExcludeStep.StructField(symbol) + } else { + node + } + } + + override fun visitExcludeStepCollWildcard(node: ExcludeStep.CollWildcard, ctx: C): AstNode { + return node + } + + override fun visitExcludeStepStructWildcard(node: ExcludeStep.StructWildcard, ctx: C): AstNode { + return node + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitExplain(node: Explain, ctx: C): AstNode { + val statement = node.statement.accept(this, ctx) as Statement + return if (statement !== node.statement) { + explain(node.options, statement) + } else { + node + } + } + + override fun visitFrom(node: From, ctx: C): AstNode { + val tableRefs = _visitList(node.tableRefs, ctx, ::visitTableRef) + return if (tableRefs !== node.tableRefs) { + From(tableRefs) + } else { + node + } + } + + override fun visitFromExpr(node: FromExpr, ctx: C): AstNode { + val expr = node.expr.accept(this, ctx) as Expr + val fromType = node.fromType + val asAlias = node.asAlias?.accept(this, ctx) as Identifier + val atAlias = node.atAlias?.accept(this, ctx) as Identifier + return if (expr !== node.expr || fromType !== node.fromType || asAlias !== node.asAlias || + atAlias !== node.atAlias + ) { + FromExpr(expr, fromType, asAlias, atAlias) + } else { + node + } + } + + override fun visitFromJoin(node: FromJoin, ctx: C): AstNode { + val lhs = node.lhs.accept(this, ctx) as FromTableRef + val rhs = node.rhs.accept(this, ctx) as FromTableRef + val joinType = node.joinType + val condition = node.condition?.accept(this, ctx) as Expr? + return if (lhs !== node.lhs || rhs !== node.rhs || joinType !== node.joinType || + condition !== node.condition + ) { + FromJoin(lhs, rhs, joinType, condition) + } else { + node + } + } + + override fun visitGroupBy(node: GroupBy, ctx: C): AstNode { + val strategy = node.strategy + val keys = _visitList(node.keys, ctx, ::visitGroupByKey) + val asAlias = node.asAlias?.accept(this, ctx) as Identifier? + return if (strategy !== node.strategy || keys !== node.keys || asAlias !== node.asAlias) { + GroupBy(strategy, keys, asAlias) + } else { + node + } + } + + override fun visitGroupByKey(node: GroupBy.Key, ctx: C): AstNode { + val expr = node.expr.accept(this, ctx) as Expr + val asAlias = node.asAlias?.accept(this, ctx) as Identifier? + return if (expr !== node.expr || asAlias !== node.asAlias) { + GroupBy.Key(expr, asAlias) + } else { + node + } + } + + override fun visitIdentifier(node: Identifier, ctx: C): AstNode { + val symbol = node.symbol + val isDelimited = node.isDelimited + return identifier(symbol, isDelimited) + } + + override fun visitIdentifierChain(node: IdentifierChain, ctx: C): AstNode { + val root = node.root.accept(this, ctx) as Identifier + val next = node.next?.accept(this, ctx) as IdentifierChain? + return if (root !== node.root || next !== node.next) { + IdentifierChain(root, next) + } else { + node + } + } + + override fun visitLet(node: Let, ctx: C): AstNode { + val bindings = _visitList(node.bindings, ctx, ::visitLetBinding) + return if (bindings !== node.bindings) { + Let(bindings) + } else { + node + } + } + + override fun visitLetBinding(node: Let.Binding, ctx: C): AstNode { + val expr = node.expr.accept(this, ctx) as Expr + val asAlias = node.asAlias.accept(this, ctx) as Identifier + return if (expr !== node.expr || asAlias !== node.asAlias) { + Let.Binding(expr, asAlias) + } else { + node + } + } + + override fun visitQuery(node: Query, ctx: C): AstNode { + val expr = node.expr.accept(this, ctx) as Expr + return if (expr !== node.expr) { + query(expr) + } else { + node + } + } + + public override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: C): AstNode { + val select = node.select.accept(this, ctx) as Select + val exclude = node.exclude?.let { it.accept(this, ctx) as Exclude? } + val from = node.from.accept(this, ctx) as From + val let = node.let?.let { it.accept(this, ctx) as Let? } + val where = node.where?.let { it.accept(this, ctx) as Expr? } + val groupBy = node.groupBy?.let { it.accept(this, ctx) as GroupBy? } + val having = node.having?.let { it.accept(this, ctx) as Expr? } + return if (select !== node.select || exclude !== node.exclude || from !== node.from || let !== + node.let || where !== node.where || groupBy !== node.groupBy || having !== node.having + ) { + QueryBody.SFW(select, exclude, from, let, where, groupBy, having) + } else { + node + } + } + + public override fun visitQueryBodySetOp(node: QueryBody.SetOp, ctx: C): AstNode { + val type = visitSetOp(node.type, ctx) as SetOp + val isOuter = node.isOuter + val lhs = node.lhs.accept(this, ctx) as Expr + val rhs = node.rhs.accept(this, ctx) as Expr + return if (type !== node.type || isOuter != node.isOuter || lhs !== node.lhs || rhs !== node.rhs) { + QueryBody.SetOp(type, isOuter, lhs, rhs) + } else { + node + } + } + + public override fun visitSelectItemStar(node: SelectItem.Star, ctx: C): AstNode { + val expr = node.expr.accept(this, ctx) as Expr + return if (expr !== node.expr) { + SelectItem.Star(expr) + } else { + node + } + } + + public override fun visitSelectItemExpr(node: SelectItem.Expr, ctx: C): AstNode { + val expr = node.expr.accept(this, ctx) as Expr + val asAlias = node.asAlias?.accept(this, ctx) as Identifier? + return if (expr !== node.expr || asAlias !== node.asAlias) { + SelectItem.Expr(expr, asAlias) + } else { + node + } + } + + override fun visitSelectList(node: SelectList, ctx: C): AstNode { + val items = _visitList(node.items, ctx, ::visitSelectItem) + val setq = node.setq + return if (items !== node.items || setq !== node.setq) { + SelectList(items, setq) + } else { + node + } + } + + override fun visitSelectPivot(node: SelectPivot, ctx: C): AstNode { + val key = node.key.accept(this, ctx) as Expr + val value = node.value.accept(this, ctx) as Expr + return if (key !== node.key || value !== node.value) { + SelectPivot(key, value) + } else { + node + } + } + + override fun visitSelectStar(node: SelectStar, ctx: C): AstNode { + val setq = node.setq + return node + } + + override fun visitSelectValue(node: SelectValue, ctx: C): AstNode { + val constructor = node.constructor.accept(this, ctx) as Expr + val setq = node.setq + return if (constructor !== node.constructor || setq !== node.setq) { + SelectValue(constructor, setq) + } else { + node + } + } + + override fun visitSetOp(node: SetOp, ctx: C): AstNode { + val setOpType = node.setOpType + val setq = node.setq + return node + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt index 41c9e24b97..572c1d6ed9 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt @@ -17,7 +17,6 @@ package org.partiql.planner.internal.normalize import org.partiql.ast.v1.Ast.fromExpr import org.partiql.ast.v1.Ast.fromJoin import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstVisitor import org.partiql.ast.v1.From import org.partiql.ast.v1.FromExpr import org.partiql.ast.v1.FromJoin @@ -33,9 +32,9 @@ import org.partiql.planner.internal.helpers.toBinder */ internal object NormalizeFromSource : AstPass { - override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, 0) as Statement + override fun apply(statement: Statement): Statement = statement.accept(Visitor, 0) as Statement - private object Visitor : AstVisitor { + private object Visitor : AstRewriter() { // Each SFW starts the ctx count again. override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Int): AstNode = super.visitQueryBodySFW(node, 0) @@ -43,9 +42,9 @@ internal object NormalizeFromSource : AstPass { override fun visitFrom(node: From, ctx: Int) = super.visitFrom(node, ctx) as From override fun visitFromJoin(node: FromJoin, ctx: Int): FromJoin { - val lhs = visitTableRef(node.lhs, ctx) as FromTableRef - val rhs = visitTableRef(node.rhs, ctx + 1) as FromTableRef - val condition = node.condition?.let { visitExpr(it, ctx) as Expr } + val lhs = node.lhs.accept(this, ctx) as FromTableRef + val rhs = node.rhs.accept(this, ctx + 1) as FromTableRef + val condition = node.condition?.accept(this, ctx) as Expr? return if (lhs !== node.lhs || rhs !== node.rhs || condition !== node.condition) { fromJoin(lhs, rhs, node.joinType, condition) } else { @@ -54,7 +53,7 @@ internal object NormalizeFromSource : AstPass { } override fun visitFromExpr(node: FromExpr, ctx: Int): FromExpr { - val expr = visitExpr(node.expr, ctx) as Expr + val expr = node.expr.accept(this, ctx) as Expr var i = ctx var asAlias = node.asAlias var atAlias = node.atAlias @@ -72,7 +71,5 @@ internal object NormalizeFromSource : AstPass { node } } - - override fun defaultReturn(node: AstNode, ctx: Int) = node } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt index 5931272531..1741a3a014 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt @@ -17,7 +17,6 @@ package org.partiql.planner.internal.normalize import org.partiql.ast.v1.Ast.groupBy import org.partiql.ast.v1.Ast.groupByKey import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstVisitor import org.partiql.ast.v1.GroupBy import org.partiql.ast.v1.Statement import org.partiql.ast.v1.expr.Expr @@ -30,7 +29,7 @@ internal object NormalizeGroupBy : AstPass { override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement - private object Visitor : AstVisitor { + private object Visitor : AstRewriter() { override fun visitGroupBy(node: GroupBy, ctx: Int): AstNode { val keys = node.keys.mapIndexed { index, key -> @@ -51,7 +50,5 @@ internal object NormalizeGroupBy : AstPass { node } } - - override fun defaultReturn(node: AstNode, ctx: Int) = node } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt index 4a95a4174b..b7f833c6a9 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt @@ -30,8 +30,6 @@ import org.partiql.ast.v1.Ast.queryBodySetOp import org.partiql.ast.v1.Ast.selectItemExpr import org.partiql.ast.v1.Ast.selectList import org.partiql.ast.v1.Ast.selectValue -import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstVisitor import org.partiql.ast.v1.DataType import org.partiql.ast.v1.From import org.partiql.ast.v1.FromExpr @@ -51,6 +49,7 @@ import org.partiql.ast.v1.expr.ExprStruct import org.partiql.ast.v1.expr.ExprVarRef import org.partiql.ast.v1.expr.Scope import org.partiql.planner.internal.helpers.toBinder +import org.partiql.planner.internal.normalize.AstRewriter import org.partiql.value.PartiQLValueExperimental import org.partiql.value.stringValue @@ -162,7 +161,7 @@ internal object V1NormalizeSelect { /** * The type parameter () -> Int */ - private object Visitor : AstVisitor Int> { + private object Visitor : AstRewriter<() -> Int>() { /** * This is used to give projections a name. For example: @@ -218,7 +217,7 @@ internal object V1NormalizeSelect { var diff = false val visitedItems = ArrayList(node.items.size) node.items.forEach { n -> - val item = visitSelectItem(n, ctx) as SelectItem + val item = n.accept(this, ctx) as SelectItem if (item !== n) diff = true visitedItems.add(item) } @@ -390,7 +389,5 @@ internal object V1NormalizeSelect { } private fun id(symbol: String) = identifier(symbol, isDelimited = false) - - override fun defaultReturn(node: AstNode, ctx: () -> Int) = node } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt index ef79254280..4551359a60 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt @@ -83,6 +83,7 @@ import org.partiql.planner.internal.ir.rexOpSelect import org.partiql.planner.internal.ir.rexOpStruct import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpVarLocal +import org.partiql.planner.internal.normalize.AstRewriter import org.partiql.planner.internal.typer.CompilerType import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental @@ -675,7 +676,7 @@ internal object V1RelConverter { /** * Rewrites a SELECT node replacing (and extracting) each aggregation `i` with a synthetic field name `$agg_i`. */ - private object AggregationTransform : AstVisitor { + private object AggregationTransform : AstRewriter() { // currently hard-coded @JvmStatic private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every") diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt index 1188f136c6..5e22fea269 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt @@ -165,7 +165,7 @@ internal object V1RexConverter { * @return */ internal fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { - val rex = super.visitExpr(node, ctx) + val rex = node.accept(this, ctx) return when (isSqlSelect(node)) { true -> { val select = rex.op as Rex.Op.Select @@ -715,7 +715,7 @@ internal object V1RexConverter { val type = BOOL // Args val arg0 = visitExprCoerce(node.lhs, ctx) - val arg1 = visitExpr(node.rhs, ctx) // !! don't insert scalar subquery coercions + val arg1 = node.rhs.accept(this, ctx) // !! don't insert scalar subquery coercions // Call var call = call("in_collection", arg0, arg1) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt index 1f11a82784..5bd3626038 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt @@ -1,12 +1,10 @@ package org.partiql.planner.internal.transforms import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstVisitor import org.partiql.ast.v1.expr.Expr +import org.partiql.planner.internal.normalize.AstRewriter -internal object V1SubstitutionVisitor : AstVisitor> { - override fun defaultReturn(node: AstNode, ctx: Map<*, AstNode>) = node - +internal object V1SubstitutionVisitor : AstRewriter>() { override fun visitExpr(node: Expr, ctx: Map<*, AstNode>): AstNode { val visited = super.visitExpr(node, ctx) if (ctx.containsKey(visited)) { diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt index 0ca74b17c7..ce81b193da 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt @@ -1,15 +1,23 @@ package org.partiql.planner.internal.transforms import org.junit.jupiter.api.Test -import org.partiql.ast.Expr -import org.partiql.ast.From -import org.partiql.ast.Identifier -import org.partiql.ast.Select -import org.partiql.ast.builder.ast -import org.partiql.ast.exprLit -import org.partiql.ast.exprVar -import org.partiql.ast.identifierSymbol -import org.partiql.ast.selectProjectItemExpression +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprQuerySet +import org.partiql.ast.v1.Ast.exprStruct +import org.partiql.ast.v1.Ast.exprStructField +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.from +import org.partiql.ast.v1.Ast.fromExpr +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.Ast.queryBodySFW +import org.partiql.ast.v1.Ast.selectItemExpr +import org.partiql.ast.v1.Ast.selectList +import org.partiql.ast.v1.Ast.selectValue +import org.partiql.ast.v1.FromType +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.Scope import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int32Value import org.partiql.value.stringValue @@ -38,7 +46,7 @@ class NormalizeSelectTest { "b" to variable("b"), "c" to variable("c"), ) - val actual = NormalizeSelect.normalize(input) + val actual = V1NormalizeSelect.normalize(input) assertEquals(expected, actual) } @@ -63,7 +71,7 @@ class NormalizeSelectTest { "_2" to lit(2), "_3" to lit(3), ) - val actual = NormalizeSelect.normalize(input) + val actual = V1NormalizeSelect.normalize(input) assertEquals(expected, actual) } @@ -88,7 +96,7 @@ class NormalizeSelectTest { "_1" to lit(2), "_2" to lit(3), ) - val actual = NormalizeSelect.normalize(input) + val actual = V1NormalizeSelect.normalize(input) assertEquals(expected, actual) } @@ -113,70 +121,104 @@ class NormalizeSelectTest { "b" to lit(2), "c" to lit(3), ) - val actual = NormalizeSelect.normalize(input) + val actual = V1NormalizeSelect.normalize(input) assertEquals(expected, actual) } // ----- HELPERS ------------------------- - private fun variable(name: String) = exprVar( - identifier = identifierSymbol( - symbol = name, - caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE, + private fun variable(name: String) = exprVarRef( + identifierChain = identifierChain( + identifier( + symbol = name, + isDelimited = false, + ), + next = null ), - scope = Expr.Var.Scope.DEFAULT, + scope = Scope.DEFAULT(), ) - private fun select(vararg items: Select.Project.Item) = ast { - exprQuerySet { - body = queryBodySFW { - select = selectProject { - this.items += items - } - from = fromValue { - expr = variable("T") - type = From.Value.Type.SCAN - } - } - } - } + private fun select(vararg items: SelectItem) = + exprQuerySet( + body = queryBodySFW( + select = selectList( + items = items.toList(), + setq = null + ), + exclude = null, + from = from( + listOf( + fromExpr( + expr = variable("T"), + fromType = FromType.SCAN(), + asAlias = null, + atAlias = null + ) + ) + ), + let = null, + where = null, + groupBy = null, + having = null, + ), + limit = null, + offset = null, + orderBy = null + ) @OptIn(PartiQLValueExperimental::class) - private fun selectValue(vararg items: Pair) = ast { - exprQuerySet { - body = queryBodySFW { - select = selectValue { - constructor = exprStruct { - for ((k, v) in items) { - fields += exprStructField { - name = exprLit(stringValue(k)) - value = v - } - } - } - } - from = fromValue { - expr = exprVar { - identifier = identifierSymbol { - symbol = "T" - caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE + private fun selectValue(vararg items: Pair) = + exprQuerySet( + body = queryBodySFW( + select = selectValue( + constructor = exprStruct( + items.map { + exprStructField( + name = exprLit(stringValue(it.first)), + value = it.second + ) } - scope = Expr.Var.Scope.DEFAULT - } - type = From.Value.Type.SCAN - } - } - } - } + ), + setq = null + ), + exclude = null, + from = from( + listOf( + fromExpr( + expr = exprVarRef( + identifierChain = identifierChain( + identifier( + symbol = "T", + isDelimited = false + ), + next = null + ), + scope = Scope.DEFAULT() + ), + fromType = FromType.SCAN(), + asAlias = null, + atAlias = null + ), + ) + ), + let = null, + where = null, + groupBy = null, + having = null, + ), + limit = null, + offset = null, + orderBy = null + ) - private fun varItem(symbol: String, asAlias: String? = null) = selectProjectItemExpression( + private fun varItem(symbol: String, asAlias: String? = null) = selectItemExpr( expr = variable(symbol), - asAlias = asAlias?.let { identifierSymbol(asAlias, Identifier.CaseSensitivity.INSENSITIVE) } + asAlias = asAlias?.let { identifier(asAlias, isDelimited = false) } ) - private fun litItem(value: Int, asAlias: String? = null) = selectProjectItemExpression( + private fun litItem(value: Int, asAlias: String? = null) = selectItemExpr( expr = lit(value), - asAlias = asAlias?.let { identifierSymbol(asAlias, Identifier.CaseSensitivity.INSENSITIVE) } + asAlias = asAlias?.let { identifier(asAlias, isDelimited = false) } ) @OptIn(PartiQLValueExperimental::class)