From 16b46890d28e7c389c43070a5872a7306112b08c Mon Sep 17 00:00:00 2001 From: Alan Cai Date: Fri, 1 Nov 2024 13:06:53 -0700 Subject: [PATCH] [v1] Rename V1Parser to Parser; move parser builder within the interface --- partiql-ast/api/partiql-ast.api | 9 - .../org/partiql/ast/helpers/ToBinder.kt | 67 - .../org/partiql/ast/normalize/AstPass.kt | 26 - .../org/partiql/ast/normalize/Normalize.kt | 31 - .../ast/normalize/NormalizeFromSource.kt | 73 - .../partiql/ast/normalize/NormalizeGroupBy.kt | 56 - .../org/partiql/cli/pipeline/Pipeline.kt | 7 +- .../eval/internal/PartiQLEvaluatorTest.kt | 6 +- partiql-parser/api/partiql-parser.api | 24 +- .../org/partiql/parser/PartiQLParser.java | 18 +- .../partiql/parser/PartiQLParserBuilder.java | 33 - .../parser/PartiQLParserBuilderV1.java | 30 - .../org/partiql/parser/PartiQLParserV1.java | 97 - .../parser/internal/PartiQLParserDefault.kt | 1199 ++++----- .../parser/internal/PartiQLParserDefaultV1.kt | 2199 ----------------- .../parser/internal/ParserTestCaseSimple.kt | 4 +- .../internal/PartiQLParserBagOpTests.kt | 2 +- .../parser/internal/PartiQLParserDDLTests.kt | 2 +- .../PartiQLParserFunctionCallTests.kt | 2 +- .../internal/PartiQLParserOperatorTests.kt | 2 +- .../PartiQLParserSessionAttributeTests.kt | 2 +- .../partiql/planner/internal/SqlPlanner.kt | 4 +- .../planner/internal/transforms/AstToPlan.kt | 39 +- .../internal/transforms/NormalizeSelect.kt | 176 +- .../internal/transforms/RelConverter.kt | 250 +- .../internal/transforms/RexConverter.kt | 499 ++-- .../transforms/SubstitutionVisitor.kt | 6 +- .../internal/transforms/V1AstToPlan.kt | 75 - .../internal/transforms/V1NormalizeSelect.kt | 393 --- .../internal/transforms/V1RelConverter.kt | 755 ------ .../internal/transforms/V1RexConverter.kt | 1077 -------- .../transforms/V1SubstitutionVisitor.kt | 15 - .../kotlin/org/partiql/planner/PlanTest.kt | 4 +- .../planner/PlannerPErrorReportingTests.kt | 4 +- .../internal/exclude/SubsumptionTest.kt | 4 +- .../transforms/NormalizeSelectTest.kt | 8 +- .../internal/typer/PartiQLTyperTestBase.kt | 4 +- .../internal/typer/PlanTyperTestsPorted.kt | 4 +- .../org/partiql/lang/randomized/eval/Utils.kt | 4 +- .../partiql/runner/executor/EvalExecutor.kt | 4 +- 40 files changed, 1218 insertions(+), 5996 deletions(-) delete mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToBinder.kt delete mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt delete mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt delete mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt delete mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt delete mode 100644 partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilder.java delete mode 100644 partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilderV1.java delete mode 100644 partiql-parser/src/main/java/org/partiql/parser/PartiQLParserV1.java delete mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefaultV1.kt delete mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt delete mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt delete mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt delete mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt delete mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt diff --git a/partiql-ast/api/partiql-ast.api b/partiql-ast/api/partiql-ast.api index 2d141cfa97..d6f0ac4302 100644 --- a/partiql-ast/api/partiql-ast.api +++ b/partiql-ast/api/partiql-ast.api @@ -4983,15 +4983,6 @@ public final class org/partiql/ast/builder/TypeVarcharBuilder { public final fun setLength (Ljava/lang/Integer;)V } -public final class org/partiql/ast/helpers/ToBinderKt { - public static final fun toBinder (Lorg/partiql/ast/Expr;I)Lorg/partiql/ast/Identifier$Symbol; - public static final fun toBinder (Lorg/partiql/ast/Expr;Lkotlin/jvm/functions/Function0;)Lorg/partiql/ast/Identifier$Symbol; -} - -public final class org/partiql/ast/normalize/Normalize { - public static final fun normalize (Lorg/partiql/ast/Statement;)Lorg/partiql/ast/Statement; -} - public abstract class org/partiql/ast/sql/SqlBlock { public static final field Companion Lorg/partiql/ast/sql/SqlBlock$Companion; public field next Lorg/partiql/ast/sql/SqlBlock; diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToBinder.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToBinder.kt deleted file mode 100644 index fbc2a78bc4..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToBinder.kt +++ /dev/null @@ -1,67 +0,0 @@ -package org.partiql.ast.helpers - -import org.partiql.ast.Expr -import org.partiql.ast.Identifier -import org.partiql.ast.builder.ast -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.StringValue - -// TODO DELETE FILE - -private val col = { index: () -> Int -> "_${index()}" } - -/** - * Produces a "binder" (AS alias) for an expression following the given rules: - * - * 1. If item is an id, use the last symbol - * 2. If item is a path with a final symbol step, use the symbol — else 4 - * 3. If item is a cast, use the value name - * 4. Else, use item index with prefix _ - * - * See https://github.com/partiql/partiql-lang-kotlin/issues/1122 - */ -public fun Expr.toBinder(index: () -> Int): Identifier.Symbol = when (this) { - is Expr.Var -> this.identifier.toBinder() - is Expr.Path -> this.toBinder(index) - is Expr.Cast -> this.value.toBinder(index) - is Expr.SessionAttribute -> this.attribute.name.uppercase().toBinder() - else -> col(index).toBinder() -} - -/** - * Simple toBinder that uses an int literal rather than a closure. - * - * @param index - * @return - */ -public fun Expr.toBinder(index: Int): Identifier.Symbol = toBinder { index } - -private fun String.toBinder(): Identifier.Symbol = ast { - // Every binder preserves case - identifierSymbol(this@toBinder, Identifier.CaseSensitivity.SENSITIVE) -} - -private fun Identifier.toBinder(): Identifier.Symbol = when (this@toBinder) { - is Identifier.Qualified -> when (steps.isEmpty()) { - true -> root.symbol.toBinder() - else -> steps.last().symbol.toBinder() - } - is Identifier.Symbol -> symbol.toBinder() -} - -@OptIn(PartiQLValueExperimental::class) -private fun Expr.Path.toBinder(index: () -> Int): Identifier.Symbol { - if (steps.isEmpty()) return root.toBinder(index) - return when (val last = steps.last()) { - is Expr.Path.Step.Symbol -> last.symbol.toBinder() - is Expr.Path.Step.Index -> { - val k = last.key - if (k is Expr.Lit && k.value is StringValue) { - k.value.value!!.toBinder() - } else { - col(index).toBinder() - } - } - else -> col(index).toBinder() - } -} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt deleted file mode 100644 index 57faa804b5..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.ast.normalize - -import org.partiql.ast.Statement - -// TODO DELETE FILE - -/** - * Wraps a rewriter with a default entry point. - */ -internal interface AstPass { - public fun apply(statement: Statement): Statement -} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt deleted file mode 100644 index d4bce86b0b..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt +++ /dev/null @@ -1,31 +0,0 @@ -@file:JvmName("Normalize") -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.ast.normalize - -import org.partiql.ast.Statement - -// TODO DELETE FILE - -/** - * AST normalization - */ -public fun Statement.normalize(): Statement { // TODO: Make this Java friendly and consider moving to planner package. - // could be a fold, but this is nice for setting breakpoints - var ast = this - ast = NormalizeFromSource.apply(ast) - ast = NormalizeGroupBy.apply(ast) - return ast -} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt deleted file mode 100644 index 29531daa6e..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.ast.normalize - -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.From -import org.partiql.ast.QueryBody -import org.partiql.ast.Statement -import org.partiql.ast.fromJoin -import org.partiql.ast.helpers.toBinder -import org.partiql.ast.util.AstRewriter - -// TODO DELETE FILE - -/** - * Assign aliases to any FROM source which does not have one. - */ -internal object NormalizeFromSource : AstPass { - - override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, 0) as Statement - - private object Visitor : AstRewriter() { - - // Each SFW starts the ctx count again. - override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Int): AstNode = super.visitQueryBodySFW(node, 0) - - override fun visitFrom(node: From, ctx: Int) = super.visitFrom(node, ctx) as From - - override fun visitFromJoin(node: From.Join, ctx: Int): From { - val lhs = visitFrom(node.lhs, ctx) - val rhs = visitFrom(node.rhs, ctx + 1) - val condition = node.condition?.let { visitExpr(it, ctx) as Expr } - return if (lhs !== node.lhs || rhs !== node.rhs || condition !== node.condition) { - fromJoin(lhs, rhs, node.type, condition) - } else { - node - } - } - - override fun visitFromValue(node: From.Value, ctx: Int): From { - val expr = visitExpr(node.expr, ctx) as Expr - var i = ctx - var asAlias = node.asAlias - var atAlias = node.atAlias - // derive AS alias - if (asAlias == null) { - asAlias = expr.toBinder(i++) - } - // derive AT binder - if (atAlias == null && node.type == From.Value.Type.UNPIVOT) { - atAlias = expr.toBinder(i++) - } - return if (expr !== node.expr || asAlias !== node.asAlias || atAlias !== node.atAlias) { - node.copy(expr = expr, asAlias = asAlias, atAlias = atAlias) - } else { - node - } - } - } -} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt deleted file mode 100644 index 66228fe237..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.ast.normalize - -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.GroupBy -import org.partiql.ast.Statement -import org.partiql.ast.groupByKey -import org.partiql.ast.helpers.toBinder -import org.partiql.ast.util.AstRewriter - -// TODO DELETE FILE - -/** - * Adds a unique binder to each group key. - */ -internal object NormalizeGroupBy : AstPass { - - override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement - - private object Visitor : AstRewriter() { - - override fun visitGroupBy(node: GroupBy, ctx: Int): AstNode { - val keys = node.keys.mapIndexed { index, key -> - visitGroupByKey(key, index + 1) - } - return node.copy(keys = keys) - } - - override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key { - val expr = visitExpr(node.expr, 0) as Expr - val alias = when (node.asAlias) { - null -> expr.toBinder(ctx) - else -> node.asAlias - } - return if (expr !== node.expr || alias !== node.asAlias) { - groupByKey(expr, alias) - } else { - node - } - } - } -} diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt index 59b7ff4ba5..0fac6dd660 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt @@ -4,8 +4,7 @@ import org.partiql.ast.v1.Statement import org.partiql.cli.ErrorCodeString import org.partiql.eval.Mode import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParserBuilderV1 -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.plan.Plan import org.partiql.planner.PartiQLPlanner import org.partiql.spi.Context @@ -15,7 +14,7 @@ import org.partiql.spi.value.Datum import java.io.PrintStream internal class Pipeline private constructor( - private val parser: PartiQLParserV1, + private val parser: PartiQLParser, private val planner: PartiQLPlanner, private val compiler: PartiQLCompiler, private val ctx: Context, @@ -83,7 +82,7 @@ internal class Pipeline private constructor( private fun create(mode: Mode, out: PrintStream, config: Config): Pipeline { val listener = config.getErrorListener(out) val ctx = Context.of(listener) - val parser = PartiQLParserBuilderV1().build() + val parser = PartiQLParser.Builder().build() val planner = PartiQLPlanner.builder().build() val compiler = PartiQLCompiler.builder().build() return Pipeline(parser, planner, compiler, ctx, mode) diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt index cedeaa173f..dddd3e95c2 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt @@ -9,7 +9,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.eval.Mode import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.plan.Plan import org.partiql.planner.PartiQLPlanner import org.partiql.spi.catalog.Catalog @@ -1308,7 +1308,7 @@ class PartiQLEvaluatorTest { ) { private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParserV1.standard() + private val parser = PartiQLParser.standard() private val planner = PartiQLPlanner.standard() /** @@ -1376,7 +1376,7 @@ class PartiQLEvaluatorTest { ) { private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParserV1.standard() + private val parser = PartiQLParser.standard() private val planner = PartiQLPlanner.standard() internal fun assert() { diff --git a/partiql-parser/api/partiql-parser.api b/partiql-parser/api/partiql-parser.api index bd84b15a30..f6cc5b551b 100644 --- a/partiql-parser/api/partiql-parser.api +++ b/partiql-parser/api/partiql-parser.api @@ -1,34 +1,16 @@ public abstract interface class org/partiql/parser/PartiQLParser { - public static fun builder ()Lorg/partiql/parser/PartiQLParserBuilder; + public static fun builder ()Lorg/partiql/parser/PartiQLParser$Builder; public fun parse (Ljava/lang/String;)Lorg/partiql/parser/PartiQLParser$Result; public abstract fun parse (Ljava/lang/String;Lorg/partiql/spi/Context;)Lorg/partiql/parser/PartiQLParser$Result; public static fun standard ()Lorg/partiql/parser/PartiQLParser; } -public final class org/partiql/parser/PartiQLParser$Result { - public field locations Lorg/partiql/spi/SourceLocations; - public field statements Ljava/util/List; - public fun (Ljava/util/List;Lorg/partiql/spi/SourceLocations;)V -} - -public class org/partiql/parser/PartiQLParserBuilder { +public class org/partiql/parser/PartiQLParser$Builder { public fun ()V public fun build ()Lorg/partiql/parser/PartiQLParser; } -public class org/partiql/parser/PartiQLParserBuilderV1 { - public fun ()V - public fun build ()Lorg/partiql/parser/PartiQLParserV1; -} - -public abstract interface class org/partiql/parser/PartiQLParserV1 { - public static fun builder ()Lorg/partiql/parser/PartiQLParserBuilderV1; - public fun parse (Ljava/lang/String;)Lorg/partiql/parser/PartiQLParserV1$Result; - public abstract fun parse (Ljava/lang/String;Lorg/partiql/spi/Context;)Lorg/partiql/parser/PartiQLParserV1$Result; - public static fun standard ()Lorg/partiql/parser/PartiQLParserV1; -} - -public final class org/partiql/parser/PartiQLParserV1$Result { +public final class org/partiql/parser/PartiQLParser$Result { public field locations Lorg/partiql/spi/SourceLocations; public field statements Ljava/util/List; public fun (Ljava/util/List;Lorg/partiql/spi/SourceLocations;)V diff --git a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParser.java b/partiql-parser/src/main/java/org/partiql/parser/PartiQLParser.java index ef85820992..4c0c2eca19 100644 --- a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParser.java +++ b/partiql-parser/src/main/java/org/partiql/parser/PartiQLParser.java @@ -15,7 +15,7 @@ package org.partiql.parser; import org.jetbrains.annotations.NotNull; -import org.partiql.ast.Statement; +import org.partiql.ast.v1.Statement; import org.partiql.parser.internal.PartiQLParserDefault; import org.partiql.spi.Context; import org.partiql.spi.SourceLocations; @@ -24,7 +24,7 @@ import java.util.List; /** - * TODO + * TODO docs */ public interface PartiQLParser { @@ -82,8 +82,8 @@ public Result(@NotNull List statements, @NotNull SourceLocations loca * @return TODO */ @NotNull - public static PartiQLParserBuilder builder() { - return new PartiQLParserBuilder(); + public static Builder builder() { + return new Builder(); } /** @@ -94,4 +94,14 @@ public static PartiQLParserBuilder builder() { public static PartiQLParser standard() { return new PartiQLParserDefault(); } + + /** + * A builder class to instantiate a [PartiQLParser]. + */ + public class Builder { + @NotNull + public PartiQLParser build() { + return new PartiQLParserDefault(); + } + } } diff --git a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilder.java b/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilder.java deleted file mode 100644 index 447424d76e..0000000000 --- a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilder.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.parser; - -import org.jetbrains.annotations.NotNull; -import org.partiql.parser.internal.PartiQLParserDefault; - -/** - * A builder class to instantiate a [PartiQLParser]. - */ -public class PartiQLParserBuilder { - - /** - * TODO - * @return TODO - */ - @NotNull - public PartiQLParser build() { - return new PartiQLParserDefault(); - } -} diff --git a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilderV1.java b/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilderV1.java deleted file mode 100644 index 24b8e405cb..0000000000 --- a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserBuilderV1.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.parser; - -import org.jetbrains.annotations.NotNull; -import org.partiql.parser.internal.PartiQLParserDefaultV1; - -/** - * A builder class to instantiate a [PartiQLParserV1]. https://github.com/partiql/partiql-lang-kotlin/issues/1632 - * TODO replace with Lombok builder once [PartiQLParserV1] is migrated to Java. - */ -public class PartiQLParserBuilderV1 { - - @NotNull - public PartiQLParserV1 build() { - return new PartiQLParserDefaultV1(); - } -} diff --git a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserV1.java b/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserV1.java deleted file mode 100644 index d16acc2677..0000000000 --- a/partiql-parser/src/main/java/org/partiql/parser/PartiQLParserV1.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.parser; - -import org.jetbrains.annotations.NotNull; -import org.partiql.ast.v1.Statement; -import org.partiql.parser.internal.PartiQLParserDefaultV1; -import org.partiql.spi.Context; -import org.partiql.spi.SourceLocations; -import org.partiql.spi.errors.PErrorListenerException; - -import java.util.List; - -/** - * TODO: Rename to PartiQLParser - */ -public interface PartiQLParserV1 { - - /** - * Parses the [source] into an AST. - * @param source the user's input - * @param ctx a configuration object for the parser - * @throws PErrorListenerException when the [org.partiql.spi.errors.PErrorListener] defined in the [ctx] throws an - * [PErrorListenerException], this method halts execution and propagates the exception. - */ - @NotNull - Result parse(@NotNull String source, @NotNull Context ctx) throws PErrorListenerException; - - /** - * Parses the [source] into an AST. - * @param source the user's input - * @throws PErrorListenerException when the [org.partiql.spi.errors.PErrorListener] defined in the context throws an - * [PErrorListenerException], this method halts execution and propagates the exception. - */ - @NotNull - default Result parse(@NotNull String source) throws PErrorListenerException { - return parse(source, Context.standard()); - } - - /** - * TODO - */ - final class Result { - - /** - * TODO - */ - @NotNull - public List statements; - - /** - * TODO - */ - @NotNull - public SourceLocations locations; - - /** - * TODO - * @param statements TODO - * @param locations TODO - */ - public Result(@NotNull List statements, @NotNull SourceLocations locations) { - this.statements = statements; - this.locations = locations; - } - } - - /** - * TODO - * @return TODO - */ - @NotNull - public static PartiQLParserBuilderV1 builder() { - return new PartiQLParserBuilderV1(); - } - - /** - * TODO - * @return TODO - */ - @NotNull - public static PartiQLParserV1 standard() { - return new PartiQLParserDefaultV1(); - } -} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index 04cdcf5f79..b2b0e74834 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -30,147 +30,128 @@ import org.antlr.v4.runtime.TokenStream import org.antlr.v4.runtime.atn.PredictionMode import org.antlr.v4.runtime.misc.ParseCancellationException import org.antlr.v4.runtime.tree.TerminalNode -import org.partiql.ast.AstNode -import org.partiql.ast.DatetimeField -import org.partiql.ast.Exclude -import org.partiql.ast.Expr -import org.partiql.ast.From -import org.partiql.ast.GraphMatch -import org.partiql.ast.GroupBy -import org.partiql.ast.Identifier -import org.partiql.ast.Let -import org.partiql.ast.Path -import org.partiql.ast.Select -import org.partiql.ast.SetOp -import org.partiql.ast.SetQuantifier -import org.partiql.ast.Sort -import org.partiql.ast.Statement -import org.partiql.ast.TableDefinition -import org.partiql.ast.Type -import org.partiql.ast.exclude -import org.partiql.ast.excludeItem -import org.partiql.ast.excludeStepCollIndex -import org.partiql.ast.excludeStepCollWildcard -import org.partiql.ast.excludeStepStructField -import org.partiql.ast.excludeStepStructWildcard -import org.partiql.ast.exprAnd -import org.partiql.ast.exprBetween -import org.partiql.ast.exprCall -import org.partiql.ast.exprCase -import org.partiql.ast.exprCaseBranch -import org.partiql.ast.exprCast -import org.partiql.ast.exprCoalesce -import org.partiql.ast.exprCollection -import org.partiql.ast.exprDateAdd -import org.partiql.ast.exprDateDiff -import org.partiql.ast.exprExtract -import org.partiql.ast.exprInCollection -import org.partiql.ast.exprIsType -import org.partiql.ast.exprLike -import org.partiql.ast.exprLit -import org.partiql.ast.exprMatch -import org.partiql.ast.exprNot -import org.partiql.ast.exprNullIf -import org.partiql.ast.exprOperator -import org.partiql.ast.exprOr -import org.partiql.ast.exprOverlay -import org.partiql.ast.exprParameter -import org.partiql.ast.exprPath -import org.partiql.ast.exprPathStepIndex -import org.partiql.ast.exprPathStepSymbol -import org.partiql.ast.exprPathStepUnpivot -import org.partiql.ast.exprPathStepWildcard -import org.partiql.ast.exprPosition -import org.partiql.ast.exprQuerySet -import org.partiql.ast.exprSessionAttribute -import org.partiql.ast.exprStruct -import org.partiql.ast.exprStructField -import org.partiql.ast.exprSubstring -import org.partiql.ast.exprTrim -import org.partiql.ast.exprVar -import org.partiql.ast.exprVariant -import org.partiql.ast.exprWindow -import org.partiql.ast.exprWindowOver -import org.partiql.ast.fromJoin -import org.partiql.ast.fromValue -import org.partiql.ast.graphMatch -import org.partiql.ast.graphMatchLabelConj -import org.partiql.ast.graphMatchLabelDisj -import org.partiql.ast.graphMatchLabelName -import org.partiql.ast.graphMatchLabelNegation -import org.partiql.ast.graphMatchLabelWildcard -import org.partiql.ast.graphMatchPattern -import org.partiql.ast.graphMatchPatternPartEdge -import org.partiql.ast.graphMatchPatternPartNode -import org.partiql.ast.graphMatchPatternPartPattern -import org.partiql.ast.graphMatchQuantifier -import org.partiql.ast.graphMatchSelectorAllShortest -import org.partiql.ast.graphMatchSelectorAny -import org.partiql.ast.graphMatchSelectorAnyK -import org.partiql.ast.graphMatchSelectorAnyShortest -import org.partiql.ast.graphMatchSelectorShortestK -import org.partiql.ast.graphMatchSelectorShortestKGroup -import org.partiql.ast.groupBy -import org.partiql.ast.groupByKey -import org.partiql.ast.identifierQualified -import org.partiql.ast.identifierSymbol -import org.partiql.ast.let -import org.partiql.ast.letBinding -import org.partiql.ast.orderBy -import org.partiql.ast.path -import org.partiql.ast.pathStepIndex -import org.partiql.ast.pathStepSymbol -import org.partiql.ast.queryBodySFW -import org.partiql.ast.queryBodySetOp -import org.partiql.ast.selectPivot -import org.partiql.ast.selectProject -import org.partiql.ast.selectProjectItemAll -import org.partiql.ast.selectProjectItemExpression -import org.partiql.ast.selectStar -import org.partiql.ast.selectValue -import org.partiql.ast.setOp -import org.partiql.ast.sort -import org.partiql.ast.statementDDLCreateIndex -import org.partiql.ast.statementDDLCreateTable -import org.partiql.ast.statementDDLDropIndex -import org.partiql.ast.statementDDLDropTable -import org.partiql.ast.statementExplain -import org.partiql.ast.statementExplainTargetDomain -import org.partiql.ast.statementQuery -import org.partiql.ast.tableDefinition -import org.partiql.ast.tableDefinitionColumn -import org.partiql.ast.tableDefinitionColumnConstraint -import org.partiql.ast.tableDefinitionColumnConstraintBodyNotNull -import org.partiql.ast.tableDefinitionColumnConstraintBodyNullable -import org.partiql.ast.typeAny -import org.partiql.ast.typeBag -import org.partiql.ast.typeBlob -import org.partiql.ast.typeBool -import org.partiql.ast.typeChar -import org.partiql.ast.typeClob -import org.partiql.ast.typeCustom -import org.partiql.ast.typeDate -import org.partiql.ast.typeDecimal -import org.partiql.ast.typeFloat32 -import org.partiql.ast.typeFloat64 -import org.partiql.ast.typeInt2 -import org.partiql.ast.typeInt4 -import org.partiql.ast.typeInt8 -import org.partiql.ast.typeList -import org.partiql.ast.typeMissing -import org.partiql.ast.typeNullType -import org.partiql.ast.typeNumeric -import org.partiql.ast.typeReal -import org.partiql.ast.typeSexp -import org.partiql.ast.typeString -import org.partiql.ast.typeStruct -import org.partiql.ast.typeSymbol -import org.partiql.ast.typeTime -import org.partiql.ast.typeTimeWithTz -import org.partiql.ast.typeTimestamp -import org.partiql.ast.typeTimestampWithTz -import org.partiql.ast.typeTuple -import org.partiql.ast.typeVarchar +import org.partiql.ast.v1.Ast +import org.partiql.ast.v1.Ast.exclude +import org.partiql.ast.v1.Ast.excludePath +import org.partiql.ast.v1.Ast.excludeStepCollIndex +import org.partiql.ast.v1.Ast.excludeStepCollWildcard +import org.partiql.ast.v1.Ast.excludeStepStructField +import org.partiql.ast.v1.Ast.excludeStepStructWildcard +import org.partiql.ast.v1.Ast.explain +import org.partiql.ast.v1.Ast.exprAnd +import org.partiql.ast.v1.Ast.exprArray +import org.partiql.ast.v1.Ast.exprBag +import org.partiql.ast.v1.Ast.exprBetween +import org.partiql.ast.v1.Ast.exprCall +import org.partiql.ast.v1.Ast.exprCase +import org.partiql.ast.v1.Ast.exprCaseBranch +import org.partiql.ast.v1.Ast.exprCast +import org.partiql.ast.v1.Ast.exprCoalesce +import org.partiql.ast.v1.Ast.exprExtract +import org.partiql.ast.v1.Ast.exprInCollection +import org.partiql.ast.v1.Ast.exprIsType +import org.partiql.ast.v1.Ast.exprLike +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprMatch +import org.partiql.ast.v1.Ast.exprNot +import org.partiql.ast.v1.Ast.exprNullIf +import org.partiql.ast.v1.Ast.exprOperator +import org.partiql.ast.v1.Ast.exprOr +import org.partiql.ast.v1.Ast.exprOverlay +import org.partiql.ast.v1.Ast.exprParameter +import org.partiql.ast.v1.Ast.exprPath +import org.partiql.ast.v1.Ast.exprPathStepAllElements +import org.partiql.ast.v1.Ast.exprPathStepAllFields +import org.partiql.ast.v1.Ast.exprPathStepElement +import org.partiql.ast.v1.Ast.exprPathStepField +import org.partiql.ast.v1.Ast.exprPosition +import org.partiql.ast.v1.Ast.exprQuerySet +import org.partiql.ast.v1.Ast.exprSessionAttribute +import org.partiql.ast.v1.Ast.exprStruct +import org.partiql.ast.v1.Ast.exprStructField +import org.partiql.ast.v1.Ast.exprSubstring +import org.partiql.ast.v1.Ast.exprTrim +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.exprVariant +import org.partiql.ast.v1.Ast.exprWindow +import org.partiql.ast.v1.Ast.exprWindowOver +import org.partiql.ast.v1.Ast.from +import org.partiql.ast.v1.Ast.fromExpr +import org.partiql.ast.v1.Ast.fromJoin +import org.partiql.ast.v1.Ast.graphLabelConj +import org.partiql.ast.v1.Ast.graphLabelDisj +import org.partiql.ast.v1.Ast.graphLabelName +import org.partiql.ast.v1.Ast.graphLabelNegation +import org.partiql.ast.v1.Ast.graphLabelWildcard +import org.partiql.ast.v1.Ast.graphMatch +import org.partiql.ast.v1.Ast.graphMatchEdge +import org.partiql.ast.v1.Ast.graphMatchNode +import org.partiql.ast.v1.Ast.graphMatchPattern +import org.partiql.ast.v1.Ast.graphPattern +import org.partiql.ast.v1.Ast.graphQuantifier +import org.partiql.ast.v1.Ast.graphSelectorAllShortest +import org.partiql.ast.v1.Ast.graphSelectorAny +import org.partiql.ast.v1.Ast.graphSelectorAnyK +import org.partiql.ast.v1.Ast.graphSelectorAnyShortest +import org.partiql.ast.v1.Ast.graphSelectorShortestK +import org.partiql.ast.v1.Ast.graphSelectorShortestKGroup +import org.partiql.ast.v1.Ast.groupBy +import org.partiql.ast.v1.Ast.groupByKey +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.Ast.letBinding +import org.partiql.ast.v1.Ast.orderBy +import org.partiql.ast.v1.Ast.query +import org.partiql.ast.v1.Ast.queryBodySFW +import org.partiql.ast.v1.Ast.queryBodySetOp +import org.partiql.ast.v1.Ast.selectItemExpr +import org.partiql.ast.v1.Ast.selectItemStar +import org.partiql.ast.v1.Ast.selectList +import org.partiql.ast.v1.Ast.selectPivot +import org.partiql.ast.v1.Ast.selectStar +import org.partiql.ast.v1.Ast.selectValue +import org.partiql.ast.v1.Ast.setOp +import org.partiql.ast.v1.Ast.sort +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.DataType +import org.partiql.ast.v1.DatetimeField +import org.partiql.ast.v1.Exclude +import org.partiql.ast.v1.ExcludeStep +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromTableRef +import org.partiql.ast.v1.FromType +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.GroupByStrategy +import org.partiql.ast.v1.Identifier +import org.partiql.ast.v1.IdentifierChain +import org.partiql.ast.v1.JoinType +import org.partiql.ast.v1.Let +import org.partiql.ast.v1.Nulls +import org.partiql.ast.v1.Order +import org.partiql.ast.v1.Select +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SetOpType +import org.partiql.ast.v1.SetQuantifier +import org.partiql.ast.v1.Sort +import org.partiql.ast.v1.Statement +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprArray +import org.partiql.ast.v1.expr.ExprBag +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprPath +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.PathStep +import org.partiql.ast.v1.expr.Scope +import org.partiql.ast.v1.expr.SessionAttribute +import org.partiql.ast.v1.expr.TrimSpec +import org.partiql.ast.v1.expr.WindowFunction +import org.partiql.ast.v1.graph.GraphDirection +import org.partiql.ast.v1.graph.GraphLabel +import org.partiql.ast.v1.graph.GraphPart +import org.partiql.ast.v1.graph.GraphPattern +import org.partiql.ast.v1.graph.GraphQuantifier +import org.partiql.ast.v1.graph.GraphRestrictor +import org.partiql.ast.v1.graph.GraphSelector import org.partiql.parser.PartiQLLexerException import org.partiql.parser.PartiQLParser import org.partiql.parser.PartiQLParserException @@ -183,9 +164,7 @@ import org.partiql.spi.errors.PError import org.partiql.spi.errors.PErrorKind import org.partiql.spi.errors.PErrorListener import org.partiql.spi.errors.PErrorListenerException -import org.partiql.value.NumericValue import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.StringValue import org.partiql.value.boolValue import org.partiql.value.dateValue import org.partiql.value.datetime.DateTimeException @@ -240,53 +219,53 @@ internal class PartiQLParserDefault : PartiQLParser { val error = PError.INTERNAL_ERROR(PErrorKind.SYNTAX(), null, throwable) ctx.errorListener.report(error) val locations = SourceLocations() - return PartiQLParser.Result(listOf(Statement.Query(Expr.Lit(nullValue()))), locations) + return PartiQLParser.Result( + mutableListOf(org.partiql.ast.v1.Query(org.partiql.ast.v1.expr.ExprLit(nullValue()))) as List, + locations + ) } } - companion object { + /** + * To reduce latency costs, the [PartiQLParserDefault] attempts to use [PredictionMode.SLL] and falls back to + * [PredictionMode.LL] if a [ParseCancellationException] is thrown by the [BailErrorStrategy]. + */ + private fun parse(source: String, listener: PErrorListener): PartiQLParser.Result = try { + parse(source, PredictionMode.SLL, listener) + } catch (ex: ParseCancellationException) { + parse(source, PredictionMode.LL, listener) + } - /** - * To reduce latency costs, the [PartiQLParserDefault] attempts to use [PredictionMode.SLL] and falls back to - * [PredictionMode.LL] if a [ParseCancellationException] is thrown by the [BailErrorStrategy]. - */ - private fun parse(source: String, listener: PErrorListener): PartiQLParser.Result = try { - parse(source, PredictionMode.SLL, listener) - } catch (ex: ParseCancellationException) { - parse(source, PredictionMode.LL, listener) - } + /** + * Parses an input string [source] using the given prediction mode. + */ + private fun parse(source: String, mode: PredictionMode, listener: PErrorListener): PartiQLParser.Result { + val tokens = createTokenStream(source, listener) + val parser = InterruptibleParser(tokens) + parser.reset() + parser.removeErrorListeners() + parser.interpreter.predictionMode = mode + when (mode) { + PredictionMode.SLL -> parser.errorHandler = BailErrorStrategy() + PredictionMode.LL -> parser.addErrorListener(ParseErrorListener(listener)) + else -> throw IllegalArgumentException("Unsupported parser mode: $mode") + } + val tree = parser.statements() + return Visitor.translate(tokens, tree) + } - /** - * Parses an input string [source] using the given prediction mode. - */ - private fun parse(source: String, mode: PredictionMode, listener: PErrorListener): PartiQLParser.Result { - val tokens = createTokenStream(source, listener) - val parser = InterruptibleParser(tokens) - parser.reset() - parser.removeErrorListeners() - parser.interpreter.predictionMode = mode - when (mode) { - PredictionMode.SLL -> parser.errorHandler = BailErrorStrategy() - PredictionMode.LL -> parser.addErrorListener(ParseErrorListener(listener)) - else -> throw IllegalArgumentException("Unsupported parser mode: $mode") - } - val tree = parser.statements() - return Visitor.translate(tokens, tree) - } - - private fun createTokenStream(source: String, listener: PErrorListener): CountingTokenStream { - val queryStream = source.byteInputStream(StandardCharsets.UTF_8) - val inputStream = try { - CharStreams.fromStream(queryStream) - } catch (ex: ClosedByInterruptException) { - throw InterruptedException() - } - val handler = TokenizeErrorListener(listener) - val lexer = GeneratedLexer(inputStream) - lexer.removeErrorListeners() - lexer.addErrorListener(handler) - return CountingTokenStream(lexer) - } + private fun createTokenStream(source: String, listener: PErrorListener): CountingTokenStream { + val queryStream = source.byteInputStream(StandardCharsets.UTF_8) + val inputStream = try { + CharStreams.fromStream(queryStream) + } catch (ex: ClosedByInterruptException) { + throw InterruptedException() + } + val handler = TokenizeErrorListener(listener) + val lexer = GeneratedLexer(inputStream) + lexer.removeErrorListeners() + lexer.addErrorListener(handler) + return CountingTokenStream(lexer) } /** @@ -302,14 +281,11 @@ internal class PartiQLParserDefault : PartiQLParser { msg: String, e: RecognitionException?, ) { - if (offendingSymbol is Token) { - val token = offendingSymbol.text - val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong()) - val error = PErrors.unrecognizedToken(location, token) - listener.report(error) - } else { - throw IllegalArgumentException("Offending symbol is not a Token.") - } + offendingSymbol as Token + val token = offendingSymbol.text + val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong()) + val error = PErrors.unrecognizedToken(location, token) + listener.report(error) } } @@ -329,16 +305,13 @@ internal class PartiQLParserDefault : PartiQLParser { msg: String, e: RecognitionException?, ) { - if (offendingSymbol is Token) { - val rule = e?.ctx?.toString(rules) ?: "UNKNOWN" // TODO: Do we want to display the offending rule? - val token = offendingSymbol.text - val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type) - val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong()) - val error = PErrors.unexpectedToken(location, tokenType, null) - listener.report(error) - } else { - throw IllegalArgumentException("Offending symbol is not a Token.") - } + offendingSymbol as Token + val rule = e?.ctx?.toString(rules) ?: "UNKNOWN" // TODO: Do we want to display the offending rule? + val token = offendingSymbol.text + val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type) + val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong()) + val error = PErrors.unexpectedToken(location, tokenType, null) + listener.report(error) } } @@ -375,7 +348,7 @@ internal class PartiQLParserDefault : PartiQLParser { } /** - * Translate an ANTLR ParseTree to a PartiQL + * Translate an ANTLR ParseTree to a PartiQL AST */ @OptIn(PartiQLValueExperimental::class) private class Visitor( @@ -400,7 +373,10 @@ internal class PartiQLParserDefault : PartiQLParser { val statements = tree.statement().map { statementCtx -> visitor.visit(statementCtx) as Statement } - return PartiQLParser.Result(statements, SourceLocations(locations)) + return PartiQLParser.Result( + statements, + SourceLocations(locations), + ) } fun error( @@ -475,11 +451,7 @@ internal class PartiQLParserDefault : PartiQLParser { val parameter = try { ExplainParameters.valueOf(option.param.text.uppercase()) } catch (ex: java.lang.IllegalArgumentException) { - throw error( - option.param, - "Unknown EXPLAIN parameter.", - ex - ) + throw error(option.param, "Unknown EXPLAIN parameter.", ex) } when (parameter) { ExplainParameters.TYPE -> { @@ -490,12 +462,13 @@ internal class PartiQLParserDefault : PartiQLParser { } } } - statementExplain( - target = statementExplainTargetDomain( - statement = visit(ctx.statement()) as Statement, - type = type, - format = format - ) + explain( + // TODO get rid of usage of PartiQLValue https://github.com/partiql/partiql-lang-kotlin/issues/1589 + options = mapOf( + "type" to stringValue(type), + "format" to stringValue(format) + ), + statement = visit(ctx.statement()) as Statement, ) } @@ -511,102 +484,102 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitByIdent(ctx: GeneratedParser.ByIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) - private fun visitSymbolPrimitive(ctx: GeneratedParser.SymbolPrimitiveContext): Identifier.Symbol = + private fun visitSymbolPrimitive(ctx: GeneratedParser.SymbolPrimitiveContext): Identifier = when (ctx) { is GeneratedParser.IdentifierQuotedContext -> visitIdentifierQuoted(ctx) is GeneratedParser.IdentifierUnquotedContext -> visitIdentifierUnquoted(ctx) else -> throw error(ctx, "Invalid symbol reference.") } - override fun visitIdentifierQuoted(ctx: GeneratedParser.IdentifierQuotedContext): Identifier.Symbol = translate(ctx) { - identifierSymbol( + override fun visitIdentifierQuoted(ctx: GeneratedParser.IdentifierQuotedContext): Identifier = translate(ctx) { + identifier( ctx.IDENTIFIER_QUOTED().getStringValue(), - Identifier.CaseSensitivity.SENSITIVE + true ) } - override fun visitIdentifierUnquoted(ctx: GeneratedParser.IdentifierUnquotedContext): Identifier.Symbol = translate(ctx) { - identifierSymbol( + override fun visitIdentifierUnquoted(ctx: GeneratedParser.IdentifierUnquotedContext): Identifier = translate(ctx) { + identifier( ctx.text, - Identifier.CaseSensitivity.INSENSITIVE + false ) } override fun visitQualifiedName(ctx: GeneratedParser.QualifiedNameContext) = translate(ctx) { val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } - val name = visitSymbolPrimitive(ctx.name) + val name = identifierChain(visitSymbolPrimitive(ctx.name), null) if (qualifier.isEmpty()) { name } else { - val root = qualifier.first() - val steps = qualifier.drop(1) + listOf(name) - identifierQualified(root, steps) + qualifier.reversed().fold(name) { acc, id -> + identifierChain(root = id, next = acc) + } } } /** * - * DATA DEFINITION LANGUAGE (DDL) + * DATA DEFINITION LANGUAGE (DDL) -- deleted in v1; will be added before final v1 release * */ - override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = visitDdl(ctx.ddl()) - - override fun visitDropTable(ctx: GeneratedParser.DropTableContext) = translate(ctx) { - val table = visitQualifiedName(ctx.qualifiedName()) - statementDDLDropTable(table) - } - - override fun visitDropIndex(ctx: GeneratedParser.DropIndexContext) = translate(ctx) { - val table = visitSymbolPrimitive(ctx.on) - val index = visitSymbolPrimitive(ctx.target) - statementDDLDropIndex(index, table) - } - - override fun visitCreateTable(ctx: GeneratedParser.CreateTableContext) = translate(ctx) { - val table = visitQualifiedName(ctx.qualifiedName()) - val definition = ctx.tableDef()?.let { visitTableDef(it) } - statementDDLCreateTable(table, definition) - } - - override fun visitCreateIndex(ctx: GeneratedParser.CreateIndexContext) = translate(ctx) { - // TODO add index name to ANTLR grammar - val name: Identifier? = null - val table = visitSymbolPrimitive(ctx.symbolPrimitive()) - val fields = ctx.pathSimple().map { path -> visitPathSimple(path) } - statementDDLCreateIndex(name, table, fields) - } - - override fun visitTableDef(ctx: GeneratedParser.TableDefContext) = translate(ctx) { - // Column Definitions are the only thing we currently allow as table definition parts - val columns = ctx.tableDefPart().filterIsInstance().map { - visitColumnDeclaration(it) - } - tableDefinition(columns) - } - - override fun visitColumnDeclaration(ctx: GeneratedParser.ColumnDeclarationContext) = translate(ctx) { - val name = visitSymbolPrimitive(ctx.columnName().symbolPrimitive()).symbol - val type = visit(ctx.type()) as Type - val constraints = ctx.columnConstraint().map { - visitColumnConstraint(it) - } - tableDefinitionColumn(name, type, constraints) - } - - override fun visitColumnConstraint(ctx: GeneratedParser.ColumnConstraintContext) = translate(ctx) { - val identifier = ctx.columnConstraintName()?.let { symbolToString(it.symbolPrimitive()) } - val body = visit(ctx.columnConstraintDef()) as TableDefinition.Column.Constraint.Body - tableDefinitionColumnConstraint(identifier, body) - } - - override fun visitColConstrNotNull(ctx: GeneratedParser.ColConstrNotNullContext) = translate(ctx) { - tableDefinitionColumnConstraintBodyNotNull() - } - - override fun visitColConstrNull(ctx: GeneratedParser.ColConstrNullContext) = translate(ctx) { - tableDefinitionColumnConstraintBodyNullable() - } +// override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = visitDdl(ctx.ddl()) +// +// override fun visitDropTable(ctx: GeneratedParser.DropTableContext) = translate(ctx) { +// val table = visitQualifiedName(ctx.qualifiedName()) +// statementDDLDropTable(table) +// } +// +// override fun visitDropIndex(ctx: GeneratedParser.DropIndexContext) = translate(ctx) { +// val table = visitSymbolPrimitive(ctx.on) +// val index = visitSymbolPrimitive(ctx.target) +// statementDDLDropIndex(index, table) +// } +// +// override fun visitCreateTable(ctx: GeneratedParser.CreateTableContext) = translate(ctx) { +// val table = visitQualifiedName(ctx.qualifiedName()) +// val definition = ctx.tableDef()?.let { visitTableDef(it) } +// statementDDLCreateTable(table, definition) +// } +// +// override fun visitCreateIndex(ctx: GeneratedParser.CreateIndexContext) = translate(ctx) { +// // TODO add index name to ANTLR grammar +// val name: Identifier? = null +// val table = visitSymbolPrimitive(ctx.symbolPrimitive()) +// val fields = ctx.pathSimple().map { path -> visitPathSimple(path) } +// statementDDLCreateIndex(name, table, fields) +// } +// +// override fun visitTableDef(ctx: GeneratedParser.TableDefContext) = translate(ctx) { +// // Column Definitions are the only thing we currently allow as table definition parts +// val columns = ctx.tableDefPart().filterIsInstance().map { +// visitColumnDeclaration(it) +// } +// tableDefinition(columns) +// } +// +// override fun visitColumnDeclaration(ctx: GeneratedParser.ColumnDeclarationContext) = translate(ctx) { +// val name = visitSymbolPrimitive(ctx.columnName().symbolPrimitive()).symbol +// val type = visit(ctx.type()) as Type +// val constraints = ctx.columnConstraint().map { +// visitColumnConstraint(it) +// } +// tableDefinitionColumn(name, type, constraints) +// } +// +// override fun visitColumnConstraint(ctx: GeneratedParser.ColumnConstraintContext) = translate(ctx) { +// val identifier = ctx.columnConstraintName()?.let { symbolToString(it.symbolPrimitive()) } +// val body = visit(ctx.columnConstraintDef()) as TableDefinition.Column.Constraint.Body +// tableDefinitionColumnConstraint(identifier, body) +// } +// +// override fun visitColConstrNotNull(ctx: GeneratedParser.ColConstrNotNullContext) = translate(ctx) { +// tableDefinitionColumnConstraintBodyNotNull() +// } +// +// override fun visitColConstrNull(ctx: GeneratedParser.ColConstrNullContext) = translate(ctx) { +// tableDefinitionColumnConstraintBodyNullable() +// } /** * @@ -727,36 +700,24 @@ internal class PartiQLParserDefault : PartiQLParser { throw error(ctx, "DML no longer supported in the default PartiQLParser.") } + // "simple paths" used by previous DDL's CREATE INDEX override fun visitPathSimple(ctx: GeneratedParser.PathSimpleContext) = translate(ctx) { - val root = visitSymbolPrimitive(ctx.symbolPrimitive()) - val steps = visitOrEmpty(ctx.pathSimpleSteps()) - path(root, steps) + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") } + // "simple paths" used by previous DDL's CREATE INDEX override fun visitPathSimpleLiteral(ctx: GeneratedParser.PathSimpleLiteralContext) = translate(ctx) { - val v = visit(ctx.literal()) - if (v !is Expr.Lit) { - throw error(ctx, "Expected a path element literal") - } - when (val i = v.value) { - is NumericValue<*> -> pathStepIndex(i.toInt32().value!!) - is StringValue -> pathStepSymbol( - identifierSymbol( - i.value!!, Identifier.CaseSensitivity.SENSITIVE - ) - ) - else -> throw error(ctx, "Expected an integer or string literal, found literal ${i.type}") - } + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") } + // "simple paths" used by previous DDL's CREATE INDEX override fun visitPathSimpleSymbol(ctx: GeneratedParser.PathSimpleSymbolContext) = translate(ctx) { - val identifier = visitSymbolPrimitive(ctx.symbolPrimitive()) - pathStepSymbol(identifier) + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") } + // "simple paths" used by previous DDL's CREATE INDEX override fun visitPathSimpleDotSymbol(ctx: GeneratedParser.PathSimpleDotSymbolContext) = translate(ctx) { - val identifier = visitSymbolPrimitive(ctx.symbolPrimitive()) - pathStepSymbol(identifier) + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") } /** @@ -778,7 +739,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitDql(ctx: GeneratedParser.DqlContext) = translate(ctx) { val expr = visitAs(ctx.expr()) - statementQuery(expr) + query(expr) } override fun visitQueryBase(ctx: GeneratedParser.QueryBaseContext): AstNode = visit(ctx.exprSelect()) @@ -816,9 +777,9 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitSelectItems(ctx: GeneratedParser.SelectItemsContext) = translate(ctx) { - val items = visitOrEmpty(ctx.projectionItems().projectionItem()) + val items = visitOrEmpty(ctx.projectionItems().projectionItem()) val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) - selectProject(items, setq) + selectList(items, setq) } override fun visitSelectPivot(ctx: GeneratedParser.SelectPivotContext) = translate(ctx) { @@ -836,10 +797,10 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitProjectionItem(ctx: GeneratedParser.ProjectionItemContext) = translate(ctx) { val expr = visitExpr(ctx.expr()) val alias = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it) } - if (expr is Expr.Path) { + if (expr is ExprPath) { convertPathToProjectionItem(ctx, expr, alias) } else { - selectProjectItemExpression(expr, alias) + selectItemExpr(expr, alias) } } @@ -874,7 +835,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitLetClause(ctx: GeneratedParser.LetClauseContext) = translate(ctx) { val bindings = visitOrEmpty(ctx.letBinding()) - let(bindings) + Ast.let(bindings) } override fun visitLetBinding(ctx: GeneratedParser.LetBindingContext) = translate(ctx) { @@ -898,14 +859,14 @@ internal class PartiQLParserDefault : PartiQLParser { val expr = visitAs(ctx.expr()) val dir = when { ctx.dir == null -> null - ctx.dir.type == GeneratedParser.ASC -> Sort.Dir.ASC - ctx.dir.type == GeneratedParser.DESC -> Sort.Dir.DESC + ctx.dir.type == GeneratedParser.ASC -> Order.ASC() + ctx.dir.type == GeneratedParser.DESC -> Order.DESC() else -> throw error(ctx.dir, "Invalid ORDER BY direction; expected ASC or DESC") } val nulls = when { ctx.nulls == null -> null - ctx.nulls.type == GeneratedParser.FIRST -> Sort.Nulls.FIRST - ctx.nulls.type == GeneratedParser.LAST -> Sort.Nulls.LAST + ctx.nulls.type == GeneratedParser.FIRST -> Nulls.FIRST() + ctx.nulls.type == GeneratedParser.LAST -> Nulls.LAST() else -> throw error(ctx.nulls, "Invalid ORDER BY null ordering; expected FIRST or LAST") } sort(expr, dir, nulls) @@ -918,7 +879,7 @@ internal class PartiQLParserDefault : PartiQLParser { */ override fun visitGroupClause(ctx: GeneratedParser.GroupClauseContext) = translate(ctx) { - val strategy = if (ctx.PARTIAL() != null) GroupBy.Strategy.PARTIAL else GroupBy.Strategy.FULL + val strategy = if (ctx.PARTIAL() != null) GroupByStrategy.PARTIAL() else GroupByStrategy.FULL() val keys = visitOrEmpty(ctx.groupKey()) val alias = ctx.groupAlias()?.symbolPrimitive()?.let { visitSymbolPrimitive(it) } groupBy(strategy, keys, alias) @@ -942,9 +903,9 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitExcludeExpr(ctx: GeneratedParser.ExcludeExprContext) = translate(ctx) { val rootId = visitSymbolPrimitive(ctx.symbolPrimitive()) - val root = exprVar(rootId, Expr.Var.Scope.DEFAULT) - val steps = visitOrEmpty(ctx.excludeExprSteps()) - excludeItem(root, steps) + val root = exprVarRef(identifierChain(rootId, null), Scope.DEFAULT()) + val steps = visitOrEmpty(ctx.excludeExprSteps()) + excludePath(root, steps) } override fun visitExcludeExprTupleAttr(ctx: GeneratedParser.ExcludeExprTupleAttrContext) = translate(ctx) { @@ -961,10 +922,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitExcludeExprCollectionAttr(ctx: GeneratedParser.ExcludeExprCollectionAttrContext) = translate(ctx) { val attr = ctx.attr.getStringValue() - val identifier = identifierSymbol( - attr, - Identifier.CaseSensitivity.SENSITIVE, - ) + val identifier = identifier(attr, true) excludeStepStructField(identifier) } @@ -985,14 +943,14 @@ internal class PartiQLParserDefault : PartiQLParser { */ override fun visitBagOp(ctx: GeneratedParser.BagOpContext) = translate(ctx) { val setq = when { - ctx.ALL() != null -> SetQuantifier.ALL - ctx.DISTINCT() != null -> SetQuantifier.DISTINCT + ctx.ALL() != null -> SetQuantifier.ALL() + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT() else -> null } val op = when (ctx.op.type) { - GeneratedParser.UNION -> setOp(SetOp.Type.UNION, setq) - GeneratedParser.INTERSECT -> setOp(SetOp.Type.INTERSECT, setq) - GeneratedParser.EXCEPT -> setOp(SetOp.Type.EXCEPT, setq) + GeneratedParser.UNION -> setOp(SetOpType.UNION(), setq) + GeneratedParser.INTERSECT -> setOp(SetOpType.INTERSECT(), setq) + GeneratedParser.EXCEPT -> setOp(SetOpType.EXCEPT(), setq) else -> error("Unsupported bag op token ${ctx.op}") } val lhs = visitAs(ctx.lhs) @@ -1002,15 +960,15 @@ internal class PartiQLParserDefault : PartiQLParser { val limit = ctx.limit?.let { visitAs(it) } val offset = ctx.offset?.let { visitAs(it) } exprQuerySet( - body = queryBodySetOp( - type = op, - isOuter = outer, - lhs = lhs, - rhs = rhs + queryBodySetOp( + op, + outer, + lhs, + rhs ), - orderBy = orderBy, - limit = limit, - offset = offset, + orderBy, + limit, + offset, ) } @@ -1022,28 +980,28 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitGpmlPattern(ctx: GeneratedParser.GpmlPatternContext) = translate(ctx) { val pattern = visitMatchPattern(ctx.matchPattern()) - val selector = visitOrNull(ctx.matchSelector()) + val selector = visitOrNull(ctx.matchSelector()) graphMatch(listOf(pattern), selector) } override fun visitGpmlPatternList(ctx: GeneratedParser.GpmlPatternListContext) = translate(ctx) { val patterns = ctx.matchPattern().map { pattern -> visitMatchPattern(pattern) } - val selector = visitOrNull(ctx.matchSelector()) + val selector = visitOrNull(ctx.matchSelector()) graphMatch(patterns, selector) } override fun visitMatchPattern(ctx: GeneratedParser.MatchPatternContext) = translate(ctx) { - val parts = visitOrEmpty(ctx.graphPart()) + val parts = visitOrEmpty(ctx.graphPart()) val restrictor = ctx.restrictor?.let { when (ctx.restrictor.text.lowercase()) { - "trail" -> GraphMatch.Restrictor.TRAIL - "acyclic" -> GraphMatch.Restrictor.ACYCLIC - "simple" -> GraphMatch.Restrictor.SIMPLE + "trail" -> GraphRestrictor.TRAIL() + "acyclic" -> GraphRestrictor.ACYCLIC() + "simple" -> GraphRestrictor.SIMPLE() else -> throw error(ctx.restrictor, "Unrecognized pattern restrictor") } } - val variable = visitOrNull(ctx.variable)?.symbol - graphMatchPattern(restrictor, null, variable, null, parts) + val variable = visitOrNull(ctx.variable)?.symbol + graphPattern(restrictor, null, variable, null, parts) } override fun visitPatternPathVariable(ctx: GeneratedParser.PatternPathVariableContext) = @@ -1051,161 +1009,175 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitSelectorBasic(ctx: GeneratedParser.SelectorBasicContext) = translate(ctx) { when (ctx.mod.type) { - GeneratedParser.ANY -> graphMatchSelectorAnyShortest() - GeneratedParser.ALL -> graphMatchSelectorAllShortest() + GeneratedParser.ANY -> graphSelectorAnyShortest() + GeneratedParser.ALL -> graphSelectorAllShortest() else -> throw error(ctx, "Unsupported match selector.") } } override fun visitSelectorAny(ctx: GeneratedParser.SelectorAnyContext) = translate(ctx) { when (ctx.k) { - null -> graphMatchSelectorAny() - else -> graphMatchSelectorAnyK(ctx.k.text.toLong()) + null -> graphSelectorAny() + else -> graphSelectorAnyK(ctx.k.text.toLong()) } } override fun visitSelectorShortest(ctx: GeneratedParser.SelectorShortestContext) = translate(ctx) { val k = ctx.k.text.toLong() when (ctx.GROUP()) { - null -> graphMatchSelectorShortestK(k) - else -> graphMatchSelectorShortestKGroup(k) + null -> graphSelectorShortestK(k) + else -> graphSelectorShortestKGroup(k) } } override fun visitLabelSpecOr(ctx: GeneratedParser.LabelSpecOrContext) = translate(ctx) { - val lhs = visit(ctx.labelSpec()) as GraphMatch.Label - val rhs = visit(ctx.labelTerm()) as GraphMatch.Label - graphMatchLabelDisj(lhs, rhs) + val lhs = visit(ctx.labelSpec()) as GraphLabel + val rhs = visit(ctx.labelTerm()) as GraphLabel + graphLabelDisj(lhs, rhs) } override fun visitLabelTermAnd(ctx: GeneratedParser.LabelTermAndContext) = translate(ctx) { - val lhs = visit(ctx.labelTerm()) as GraphMatch.Label - val rhs = visit(ctx.labelFactor()) as GraphMatch.Label - graphMatchLabelConj(lhs, rhs) + val lhs = visit(ctx.labelTerm()) as GraphLabel + val rhs = visit(ctx.labelFactor()) as GraphLabel + graphLabelConj(lhs, rhs) } override fun visitLabelFactorNot(ctx: GeneratedParser.LabelFactorNotContext) = translate(ctx) { - val arg = visit(ctx.labelPrimary()) as GraphMatch.Label - graphMatchLabelNegation(arg) + val arg = visit(ctx.labelPrimary()) as GraphLabel + graphLabelNegation(arg) } override fun visitLabelPrimaryName(ctx: GeneratedParser.LabelPrimaryNameContext) = translate(ctx) { val x = visitSymbolPrimitive(ctx.symbolPrimitive()) - graphMatchLabelName(x.symbol) + graphLabelName(x.symbol) } override fun visitLabelPrimaryWild(ctx: GeneratedParser.LabelPrimaryWildContext) = translate(ctx) { - graphMatchLabelWildcard() + graphLabelWildcard() } override fun visitLabelPrimaryParen(ctx: GeneratedParser.LabelPrimaryParenContext) = - visit(ctx.labelSpec()) as GraphMatch.Label + visit(ctx.labelSpec()) as GraphLabel override fun visitPattern(ctx: GeneratedParser.PatternContext) = translate(ctx) { val restrictor = visitRestrictor(ctx.restrictor) - val variable = visitOrNull(ctx.variable)?.symbol + val variable = visitOrNull(ctx.variable)?.symbol val prefilter = ctx.where?.let { visitExpr(it.expr()) } val quantifier = ctx.quantifier?.let { visitPatternQuantifier(it) } - val parts = visitOrEmpty(ctx.graphPart()) - graphMatchPattern(restrictor, prefilter, variable, quantifier, parts) + val parts = visitOrEmpty(ctx.graphPart()) + graphPattern(restrictor, prefilter, variable, quantifier, parts) } override fun visitEdgeAbbreviated(ctx: GeneratedParser.EdgeAbbreviatedContext) = translate(ctx) { val direction = visitEdge(ctx.edgeAbbrev()) - val quantifier = visitOrNull(ctx.quantifier) - graphMatchPatternPartEdge(direction, quantifier, null, null, null) - } + val quantifier = visitOrNull(ctx.quantifier) + graphMatchEdge(direction, quantifier, null, null, null) + } + + private fun GraphPart.Edge.copy( + direction: GraphDirection? = null, + quantifier: GraphQuantifier? = null, + prefilter: Expr? = null, + variable: String? = null, + label: GraphLabel? = null, + ) = graphMatchEdge( + direction = direction ?: this.direction, + quantifier = quantifier ?: this.quantifier, + prefilter = prefilter ?: this.prefilter, + variable = variable ?: this.variable, + label = label ?: this.label, + ) override fun visitEdgeWithSpec(ctx: GeneratedParser.EdgeWithSpecContext) = translate(ctx) { - val quantifier = visitOrNull(ctx.quantifier) - val edge = visitOrNull(ctx.edgeWSpec()) + val quantifier = visitOrNull(ctx.quantifier) + val edge = visitOrNull(ctx.edgeWSpec()) edge!!.copy(quantifier = quantifier) } override fun visitEdgeSpec(ctx: GeneratedParser.EdgeSpecContext) = translate(ctx) { - val placeholderDirection = GraphMatch.Direction.RIGHT - val variable = visitOrNull(ctx.symbolPrimitive())?.symbol + val placeholderDirection = GraphDirection.RIGHT() + val variable = visitOrNull(ctx.symbolPrimitive())?.symbol val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } - val label = visitOrNull(ctx.labelSpec()) - graphMatchPatternPartEdge(placeholderDirection, null, prefilter, variable, label) + val label = visitOrNull(ctx.labelSpec()) + graphMatchEdge(placeholderDirection, null, prefilter, variable, label) } override fun visitEdgeSpecLeft(ctx: GeneratedParser.EdgeSpecLeftContext): AstNode { val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphMatch.Direction.LEFT) + return edge.copy(direction = GraphDirection.LEFT()) } override fun visitEdgeSpecRight(ctx: GeneratedParser.EdgeSpecRightContext): AstNode { val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphMatch.Direction.RIGHT) + return edge.copy(direction = GraphDirection.RIGHT()) } override fun visitEdgeSpecBidirectional(ctx: GeneratedParser.EdgeSpecBidirectionalContext): AstNode { val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphMatch.Direction.LEFT_OR_RIGHT) + return edge.copy(direction = GraphDirection.LEFT_OR_RIGHT()) } override fun visitEdgeSpecUndirectedBidirectional(ctx: GeneratedParser.EdgeSpecUndirectedBidirectionalContext): AstNode { val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphMatch.Direction.LEFT_UNDIRECTED_OR_RIGHT) + return edge.copy(direction = GraphDirection.LEFT_UNDIRECTED_OR_RIGHT()) } override fun visitEdgeSpecUndirected(ctx: GeneratedParser.EdgeSpecUndirectedContext): AstNode { val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphMatch.Direction.UNDIRECTED) + return edge.copy(direction = GraphDirection.UNDIRECTED()) } override fun visitEdgeSpecUndirectedLeft(ctx: GeneratedParser.EdgeSpecUndirectedLeftContext): AstNode { val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphMatch.Direction.LEFT_OR_UNDIRECTED) + return edge.copy(direction = GraphDirection.LEFT_OR_UNDIRECTED()) } override fun visitEdgeSpecUndirectedRight(ctx: GeneratedParser.EdgeSpecUndirectedRightContext): AstNode { val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphMatch.Direction.UNDIRECTED_OR_RIGHT) + return edge.copy(direction = GraphDirection.UNDIRECTED_OR_RIGHT()) } - private fun visitEdge(ctx: GeneratedParser.EdgeAbbrevContext): GraphMatch.Direction = when { - ctx.TILDE() != null && ctx.ANGLE_RIGHT() != null -> GraphMatch.Direction.UNDIRECTED_OR_RIGHT - ctx.TILDE() != null && ctx.ANGLE_LEFT() != null -> GraphMatch.Direction.LEFT_OR_UNDIRECTED - ctx.TILDE() != null -> GraphMatch.Direction.UNDIRECTED - ctx.MINUS() != null && ctx.ANGLE_LEFT() != null && ctx.ANGLE_RIGHT() != null -> GraphMatch.Direction.LEFT_OR_RIGHT - ctx.MINUS() != null && ctx.ANGLE_LEFT() != null -> GraphMatch.Direction.LEFT - ctx.MINUS() != null && ctx.ANGLE_RIGHT() != null -> GraphMatch.Direction.RIGHT - ctx.MINUS() != null -> GraphMatch.Direction.LEFT_UNDIRECTED_OR_RIGHT + private fun visitEdge(ctx: GeneratedParser.EdgeAbbrevContext): GraphDirection = when { + ctx.TILDE() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.UNDIRECTED_OR_RIGHT() + ctx.TILDE() != null && ctx.ANGLE_LEFT() != null -> GraphDirection.LEFT_OR_UNDIRECTED() + ctx.TILDE() != null -> GraphDirection.UNDIRECTED() + ctx.MINUS() != null && ctx.ANGLE_LEFT() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.LEFT_OR_RIGHT() + ctx.MINUS() != null && ctx.ANGLE_LEFT() != null -> GraphDirection.LEFT() + ctx.MINUS() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.RIGHT() + ctx.MINUS() != null -> GraphDirection.LEFT_UNDIRECTED_OR_RIGHT() else -> throw error(ctx, "Unsupported edge type") } - override fun visitGraphPart(ctx: GeneratedParser.GraphPartContext): GraphMatch.Pattern.Part { + override fun visitGraphPart(ctx: GeneratedParser.GraphPartContext): GraphPart { val part = super.visitGraphPart(ctx) - if (part is GraphMatch.Pattern) { - return translate(ctx) { graphMatchPatternPartPattern(part) } + if (part is GraphPattern) { + return translate(ctx) { graphMatchPattern(part) } } - return part as GraphMatch.Pattern.Part + return part as GraphPart } override fun visitPatternQuantifier(ctx: GeneratedParser.PatternQuantifierContext) = translate(ctx) { when { - ctx.quant == null -> graphMatchQuantifier(ctx.lower.text.toLong(), ctx.upper?.text?.toLong()) - ctx.quant.type == GeneratedParser.PLUS -> graphMatchQuantifier(1L, null) - ctx.quant.type == GeneratedParser.ASTERISK -> graphMatchQuantifier(0L, null) + ctx.quant == null -> graphQuantifier(ctx.lower.text.toLong(), ctx.upper?.text?.toLong()) + ctx.quant.type == GeneratedParser.PLUS -> graphQuantifier(1L, null) + ctx.quant.type == GeneratedParser.ASTERISK -> graphQuantifier(0L, null) else -> throw error(ctx, "Unsupported quantifier") } } override fun visitNode(ctx: GeneratedParser.NodeContext) = translate(ctx) { - val variable = visitOrNull(ctx.symbolPrimitive())?.symbol + val variable = visitOrNull(ctx.symbolPrimitive())?.symbol val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } - val label = visitOrNull(ctx.labelSpec()) - graphMatchPatternPartNode(prefilter, variable, label) + val label = visitOrNull(ctx.labelSpec()) + graphMatchNode(prefilter, variable, label) } - private fun visitRestrictor(ctx: GeneratedParser.PatternRestrictorContext?): GraphMatch.Restrictor? { + private fun visitRestrictor(ctx: GeneratedParser.PatternRestrictorContext?): GraphRestrictor? { if (ctx == null) return null return when (ctx.restrictor.text.lowercase()) { - "trail" -> GraphMatch.Restrictor.TRAIL - "acyclic" -> GraphMatch.Restrictor.ACYCLIC - "simple" -> GraphMatch.Restrictor.SIMPLE + "trail" -> GraphRestrictor.TRAIL() + "acyclic" -> GraphRestrictor.ACYCLIC() + "simple" -> GraphRestrictor.SIMPLE() else -> throw error(ctx, "Unrecognized pattern restrictor") } } @@ -1215,110 +1187,104 @@ internal class PartiQLParserDefault : PartiQLParser { * TABLE REFERENCES & JOINS & FROM CLAUSE * */ + override fun visitFromClause(ctx: GeneratedParser.FromClauseContext): From = translate(ctx) { + val tableRefs = visitOrEmpty(ctx.tableReference()) + from(tableRefs) + } - override fun visitFromClause(ctx: GeneratedParser.FromClauseContext) = translate(ctx) { - val tableRefs = visitOrEmpty(ctx.tableReference()) - tableRefs.drop(1).fold(tableRefs.first()) { acc, tableRef -> - fromJoin(acc, tableRef, From.Join.Type.CROSS, null) - } + override fun visitTableBaseRefSymbol(ctx: GeneratedParser.TableBaseRefSymbolContext): FromTableRef = translate(ctx) { + val expr = visitAs(ctx.source) + val asAlias = visitSymbolPrimitive(ctx.symbolPrimitive()) + fromExpr(expr, FromType.SCAN(), asAlias, null) } - override fun visitTableBaseRefClauses(ctx: GeneratedParser.TableBaseRefClausesContext) = translate(ctx) { + override fun visitTableBaseRefClauses(ctx: GeneratedParser.TableBaseRefClausesContext): FromTableRef = translate(ctx) { val expr = visitAs(ctx.source) val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - val byAlias = ctx.byIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - fromValue(expr, From.Value.Type.SCAN, asAlias, atAlias, byAlias) + fromExpr(expr, FromType.SCAN(), asAlias, atAlias) } - override fun visitTableBaseRefMatch(ctx: GeneratedParser.TableBaseRefMatchContext) = translate(ctx) { + override fun visitTableBaseRefMatch(ctx: GeneratedParser.TableBaseRefMatchContext): FromTableRef = translate(ctx) { val expr = visitAs(ctx.source) val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - val byAlias = ctx.byIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - fromValue(expr, From.Value.Type.SCAN, asAlias, atAlias, byAlias) + fromExpr(expr, FromType.SCAN(), asAlias, atAlias) } - /** - * TODO Remove as/at/by aliases from DELETE command grammar in PartiQL.g4 - */ - override fun visitFromClauseSimpleExplicit(ctx: GeneratedParser.FromClauseSimpleExplicitContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - /** - * TODO Remove fromClauseSimple rule from DELETE command grammar in PartiQL.g4 - */ - override fun visitFromClauseSimpleImplicit(ctx: GeneratedParser.FromClauseSimpleImplicitContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitTableUnpivot(ctx: GeneratedParser.TableUnpivotContext) = translate(ctx) { + override fun visitTableUnpivot(ctx: GeneratedParser.TableUnpivotContext): FromTableRef = translate(ctx) { val expr = visitAs(ctx.expr()) val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - val byAlias = ctx.byIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - fromValue(expr, From.Value.Type.UNPIVOT, asAlias, atAlias, byAlias) + fromExpr(expr, FromType.UNPIVOT(), asAlias, atAlias) } - override fun visitTableLeftCrossJoin(ctx: org.partiql.parser.internal.antlr.PartiQLParser.TableLeftCrossJoinContext) = translate(ctx) { - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs) + override fun visitTableWrapped(ctx: GeneratedParser.TableWrappedContext): FromTableRef = translate(ctx) { + visitAs(ctx.tableReference()) + } + + override fun visitTableLeftCrossJoin(ctx: GeneratedParser.TableLeftCrossJoinContext): FromTableRef = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) // PartiQL spec defines equivalence of // l LEFT CROSS JOIN r <=> l LEFT JOIN r ON TRUE // The other join types combined w/ CROSS JOIN are unspecified -- https://github.com/partiql/partiql-lang-kotlin/issues/1013 - fromJoin(lhs, rhs, From.Join.Type.LEFT, null) + fromJoin(lhs, rhs, JoinType.LEFT_CROSS(), null) + } + + override fun visitTableCrossJoin(ctx: GeneratedParser.TableCrossJoinContext): FromTableRef = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + fromJoin(lhs, rhs, JoinType.CROSS(), null) } - override fun visitTableCrossJoin(ctx: GeneratedParser.TableCrossJoinContext) = translate(ctx) { - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs) - fromJoin(lhs, rhs, From.Join.Type.CROSS, null) + override fun visitTableQualifiedJoin(ctx: GeneratedParser.TableQualifiedJoinContext): FromTableRef = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val type = convertJoinType(ctx.joinType()) + val condition = ctx.joinSpec()?.let { visitExpr(it.expr()) } + fromJoin(lhs, rhs, type, condition) } - private fun convertJoinType(ctx: GeneratedParser.JoinTypeContext?): From.Join.Type? { + private fun convertJoinType(ctx: GeneratedParser.JoinTypeContext?): JoinType? { if (ctx == null) return null return when (ctx.mod.type) { - GeneratedParser.INNER -> From.Join.Type.INNER + GeneratedParser.INNER -> JoinType.INNER() GeneratedParser.LEFT -> when (ctx.OUTER()) { - null -> From.Join.Type.LEFT - else -> From.Join.Type.LEFT_OUTER + null -> JoinType.LEFT() + else -> JoinType.LEFT_OUTER() } GeneratedParser.RIGHT -> when (ctx.OUTER()) { - null -> From.Join.Type.RIGHT - else -> From.Join.Type.RIGHT_OUTER + null -> JoinType.RIGHT() + else -> JoinType.RIGHT_OUTER() } GeneratedParser.FULL -> when (ctx.OUTER()) { - null -> From.Join.Type.FULL - else -> From.Join.Type.FULL_OUTER + null -> JoinType.FULL() + else -> JoinType.FULL_OUTER() } GeneratedParser.OUTER -> { // TODO https://github.com/partiql/partiql-spec/issues/41 // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1013 - From.Join.Type.FULL_OUTER + JoinType.FULL_OUTER() } else -> null } } - override fun visitTableQualifiedJoin(ctx: GeneratedParser.TableQualifiedJoinContext) = translate(ctx) { - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs) - val type = convertJoinType(ctx.joinType()) - val condition = ctx.joinSpec()?.let { visitExpr(it.expr()) } - fromJoin(lhs, rhs, type, condition) + /** + * TODO Remove as/at/by aliases from DELETE command grammar in PartiQL.g4 + */ + override fun visitFromClauseSimpleExplicit(ctx: GeneratedParser.FromClauseSimpleExplicitContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") } - override fun visitTableBaseRefSymbol(ctx: GeneratedParser.TableBaseRefSymbolContext) = translate(ctx) { - val expr = visitAs(ctx.source) - val asAlias = visitSymbolPrimitive(ctx.symbolPrimitive()) - fromValue(expr, From.Value.Type.SCAN, asAlias, null, null) + /** + * TODO Remove fromClauseSimple rule from DELETE command grammar in PartiQL.g4 + */ + override fun visitFromClauseSimpleImplicit(ctx: GeneratedParser.FromClauseSimpleImplicitContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") } - override fun visitTableWrapped(ctx: GeneratedParser.TableWrappedContext): AstNode = visit(ctx.tableReference()) - - override fun visitJoinSpec(ctx: GeneratedParser.JoinSpecContext) = visitExpr(ctx.expr()) - /** * SIMPLE EXPRESSIONS */ @@ -1414,11 +1380,11 @@ internal class PartiQLParserDefault : PartiQLParser { val lhs = visitAs(ctx.lhs) val rhs = visitAs(ctx.rhs ?: ctx.expr()).let { // Wrap rhs in an array unless it's a query or already a collection - if (it is Expr.QuerySet || it is Expr.Collection || ctx.PAREN_LEFT() == null) { + if (it is ExprQuerySet || it is ExprArray || it is ExprBag || ctx.PAREN_LEFT() == null) { it } else { // IN ( expr ) - exprCollection(Expr.Collection.Type.LIST, listOf(it)) + exprArray(listOf(it)) } } val not = ctx.NOT() != null @@ -1427,7 +1393,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitPredicateIs(ctx: GeneratedParser.PredicateIsContext) = translate(ctx) { val value = visitAs(ctx.lhs) - val type = visitAs(ctx.type()) + val type = visitAs(ctx.type()) val not = ctx.NOT() != null exprIsType(value, type, not) } @@ -1459,25 +1425,37 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitVariableIdentifier(ctx: GeneratedParser.VariableIdentifierContext) = translate(ctx) { val symbol = ctx.ident.getStringValue() - val case = when (ctx.ident.type) { - GeneratedParser.IDENTIFIER -> Identifier.CaseSensitivity.INSENSITIVE - else -> Identifier.CaseSensitivity.SENSITIVE + val isDelimited = when (ctx.ident.type) { + GeneratedParser.IDENTIFIER -> false + else -> true } val scope = when (ctx.qualifier) { - null -> Expr.Var.Scope.DEFAULT - else -> Expr.Var.Scope.LOCAL + null -> Scope.DEFAULT() + else -> Scope.LOCAL() } - exprVar(identifierSymbol(symbol, case), scope) + exprVarRef( + identifierChain( + root = identifier(symbol, isDelimited), + next = null + ), + scope + ) } override fun visitVariableKeyword(ctx: GeneratedParser.VariableKeywordContext) = translate(ctx) { val symbol = ctx.key.text - val case = Identifier.CaseSensitivity.INSENSITIVE + val isDelimited = false val scope = when (ctx.qualifier) { - null -> Expr.Var.Scope.DEFAULT - else -> Expr.Var.Scope.LOCAL + null -> Scope.DEFAULT() + else -> Scope.LOCAL() } - exprVar(identifierSymbol(symbol, case), scope) + exprVarRef( + identifierChain( + root = identifier(symbol, isDelimited), + next = null + ), + scope + ) } override fun visitParameter(ctx: GeneratedParser.ParameterContext) = translate(ctx) { @@ -1488,52 +1466,58 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitSequenceConstructor(ctx: GeneratedParser.SequenceConstructorContext) = translate(ctx) { - val expressions = visitOrEmpty(ctx.expr()) - val type = when (ctx.datatype.type) { - GeneratedParser.LIST -> Expr.Collection.Type.LIST - GeneratedParser.SEXP -> Expr.Collection.Type.SEXP - else -> throw error(ctx.datatype, "Invalid sequence type") - } - exprCollection(type, expressions) + error("Sequence constructor not supported") + } + + private fun PathStep.copy(next: PathStep?) = when (this) { + is PathStep.Element -> exprPathStepElement(this.element, next) + is PathStep.Field -> exprPathStepField(this.field, next) + is PathStep.AllElements -> exprPathStepAllElements(next) + is PathStep.AllFields -> exprPathStepAllFields(next) + else -> error("Unsupported PathStep: $this") } override fun visitExprPrimaryPath(ctx: GeneratedParser.ExprPrimaryPathContext) = translate(ctx) { val base = visitAs(ctx.exprPrimary()) - val steps = ctx.pathStep().map { visit(it) as Expr.Path.Step } + val init: PathStep? = null + val steps = ctx.pathStep().reversed().fold(init) { acc, step -> + val stepExpr = visit(step) as PathStep + stepExpr.copy(acc) + } exprPath(base, steps) } override fun visitPathStepIndexExpr(ctx: GeneratedParser.PathStepIndexExprContext) = translate(ctx) { val key = visitAs(ctx.key) - exprPathStepIndex(key) + exprPathStepElement(key, null) } override fun visitPathStepDotExpr(ctx: GeneratedParser.PathStepDotExprContext) = translate(ctx) { val symbol = visitSymbolPrimitive(ctx.symbolPrimitive()) - exprPathStepSymbol(symbol) + exprPathStepField(symbol, null) } override fun visitPathStepIndexAll(ctx: GeneratedParser.PathStepIndexAllContext) = translate(ctx) { - exprPathStepWildcard() + exprPathStepAllElements(null) } override fun visitPathStepDotAll(ctx: GeneratedParser.PathStepDotAllContext) = translate(ctx) { - exprPathStepUnpivot() + exprPathStepAllFields(null) } override fun visitValues(ctx: GeneratedParser.ValuesContext) = translate(ctx) { - val rows = visitOrEmpty(ctx.valueRow()) - exprCollection(Expr.Collection.Type.BAG, rows) + val rows = visitOrEmpty(ctx.valueRow()) + exprBag(rows) } override fun visitValueRow(ctx: GeneratedParser.ValueRowContext) = translate(ctx) { val expressions = visitOrEmpty(ctx.expr()) - exprCollection(Expr.Collection.Type.LIST, expressions) + exprArray(expressions) } override fun visitValueList(ctx: GeneratedParser.ValueListContext) = translate(ctx) { val expressions = visitOrEmpty(ctx.expr()) - exprCollection(Expr.Collection.Type.LIST, expressions) + exprArray(expressions) } override fun visitExprGraphMatchMany(ctx: GeneratedParser.ExprGraphMatchManyContext) = translate(ctx) { @@ -1549,12 +1533,12 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitExprTermCurrentUser(ctx: GeneratedParser.ExprTermCurrentUserContext) = translate(ctx) { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER) + exprSessionAttribute(SessionAttribute.CURRENT_USER()) } override fun visitExprTermCurrentDate(ctx: GeneratedParser.ExprTermCurrentDateContext) = translate(ctx) { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE) + exprSessionAttribute(SessionAttribute.CURRENT_DATE()) } /** @@ -1588,7 +1572,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitCast(ctx: GeneratedParser.CastContext) = translate(ctx) { val expr = visitExpr(ctx.expr()) - val type = visitAs(ctx.type()) + val type = visitAs(ctx.type()) exprCast(expr, type) } @@ -1608,13 +1592,14 @@ internal class PartiQLParserDefault : PartiQLParser { GeneratedParser.MOD -> exprOperator("%", args[0], args[1]) GeneratedParser.CHARACTER_LENGTH, GeneratedParser.CHAR_LENGTH -> { val path = ctx.qualifiedName().qualifier.map { visitSymbolPrimitive(it) } - val name = identifierSymbol("char_length", Identifier.CaseSensitivity.INSENSITIVE) + val name = identifierChain(identifier("char_length", false), null) if (path.isEmpty()) { - exprCall(name, args, setq = null) // setq = null for scalar fn + exprCall(name, args, null) // setq = null for scalar fn } else { - val root = path.first() - val steps = path.drop(1) + listOf(name) - exprCall(identifierQualified(root, steps), args, setq = null) + val function = path.reversed().fold(name) { acc, id -> + identifierChain(root = id, next = acc) + } + exprCall(function, args, setq = null) } } else -> visitNonReservedFunctionCall(ctx, args) @@ -1623,7 +1608,7 @@ internal class PartiQLParserDefault : PartiQLParser { else -> visitNonReservedFunctionCall(ctx, args) } } - private fun visitNonReservedFunctionCall(ctx: GeneratedParser.FunctionCallContext, args: List): Expr.Call { + private fun visitNonReservedFunctionCall(ctx: GeneratedParser.FunctionCallContext, args: List): ExprCall { val function = visitQualifiedName(ctx.qualifiedName()) return exprCall(function, args, convertSetQuantifier(ctx.setQuantifierStrategy())) } @@ -1635,16 +1620,19 @@ internal class PartiQLParserDefault : PartiQLParser { */ override fun visitDateFunction(ctx: GeneratedParser.DateFunctionContext) = translate(ctx) { - val field = try { - DatetimeField.valueOf(ctx.dt.text.uppercase()) + try { + DatetimeField.parse(ctx.dt.text) } catch (ex: IllegalArgumentException) { - throw error(ctx.dt, "Expected one of: ${DatetimeField.values().joinToString()}", ex) + throw error(ctx.dt, "Expected one of: ${DatetimeField.codes().joinToString()}", ex) } val lhs = visitExpr(ctx.expr(0)) val rhs = visitExpr(ctx.expr(1)) + // TODO change to not use PartiQLValue -- https://github.com/partiql/partiql-lang-kotlin/issues/1589 + val fieldLit = ctx.dt.text.lowercase() + // TODO error on invalid datetime fields like TIMEZONE_HOUR and TIMEZONE_MINUTE when { - ctx.DATE_ADD() != null -> exprDateAdd(field, lhs, rhs) - ctx.DATE_DIFF() != null -> exprDateDiff(field, lhs, rhs) + ctx.DATE_ADD() != null -> exprCall(identifierChain(identifier("date_add_$fieldLit", false), null), listOf(lhs, rhs), null) + ctx.DATE_DIFF() != null -> exprCall(identifierChain(identifier("date_diff_$fieldLit", false), null), listOf(lhs, rhs), null) else -> throw error(ctx, "Expected DATE_ADD or DATE_DIFF") } } @@ -1655,7 +1643,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitSubstring(ctx: GeneratedParser.SubstringContext) = translate(ctx) { if (ctx.FROM() == null) { // normal form - val function = "SUBSTRING".toIdentifier() + val function = "SUBSTRING".toIdentifierChain() val args = visitOrEmpty(ctx.expr()) exprCall(function, args, setq = null) // setq = null for scalar fn } else { @@ -1673,7 +1661,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitPosition(ctx: GeneratedParser.PositionContext) = translate(ctx) { if (ctx.IN() == null) { // normal form - val function = "POSITION".toIdentifier() + val function = "POSITION".toIdentifierChain() val args = visitOrEmpty(ctx.expr()) exprCall(function, args, setq = null) // setq = null for scalar fn } else { @@ -1691,7 +1679,7 @@ internal class PartiQLParserDefault : PartiQLParser { // TODO: figure out why do we have a normalized form for overlay? if (ctx.PLACING() == null) { // normal form - val function = "OVERLAY".toIdentifier() + val function = "OVERLAY".toIdentifierChain() val args = arrayOfNulls(4).also { visitOrEmpty(ctx.expr()).forEachIndexed { index, expr -> it[index] = expr @@ -1712,9 +1700,11 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitExtract(ctx: GeneratedParser.ExtractContext) = translate(ctx) { val field = try { - DatetimeField.valueOf(ctx.IDENTIFIER().text.uppercase()) + DatetimeField.parse(ctx.IDENTIFIER().text.uppercase()) } catch (ex: IllegalArgumentException) { - throw error(ctx.IDENTIFIER().symbol, "Expected one of: ${DatetimeField.values().joinToString()}", ex) + // TODO decide if we want int codes here or actual text. If we want text here, then there should be a + // method to convert the code into text. + throw error(ctx.IDENTIFIER().symbol, "Expected one of: ${DatetimeField.codes().joinToString()}", ex) } val source = visitExpr(ctx.expr()) exprExtract(field, source) @@ -1723,9 +1713,9 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitTrimFunction(ctx: GeneratedParser.TrimFunctionContext) = translate(ctx) { val spec = ctx.mod?.let { try { - Expr.Trim.Spec.valueOf(it.text.uppercase()) + TrimSpec.parse(it.text.uppercase()) } catch (ex: IllegalArgumentException) { - throw error(it, "Expected on of: ${Expr.Trim.Spec.values().joinToString()}", ex) + throw error(it, "Expected on of: ${TrimSpec.codes().joinToString()}", ex) } } val (chars, value) = when (ctx.expr().size) { @@ -1742,8 +1732,8 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitLagLeadFunction(ctx: GeneratedParser.LagLeadFunctionContext) = translate(ctx) { val function = when { - ctx.LAG() != null -> Expr.Window.Function.LAG - ctx.LEAD() != null -> Expr.Window.Function.LEAD + ctx.LAG() != null -> WindowFunction.LAG() + ctx.LEAD() != null -> WindowFunction.LEAD() else -> throw error(ctx, "Expected LAG or LEAD") } val expression = visitExpr(ctx.expr(0)) @@ -1776,7 +1766,7 @@ internal class PartiQLParserDefault : PartiQLParser { throw error(ctx, "Invalid bag expression") } val expressions = visitOrEmpty(ctx.expr()) - exprCollection(Expr.Collection.Type.BAG, expressions) + exprBag(expressions) } override fun visitLiteralDecimal(ctx: GeneratedParser.LiteralDecimalContext) = translate(ctx) { @@ -1791,7 +1781,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitArray(ctx: GeneratedParser.ArrayContext) = translate(ctx) { val expressions = visitOrEmpty(ctx.expr()) - exprCollection(Expr.Collection.Type.ARRAY, expressions) + exprArray(expressions) } override fun visitLiteralNull(ctx: GeneratedParser.LiteralNullContext) = translate(ctx) { @@ -1905,51 +1895,67 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitTypeAtomic(ctx: GeneratedParser.TypeAtomicContext) = translate(ctx) { when (ctx.datatype.type) { - GeneratedParser.NULL -> typeNullType() - GeneratedParser.BOOL, GeneratedParser.BOOLEAN -> typeBool() - GeneratedParser.SMALLINT, GeneratedParser.INT2, GeneratedParser.INTEGER2 -> typeInt2() + GeneratedParser.NULL -> DataType.NULL() + GeneratedParser.BOOL -> DataType.BOOLEAN() + GeneratedParser.BOOLEAN -> DataType.BOOL() + GeneratedParser.SMALLINT -> DataType.SMALLINT() + GeneratedParser.INT2 -> DataType.INT2() + GeneratedParser.INTEGER2 -> DataType.INTEGER2() // TODO, we have INT aliased to INT4 when it should be visa-versa. - GeneratedParser.INT4, GeneratedParser.INTEGER4 -> typeInt4() - GeneratedParser.INT, GeneratedParser.INTEGER -> typeInt4() - GeneratedParser.BIGINT, GeneratedParser.INT8, GeneratedParser.INTEGER8 -> typeInt8() - GeneratedParser.FLOAT -> typeFloat32() - GeneratedParser.DOUBLE -> typeFloat64() - GeneratedParser.REAL -> typeReal() - GeneratedParser.TIMESTAMP -> typeTimestamp(null) - GeneratedParser.CHAR, GeneratedParser.CHARACTER -> typeChar(null) - GeneratedParser.MISSING -> typeMissing() - GeneratedParser.STRING -> typeString(null) - GeneratedParser.SYMBOL -> typeSymbol() + GeneratedParser.INT4 -> DataType.INT4() + GeneratedParser.INTEGER4 -> DataType.INTEGER4() + GeneratedParser.INT -> DataType.INT() + GeneratedParser.INTEGER -> DataType.INTEGER() + GeneratedParser.BIGINT -> DataType.BIGINT() + GeneratedParser.INT8 -> DataType.INT8() + GeneratedParser.INTEGER8 -> DataType.INTEGER8() + GeneratedParser.FLOAT -> DataType.FLOAT() + GeneratedParser.DOUBLE -> TODO() // not sure if DOUBLE is to be supported + GeneratedParser.REAL -> DataType.REAL() + GeneratedParser.TIMESTAMP -> DataType.TIMESTAMP() + GeneratedParser.CHAR -> DataType.CHAR() + GeneratedParser.CHARACTER -> DataType.CHARACTER() + GeneratedParser.MISSING -> DataType.MISSING() + GeneratedParser.STRING -> DataType.STRING() + GeneratedParser.SYMBOL -> DataType.SYMBOL() // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1125 - GeneratedParser.BLOB -> typeBlob(null) - GeneratedParser.CLOB -> typeClob(null) - GeneratedParser.DATE -> typeDate() - GeneratedParser.STRUCT -> typeStruct() - GeneratedParser.TUPLE -> typeTuple() - GeneratedParser.LIST -> typeList() - GeneratedParser.SEXP -> typeSexp() - GeneratedParser.BAG -> typeBag() - GeneratedParser.ANY -> typeAny() - else -> throw error(ctx, "Unknown atomic type.") + GeneratedParser.BLOB -> DataType.BLOB() + GeneratedParser.CLOB -> DataType.CLOB() + GeneratedParser.DATE -> DataType.DATE() + GeneratedParser.STRUCT -> DataType.STRUCT() + GeneratedParser.TUPLE -> DataType.TUPLE() + GeneratedParser.LIST -> DataType.LIST() + GeneratedParser.SEXP -> DataType.SEXP() + GeneratedParser.BAG -> DataType.BAG() + GeneratedParser.ANY -> TODO() // not sure if ANY is to be supported + else -> throw error(ctx, "Unknown atomic type.") // TODO other types included in parser } } - override fun visitTypeVarChar(ctx: GeneratedParser.TypeVarCharContext) = translate(ctx) { - val n = ctx.arg0?.text?.toInt() - typeVarchar(n) + override fun visitTypeVarChar(ctx: GeneratedParser.TypeVarCharContext): DataType = translate(ctx) { + when (val n = ctx.arg0?.text?.toInt()) { + null -> DataType.VARCHAR() + else -> DataType.VARCHAR(n) + } } override fun visitTypeArgSingle(ctx: GeneratedParser.TypeArgSingleContext) = translate(ctx) { val n = ctx.arg0?.text?.toInt() when (ctx.datatype.type) { GeneratedParser.FLOAT -> when (n) { - null -> typeFloat64() - 32 -> typeFloat32() - 64 -> typeFloat64() + null -> DataType.FLOAT(64) + 32 -> DataType.FLOAT(32) + 64 -> DataType.FLOAT(64) else -> throw error(ctx.datatype, "Invalid FLOAT precision. Expected 32 or 64") } - GeneratedParser.CHAR, GeneratedParser.CHARACTER -> typeChar(n) - GeneratedParser.VARCHAR -> typeVarchar(n) + GeneratedParser.CHAR, GeneratedParser.CHARACTER -> when (n) { + null -> DataType.CHAR() + else -> DataType.CHAR(n) + } + GeneratedParser.VARCHAR -> when (n) { + null -> DataType.VARCHAR() + else -> DataType.VARCHAR(n) + } else -> throw error(ctx.datatype, "Invalid datatype") } } @@ -1958,8 +1964,24 @@ internal class PartiQLParserDefault : PartiQLParser { val arg0 = ctx.arg0?.text?.toInt() val arg1 = ctx.arg1?.text?.toInt() when (ctx.datatype.type) { - GeneratedParser.DECIMAL, GeneratedParser.DEC -> typeDecimal(arg0, arg1) - GeneratedParser.NUMERIC -> typeNumeric(arg0, arg1) + GeneratedParser.DECIMAL -> when { + arg0 == null && arg1 == null -> DataType.DECIMAL() + arg0 != null && arg1 == null -> DataType.DECIMAL(arg0) + arg0 != null && arg1 != null -> DataType.DECIMAL(arg0, arg1) + else -> error("Invalid parameters for decimal") + } + GeneratedParser.DEC -> when { + arg0 == null && arg1 == null -> DataType.DEC() + arg0 != null && arg1 == null -> DataType.DEC(arg0) + arg0 != null && arg1 != null -> DataType.DEC(arg0, arg1) + else -> error("Invalid parameters for dec") + } + GeneratedParser.NUMERIC -> when { + arg0 == null && arg1 == null -> DataType.NUMERIC() + arg0 != null && arg1 == null -> DataType.NUMERIC(arg0) + arg0 != null && arg1 != null -> DataType.NUMERIC(arg0, arg1) + else -> error("Invalid parameters for decimal") + } else -> throw error(ctx.datatype, "Invalid datatype") } } @@ -1973,19 +1995,31 @@ internal class PartiQLParserDefault : PartiQLParser { when (ctx.datatype.type) { GeneratedParser.TIME -> when (ctx.ZONE()) { - null -> typeTime(precision) - else -> typeTimeWithTz(precision) + null -> when (precision) { + null -> DataType.TIME() + else -> DataType.TIME(precision) + } + else -> when (precision) { + null -> DataType.TIME_WITH_TIME_ZONE() + else -> DataType.TIME_WITH_TIME_ZONE(precision) + } } GeneratedParser.TIMESTAMP -> when (ctx.ZONE()) { - null -> typeTimestamp(precision) - else -> typeTimestampWithTz(precision) + null -> when (precision) { + null -> DataType.TIMESTAMP() + else -> DataType.TIMESTAMP(precision) + } + else -> when (precision) { + null -> DataType.TIMESTAMP_WITH_TIME_ZONE() + else -> DataType.TIMESTAMP_WITH_TIME_ZONE(precision) + } } else -> throw error(ctx.datatype, "Invalid datatype") } } override fun visitTypeCustom(ctx: GeneratedParser.TypeCustomContext) = translate(ctx) { - typeCustom(ctx.text.uppercase()) + DataType.USER_DEFINED(ctx.text.uppercase().toIdentifierChain()) } private inline fun visitOrEmpty(ctx: List?): List = when { @@ -2012,8 +2046,8 @@ internal class PartiQLParserDefault : PartiQLParser { */ private fun convertSetQuantifier(ctx: GeneratedParser.SetQuantifierStrategyContext?): SetQuantifier? = when { ctx == null -> null - ctx.ALL() != null -> SetQuantifier.ALL - ctx.DISTINCT() != null -> SetQuantifier.DISTINCT + ctx.ALL() != null -> SetQuantifier.ALL() + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT() else -> throw error(ctx, "Expected set quantifier ALL or DISTINCT") } @@ -2075,39 +2109,55 @@ internal class PartiQLParserDefault : PartiQLParser { * SELECT foo.*.bar FROM foo * ``` */ - protected fun convertPathToProjectionItem(ctx: ParserRuleContext, path: Expr.Path, alias: Identifier.Symbol?) = + protected fun convertPathToProjectionItem(ctx: ParserRuleContext, path: ExprPath, alias: Identifier?) = translate(ctx) { - val steps = mutableListOf() + val steps = mutableListOf() var containsIndex = false - path.steps.forEachIndexed { index, step -> + var curStep = path.next + var last = curStep + while (curStep != null) { + val isLastStep = curStep.next == null // Only last step can have a '.*' - if (step is Expr.Path.Step.Unpivot && index != path.steps.lastIndex) { + if (curStep is PathStep.AllFields && !isLastStep) { throw error(ctx, "Projection item cannot unpivot unless at end.") } // No step can have an indexed wildcard: '[*]' - if (step is Expr.Path.Step.Wildcard) { + if (curStep is PathStep.AllElements) { throw error(ctx, "Projection item cannot index using wildcard.") } // TODO If the last step is '.*', no indexing is allowed // if (step.metas.containsKey(IsPathIndexMeta.TAG)) { // containsIndex = true // } - if (step !is Expr.Path.Step.Unpivot) { - steps.add(step) + if (curStep !is PathStep.AllFields) { + steps.add(curStep) } - } - if (path.steps.last() is Expr.Path.Step.Unpivot && containsIndex) { - throw error(ctx, "Projection item use wildcard with any indexing.") + + if (isLastStep && curStep is PathStep.AllFields && containsIndex) { + throw error(ctx, "Projection item use wildcard with any indexing.") + } + last = curStep + curStep = curStep.next } when { - path.steps.last() is Expr.Path.Step.Unpivot && steps.isEmpty() -> { - selectProjectItemAll(path.root) + last is PathStep.AllFields && steps.isEmpty() -> { + selectItemStar(path.root) } - path.steps.last() is Expr.Path.Step.Unpivot -> { - selectProjectItemAll(exprPath(path.root, steps)) + last is PathStep.AllFields -> { + val init: PathStep? = null + val newSteps = steps.reversed().fold(init) { acc, step -> + when (step) { + is PathStep.Element -> PathStep.Element(step.element, acc) + is PathStep.Field -> PathStep.Field(step.field, acc) + is PathStep.AllElements -> PathStep.AllElements(acc) + is PathStep.AllFields -> PathStep.AllFields(acc) + else -> error("Unexpected path step") + } + } + selectItemStar(exprPath(path.root, newSteps)) } else -> { - selectProjectItemExpression(path, alias) + selectItemExpr(path, alias) } } } @@ -2122,10 +2172,9 @@ internal class PartiQLParserDefault : PartiQLParser { else -> throw error(this, "Unsupported token for grabbing string value.") } - private fun String.toIdentifier(): Identifier.Symbol = identifierSymbol( - symbol = this, - caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE, - ) + private fun String.toIdentifier(): Identifier = identifier(this, false) + + private fun String.toIdentifierChain(): IdentifierChain = identifierChain(root = this.toIdentifier(), next = null) private fun String.toBigInteger() = BigInteger(this, 10) diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefaultV1.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefaultV1.kt deleted file mode 100644 index e861073bc1..0000000000 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefaultV1.kt +++ /dev/null @@ -1,2199 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.parser.internal - -import com.amazon.ionelement.api.IntElement -import com.amazon.ionelement.api.IntElementSize -import com.amazon.ionelement.api.IonElement -import org.antlr.v4.runtime.BailErrorStrategy -import org.antlr.v4.runtime.BaseErrorListener -import org.antlr.v4.runtime.CharStreams -import org.antlr.v4.runtime.CommonTokenStream -import org.antlr.v4.runtime.ParserRuleContext -import org.antlr.v4.runtime.RecognitionException -import org.antlr.v4.runtime.Recognizer -import org.antlr.v4.runtime.Token -import org.antlr.v4.runtime.TokenSource -import org.antlr.v4.runtime.TokenStream -import org.antlr.v4.runtime.atn.PredictionMode -import org.antlr.v4.runtime.misc.ParseCancellationException -import org.antlr.v4.runtime.tree.TerminalNode -import org.partiql.ast.v1.Ast -import org.partiql.ast.v1.Ast.exclude -import org.partiql.ast.v1.Ast.excludePath -import org.partiql.ast.v1.Ast.excludeStepCollIndex -import org.partiql.ast.v1.Ast.excludeStepCollWildcard -import org.partiql.ast.v1.Ast.excludeStepStructField -import org.partiql.ast.v1.Ast.excludeStepStructWildcard -import org.partiql.ast.v1.Ast.explain -import org.partiql.ast.v1.Ast.exprAnd -import org.partiql.ast.v1.Ast.exprArray -import org.partiql.ast.v1.Ast.exprBag -import org.partiql.ast.v1.Ast.exprBetween -import org.partiql.ast.v1.Ast.exprCall -import org.partiql.ast.v1.Ast.exprCase -import org.partiql.ast.v1.Ast.exprCaseBranch -import org.partiql.ast.v1.Ast.exprCast -import org.partiql.ast.v1.Ast.exprCoalesce -import org.partiql.ast.v1.Ast.exprExtract -import org.partiql.ast.v1.Ast.exprInCollection -import org.partiql.ast.v1.Ast.exprIsType -import org.partiql.ast.v1.Ast.exprLike -import org.partiql.ast.v1.Ast.exprLit -import org.partiql.ast.v1.Ast.exprMatch -import org.partiql.ast.v1.Ast.exprNot -import org.partiql.ast.v1.Ast.exprNullIf -import org.partiql.ast.v1.Ast.exprOperator -import org.partiql.ast.v1.Ast.exprOr -import org.partiql.ast.v1.Ast.exprOverlay -import org.partiql.ast.v1.Ast.exprParameter -import org.partiql.ast.v1.Ast.exprPath -import org.partiql.ast.v1.Ast.exprPathStepAllElements -import org.partiql.ast.v1.Ast.exprPathStepAllFields -import org.partiql.ast.v1.Ast.exprPathStepElement -import org.partiql.ast.v1.Ast.exprPathStepField -import org.partiql.ast.v1.Ast.exprPosition -import org.partiql.ast.v1.Ast.exprQuerySet -import org.partiql.ast.v1.Ast.exprSessionAttribute -import org.partiql.ast.v1.Ast.exprStruct -import org.partiql.ast.v1.Ast.exprStructField -import org.partiql.ast.v1.Ast.exprSubstring -import org.partiql.ast.v1.Ast.exprTrim -import org.partiql.ast.v1.Ast.exprVarRef -import org.partiql.ast.v1.Ast.exprVariant -import org.partiql.ast.v1.Ast.exprWindow -import org.partiql.ast.v1.Ast.exprWindowOver -import org.partiql.ast.v1.Ast.from -import org.partiql.ast.v1.Ast.fromExpr -import org.partiql.ast.v1.Ast.fromJoin -import org.partiql.ast.v1.Ast.graphLabelConj -import org.partiql.ast.v1.Ast.graphLabelDisj -import org.partiql.ast.v1.Ast.graphLabelName -import org.partiql.ast.v1.Ast.graphLabelNegation -import org.partiql.ast.v1.Ast.graphLabelWildcard -import org.partiql.ast.v1.Ast.graphMatch -import org.partiql.ast.v1.Ast.graphMatchEdge -import org.partiql.ast.v1.Ast.graphMatchNode -import org.partiql.ast.v1.Ast.graphMatchPattern -import org.partiql.ast.v1.Ast.graphPattern -import org.partiql.ast.v1.Ast.graphQuantifier -import org.partiql.ast.v1.Ast.graphSelectorAllShortest -import org.partiql.ast.v1.Ast.graphSelectorAny -import org.partiql.ast.v1.Ast.graphSelectorAnyK -import org.partiql.ast.v1.Ast.graphSelectorAnyShortest -import org.partiql.ast.v1.Ast.graphSelectorShortestK -import org.partiql.ast.v1.Ast.graphSelectorShortestKGroup -import org.partiql.ast.v1.Ast.groupBy -import org.partiql.ast.v1.Ast.groupByKey -import org.partiql.ast.v1.Ast.identifier -import org.partiql.ast.v1.Ast.identifierChain -import org.partiql.ast.v1.Ast.letBinding -import org.partiql.ast.v1.Ast.orderBy -import org.partiql.ast.v1.Ast.query -import org.partiql.ast.v1.Ast.queryBodySFW -import org.partiql.ast.v1.Ast.queryBodySetOp -import org.partiql.ast.v1.Ast.selectItemExpr -import org.partiql.ast.v1.Ast.selectItemStar -import org.partiql.ast.v1.Ast.selectList -import org.partiql.ast.v1.Ast.selectPivot -import org.partiql.ast.v1.Ast.selectStar -import org.partiql.ast.v1.Ast.selectValue -import org.partiql.ast.v1.Ast.setOp -import org.partiql.ast.v1.Ast.sort -import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.DataType -import org.partiql.ast.v1.DatetimeField -import org.partiql.ast.v1.Exclude -import org.partiql.ast.v1.ExcludeStep -import org.partiql.ast.v1.From -import org.partiql.ast.v1.FromTableRef -import org.partiql.ast.v1.FromType -import org.partiql.ast.v1.GroupBy -import org.partiql.ast.v1.GroupByStrategy -import org.partiql.ast.v1.Identifier -import org.partiql.ast.v1.IdentifierChain -import org.partiql.ast.v1.JoinType -import org.partiql.ast.v1.Let -import org.partiql.ast.v1.Nulls -import org.partiql.ast.v1.Order -import org.partiql.ast.v1.Select -import org.partiql.ast.v1.SelectItem -import org.partiql.ast.v1.SetOpType -import org.partiql.ast.v1.SetQuantifier -import org.partiql.ast.v1.Sort -import org.partiql.ast.v1.Statement -import org.partiql.ast.v1.expr.Expr -import org.partiql.ast.v1.expr.ExprArray -import org.partiql.ast.v1.expr.ExprBag -import org.partiql.ast.v1.expr.ExprCall -import org.partiql.ast.v1.expr.ExprPath -import org.partiql.ast.v1.expr.ExprQuerySet -import org.partiql.ast.v1.expr.PathStep -import org.partiql.ast.v1.expr.Scope -import org.partiql.ast.v1.expr.SessionAttribute -import org.partiql.ast.v1.expr.TrimSpec -import org.partiql.ast.v1.expr.WindowFunction -import org.partiql.ast.v1.graph.GraphDirection -import org.partiql.ast.v1.graph.GraphLabel -import org.partiql.ast.v1.graph.GraphPart -import org.partiql.ast.v1.graph.GraphPattern -import org.partiql.ast.v1.graph.GraphQuantifier -import org.partiql.ast.v1.graph.GraphRestrictor -import org.partiql.ast.v1.graph.GraphSelector -import org.partiql.parser.PartiQLLexerException -import org.partiql.parser.PartiQLParserException -import org.partiql.parser.PartiQLParserV1 -import org.partiql.parser.internal.antlr.PartiQLParser -import org.partiql.parser.internal.antlr.PartiQLParserBaseVisitor -import org.partiql.parser.internal.util.DateTimeUtils -import org.partiql.spi.Context -import org.partiql.spi.SourceLocation -import org.partiql.spi.SourceLocations -import org.partiql.spi.errors.PError -import org.partiql.spi.errors.PErrorKind -import org.partiql.spi.errors.PErrorListener -import org.partiql.spi.errors.PErrorListenerException -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.boolValue -import org.partiql.value.dateValue -import org.partiql.value.datetime.DateTimeException -import org.partiql.value.datetime.DateTimeValue -import org.partiql.value.decimalValue -import org.partiql.value.int32Value -import org.partiql.value.int64Value -import org.partiql.value.intValue -import org.partiql.value.missingValue -import org.partiql.value.nullValue -import org.partiql.value.stringValue -import org.partiql.value.timeValue -import org.partiql.value.timestampValue -import java.math.BigDecimal -import java.math.BigInteger -import java.math.MathContext -import java.math.RoundingMode -import java.nio.channels.ClosedByInterruptException -import java.nio.charset.StandardCharsets -import java.time.LocalDate -import java.time.format.DateTimeFormatter -import java.time.format.DateTimeParseException -import org.partiql.parser.internal.antlr.PartiQLParser as GeneratedParser -import org.partiql.parser.internal.antlr.PartiQLTokens as GeneratedLexer - -/** - * ANTLR Based Implementation of a PartiQLParser - * - * SLL Prediction Mode - * ------------------- - * The [PredictionMode.SLL] mode uses the [BailErrorStrategy]. The [GeneratedParser], upon seeing a syntax error, - * will throw a [ParseCancellationException] due to the [GeneratedParser.getErrorHandler] - * being a [BailErrorStrategy]. The purpose of this is to throw syntax errors as quickly as possible once encountered. - * As noted by the [PredictionMode.SLL] documentation, to guarantee results, it is useful to follow up a failed parse - * by parsing with [PredictionMode.LL]. See the JavaDocs for [PredictionMode.SLL] and [BailErrorStrategy] for more. - * - * LL Prediction Mode - * ------------------ - * The [PredictionMode.LL] mode is capable of parsing all valid inputs for a grammar, - * but is slower than [PredictionMode.SLL]. Upon seeing a syntax error, this parser throws a [PartiQLParserException]. - */ -internal class PartiQLParserDefaultV1 : PartiQLParserV1 { - - @OptIn(PartiQLValueExperimental::class) - @Throws(PErrorListenerException::class) - override fun parse(source: String, ctx: Context): PartiQLParserV1.Result { - try { - return parse(source, ctx.errorListener) - } catch (e: PErrorListenerException) { - throw e - } catch (throwable: Throwable) { - val error = PError.INTERNAL_ERROR(PErrorKind.SYNTAX(), null, throwable) - ctx.errorListener.report(error) - val locations = SourceLocations() - return PartiQLParserV1.Result( - mutableListOf(org.partiql.ast.v1.Query(org.partiql.ast.v1.expr.ExprLit(nullValue()))) as List, - locations - ) - } - } - - /** - * To reduce latency costs, the [PartiQLParserDefaultV1] attempts to use [PredictionMode.SLL] and falls back to - * [PredictionMode.LL] if a [ParseCancellationException] is thrown by the [BailErrorStrategy]. - */ - private fun parse(source: String, listener: PErrorListener): PartiQLParserV1.Result = try { - parse(source, PredictionMode.SLL, listener) - } catch (ex: ParseCancellationException) { - parse(source, PredictionMode.LL, listener) - } - - /** - * Parses an input string [source] using the given prediction mode. - */ - private fun parse(source: String, mode: PredictionMode, listener: PErrorListener): PartiQLParserV1.Result { - val tokens = createTokenStream(source, listener) - val parser = InterruptibleParser(tokens) - parser.reset() - parser.removeErrorListeners() - parser.interpreter.predictionMode = mode - when (mode) { - PredictionMode.SLL -> parser.errorHandler = BailErrorStrategy() - PredictionMode.LL -> parser.addErrorListener(ParseErrorListener(listener)) - else -> throw IllegalArgumentException("Unsupported parser mode: $mode") - } - val tree = parser.statements() - return Visitor.translate(tokens, tree) - } - - private fun createTokenStream(source: String, listener: PErrorListener): CountingTokenStream { - val queryStream = source.byteInputStream(StandardCharsets.UTF_8) - val inputStream = try { - CharStreams.fromStream(queryStream) - } catch (ex: ClosedByInterruptException) { - throw InterruptedException() - } - val handler = TokenizeErrorListener(listener) - val lexer = GeneratedLexer(inputStream) - lexer.removeErrorListeners() - lexer.addErrorListener(handler) - return CountingTokenStream(lexer) - } - - /** - * Catches Lexical errors (unidentified tokens) and throws a [PartiQLParserException] - */ - private class TokenizeErrorListener(private val listener: PErrorListener) : BaseErrorListener() { - @Throws(PartiQLParserException::class) - override fun syntaxError( - recognizer: Recognizer<*, *>?, - offendingSymbol: Any?, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException?, - ) { - offendingSymbol as Token - val token = offendingSymbol.text - val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong()) - val error = PErrors.unrecognizedToken(location, token) - listener.report(error) - } - } - - /** - * Catches Parser errors (malformed syntax) and throws a [PartiQLParserException] - */ - private class ParseErrorListener(private val listener: PErrorListener) : BaseErrorListener() { - - private val rules = GeneratedParser.ruleNames.asList() - - @Throws(PartiQLParserException::class) - override fun syntaxError( - recognizer: Recognizer<*, *>?, - offendingSymbol: Any, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException?, - ) { - offendingSymbol as Token - val rule = e?.ctx?.toString(rules) ?: "UNKNOWN" // TODO: Do we want to display the offending rule? - val token = offendingSymbol.text - val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type) - val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong()) - val error = PErrors.unexpectedToken(location, tokenType, null) - listener.report(error) - } - } - - /** - * A wrapped [GeneratedParser] to allow thread interruption during parse. - */ - internal class InterruptibleParser(input: TokenStream) : GeneratedParser(input) { - override fun enterRule(localctx: ParserRuleContext?, state: Int, ruleIndex: Int) { - if (Thread.interrupted()) { - throw InterruptedException() - } - super.enterRule(localctx, state, ruleIndex) - } - } - - /** - * This token stream creates [parameterIndexes], which is a map, where the keys represent the - * indexes of all [GeneratedLexer.QUESTION_MARK]'s and the values represent their relative index amongst all other - * [GeneratedLexer.QUESTION_MARK]'s. - */ - internal open class CountingTokenStream(tokenSource: TokenSource) : CommonTokenStream(tokenSource) { - // TODO: Research use-case of parameters and implementation -- see https://github.com/partiql/partiql-docs/issues/23 - val parameterIndexes = mutableMapOf() - private var parametersFound = 0 - override fun LT(k: Int): Token? { - val token = super.LT(k) - token?.let { - if (it.type == GeneratedLexer.QUESTION_MARK && parameterIndexes.containsKey(token.tokenIndex).not()) { - parameterIndexes[token.tokenIndex] = ++parametersFound - } - } - return token - } - } - - /** - * Translate an ANTLR ParseTree to a PartiQL AST - */ - @OptIn(PartiQLValueExperimental::class) - private class Visitor( - private val tokens: CommonTokenStream, - private val locations: MutableMap, - private val parameters: Map = mapOf(), - ) : PartiQLParserBaseVisitor() { - - companion object { - - private val rules = GeneratedParser.ruleNames.asList() - - /** - * Expose an (internal) friendly entry point into the traversal; mostly for keeping mutable state contained. - */ - fun translate( - tokens: CountingTokenStream, - tree: PartiQLParser.StatementsContext, - ): PartiQLParserV1.Result { - val locations = mutableMapOf() - val visitor = Visitor(tokens, locations, tokens.parameterIndexes) - val statements = tree.statement().map { statementCtx -> - visitor.visit(statementCtx) as Statement - } - return PartiQLParserV1.Result( - statements, - SourceLocations(locations), - ) - } - - fun error( - ctx: ParserRuleContext, - message: String, - cause: Throwable? = null, - ) = PartiQLParserException( - rule = ctx.toStringTree(rules), - token = ctx.start.text, - tokenType = GeneratedParser.VOCABULARY.getSymbolicName(ctx.start.type), - message = message, - cause = cause, - location = SourceLocation( - ctx.start.line, - ctx.start.charPositionInLine + 1, - ctx.stop.stopIndex - ctx.start.startIndex, - ), - ) - - fun error( - token: Token, - message: String, - cause: Throwable? = null, - ) = PartiQLLexerException( - token = token.text, - tokenType = GeneratedParser.VOCABULARY.getSymbolicName(token.type), - message = message, - cause = cause, - location = SourceLocation( - token.line, - token.charPositionInLine + 1, - token.stopIndex - token.startIndex, - ), - ) - - internal val DATE_PATTERN_REGEX = Regex("\\d\\d\\d\\d-\\d\\d-\\d\\d") - - internal val GENERIC_TIME_REGEX = Regex("\\d\\d:\\d\\d:\\d\\d(\\.\\d*)?([+|-]\\d\\d:\\d\\d)?") - } - - /** - * Each visit attaches source locations from the given parse tree node; constructs nodes via the factory. - */ - private inline fun translate(ctx: ParserRuleContext, block: () -> T): T { - val node = block() - if (ctx.start != null) { - locations[node.tag] = SourceLocation( - ctx.start.line, - ctx.start.charPositionInLine + 1, - (ctx.stop?.stopIndex ?: ctx.start.stopIndex) - ctx.start.startIndex + 1, - ) - } - return node - } - - /** - * - * TOP LEVEL - * - */ - - override fun visitQueryDql(ctx: GeneratedParser.QueryDqlContext): AstNode = visitDql(ctx.dql()) - - override fun visitQueryDml(ctx: GeneratedParser.QueryDmlContext): AstNode = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitExplain(ctx: PartiQLParser.ExplainContext) = translate(ctx) { - var type: String? = null - var format: String? = null - ctx.explainOption().forEach { option -> - val parameter = try { - ExplainParameters.valueOf(option.param.text.uppercase()) - } catch (ex: java.lang.IllegalArgumentException) { - throw error(option.param, "Unknown EXPLAIN parameter.", ex) - } - when (parameter) { - ExplainParameters.TYPE -> { - type = parameter.getCompliantString(type, option.value) - } - ExplainParameters.FORMAT -> { - format = parameter.getCompliantString(format, option.value) - } - } - } - explain( - // TODO get rid of usage of PartiQLValue https://github.com/partiql/partiql-lang-kotlin/issues/1589 - options = mapOf( - "type" to stringValue(type), - "format" to stringValue(format) - ), - statement = visit(ctx.statement()) as Statement, - ) - } - - /** - * - * COMMON USAGES - * - */ - - override fun visitAsIdent(ctx: GeneratedParser.AsIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) - - override fun visitAtIdent(ctx: GeneratedParser.AtIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) - - override fun visitByIdent(ctx: GeneratedParser.ByIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) - - private fun visitSymbolPrimitive(ctx: GeneratedParser.SymbolPrimitiveContext): Identifier = - when (ctx) { - is GeneratedParser.IdentifierQuotedContext -> visitIdentifierQuoted(ctx) - is GeneratedParser.IdentifierUnquotedContext -> visitIdentifierUnquoted(ctx) - else -> throw error(ctx, "Invalid symbol reference.") - } - - override fun visitIdentifierQuoted(ctx: GeneratedParser.IdentifierQuotedContext): Identifier = translate(ctx) { - identifier( - ctx.IDENTIFIER_QUOTED().getStringValue(), - true - ) - } - - override fun visitIdentifierUnquoted(ctx: GeneratedParser.IdentifierUnquotedContext): Identifier = translate(ctx) { - identifier( - ctx.text, - false - ) - } - - override fun visitQualifiedName(ctx: GeneratedParser.QualifiedNameContext) = translate(ctx) { - val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } - val name = identifierChain(visitSymbolPrimitive(ctx.name), null) - if (qualifier.isEmpty()) { - name - } else { - qualifier.reversed().fold(name) { acc, id -> - identifierChain(root = id, next = acc) - } - } - } - - /** - * - * DATA DEFINITION LANGUAGE (DDL) -- deleted in v1; will be added before final v1 release - * - */ - -// override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = visitDdl(ctx.ddl()) -// -// override fun visitDropTable(ctx: GeneratedParser.DropTableContext) = translate(ctx) { -// val table = visitQualifiedName(ctx.qualifiedName()) -// statementDDLDropTable(table) -// } -// -// override fun visitDropIndex(ctx: GeneratedParser.DropIndexContext) = translate(ctx) { -// val table = visitSymbolPrimitive(ctx.on) -// val index = visitSymbolPrimitive(ctx.target) -// statementDDLDropIndex(index, table) -// } -// -// override fun visitCreateTable(ctx: GeneratedParser.CreateTableContext) = translate(ctx) { -// val table = visitQualifiedName(ctx.qualifiedName()) -// val definition = ctx.tableDef()?.let { visitTableDef(it) } -// statementDDLCreateTable(table, definition) -// } -// -// override fun visitCreateIndex(ctx: GeneratedParser.CreateIndexContext) = translate(ctx) { -// // TODO add index name to ANTLR grammar -// val name: Identifier? = null -// val table = visitSymbolPrimitive(ctx.symbolPrimitive()) -// val fields = ctx.pathSimple().map { path -> visitPathSimple(path) } -// statementDDLCreateIndex(name, table, fields) -// } -// -// override fun visitTableDef(ctx: GeneratedParser.TableDefContext) = translate(ctx) { -// // Column Definitions are the only thing we currently allow as table definition parts -// val columns = ctx.tableDefPart().filterIsInstance().map { -// visitColumnDeclaration(it) -// } -// tableDefinition(columns) -// } -// -// override fun visitColumnDeclaration(ctx: GeneratedParser.ColumnDeclarationContext) = translate(ctx) { -// val name = visitSymbolPrimitive(ctx.columnName().symbolPrimitive()).symbol -// val type = visit(ctx.type()) as Type -// val constraints = ctx.columnConstraint().map { -// visitColumnConstraint(it) -// } -// tableDefinitionColumn(name, type, constraints) -// } -// -// override fun visitColumnConstraint(ctx: GeneratedParser.ColumnConstraintContext) = translate(ctx) { -// val identifier = ctx.columnConstraintName()?.let { symbolToString(it.symbolPrimitive()) } -// val body = visit(ctx.columnConstraintDef()) as TableDefinition.Column.Constraint.Body -// tableDefinitionColumnConstraint(identifier, body) -// } -// -// override fun visitColConstrNotNull(ctx: GeneratedParser.ColConstrNotNullContext) = translate(ctx) { -// tableDefinitionColumnConstraintBodyNotNull() -// } -// -// override fun visitColConstrNull(ctx: GeneratedParser.ColConstrNullContext) = translate(ctx) { -// tableDefinitionColumnConstraintBodyNullable() -// } - - /** - * - * EXECUTE - * - */ - - override fun visitQueryExec(ctx: GeneratedParser.QueryExecContext) = translate(ctx) { - throw error(ctx, "EXEC no longer supported in the default PartiQLParser.") - } - - /** - * TODO EXEC accepts an `expr` as the procedure name so we have to unpack the string. - * - https://github.com/partiql/partiql-lang-kotlin/issues/707 - */ - override fun visitExecCommand(ctx: GeneratedParser.ExecCommandContext) = translate(ctx) { - throw error(ctx, "EXEC no longer supported in the default PartiQLParser.") - } - - /** - * - * DATA MANIPULATION LANGUAGE (DML) - * - */ - - /** - * The PartiQL grammars allows for multiple DML commands in one UPDATE statement. - * This function unwraps DML commands to the more limited DML.BatchLegacy.Op commands. - */ - override fun visitDmlBaseWrapper(ctx: GeneratedParser.DmlBaseWrapperContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitDmlDelete(ctx: GeneratedParser.DmlDeleteContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitDmlInsertReturning(ctx: GeneratedParser.DmlInsertReturningContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitDmlBase(ctx: GeneratedParser.DmlBaseContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitDmlBaseCommand(ctx: GeneratedParser.DmlBaseCommandContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitRemoveCommand(ctx: GeneratedParser.RemoveCommandContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitDeleteCommand(ctx: GeneratedParser.DeleteCommandContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - /** - * Legacy INSERT with RETURNING clause is not represented in the AST as this grammar .. - * .. only exists for backwards compatibility. The RETURNING clause is ignored. - * - * TODO remove insertCommandReturning grammar rule - * - https://github.com/partiql/partiql-lang-kotlin/issues/698 - * - https://github.com/partiql/partiql-lang-kotlin/issues/708 - */ - override fun visitInsertCommandReturning(ctx: GeneratedParser.InsertCommandReturningContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitInsertStatementLegacy(ctx: GeneratedParser.InsertStatementLegacyContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitInsertStatement(ctx: GeneratedParser.InsertStatementContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitReplaceCommand(ctx: GeneratedParser.ReplaceCommandContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitUpsertCommand(ctx: GeneratedParser.UpsertCommandContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitReturningClause(ctx: GeneratedParser.ReturningClauseContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitReturningColumn(ctx: GeneratedParser.ReturningColumnContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitOnConflict(ctx: GeneratedParser.OnConflictContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - /** - * TODO Remove this when we remove INSERT LEGACY as no other conflict actions are allowed in PartiQL.g4. - */ - override fun visitOnConflictLegacy(ctx: GeneratedParser.OnConflictLegacyContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitConflictTarget(ctx: GeneratedParser.ConflictTargetContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitConflictAction(ctx: GeneratedParser.ConflictActionContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitDoReplace(ctx: GeneratedParser.DoReplaceContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitDoUpdate(ctx: GeneratedParser.DoUpdateContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - // "simple paths" used by previous DDL's CREATE INDEX - override fun visitPathSimple(ctx: GeneratedParser.PathSimpleContext) = translate(ctx) { - throw error(ctx, "DDL no longer supported in the default PartiQLParser.") - } - - // "simple paths" used by previous DDL's CREATE INDEX - override fun visitPathSimpleLiteral(ctx: GeneratedParser.PathSimpleLiteralContext) = translate(ctx) { - throw error(ctx, "DDL no longer supported in the default PartiQLParser.") - } - - // "simple paths" used by previous DDL's CREATE INDEX - override fun visitPathSimpleSymbol(ctx: GeneratedParser.PathSimpleSymbolContext) = translate(ctx) { - throw error(ctx, "DDL no longer supported in the default PartiQLParser.") - } - - // "simple paths" used by previous DDL's CREATE INDEX - override fun visitPathSimpleDotSymbol(ctx: GeneratedParser.PathSimpleDotSymbolContext) = translate(ctx) { - throw error(ctx, "DDL no longer supported in the default PartiQLParser.") - } - - /** - * TODO current PartiQL.g4 grammar models a SET with no UPDATE target as valid DML command. - */ - override fun visitSetCommand(ctx: GeneratedParser.SetCommandContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - override fun visitSetAssignment(ctx: GeneratedParser.SetAssignmentContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - /** - * - * DATA QUERY LANGUAGE (DQL) - * - */ - - override fun visitDql(ctx: GeneratedParser.DqlContext) = translate(ctx) { - val expr = visitAs(ctx.expr()) - query(expr) - } - - override fun visitQueryBase(ctx: GeneratedParser.QueryBaseContext): AstNode = visit(ctx.exprSelect()) - - override fun visitSfwQuery(ctx: GeneratedParser.SfwQueryContext) = translate(ctx) { - val select = visit(ctx.select) as Select - val from = visitFromClause(ctx.from) - val exclude = visitOrNull(ctx.exclude) - val let = visitOrNull(ctx.let) - val where = visitOrNull(ctx.where) - val groupBy = ctx.group?.let { visitGroupClause(it) } - val having = visitOrNull(ctx.having?.arg) - val orderBy = ctx.order?.let { visitOrderByClause(it) } - val limit = visitOrNull(ctx.limit?.arg) - val offset = visitOrNull(ctx.offset?.arg) - exprQuerySet( - body = queryBodySFW( - select, exclude, from, let, where, groupBy, having - ), - orderBy = orderBy, - limit = limit, - offset = offset - ) - } - - /** - * - * SELECT & PROJECTIONS - * - */ - - override fun visitSelectAll(ctx: GeneratedParser.SelectAllContext) = translate(ctx) { - val quantifier = convertSetQuantifier(ctx.setQuantifierStrategy()) - selectStar(quantifier) - } - - override fun visitSelectItems(ctx: GeneratedParser.SelectItemsContext) = translate(ctx) { - val items = visitOrEmpty(ctx.projectionItems().projectionItem()) - val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) - selectList(items, setq) - } - - override fun visitSelectPivot(ctx: GeneratedParser.SelectPivotContext) = translate(ctx) { - val key = visitExpr(ctx.at) - val value = visitExpr(ctx.pivot) - selectPivot(key, value) - } - - override fun visitSelectValue(ctx: GeneratedParser.SelectValueContext) = translate(ctx) { - val constructor = visitExpr(ctx.expr()) - val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) - selectValue(constructor, setq) - } - - override fun visitProjectionItem(ctx: GeneratedParser.ProjectionItemContext) = translate(ctx) { - val expr = visitExpr(ctx.expr()) - val alias = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it) } - if (expr is ExprPath) { - convertPathToProjectionItem(ctx, expr, alias) - } else { - selectItemExpr(expr, alias) - } - } - - /** - * - * SIMPLE CLAUSES - * - */ - - override fun visitLimitClause(ctx: GeneratedParser.LimitClauseContext): Expr = visitAs(ctx.arg) - - override fun visitExpr(ctx: GeneratedParser.ExprContext): Expr { - if (Thread.interrupted()) { - throw InterruptedException() - } - return visitAs(ctx.exprBagOp()) - } - - override fun visitOffsetByClause(ctx: GeneratedParser.OffsetByClauseContext) = visitAs(ctx.arg) - - override fun visitWhereClause(ctx: GeneratedParser.WhereClauseContext) = visitExpr(ctx.arg) - - override fun visitWhereClauseSelect(ctx: GeneratedParser.WhereClauseSelectContext) = visitAs(ctx.arg) - - override fun visitHavingClause(ctx: GeneratedParser.HavingClauseContext) = visitAs(ctx.arg) - - /** - * - * LET CLAUSE - * - */ - - override fun visitLetClause(ctx: GeneratedParser.LetClauseContext) = translate(ctx) { - val bindings = visitOrEmpty(ctx.letBinding()) - Ast.let(bindings) - } - - override fun visitLetBinding(ctx: GeneratedParser.LetBindingContext) = translate(ctx) { - val expr = visitAs(ctx.expr()) - val alias = visitSymbolPrimitive(ctx.symbolPrimitive()) - letBinding(expr, alias) - } - - /** - * - * ORDER BY CLAUSE - * - */ - - override fun visitOrderByClause(ctx: GeneratedParser.OrderByClauseContext) = translate(ctx) { - val sorts = visitOrEmpty(ctx.orderSortSpec()) - orderBy(sorts) - } - - override fun visitOrderSortSpec(ctx: GeneratedParser.OrderSortSpecContext) = translate(ctx) { - val expr = visitAs(ctx.expr()) - val dir = when { - ctx.dir == null -> null - ctx.dir.type == GeneratedParser.ASC -> Order.ASC() - ctx.dir.type == GeneratedParser.DESC -> Order.DESC() - else -> throw error(ctx.dir, "Invalid ORDER BY direction; expected ASC or DESC") - } - val nulls = when { - ctx.nulls == null -> null - ctx.nulls.type == GeneratedParser.FIRST -> Nulls.FIRST() - ctx.nulls.type == GeneratedParser.LAST -> Nulls.LAST() - else -> throw error(ctx.nulls, "Invalid ORDER BY null ordering; expected FIRST or LAST") - } - sort(expr, dir, nulls) - } - - /** - * - * GROUP BY CLAUSE - * - */ - - override fun visitGroupClause(ctx: GeneratedParser.GroupClauseContext) = translate(ctx) { - val strategy = if (ctx.PARTIAL() != null) GroupByStrategy.PARTIAL() else GroupByStrategy.FULL() - val keys = visitOrEmpty(ctx.groupKey()) - val alias = ctx.groupAlias()?.symbolPrimitive()?.let { visitSymbolPrimitive(it) } - groupBy(strategy, keys, alias) - } - - override fun visitGroupKey(ctx: GeneratedParser.GroupKeyContext) = translate(ctx) { - val expr = visitAs(ctx.key) - val alias = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it) } - groupByKey(expr, alias) - } - - /** - * EXCLUDE CLAUSE - */ - override fun visitExcludeClause(ctx: GeneratedParser.ExcludeClauseContext) = translate(ctx) { - val excludeExprs = ctx.excludeExpr().map { expr -> - visitExcludeExpr(expr) - } - exclude(excludeExprs) - } - - override fun visitExcludeExpr(ctx: GeneratedParser.ExcludeExprContext) = translate(ctx) { - val rootId = visitSymbolPrimitive(ctx.symbolPrimitive()) - val root = exprVarRef(identifierChain(rootId, null), Scope.DEFAULT()) - val steps = visitOrEmpty(ctx.excludeExprSteps()) - excludePath(root, steps) - } - - override fun visitExcludeExprTupleAttr(ctx: GeneratedParser.ExcludeExprTupleAttrContext) = translate(ctx) { - val identifier = visitSymbolPrimitive(ctx.symbolPrimitive()) - excludeStepStructField(identifier) - } - - override fun visitExcludeExprCollectionIndex(ctx: GeneratedParser.ExcludeExprCollectionIndexContext) = - translate(ctx) { - val index = ctx.index.text.toInt() - excludeStepCollIndex(index) - } - - override fun visitExcludeExprCollectionAttr(ctx: GeneratedParser.ExcludeExprCollectionAttrContext) = - translate(ctx) { - val attr = ctx.attr.getStringValue() - val identifier = identifier(attr, true) - excludeStepStructField(identifier) - } - - override fun visitExcludeExprCollectionWildcard(ctx: GeneratedParser.ExcludeExprCollectionWildcardContext) = - translate(ctx) { - excludeStepCollWildcard() - } - - override fun visitExcludeExprTupleWildcard(ctx: GeneratedParser.ExcludeExprTupleWildcardContext) = - translate(ctx) { - excludeStepStructWildcard() - } - - /** - * - * BAG OPERATIONS - * - */ - override fun visitBagOp(ctx: GeneratedParser.BagOpContext) = translate(ctx) { - val setq = when { - ctx.ALL() != null -> SetQuantifier.ALL() - ctx.DISTINCT() != null -> SetQuantifier.DISTINCT() - else -> null - } - val op = when (ctx.op.type) { - GeneratedParser.UNION -> setOp(SetOpType.UNION(), setq) - GeneratedParser.INTERSECT -> setOp(SetOpType.INTERSECT(), setq) - GeneratedParser.EXCEPT -> setOp(SetOpType.EXCEPT(), setq) - else -> error("Unsupported bag op token ${ctx.op}") - } - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs) - val outer = ctx.OUTER() != null - val orderBy = ctx.order?.let { visitOrderByClause(it) } - val limit = ctx.limit?.let { visitAs(it) } - val offset = ctx.offset?.let { visitAs(it) } - exprQuerySet( - queryBodySetOp( - op, - outer, - lhs, - rhs - ), - orderBy, - limit, - offset, - ) - } - - /** - * - * GRAPH PATTERN MANIPULATION LANGUAGE (GPML) - * - */ - - override fun visitGpmlPattern(ctx: GeneratedParser.GpmlPatternContext) = translate(ctx) { - val pattern = visitMatchPattern(ctx.matchPattern()) - val selector = visitOrNull(ctx.matchSelector()) - graphMatch(listOf(pattern), selector) - } - - override fun visitGpmlPatternList(ctx: GeneratedParser.GpmlPatternListContext) = translate(ctx) { - val patterns = ctx.matchPattern().map { pattern -> visitMatchPattern(pattern) } - val selector = visitOrNull(ctx.matchSelector()) - graphMatch(patterns, selector) - } - - override fun visitMatchPattern(ctx: GeneratedParser.MatchPatternContext) = translate(ctx) { - val parts = visitOrEmpty(ctx.graphPart()) - val restrictor = ctx.restrictor?.let { - when (ctx.restrictor.text.lowercase()) { - "trail" -> GraphRestrictor.TRAIL() - "acyclic" -> GraphRestrictor.ACYCLIC() - "simple" -> GraphRestrictor.SIMPLE() - else -> throw error(ctx.restrictor, "Unrecognized pattern restrictor") - } - } - val variable = visitOrNull(ctx.variable)?.symbol - graphPattern(restrictor, null, variable, null, parts) - } - - override fun visitPatternPathVariable(ctx: GeneratedParser.PatternPathVariableContext) = - visitSymbolPrimitive(ctx.symbolPrimitive()) - - override fun visitSelectorBasic(ctx: GeneratedParser.SelectorBasicContext) = translate(ctx) { - when (ctx.mod.type) { - GeneratedParser.ANY -> graphSelectorAnyShortest() - GeneratedParser.ALL -> graphSelectorAllShortest() - else -> throw error(ctx, "Unsupported match selector.") - } - } - - override fun visitSelectorAny(ctx: GeneratedParser.SelectorAnyContext) = translate(ctx) { - when (ctx.k) { - null -> graphSelectorAny() - else -> graphSelectorAnyK(ctx.k.text.toLong()) - } - } - - override fun visitSelectorShortest(ctx: GeneratedParser.SelectorShortestContext) = translate(ctx) { - val k = ctx.k.text.toLong() - when (ctx.GROUP()) { - null -> graphSelectorShortestK(k) - else -> graphSelectorShortestKGroup(k) - } - } - - override fun visitLabelSpecOr(ctx: GeneratedParser.LabelSpecOrContext) = translate(ctx) { - val lhs = visit(ctx.labelSpec()) as GraphLabel - val rhs = visit(ctx.labelTerm()) as GraphLabel - graphLabelDisj(lhs, rhs) - } - - override fun visitLabelTermAnd(ctx: GeneratedParser.LabelTermAndContext) = translate(ctx) { - val lhs = visit(ctx.labelTerm()) as GraphLabel - val rhs = visit(ctx.labelFactor()) as GraphLabel - graphLabelConj(lhs, rhs) - } - - override fun visitLabelFactorNot(ctx: GeneratedParser.LabelFactorNotContext) = translate(ctx) { - val arg = visit(ctx.labelPrimary()) as GraphLabel - graphLabelNegation(arg) - } - - override fun visitLabelPrimaryName(ctx: GeneratedParser.LabelPrimaryNameContext) = translate(ctx) { - val x = visitSymbolPrimitive(ctx.symbolPrimitive()) - graphLabelName(x.symbol) - } - - override fun visitLabelPrimaryWild(ctx: GeneratedParser.LabelPrimaryWildContext) = translate(ctx) { - graphLabelWildcard() - } - - override fun visitLabelPrimaryParen(ctx: GeneratedParser.LabelPrimaryParenContext) = - visit(ctx.labelSpec()) as GraphLabel - - override fun visitPattern(ctx: GeneratedParser.PatternContext) = translate(ctx) { - val restrictor = visitRestrictor(ctx.restrictor) - val variable = visitOrNull(ctx.variable)?.symbol - val prefilter = ctx.where?.let { visitExpr(it.expr()) } - val quantifier = ctx.quantifier?.let { visitPatternQuantifier(it) } - val parts = visitOrEmpty(ctx.graphPart()) - graphPattern(restrictor, prefilter, variable, quantifier, parts) - } - - override fun visitEdgeAbbreviated(ctx: GeneratedParser.EdgeAbbreviatedContext) = translate(ctx) { - val direction = visitEdge(ctx.edgeAbbrev()) - val quantifier = visitOrNull(ctx.quantifier) - graphMatchEdge(direction, quantifier, null, null, null) - } - - private fun GraphPart.Edge.copy( - direction: GraphDirection? = null, - quantifier: GraphQuantifier? = null, - prefilter: Expr? = null, - variable: String? = null, - label: GraphLabel? = null, - ) = graphMatchEdge( - direction = direction ?: this.direction, - quantifier = quantifier ?: this.quantifier, - prefilter = prefilter ?: this.prefilter, - variable = variable ?: this.variable, - label = label ?: this.label, - ) - - override fun visitEdgeWithSpec(ctx: GeneratedParser.EdgeWithSpecContext) = translate(ctx) { - val quantifier = visitOrNull(ctx.quantifier) - val edge = visitOrNull(ctx.edgeWSpec()) - edge!!.copy(quantifier = quantifier) - } - - override fun visitEdgeSpec(ctx: GeneratedParser.EdgeSpecContext) = translate(ctx) { - val placeholderDirection = GraphDirection.RIGHT() - val variable = visitOrNull(ctx.symbolPrimitive())?.symbol - val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } - val label = visitOrNull(ctx.labelSpec()) - graphMatchEdge(placeholderDirection, null, prefilter, variable, label) - } - - override fun visitEdgeSpecLeft(ctx: GeneratedParser.EdgeSpecLeftContext): AstNode { - val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphDirection.LEFT()) - } - - override fun visitEdgeSpecRight(ctx: GeneratedParser.EdgeSpecRightContext): AstNode { - val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphDirection.RIGHT()) - } - - override fun visitEdgeSpecBidirectional(ctx: GeneratedParser.EdgeSpecBidirectionalContext): AstNode { - val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphDirection.LEFT_OR_RIGHT()) - } - - override fun visitEdgeSpecUndirectedBidirectional(ctx: GeneratedParser.EdgeSpecUndirectedBidirectionalContext): AstNode { - val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphDirection.LEFT_UNDIRECTED_OR_RIGHT()) - } - - override fun visitEdgeSpecUndirected(ctx: GeneratedParser.EdgeSpecUndirectedContext): AstNode { - val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphDirection.UNDIRECTED()) - } - - override fun visitEdgeSpecUndirectedLeft(ctx: GeneratedParser.EdgeSpecUndirectedLeftContext): AstNode { - val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphDirection.LEFT_OR_UNDIRECTED()) - } - - override fun visitEdgeSpecUndirectedRight(ctx: GeneratedParser.EdgeSpecUndirectedRightContext): AstNode { - val edge = visitEdgeSpec(ctx.edgeSpec()) - return edge.copy(direction = GraphDirection.UNDIRECTED_OR_RIGHT()) - } - - private fun visitEdge(ctx: GeneratedParser.EdgeAbbrevContext): GraphDirection = when { - ctx.TILDE() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.UNDIRECTED_OR_RIGHT() - ctx.TILDE() != null && ctx.ANGLE_LEFT() != null -> GraphDirection.LEFT_OR_UNDIRECTED() - ctx.TILDE() != null -> GraphDirection.UNDIRECTED() - ctx.MINUS() != null && ctx.ANGLE_LEFT() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.LEFT_OR_RIGHT() - ctx.MINUS() != null && ctx.ANGLE_LEFT() != null -> GraphDirection.LEFT() - ctx.MINUS() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.RIGHT() - ctx.MINUS() != null -> GraphDirection.LEFT_UNDIRECTED_OR_RIGHT() - else -> throw error(ctx, "Unsupported edge type") - } - - override fun visitGraphPart(ctx: GeneratedParser.GraphPartContext): GraphPart { - val part = super.visitGraphPart(ctx) - if (part is GraphPattern) { - return translate(ctx) { graphMatchPattern(part) } - } - return part as GraphPart - } - - override fun visitPatternQuantifier(ctx: GeneratedParser.PatternQuantifierContext) = translate(ctx) { - when { - ctx.quant == null -> graphQuantifier(ctx.lower.text.toLong(), ctx.upper?.text?.toLong()) - ctx.quant.type == GeneratedParser.PLUS -> graphQuantifier(1L, null) - ctx.quant.type == GeneratedParser.ASTERISK -> graphQuantifier(0L, null) - else -> throw error(ctx, "Unsupported quantifier") - } - } - - override fun visitNode(ctx: GeneratedParser.NodeContext) = translate(ctx) { - val variable = visitOrNull(ctx.symbolPrimitive())?.symbol - val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } - val label = visitOrNull(ctx.labelSpec()) - graphMatchNode(prefilter, variable, label) - } - - private fun visitRestrictor(ctx: GeneratedParser.PatternRestrictorContext?): GraphRestrictor? { - if (ctx == null) return null - return when (ctx.restrictor.text.lowercase()) { - "trail" -> GraphRestrictor.TRAIL() - "acyclic" -> GraphRestrictor.ACYCLIC() - "simple" -> GraphRestrictor.SIMPLE() - else -> throw error(ctx, "Unrecognized pattern restrictor") - } - } - - /** - * - * TABLE REFERENCES & JOINS & FROM CLAUSE - * - */ - override fun visitFromClause(ctx: GeneratedParser.FromClauseContext): From = translate(ctx) { - val tableRefs = visitOrEmpty(ctx.tableReference()) - from(tableRefs) - } - - override fun visitTableBaseRefSymbol(ctx: PartiQLParser.TableBaseRefSymbolContext): FromTableRef = translate(ctx) { - val expr = visitAs(ctx.source) - val asAlias = visitSymbolPrimitive(ctx.symbolPrimitive()) - fromExpr(expr, FromType.SCAN(), asAlias, null) - } - - override fun visitTableBaseRefClauses(ctx: PartiQLParser.TableBaseRefClausesContext): FromTableRef = translate(ctx) { - val expr = visitAs(ctx.source) - val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - fromExpr(expr, FromType.SCAN(), asAlias, atAlias) - } - - override fun visitTableBaseRefMatch(ctx: PartiQLParser.TableBaseRefMatchContext): FromTableRef = translate(ctx) { - val expr = visitAs(ctx.source) - val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - fromExpr(expr, FromType.SCAN(), asAlias, atAlias) - } - - override fun visitTableUnpivot(ctx: PartiQLParser.TableUnpivotContext): FromTableRef = translate(ctx) { - val expr = visitAs(ctx.expr()) - val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } - fromExpr(expr, FromType.UNPIVOT(), asAlias, atAlias) - } - - override fun visitTableWrapped(ctx: PartiQLParser.TableWrappedContext): FromTableRef = translate(ctx) { - visitAs(ctx.tableReference()) - } - - override fun visitTableLeftCrossJoin(ctx: PartiQLParser.TableLeftCrossJoinContext): FromTableRef = translate(ctx) { - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs) - // PartiQL spec defines equivalence of - // l LEFT CROSS JOIN r <=> l LEFT JOIN r ON TRUE - // The other join types combined w/ CROSS JOIN are unspecified -- https://github.com/partiql/partiql-lang-kotlin/issues/1013 - fromJoin(lhs, rhs, JoinType.LEFT_CROSS(), null) - } - - override fun visitTableCrossJoin(ctx: PartiQLParser.TableCrossJoinContext): FromTableRef = translate(ctx) { - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs) - fromJoin(lhs, rhs, JoinType.CROSS(), null) - } - - override fun visitTableQualifiedJoin(ctx: PartiQLParser.TableQualifiedJoinContext): FromTableRef = translate(ctx) { - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs) - val type = convertJoinType(ctx.joinType()) - val condition = ctx.joinSpec()?.let { visitExpr(it.expr()) } - fromJoin(lhs, rhs, type, condition) - } - - private fun convertJoinType(ctx: GeneratedParser.JoinTypeContext?): JoinType? { - if (ctx == null) return null - return when (ctx.mod.type) { - GeneratedParser.INNER -> JoinType.INNER() - GeneratedParser.LEFT -> when (ctx.OUTER()) { - null -> JoinType.LEFT() - else -> JoinType.LEFT_OUTER() - } - GeneratedParser.RIGHT -> when (ctx.OUTER()) { - null -> JoinType.RIGHT() - else -> JoinType.RIGHT_OUTER() - } - GeneratedParser.FULL -> when (ctx.OUTER()) { - null -> JoinType.FULL() - else -> JoinType.FULL_OUTER() - } - GeneratedParser.OUTER -> { - // TODO https://github.com/partiql/partiql-spec/issues/41 - // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1013 - JoinType.FULL_OUTER() - } - else -> null - } - } - - /** - * TODO Remove as/at/by aliases from DELETE command grammar in PartiQL.g4 - */ - override fun visitFromClauseSimpleExplicit(ctx: GeneratedParser.FromClauseSimpleExplicitContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - /** - * TODO Remove fromClauseSimple rule from DELETE command grammar in PartiQL.g4 - */ - override fun visitFromClauseSimpleImplicit(ctx: GeneratedParser.FromClauseSimpleImplicitContext) = translate(ctx) { - throw error(ctx, "DML no longer supported in the default PartiQLParser.") - } - - /** - * SIMPLE EXPRESSIONS - */ - - override fun visitOr(ctx: GeneratedParser.OrContext) = translate(ctx) { - val l = visit(ctx.lhs) as Expr - val r = visit(ctx.rhs) as Expr - exprOr(l, r) - } - - override fun visitAnd(ctx: GeneratedParser.AndContext) = translate(ctx) { - val l = visit(ctx.lhs) as Expr - val r = visit(ctx.rhs) as Expr - exprAnd(l, r) - } - - override fun visitNot(ctx: GeneratedParser.NotContext) = translate(ctx) { - val expr = visit(ctx.exprNot()) as Expr - exprNot(expr) - } - - private fun checkForInvalidTokens(op: ParserRuleContext) { - val start = op.start.tokenIndex - val stop = op.stop.tokenIndex - val tokensInRange = tokens.get(start, stop) - if (tokensInRange.any { it.channel == GeneratedLexer.HIDDEN }) { - throw error(op, "Invalid whitespace or comment in operator") - } - } - - private fun convertToOperator(value: ParserRuleContext, op: ParserRuleContext): Expr { - checkForInvalidTokens(op) - return convertToOperator(value, op.text) - } - - private fun convertToOperator(value: ParserRuleContext, op: String): Expr { - val v = visit(value) as Expr - return exprOperator(op, null, v) - } - - private fun convertToOperator(lhs: ParserRuleContext, rhs: ParserRuleContext, op: ParserRuleContext): Expr { - checkForInvalidTokens(op) - return convertToOperator(lhs, rhs, op.text) - } - - private fun convertToOperator(lhs: ParserRuleContext, rhs: ParserRuleContext, op: String): Expr { - val l = visit(lhs) as Expr - val r = visit(rhs) as Expr - return exprOperator(op, l, r) - } - - override fun visitMathOp00(ctx: GeneratedParser.MathOp00Context) = translate(ctx) { - if (ctx.parent != null) return@translate visit(ctx.parent) - convertToOperator(ctx.lhs, ctx.rhs, ctx.op) - } - - override fun visitMathOp01(ctx: GeneratedParser.MathOp01Context) = translate(ctx) { - if (ctx.parent != null) return@translate visit(ctx.parent) - convertToOperator(ctx.rhs, ctx.op) - } - - override fun visitMathOp02(ctx: GeneratedParser.MathOp02Context) = translate(ctx) { - if (ctx.parent != null) return@translate visit(ctx.parent) - convertToOperator(ctx.lhs, ctx.rhs, ctx.op.text) - } - - override fun visitMathOp03(ctx: GeneratedParser.MathOp03Context) = translate(ctx) { - if (ctx.parent != null) return@translate visit(ctx.parent) - convertToOperator(ctx.lhs, ctx.rhs, ctx.op.text) - } - - override fun visitValueExpr(ctx: GeneratedParser.ValueExprContext) = translate(ctx) { - if (ctx.parent != null) return@translate visit(ctx.parent) - convertToOperator(ctx.rhs, ctx.sign.text) - } - - /** - * - * PREDICATES - * - */ - - override fun visitPredicateComparison(ctx: GeneratedParser.PredicateComparisonContext) = translate(ctx) { - convertToOperator(ctx.lhs, ctx.rhs, ctx.op) - } - - /** - * TODO Fix the IN collection grammar, also label alternative forms - * - https://github.com/partiql/partiql-lang-kotlin/issues/1115 - * - https://github.com/partiql/partiql-lang-kotlin/issues/1113 - */ - override fun visitPredicateIn(ctx: GeneratedParser.PredicateInContext) = translate(ctx) { - val lhs = visitAs(ctx.lhs) - val rhs = visitAs(ctx.rhs ?: ctx.expr()).let { - // Wrap rhs in an array unless it's a query or already a collection - if (it is ExprQuerySet || it is ExprArray || it is ExprBag || ctx.PAREN_LEFT() == null) { - it - } else { - // IN ( expr ) - exprArray(listOf(it)) - } - } - val not = ctx.NOT() != null - exprInCollection(lhs, rhs, not) - } - - override fun visitPredicateIs(ctx: GeneratedParser.PredicateIsContext) = translate(ctx) { - val value = visitAs(ctx.lhs) - val type = visitAs(ctx.type()) - val not = ctx.NOT() != null - exprIsType(value, type, not) - } - - override fun visitPredicateBetween(ctx: GeneratedParser.PredicateBetweenContext) = translate(ctx) { - val value = visitAs(ctx.lhs) - val lower = visitAs(ctx.lower) - val upper = visitAs(ctx.upper) - val not = ctx.NOT() != null - exprBetween(value, lower, upper, not) - } - - override fun visitPredicateLike(ctx: GeneratedParser.PredicateLikeContext) = translate(ctx) { - val value = visitAs(ctx.lhs) - val pattern = visitAs(ctx.rhs) - val escape = visitOrNull(ctx.escape) - val not = ctx.NOT() != null - exprLike(value, pattern, escape, not) - } - - /** - * - * PRIMARY EXPRESSIONS - * - */ - - override fun visitExprTermWrappedQuery(ctx: GeneratedParser.ExprTermWrappedQueryContext): AstNode = - visit(ctx.expr()) - - override fun visitVariableIdentifier(ctx: GeneratedParser.VariableIdentifierContext) = translate(ctx) { - val symbol = ctx.ident.getStringValue() - val isDelimited = when (ctx.ident.type) { - GeneratedParser.IDENTIFIER -> false - else -> true - } - val scope = when (ctx.qualifier) { - null -> Scope.DEFAULT() - else -> Scope.LOCAL() - } - exprVarRef( - identifierChain( - root = identifier(symbol, isDelimited), - next = null - ), - scope - ) - } - - override fun visitVariableKeyword(ctx: GeneratedParser.VariableKeywordContext) = translate(ctx) { - val symbol = ctx.key.text - val isDelimited = false - val scope = when (ctx.qualifier) { - null -> Scope.DEFAULT() - else -> Scope.LOCAL() - } - exprVarRef( - identifierChain( - root = identifier(symbol, isDelimited), - next = null - ), - scope - ) - } - - override fun visitParameter(ctx: GeneratedParser.ParameterContext) = translate(ctx) { - val index = parameters[ctx.QUESTION_MARK().symbol.tokenIndex] ?: throw error( - ctx, "Unable to find index of parameter." - ) - exprParameter(index) - } - - override fun visitSequenceConstructor(ctx: GeneratedParser.SequenceConstructorContext) = translate(ctx) { - error("Sequence constructor not supported") - } - - private fun PathStep.copy(next: PathStep?) = when (this) { - is PathStep.Element -> exprPathStepElement(this.element, next) - is PathStep.Field -> exprPathStepField(this.field, next) - is PathStep.AllElements -> exprPathStepAllElements(next) - is PathStep.AllFields -> exprPathStepAllFields(next) - else -> error("Unsupported PathStep: $this") - } - - override fun visitExprPrimaryPath(ctx: GeneratedParser.ExprPrimaryPathContext) = translate(ctx) { - val base = visitAs(ctx.exprPrimary()) - val init: PathStep? = null - val steps = ctx.pathStep().reversed().fold(init) { acc, step -> - val stepExpr = visit(step) as PathStep - stepExpr.copy(acc) - } - exprPath(base, steps) - } - - override fun visitPathStepIndexExpr(ctx: GeneratedParser.PathStepIndexExprContext) = translate(ctx) { - val key = visitAs(ctx.key) - exprPathStepElement(key, null) - } - - override fun visitPathStepDotExpr(ctx: GeneratedParser.PathStepDotExprContext) = translate(ctx) { - val symbol = visitSymbolPrimitive(ctx.symbolPrimitive()) - exprPathStepField(symbol, null) - } - - override fun visitPathStepIndexAll(ctx: GeneratedParser.PathStepIndexAllContext) = translate(ctx) { - exprPathStepAllElements(null) - } - - override fun visitPathStepDotAll(ctx: GeneratedParser.PathStepDotAllContext) = translate(ctx) { - exprPathStepAllFields(null) - } - - override fun visitValues(ctx: GeneratedParser.ValuesContext) = translate(ctx) { - val rows = visitOrEmpty(ctx.valueRow()) - exprBag(rows) - } - - override fun visitValueRow(ctx: GeneratedParser.ValueRowContext) = translate(ctx) { - val expressions = visitOrEmpty(ctx.expr()) - exprArray(expressions) - } - - override fun visitValueList(ctx: GeneratedParser.ValueListContext) = translate(ctx) { - val expressions = visitOrEmpty(ctx.expr()) - exprArray(expressions) - } - - override fun visitExprGraphMatchMany(ctx: GeneratedParser.ExprGraphMatchManyContext) = translate(ctx) { - val graph = visit(ctx.exprPrimary()) as Expr - val pattern = visitGpmlPatternList(ctx.gpmlPatternList()) - exprMatch(graph, pattern) - } - - override fun visitExprGraphMatchOne(ctx: GeneratedParser.ExprGraphMatchOneContext) = translate(ctx) { - val graph = visit(ctx.exprPrimary()) as Expr - val pattern = visitGpmlPattern(ctx.gpmlPattern()) - exprMatch(graph, pattern) - } - - override fun visitExprTermCurrentUser(ctx: GeneratedParser.ExprTermCurrentUserContext) = translate(ctx) { - exprSessionAttribute(SessionAttribute.CURRENT_USER()) - } - - override fun visitExprTermCurrentDate(ctx: GeneratedParser.ExprTermCurrentDateContext) = - translate(ctx) { - exprSessionAttribute(SessionAttribute.CURRENT_DATE()) - } - - /** - * - * FUNCTIONS - * - */ - - override fun visitNullIf(ctx: GeneratedParser.NullIfContext) = translate(ctx) { - val value = visitExpr(ctx.expr(0)) - val nullifier = visitExpr(ctx.expr(1)) - exprNullIf(value, nullifier) - } - - override fun visitCoalesce(ctx: GeneratedParser.CoalesceContext) = translate(ctx) { - val expressions = visitOrEmpty(ctx.expr()) - exprCoalesce(expressions) - } - - override fun visitCaseExpr(ctx: GeneratedParser.CaseExprContext) = translate(ctx) { - val expr = ctx.case_?.let { visitExpr(it) } - val branches = ctx.whens.indices.map { i -> - // consider adding locations - val w = visitExpr(ctx.whens[i]) - val t = visitExpr(ctx.thens[i]) - exprCaseBranch(w, t) - } - val default = ctx.else_?.let { visitExpr(it) } - exprCase(expr, branches, default) - } - - override fun visitCast(ctx: GeneratedParser.CastContext) = translate(ctx) { - val expr = visitExpr(ctx.expr()) - val type = visitAs(ctx.type()) - exprCast(expr, type) - } - - override fun visitCanCast(ctx: GeneratedParser.CanCastContext) = translate(ctx) { - throw error(ctx, "CAN_CAST is no longer supported in the default PartiQLParser") - } - - override fun visitCanLosslessCast(ctx: GeneratedParser.CanLosslessCastContext) = translate(ctx) { - throw error(ctx, "CAN_LOSSLESS_CAST is no longer supported in the default PartiQLParser") - } - - override fun visitFunctionCall(ctx: GeneratedParser.FunctionCallContext) = translate(ctx) { - val args = visitOrEmpty(ctx.expr()) - when (val funcName = ctx.qualifiedName()) { - is GeneratedParser.QualifiedNameContext -> { - when (funcName.name.start.type) { - GeneratedParser.MOD -> exprOperator("%", args[0], args[1]) - GeneratedParser.CHARACTER_LENGTH, GeneratedParser.CHAR_LENGTH -> { - val path = ctx.qualifiedName().qualifier.map { visitSymbolPrimitive(it) } - val name = identifierChain(identifier("char_length", false), null) - if (path.isEmpty()) { - exprCall(name, args, null) // setq = null for scalar fn - } else { - val function = path.reversed().fold(name) { acc, id -> - identifierChain(root = id, next = acc) - } - exprCall(function, args, setq = null) - } - } - else -> visitNonReservedFunctionCall(ctx, args) - } - } - else -> visitNonReservedFunctionCall(ctx, args) - } - } - private fun visitNonReservedFunctionCall(ctx: GeneratedParser.FunctionCallContext, args: List): ExprCall { - val function = visitQualifiedName(ctx.qualifiedName()) - return exprCall(function, args, convertSetQuantifier(ctx.setQuantifierStrategy())) - } - - /** - * - * FUNCTIONS WITH SPECIAL FORMS - * - */ - - override fun visitDateFunction(ctx: GeneratedParser.DateFunctionContext) = translate(ctx) { - try { - DatetimeField.parse(ctx.dt.text) - } catch (ex: IllegalArgumentException) { - throw error(ctx.dt, "Expected one of: ${DatetimeField.codes().joinToString()}", ex) - } - val lhs = visitExpr(ctx.expr(0)) - val rhs = visitExpr(ctx.expr(1)) - // TODO change to not use PartiQLValue -- https://github.com/partiql/partiql-lang-kotlin/issues/1589 - val fieldLit = ctx.dt.text.lowercase() - // TODO error on invalid datetime fields like TIMEZONE_HOUR and TIMEZONE_MINUTE - when { - ctx.DATE_ADD() != null -> exprCall(identifierChain(identifier("date_add_$fieldLit", false), null), listOf(lhs, rhs), null) - ctx.DATE_DIFF() != null -> exprCall(identifierChain(identifier("date_diff_$fieldLit", false), null), listOf(lhs, rhs), null) - else -> throw error(ctx, "Expected DATE_ADD or DATE_DIFF") - } - } - - /** - * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 - */ - override fun visitSubstring(ctx: GeneratedParser.SubstringContext) = translate(ctx) { - if (ctx.FROM() == null) { - // normal form - val function = "SUBSTRING".toIdentifierChain() - val args = visitOrEmpty(ctx.expr()) - exprCall(function, args, setq = null) // setq = null for scalar fn - } else { - // special form - val value = visitExpr(ctx.expr(0)) - val start = visitOrNull(ctx.expr(1)) - val length = visitOrNull(ctx.expr(2)) - exprSubstring(value, start, length) - } - } - - /** - * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 - */ - override fun visitPosition(ctx: GeneratedParser.PositionContext) = translate(ctx) { - if (ctx.IN() == null) { - // normal form - val function = "POSITION".toIdentifierChain() - val args = visitOrEmpty(ctx.expr()) - exprCall(function, args, setq = null) // setq = null for scalar fn - } else { - // special form - val lhs = visitExpr(ctx.expr(0)) - val rhs = visitExpr(ctx.expr(1)) - exprPosition(lhs, rhs) - } - } - - /** - * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 - */ - override fun visitOverlay(ctx: GeneratedParser.OverlayContext) = translate(ctx) { - // TODO: figure out why do we have a normalized form for overlay? - if (ctx.PLACING() == null) { - // normal form - val function = "OVERLAY".toIdentifierChain() - val args = arrayOfNulls(4).also { - visitOrEmpty(ctx.expr()).forEachIndexed { index, expr -> - it[index] = expr - } - } - val e = error(ctx, "overlay function requires at least three args") - - exprOverlay(args[0] ?: throw e, args[1] ?: throw e, args[2] ?: throw e, args[3]) - } else { - // special form - val value = visitExpr(ctx.expr(0)) - val overlay = visitExpr(ctx.expr(1)) - val start = visitExpr(ctx.expr(2)) - val length = visitOrNull(ctx.expr(3)) - exprOverlay(value, overlay, start, length) - } - } - - override fun visitExtract(ctx: GeneratedParser.ExtractContext) = translate(ctx) { - val field = try { - DatetimeField.parse(ctx.IDENTIFIER().text.uppercase()) - } catch (ex: IllegalArgumentException) { - // TODO decide if we want int codes here or actual text. If we want text here, then there should be a - // method to convert the code into text. - throw error(ctx.IDENTIFIER().symbol, "Expected one of: ${DatetimeField.codes().joinToString()}", ex) - } - val source = visitExpr(ctx.expr()) - exprExtract(field, source) - } - - override fun visitTrimFunction(ctx: GeneratedParser.TrimFunctionContext) = translate(ctx) { - val spec = ctx.mod?.let { - try { - TrimSpec.parse(it.text.uppercase()) - } catch (ex: IllegalArgumentException) { - throw error(it, "Expected on of: ${TrimSpec.codes().joinToString()}", ex) - } - } - val (chars, value) = when (ctx.expr().size) { - 1 -> null to visitExpr(ctx.expr(0)) - 2 -> visitExpr(ctx.expr(0)) to visitExpr(ctx.expr(1)) - else -> throw error(ctx, "Expected one or two TRIM expression arguments") - } - exprTrim(value, chars, spec) - } - - /** - * Window Functions - */ - - override fun visitLagLeadFunction(ctx: GeneratedParser.LagLeadFunctionContext) = translate(ctx) { - val function = when { - ctx.LAG() != null -> WindowFunction.LAG() - ctx.LEAD() != null -> WindowFunction.LEAD() - else -> throw error(ctx, "Expected LAG or LEAD") - } - val expression = visitExpr(ctx.expr(0)) - val offset = visitOrNull(ctx.expr(1)) - val default = visitOrNull(ctx.expr(2)) - val over = visitOver(ctx.over()) - if (over.sorts == null) { - throw error(ctx.over(), "$function requires Window ORDER BY") - } - exprWindow(function, expression, offset, default, over) - } - - override fun visitOver(ctx: GeneratedParser.OverContext) = translate(ctx) { - val partitions = ctx.windowPartitionList()?.let { visitOrEmpty(it.expr()) } - val sorts = ctx.windowSortSpecList()?.let { visitOrEmpty(it.orderSortSpec()) } - exprWindowOver(partitions, sorts) - } - - /** - * - * LITERALS - * - */ - - override fun visitBag(ctx: GeneratedParser.BagContext) = translate(ctx) { - // Prohibit hidden characters between angle brackets - val startTokenIndex = ctx.start.tokenIndex - val endTokenIndex = ctx.stop.tokenIndex - if (tokens.getHiddenTokensToRight(startTokenIndex, GeneratedLexer.HIDDEN) != null || tokens.getHiddenTokensToLeft(endTokenIndex, GeneratedLexer.HIDDEN) != null) { - throw error(ctx, "Invalid bag expression") - } - val expressions = visitOrEmpty(ctx.expr()) - exprBag(expressions) - } - - override fun visitLiteralDecimal(ctx: GeneratedParser.LiteralDecimalContext) = translate(ctx) { - val decimal = try { - val v = ctx.LITERAL_DECIMAL().text.trim() - BigDecimal(v, MathContext(38, RoundingMode.HALF_EVEN)) - } catch (e: NumberFormatException) { - throw error(ctx, "Invalid decimal literal", e) - } - exprLit(decimalValue(decimal)) - } - - override fun visitArray(ctx: GeneratedParser.ArrayContext) = translate(ctx) { - val expressions = visitOrEmpty(ctx.expr()) - exprArray(expressions) - } - - override fun visitLiteralNull(ctx: GeneratedParser.LiteralNullContext) = translate(ctx) { - exprLit(nullValue()) - } - - override fun visitLiteralMissing(ctx: GeneratedParser.LiteralMissingContext) = translate(ctx) { - exprLit(missingValue()) - } - - override fun visitLiteralTrue(ctx: GeneratedParser.LiteralTrueContext) = translate(ctx) { - exprLit(boolValue(true)) - } - - override fun visitLiteralFalse(ctx: GeneratedParser.LiteralFalseContext) = translate(ctx) { - exprLit(boolValue(false)) - } - - override fun visitLiteralIon(ctx: GeneratedParser.LiteralIonContext) = translate(ctx) { - val value = ctx.ION_CLOSURE().getStringValue() - val encoding = "ion" - exprVariant(value, encoding) - } - - override fun visitLiteralString(ctx: GeneratedParser.LiteralStringContext) = translate(ctx) { - val value = ctx.LITERAL_STRING().getStringValue() - exprLit(stringValue(value)) - } - - override fun visitLiteralInteger(ctx: GeneratedParser.LiteralIntegerContext) = translate(ctx) { - val n = ctx.LITERAL_INTEGER().text - - // 1st, try parse as int - try { - val v = n.toInt(10) - return@translate exprLit(int32Value(v)) - } catch (ex: NumberFormatException) { - // ignore - } - - // 2nd, try parse as long - try { - val v = n.toLong(10) - return@translate exprLit(int64Value(v)) - } catch (ex: NumberFormatException) { - // ignore - } - - // 3rd, try parse as BigInteger - try { - val v = BigInteger(n) - return@translate exprLit(intValue(v)) - } catch (ex: NumberFormatException) { - throw ex - } - } - - override fun visitLiteralDate(ctx: GeneratedParser.LiteralDateContext) = translate(ctx) { - val pattern = ctx.LITERAL_STRING().symbol - val dateString = ctx.LITERAL_STRING().getStringValue() - if (DATE_PATTERN_REGEX.matches(dateString).not()) { - throw error(pattern, "Expected DATE string to be of the format yyyy-MM-dd") - } - val value = try { - LocalDate.parse(dateString, DateTimeFormatter.ISO_LOCAL_DATE) - } catch (e: DateTimeParseException) { - throw error(pattern, e.localizedMessage, e) - } catch (e: IndexOutOfBoundsException) { - throw error(pattern, e.localizedMessage, e) - } - val date = DateTimeValue.date(value.year, value.monthValue, value.dayOfMonth) - exprLit(dateValue(date)) - } - - override fun visitLiteralTime(ctx: GeneratedParser.LiteralTimeContext) = translate(ctx) { - val (timeString, precision) = getTimeStringAndPrecision(ctx.LITERAL_STRING(), ctx.LITERAL_INTEGER()) - val time = try { - DateTimeUtils.parseTimeLiteral(timeString) - } catch (e: DateTimeException) { - throw error(ctx, "Invalid Date Time Literal", e) - } - val value = time.toPrecision(precision) - exprLit(timeValue(value)) - } - - override fun visitLiteralTimestamp(ctx: GeneratedParser.LiteralTimestampContext) = translate(ctx) { - val (timeString, precision) = getTimeStringAndPrecision(ctx.LITERAL_STRING(), ctx.LITERAL_INTEGER()) - val timestamp = try { - DateTimeUtils.parseTimestamp(timeString) - } catch (e: DateTimeException) { - throw error(ctx, "Invalid Date Time Literal", e) - } - val value = timestamp.toPrecision(precision) - exprLit(timestampValue(value)) - } - - override fun visitTuple(ctx: GeneratedParser.TupleContext) = translate(ctx) { - val fields = ctx.pair().map { - val k = visitExpr(it.lhs) - val v = visitExpr(it.rhs) - exprStructField(k, v) - } - exprStruct(fields) - } - - /** - * - * TYPES - * - */ - - override fun visitTypeAtomic(ctx: GeneratedParser.TypeAtomicContext) = translate(ctx) { - when (ctx.datatype.type) { - GeneratedParser.NULL -> DataType.NULL() - GeneratedParser.BOOL -> DataType.BOOLEAN() - GeneratedParser.BOOLEAN -> DataType.BOOL() - GeneratedParser.SMALLINT -> DataType.SMALLINT() - GeneratedParser.INT2 -> DataType.INT2() - GeneratedParser.INTEGER2 -> DataType.INTEGER2() - // TODO, we have INT aliased to INT4 when it should be visa-versa. - GeneratedParser.INT4 -> DataType.INT4() - GeneratedParser.INTEGER4 -> DataType.INTEGER4() - GeneratedParser.INT -> DataType.INT() - GeneratedParser.INTEGER -> DataType.INTEGER() - GeneratedParser.BIGINT -> DataType.BIGINT() - GeneratedParser.INT8 -> DataType.INT8() - GeneratedParser.INTEGER8 -> DataType.INTEGER8() - GeneratedParser.FLOAT -> DataType.FLOAT() - GeneratedParser.DOUBLE -> TODO() // not sure if DOUBLE is to be supported - GeneratedParser.REAL -> DataType.REAL() - GeneratedParser.TIMESTAMP -> DataType.TIMESTAMP() - GeneratedParser.CHAR -> DataType.CHAR() - GeneratedParser.CHARACTER -> DataType.CHARACTER() - GeneratedParser.MISSING -> DataType.MISSING() - GeneratedParser.STRING -> DataType.STRING() - GeneratedParser.SYMBOL -> DataType.SYMBOL() - // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1125 - GeneratedParser.BLOB -> DataType.BLOB() - GeneratedParser.CLOB -> DataType.CLOB() - GeneratedParser.DATE -> DataType.DATE() - GeneratedParser.STRUCT -> DataType.STRUCT() - GeneratedParser.TUPLE -> DataType.TUPLE() - GeneratedParser.LIST -> DataType.LIST() - GeneratedParser.SEXP -> DataType.SEXP() - GeneratedParser.BAG -> DataType.BAG() - GeneratedParser.ANY -> TODO() // not sure if ANY is to be supported - else -> throw error(ctx, "Unknown atomic type.") // TODO other types included in parser - } - } - - override fun visitTypeVarChar(ctx: GeneratedParser.TypeVarCharContext): DataType = translate(ctx) { - when (val n = ctx.arg0?.text?.toInt()) { - null -> DataType.VARCHAR() - else -> DataType.VARCHAR(n) - } - } - - override fun visitTypeArgSingle(ctx: GeneratedParser.TypeArgSingleContext) = translate(ctx) { - val n = ctx.arg0?.text?.toInt() - when (ctx.datatype.type) { - GeneratedParser.FLOAT -> when (n) { - null -> DataType.FLOAT(64) - 32 -> DataType.FLOAT(32) - 64 -> DataType.FLOAT(64) - else -> throw error(ctx.datatype, "Invalid FLOAT precision. Expected 32 or 64") - } - GeneratedParser.CHAR, GeneratedParser.CHARACTER -> when (n) { - null -> DataType.CHAR() - else -> DataType.CHAR(n) - } - GeneratedParser.VARCHAR -> when (n) { - null -> DataType.VARCHAR() - else -> DataType.VARCHAR(n) - } - else -> throw error(ctx.datatype, "Invalid datatype") - } - } - - override fun visitTypeArgDouble(ctx: GeneratedParser.TypeArgDoubleContext) = translate(ctx) { - val arg0 = ctx.arg0?.text?.toInt() - val arg1 = ctx.arg1?.text?.toInt() - when (ctx.datatype.type) { - GeneratedParser.DECIMAL -> when { - arg0 == null && arg1 == null -> DataType.DECIMAL() - arg0 != null && arg1 == null -> DataType.DECIMAL(arg0) - arg0 != null && arg1 != null -> DataType.DECIMAL(arg0, arg1) - else -> error("Invalid parameters for decimal") - } - GeneratedParser.DEC -> when { - arg0 == null && arg1 == null -> DataType.DEC() - arg0 != null && arg1 == null -> DataType.DEC(arg0) - arg0 != null && arg1 != null -> DataType.DEC(arg0, arg1) - else -> error("Invalid parameters for dec") - } - GeneratedParser.NUMERIC -> when { - arg0 == null && arg1 == null -> DataType.NUMERIC() - arg0 != null && arg1 == null -> DataType.NUMERIC(arg0) - arg0 != null && arg1 != null -> DataType.NUMERIC(arg0, arg1) - else -> error("Invalid parameters for decimal") - } - else -> throw error(ctx.datatype, "Invalid datatype") - } - } - - override fun visitTypeTimeZone(ctx: GeneratedParser.TypeTimeZoneContext) = translate(ctx) { - val precision = ctx.precision?.let { - val p = ctx.precision.text.toInt() - if (p < 0 || 9 < p) throw error(ctx.precision, "Unsupported time precision") - p - } - - when (ctx.datatype.type) { - GeneratedParser.TIME -> when (ctx.ZONE()) { - null -> when (precision) { - null -> DataType.TIME() - else -> DataType.TIME(precision) - } - else -> when (precision) { - null -> DataType.TIME_WITH_TIME_ZONE() - else -> DataType.TIME_WITH_TIME_ZONE(precision) - } - } - GeneratedParser.TIMESTAMP -> when (ctx.ZONE()) { - null -> when (precision) { - null -> DataType.TIMESTAMP() - else -> DataType.TIMESTAMP(precision) - } - else -> when (precision) { - null -> DataType.TIMESTAMP_WITH_TIME_ZONE() - else -> DataType.TIMESTAMP_WITH_TIME_ZONE(precision) - } - } - else -> throw error(ctx.datatype, "Invalid datatype") - } - } - - override fun visitTypeCustom(ctx: GeneratedParser.TypeCustomContext) = translate(ctx) { - DataType.USER_DEFINED(ctx.text.uppercase().toIdentifierChain()) - } - - private inline fun visitOrEmpty(ctx: List?): List = when { - ctx.isNullOrEmpty() -> emptyList() - else -> ctx.map { visit(it) as T } - } - - private inline fun visitOrNull(ctx: ParserRuleContext?): T? = - ctx?.let { it.accept(this) as T } - - private inline fun visitAs(ctx: ParserRuleContext): T = visit(ctx) as T - - /** - * Visiting a symbol to get a string, skip the wrapping, unwrapping, and location tracking. - */ - private fun symbolToString(ctx: GeneratedParser.SymbolPrimitiveContext) = when (ctx) { - is GeneratedParser.IdentifierQuotedContext -> ctx.IDENTIFIER_QUOTED().getStringValue() - is GeneratedParser.IdentifierUnquotedContext -> ctx.text - else -> throw error(ctx, "Invalid symbol reference.") - } - - /** - * Convert [ALL|DISTINCT] to SetQuantifier Enum - */ - private fun convertSetQuantifier(ctx: GeneratedParser.SetQuantifierStrategyContext?): SetQuantifier? = when { - ctx == null -> null - ctx.ALL() != null -> SetQuantifier.ALL() - ctx.DISTINCT() != null -> SetQuantifier.DISTINCT() - else -> throw error(ctx, "Expected set quantifier ALL or DISTINCT") - } - - /** - * With the and nodes of a literal time expression, returns the parsed string and precision. - * TIME ()? (WITH TIME ZONE)? - */ - private fun getTimeStringAndPrecision( - stringNode: TerminalNode, - integerNode: TerminalNode?, - ): Pair { - val timeString = stringNode.getStringValue() - val precision = when (integerNode) { - null -> { - try { - getPrecisionFromTimeString(timeString) - } catch (e: Exception) { - throw error(stringNode.symbol, "Unable to parse precision.", e) - } - } - else -> { - val p = integerNode.text.toBigInteger().toInt() - if (p < 0 || 9 < p) throw error(integerNode.symbol, "Precision out of bounds") - p - } - } - return timeString to precision - } - - private fun getPrecisionFromTimeString(timeString: String): Int { - val matcher = GENERIC_TIME_REGEX.toPattern().matcher(timeString) - if (!matcher.find()) { - throw IllegalArgumentException("Time string does not match the format 'HH:MM:SS[.ddd....][+|-HH:MM]'") - } - val fraction = matcher.group(1)?.removePrefix(".") - return fraction?.length ?: 0 - } - - /** - * Converts a Path expression into a Projection Item (either ALL or EXPR). Note: A Projection Item only allows a - * subset of a typical Path expressions. See the following examples. - * - * Examples of valid projections are: - * - * ```partiql - * SELECT * FROM foo - * SELECT foo.* FROM foo - * SELECT f.* FROM foo as f - * SELECT foo.bar.* FROM foo - * SELECT f.bar.* FROM foo as f - * ``` - * Also validates that the expression is valid for select list context. It does this by making - * sure that expressions looking like the following do not appear: - * - * ```partiql - * SELECT foo[*] FROM foo - * SELECT f.*.bar FROM foo as f - * SELECT foo[1].* FROM foo - * SELECT foo.*.bar FROM foo - * ``` - */ - protected fun convertPathToProjectionItem(ctx: ParserRuleContext, path: ExprPath, alias: Identifier?) = - translate(ctx) { - val steps = mutableListOf() - var containsIndex = false - var curStep = path.next - var last = curStep - while (curStep != null) { - val isLastStep = curStep.next == null - // Only last step can have a '.*' - if (curStep is PathStep.AllFields && !isLastStep) { - throw error(ctx, "Projection item cannot unpivot unless at end.") - } - // No step can have an indexed wildcard: '[*]' - if (curStep is PathStep.AllElements) { - throw error(ctx, "Projection item cannot index using wildcard.") - } - // TODO If the last step is '.*', no indexing is allowed - // if (step.metas.containsKey(IsPathIndexMeta.TAG)) { - // containsIndex = true - // } - if (curStep !is PathStep.AllFields) { - steps.add(curStep) - } - - if (isLastStep && curStep is PathStep.AllFields && containsIndex) { - throw error(ctx, "Projection item use wildcard with any indexing.") - } - last = curStep - curStep = curStep.next - } - when { - last is PathStep.AllFields && steps.isEmpty() -> { - selectItemStar(path.root) - } - last is PathStep.AllFields -> { - val init: PathStep? = null - val newSteps = steps.reversed().fold(init) { acc, step -> - when (step) { - is PathStep.Element -> PathStep.Element(step.element, acc) - is PathStep.Field -> PathStep.Field(step.field, acc) - is PathStep.AllElements -> PathStep.AllElements(acc) - is PathStep.AllFields -> PathStep.AllFields(acc) - else -> error("Unexpected path step") - } - } - selectItemStar(exprPath(path.root, newSteps)) - } - else -> { - selectItemExpr(path, alias) - } - } - } - - private fun TerminalNode.getStringValue(): String = this.symbol.getStringValue() - - private fun Token.getStringValue(): String = when (this.type) { - GeneratedParser.IDENTIFIER -> this.text - GeneratedParser.IDENTIFIER_QUOTED -> this.text.removePrefix("\"").removeSuffix("\"").replace("\"\"", "\"") - GeneratedParser.LITERAL_STRING -> this.text.removePrefix("'").removeSuffix("'").replace("''", "'") - GeneratedParser.ION_CLOSURE -> this.text.removePrefix("`").removeSuffix("`") - else -> throw error(this, "Unsupported token for grabbing string value.") - } - - private fun String.toIdentifier(): Identifier = identifier(this, false) - - private fun String.toIdentifierChain(): IdentifierChain = identifierChain(root = this.toIdentifier(), next = null) - - private fun String.toBigInteger() = BigInteger(this, 10) - - private fun assertIntegerElement(token: Token?, value: IonElement?) { - if (value == null || token == null) return - if (value !is IntElement) throw error(token, "Expected an integer value.") - if (value.integerSize == IntElementSize.BIG_INTEGER || value.longValue > Int.MAX_VALUE || value.longValue < Int.MIN_VALUE) throw error( - token, "Type parameter exceeded maximum value" - ) - } - - private enum class ExplainParameters { - TYPE, FORMAT; - - fun getCompliantString(target: String?, input: Token): String = when (target) { - null -> input.text!! - else -> throw error(input, "Cannot set EXPLAIN parameter ${this.name} multiple times.") - } - } - } -} diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/ParserTestCaseSimple.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/ParserTestCaseSimple.kt index a3e526948d..8005d79def 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/ParserTestCaseSimple.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/ParserTestCaseSimple.kt @@ -1,6 +1,6 @@ package org.partiql.parser.internal -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser /** * This test case simply cares about whether the [input] can be parsed or not. @@ -13,7 +13,7 @@ class ParserTestCaseSimple( override fun name(): String = name - private val parser: PartiQLParserV1 = PartiQLParserV1.standard() + private val parser: PartiQLParser = PartiQLParser.standard() override fun assert() { when (isValid) { diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt index 52794b27fa..badff87bc3 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt @@ -26,7 +26,7 @@ import kotlin.test.assertEquals class PartiQLParserBagOpTests { - private val parser = PartiQLParserDefaultV1() + private val parser = PartiQLParserDefault() private fun queryBody(body: () -> Expr) = query(body()) diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt index dbbacafdf5..c9e9ada52f 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt @@ -5,7 +5,7 @@ import kotlin.test.assertEquals class PartiQLParserDDLTests { - private val parser = PartiQLParserDefaultV1() + private val parser = PartiQLParserDefault() data class SuccessTestCase( val description: String? = null, diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt index 2bda30224f..f529bcbe4c 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt @@ -11,7 +11,7 @@ import kotlin.test.assertEquals class PartiQLParserFunctionCallTests { - private val parser = PartiQLParserDefaultV1() + private val parser = PartiQLParserDefault() private inline fun queryBody(body: () -> Expr) = query(body()) diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt index a5c8634066..1abeb2b66b 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt @@ -13,7 +13,7 @@ import kotlin.test.assertEquals @OptIn(PartiQLValueExperimental::class) class PartiQLParserOperatorTests { - private val parser = PartiQLParserDefaultV1() + private val parser = PartiQLParserDefault() private inline fun queryBody(body: () -> Expr) = query(body()) diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt index 6dd55d49e8..125e3817bb 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt @@ -15,7 +15,7 @@ import kotlin.test.assertEquals @OptIn(PartiQLValueExperimental::class) class PartiQLParserSessionAttributeTests { - private val parser = PartiQLParserDefaultV1() + private val parser = PartiQLParserDefault() private inline fun queryBody(body: () -> Expr) = query(body()) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt index 3bbd30250b..1733c559ac 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt @@ -8,8 +8,8 @@ import org.partiql.plan.rex.Rex import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PartiQLPlannerPass import org.partiql.planner.internal.normalize.normalize +import org.partiql.planner.internal.transforms.AstToPlan import org.partiql.planner.internal.transforms.PlanTransform -import org.partiql.planner.internal.transforms.V1AstToPlan import org.partiql.planner.internal.typer.PlanTyper import org.partiql.spi.Context import org.partiql.spi.catalog.Session @@ -39,7 +39,7 @@ internal class SqlPlanner( val ast = statement.normalize() // 2. AST to Rel/Rex - val root = V1AstToPlan.apply(ast, env) + val root = AstToPlan.apply(ast, env) // 3. Resolve variables val typer = PlanTyper(env, ctx) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt index 5b51624bd3..1796b45331 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt @@ -16,14 +16,16 @@ package org.partiql.planner.internal.transforms -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.visitor.AstBaseVisitor +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.Query +import org.partiql.ast.v1.expr.ExprQuerySet import org.partiql.planner.internal.Env import org.partiql.planner.internal.ir.statementQuery import org.partiql.spi.catalog.Identifier -import org.partiql.ast.Identifier as AstIdentifier -import org.partiql.ast.Statement as AstStatement +import org.partiql.ast.v1.Identifier as AstIdentifier +import org.partiql.ast.v1.IdentifierChain as AstIdentifierChain +import org.partiql.ast.v1.Statement as AstStatement import org.partiql.planner.internal.ir.Statement as PlanStatement /** @@ -36,13 +38,13 @@ internal object AstToPlan { fun apply(statement: AstStatement, env: Env): PlanStatement = statement.accept(ToPlanStatement, env) @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") - private object ToPlanStatement : AstBaseVisitor() { + private object ToPlanStatement : AstVisitor() { override fun defaultReturn(node: AstNode, env: Env) = throw IllegalArgumentException("Unsupported statement") - override fun visitStatementQuery(node: AstStatement.Query, env: Env): PlanStatement { + override fun visitQuery(node: Query, env: Env): PlanStatement { val rex = when (val expr = node.expr) { - is Expr.QuerySet -> RelConverter.apply(expr, env) + is ExprQuerySet -> RelConverter.apply(expr, env) else -> RexConverter.apply(expr, env) } return statementQuery(rex) @@ -51,24 +53,23 @@ internal object AstToPlan { // --- Helpers -------------------- - fun convert(identifier: AstIdentifier): Identifier = when (identifier) { - is AstIdentifier.Qualified -> convert(identifier) - is AstIdentifier.Symbol -> convert(identifier) - } - - fun convert(identifier: AstIdentifier.Qualified): Identifier { + fun convert(identifier: AstIdentifierChain): Identifier { val parts = mutableListOf() parts.add(part(identifier.root)) - parts.addAll(identifier.steps.map { part(it) }) + var curStep = identifier.next + while (curStep != null) { + parts.add(part(curStep.root)) + curStep = curStep.next + } return Identifier.of(parts) } - fun convert(identifier: AstIdentifier.Symbol): Identifier { + fun convert(identifier: AstIdentifier): Identifier { return Identifier.of(part(identifier)) } - fun part(identifier: AstIdentifier.Symbol): Identifier.Part = when (identifier.caseSensitivity) { - AstIdentifier.CaseSensitivity.SENSITIVE -> Identifier.Part.delimited(identifier.symbol) - AstIdentifier.CaseSensitivity.INSENSITIVE -> Identifier.Part.regular(identifier.symbol) + fun part(identifier: AstIdentifier): Identifier.Part = when (identifier.isDelimited) { + true -> Identifier.Part.delimited(identifier.symbol) + false -> Identifier.Part.regular(identifier.symbol) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt index 0354c39f97..da96a785a6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/NormalizeSelect.kt @@ -14,27 +14,42 @@ package org.partiql.planner.internal.transforms -import org.partiql.ast.Expr -import org.partiql.ast.From -import org.partiql.ast.GroupBy -import org.partiql.ast.Identifier -import org.partiql.ast.QueryBody -import org.partiql.ast.Select -import org.partiql.ast.exprCall -import org.partiql.ast.exprCase -import org.partiql.ast.exprCaseBranch -import org.partiql.ast.exprIsType -import org.partiql.ast.exprLit -import org.partiql.ast.exprStruct -import org.partiql.ast.exprStructField -import org.partiql.ast.exprVar -import org.partiql.ast.helpers.toBinder -import org.partiql.ast.identifierSymbol -import org.partiql.ast.selectProject -import org.partiql.ast.selectProjectItemExpression -import org.partiql.ast.selectValue -import org.partiql.ast.typeStruct -import org.partiql.ast.util.AstRewriter +import org.partiql.ast.v1.Ast.exprCall +import org.partiql.ast.v1.Ast.exprCase +import org.partiql.ast.v1.Ast.exprCaseBranch +import org.partiql.ast.v1.Ast.exprIsType +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprQuerySet +import org.partiql.ast.v1.Ast.exprStruct +import org.partiql.ast.v1.Ast.exprStructField +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.Ast.queryBodySFW +import org.partiql.ast.v1.Ast.queryBodySetOp +import org.partiql.ast.v1.Ast.selectItemExpr +import org.partiql.ast.v1.Ast.selectList +import org.partiql.ast.v1.Ast.selectValue +import org.partiql.ast.v1.AstRewriter +import org.partiql.ast.v1.DataType +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromExpr +import org.partiql.ast.v1.FromJoin +import org.partiql.ast.v1.FromTableRef +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.SelectValue +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprCase +import org.partiql.ast.v1.expr.ExprLit +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.ExprStruct +import org.partiql.ast.v1.expr.ExprVarRef +import org.partiql.ast.v1.expr.Scope +import org.partiql.planner.internal.helpers.toBinder import org.partiql.value.PartiQLValueExperimental import org.partiql.value.stringValue @@ -98,30 +113,39 @@ import org.partiql.value.stringValue */ internal object NormalizeSelect { - internal fun normalize(node: Expr.QuerySet): Expr.QuerySet { + internal fun normalize(node: ExprQuerySet): ExprQuerySet { return when (val body = node.body) { is QueryBody.SFW -> { val sfw = Visitor.visitSFW(body, newCtx()) - node.copy( - body = sfw + exprQuerySet( + body = sfw, + orderBy = node.orderBy, + limit = node.limit, + offset = node.offset ) } is QueryBody.SetOp -> { val lhs = body.lhs.normalizeOrIdentity() val rhs = body.rhs.normalizeOrIdentity() - node.copy( - body = body.copy( + exprQuerySet( + body = queryBodySetOp( + type = body.type, + isOuter = body.isOuter, lhs = lhs, rhs = rhs - ) + ), + orderBy = node.orderBy, + limit = node.limit, + offset = node.offset ) } + else -> error("Unexpected QueryBody type: $body") } } private fun Expr.normalizeOrIdentity(): Expr { return when (this) { - is Expr.QuerySet -> normalize(this) + is ExprQuerySet -> normalize(this) else -> this } } @@ -164,12 +188,20 @@ internal object NormalizeSelect { internal fun visitSFW(node: QueryBody.SFW, ctx: () -> Int): QueryBody.SFW { val sfw = super.visitQueryBodySFW(node, ctx) as QueryBody.SFW return when (val select = sfw.select) { - is Select.Star -> { + is SelectStar -> { val selectValue = when (val group = sfw.groupBy) { null -> visitSelectAll(select, sfw.from) else -> visitSelectAll(select, group) } - sfw.copy(select = selectValue) + queryBodySFW( + select = selectValue, + exclude = sfw.exclude, + from = sfw.from, + let = sfw.let, + where = sfw.where, + groupBy = sfw.groupBy, + having = sfw.having, + ) } else -> sfw } @@ -179,33 +211,33 @@ internal object NormalizeSelect { return node } - override fun visitSelectProject(node: Select.Project, ctx: () -> Int): Select.Value { + override fun visitSelectList(node: SelectList, ctx: () -> Int): SelectValue { // Visit items, adding a binder if necessary var diff = false - val visitedItems = ArrayList(node.items.size) + val visitedItems = ArrayList(node.items.size) node.items.forEach { n -> - val item = visitSelectProjectItem(n, ctx) as Select.Project.Item + val item = n.accept(this, ctx) as SelectItem if (item !== n) diff = true visitedItems.add(item) } - val visitedNode = if (diff) selectProject(visitedItems, node.setq) else node + val visitedNode = if (diff) selectList(visitedItems, node.setq) else node // Rewrite selection - return when (node.items.any { it is Select.Project.Item.All }) { + return when (node.items.any { it is SelectItem.Star }) { false -> visitSelectProjectWithoutProjectAll(visitedNode) true -> visitSelectProjectWithProjectAll(visitedNode) } } - override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: () -> Int): Select.Project.Item.Expression { + override fun visitSelectItemExpr(node: SelectItem.Expr, ctx: () -> Int): SelectItem.Expr { val expr = visitExpr(node.expr, newCtx()) as Expr val alias = when (node.asAlias) { null -> expr.toBinder(ctx) else -> node.asAlias } return if (expr != node.expr || alias != node.asAlias) { - selectProjectItemExpression(expr, alias) + selectItemExpr(expr, alias) } else { node } @@ -219,28 +251,22 @@ internal object NormalizeSelect { * * Note: We assume that [select] and [from] have already been visited. */ - private fun visitSelectAll(select: Select.Star, from: From): Select.Value { - val tupleUnionArgs = from.aliases().flatMapIndexed { i, binding -> + private fun visitSelectAll(select: SelectStar, from: From): SelectValue { + val tupleUnionArgs = from.tableRefs.flatMap { it.aliases() }.flatMapIndexed { i, binding -> val asAlias = binding.first val atAlias = binding.second - val byAlias = binding.third val atAliasItem = atAlias?.simple()?.let { val alias = it.asAlias ?: error("The AT alias should be present. This wasn't normalized.") buildSimpleStruct(it.expr, alias.symbol) } - val byAliasItem = byAlias?.simple()?.let { - val alias = it.asAlias ?: error("The BY alias should be present. This wasn't normalized.") - buildSimpleStruct(it.expr, alias.symbol) - } listOfNotNull( buildCaseWhenStruct(asAlias.star(i).expr, i), atAliasItem, - byAliasItem ) } return selectValue( constructor = exprCall( - function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE), + function = identifierChain(identifier("TUPLEUNION", isDelimited = true), next = null), args = tupleUnionArgs, setq = null // setq = null for scalar fn ), @@ -254,7 +280,7 @@ internal object NormalizeSelect { * * Note: We assume that [select] and [group] have already been visited. */ - private fun visitSelectAll(select: Select.Star, group: GroupBy): Select.Value { + private fun visitSelectAll(select: SelectStar, group: GroupBy): SelectValue { val groupAs = group.asAlias?.let { structField(it.symbol, varLocal(it.symbol)) } val fields = group.keys.map { key -> val alias = key.asAlias ?: error("Expected a GROUP BY alias.") @@ -267,21 +293,22 @@ internal object NormalizeSelect { ) } - private fun visitSelectProjectWithProjectAll(node: Select.Project): Select.Value { + private fun visitSelectProjectWithProjectAll(node: SelectList): SelectValue { val tupleUnionArgs = node.items.mapIndexed { index, item -> when (item) { - is Select.Project.Item.All -> buildCaseWhenStruct(item.expr, index) - is Select.Project.Item.Expression -> buildSimpleStruct( + is SelectItem.Star -> buildCaseWhenStruct(item.expr, index) + is SelectItem.Expr -> buildSimpleStruct( item.expr, item.asAlias?.symbol ?: error("The alias should've been here. This AST is not normalized.") ) + else -> error("Unexpected SelectItem type: $item") } } return selectValue( setq = node.setq, constructor = exprCall( - function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE), + function = identifierChain(identifier("TUPLEUNION", isDelimited = true), next = null), args = tupleUnionArgs, setq = null // setq = null for scalar fn ) @@ -289,9 +316,9 @@ internal object NormalizeSelect { } @OptIn(PartiQLValueExperimental::class) - private fun visitSelectProjectWithoutProjectAll(node: Select.Project): Select.Value { + private fun visitSelectProjectWithoutProjectAll(node: SelectList): SelectValue { val structFields = node.items.map { item -> - val itemExpr = item as? Select.Project.Item.Expression ?: error("Expected the projection to be an expression.") + val itemExpr = item as? SelectItem.Expr ?: error("Expected the projection to be an expression.") exprStructField( name = exprLit(stringValue(itemExpr.asAlias?.symbol!!)), value = item.expr @@ -305,20 +332,19 @@ internal object NormalizeSelect { ) } - @OptIn(PartiQLValueExperimental::class) - private fun buildCaseWhenStruct(expr: Expr, index: Int): Expr.Case = exprCase( + private fun buildCaseWhenStruct(expr: Expr, index: Int): ExprCase = exprCase( expr = null, branches = listOf( exprCaseBranch( - condition = exprIsType(expr, typeStruct(), null), + condition = exprIsType(expr, DataType.STRUCT(), not = false), expr = expr ) ), - default = buildSimpleStruct(expr, col(index)) + defaultExpr = buildSimpleStruct(expr, col(index)) ) @OptIn(PartiQLValueExperimental::class) - private fun buildSimpleStruct(expr: Expr, name: String): Expr.Struct = exprStruct( + private fun buildSimpleStruct(expr: Expr, name: String): ExprStruct = exprStruct( fields = listOf( exprStructField( name = exprLit(stringValue(name)), @@ -328,40 +354,40 @@ internal object NormalizeSelect { ) @OptIn(PartiQLValueExperimental::class) - private fun structField(name: String, expr: Expr): Expr.Struct.Field = Expr.Struct.Field( - name = Expr.Lit(stringValue(name)), + private fun structField(name: String, expr: Expr): ExprStruct.Field = exprStructField( + name = ExprLit(stringValue(name)), value = expr ) - private fun varLocal(name: String): Expr.Var = Expr.Var( - identifier = Identifier.Symbol(name, Identifier.CaseSensitivity.SENSITIVE), - scope = Expr.Var.Scope.LOCAL + private fun varLocal(name: String): ExprVarRef = exprVarRef( + identifierChain = identifierChain(identifier(name, isDelimited = true), next = null), + scope = Scope.LOCAL() ) - private fun From.aliases(): List> = when (this) { - is From.Join -> lhs.aliases() + rhs.aliases() - is From.Value -> { + private fun FromTableRef.aliases(): List> = when (this) { + is FromJoin -> lhs.aliases() + rhs.aliases() + is FromExpr -> { val asAlias = asAlias?.symbol ?: error("AST not normalized, missing asAlias on FROM source.") val atAlias = atAlias?.symbol - val byAlias = byAlias?.symbol - listOf(Triple(asAlias, atAlias, byAlias)) + listOf(Pair(asAlias, atAlias)) } + else -> error("Unexpected FromTableRef type: $this") } // t -> t.* AS _i - private fun String.star(i: Int): Select.Project.Item.Expression { - val expr = exprVar(id(this), Expr.Var.Scope.DEFAULT) + private fun String.star(i: Int): SelectItem.Expr { + val expr = exprVarRef(identifierChain(id(this), next = null), Scope.DEFAULT()) val alias = expr.toBinder(i) - return selectProjectItemExpression(expr, alias) + return selectItemExpr(expr, alias) } // t -> t AS t - private fun String.simple(): Select.Project.Item.Expression { - val expr = exprVar(id(this), Expr.Var.Scope.DEFAULT) + private fun String.simple(): SelectItem.Expr { + val expr = exprVarRef(identifierChain(id(this), next = null), Scope.DEFAULT()) val alias = id(this) - return selectProjectItemExpression(expr, alias) + return selectItemExpr(expr, alias) } - private fun id(symbol: String) = identifierSymbol(symbol, Identifier.CaseSensitivity.INSENSITIVE) + private fun id(symbol: String) = identifier(symbol, isDelimited = false) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index 6579a873fa..f6e48421c9 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -16,26 +16,40 @@ package org.partiql.planner.internal.transforms -import org.partiql.ast.AstNode -import org.partiql.ast.Exclude -import org.partiql.ast.Expr -import org.partiql.ast.From -import org.partiql.ast.GroupBy -import org.partiql.ast.Identifier -import org.partiql.ast.OrderBy -import org.partiql.ast.QueryBody -import org.partiql.ast.Select -import org.partiql.ast.SetOp -import org.partiql.ast.SetQuantifier -import org.partiql.ast.Sort -import org.partiql.ast.builder.ast -import org.partiql.ast.exprLit -import org.partiql.ast.exprVar -import org.partiql.ast.helpers.toBinder -import org.partiql.ast.identifierSymbol -import org.partiql.ast.util.AstRewriter -import org.partiql.ast.visitor.AstBaseVisitor +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstRewriter +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.Exclude +import org.partiql.ast.v1.ExcludeStep +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromExpr +import org.partiql.ast.v1.FromJoin +import org.partiql.ast.v1.FromType +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.GroupByStrategy +import org.partiql.ast.v1.IdentifierChain +import org.partiql.ast.v1.JoinType +import org.partiql.ast.v1.Nulls +import org.partiql.ast.v1.Order +import org.partiql.ast.v1.OrderBy +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectPivot +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.SelectValue +import org.partiql.ast.v1.SetOpType +import org.partiql.ast.v1.SetQuantifier +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.Scope import org.partiql.planner.internal.Env +import org.partiql.planner.internal.helpers.toBinder import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.rel @@ -88,14 +102,14 @@ internal object RelConverter { /** * Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex]. */ - internal fun apply(qSet: Expr.QuerySet, env: Env): Rex { + internal fun apply(qSet: ExprQuerySet, env: Env): Rex { val newQSet = NormalizeSelect.normalize(qSet) val rex = when (val body = newQSet.body) { is QueryBody.SFW -> { val rel = newQSet.accept(ToRel(env), nil) when (val projection = body.select) { // PIVOT ... FROM - is Select.Pivot -> { + is SelectPivot -> { val key = projection.key.toRex(env) val value = projection.value.toRex(env) val type = (STRUCT) @@ -103,7 +117,7 @@ internal object RelConverter { rex(type, op) } // SELECT VALUE ... FROM - is Select.Value -> { + is SelectValue -> { assert(rel.type.schema.size == 1) { "Expected SELECT VALUE's input to have a single binding. " + "However, it contained: ${rel.type.schema.map { it.name }}." @@ -117,13 +131,14 @@ internal object RelConverter { rex(type, op) } // SELECT * FROM - is Select.Star -> { + is SelectStar -> { throw IllegalArgumentException("AST not normalized") } // SELECT ... FROM - is Select.Project -> { + is SelectList -> { throw IllegalArgumentException("AST not normalized") } + else -> error("Unexpected Select type: $projection") } } is QueryBody.SetOp -> { @@ -136,6 +151,7 @@ internal object RelConverter { } rex(type, op) } + else -> error("Unexpected QueryBody type: ${newQSet.body}") } return rex } @@ -146,7 +162,7 @@ internal object RelConverter { private fun Expr.toRex(env: Env): Rex = RexConverter.apply(this, env) @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE", "LocalVariableName") - internal class ToRel(private val env: Env) : AstBaseVisitor() { + internal class ToRel(private val env: Env) : AstVisitor() { override fun defaultReturn(node: AstNode, input: Rel): Rel = throw IllegalArgumentException("unsupported rel $node") @@ -155,7 +171,7 @@ internal object RelConverter { * Translate SFW AST node to a pipeline of [Rel] operators; skip any SELECT VALUE or PIVOT projection. */ - override fun visitExprQuerySet(node: Expr.QuerySet, ctx: Rel): Rel { + override fun visitExprQuerySet(node: ExprQuerySet, ctx: Rel): Rel { val body = node.body val orderBy = node.orderBy val limit = node.limit @@ -178,14 +194,15 @@ internal object RelConverter { rel = convertExclude(rel, sel.exclude) // append SQL projection if present rel = when (val projection = sel.select) { - is Select.Value -> { + is SelectValue -> { val project = visitSelectValue(projection, rel) visitSetQuantifier(projection.setq, project) } - is Select.Star, is Select.Project -> { + is SelectStar, is SelectList -> { error("AST not normalized, found ${projection.javaClass.simpleName}") } - is Select.Pivot -> rel // Skip PIVOT + is SelectPivot -> rel // Skip PIVOT + else -> error("Unexpected Select type: $projection") } return rel } @@ -197,27 +214,29 @@ internal object RelConverter { rel = convertLimit(rel, limit) return rel } + else -> error("Unexpected QueryBody type: $body") } } /** - * Given a non-null [setQuantifier], this will return a [Rel] of [Rel.Op.Distinct] wrapping the [input]. + * Given a [setQuantifier], this will return a [Rel] of [Rel.Op.Distinct] wrapping the [input]. * If [setQuantifier] is null or ALL, this will return the [input]. */ private fun visitSetQuantifier(setQuantifier: SetQuantifier?, input: Rel): Rel { - return when (setQuantifier) { + return when (setQuantifier?.code()) { SetQuantifier.DISTINCT -> rel(input.type, relOpDistinct(input)) SetQuantifier.ALL, null -> input + else -> error("Unexpected SetQuantifier type: $setQuantifier") } } - override fun visitSelectProject(node: Select.Project, input: Rel): Rel { + override fun visitSelectList(node: SelectList, input: Rel): Rel { // this ignores aggregations val schema = mutableListOf() val props = input.type.props val projections = mutableListOf() node.items.forEach { - val (binding, projection) = convertProjectionItem(it) + val (binding, projection) = convertSelectItem(it) schema.add(binding) projections.add(projection) } @@ -226,7 +245,7 @@ internal object RelConverter { return rel(type, op) } - override fun visitSelectValue(node: Select.Value, input: Rel): Rel { + override fun visitSelectValue(node: SelectValue, input: Rel): Rel { val name = node.constructor.toBinder(1).symbol val rex = RexConverter.apply(node.constructor, env) val schema = listOf(relBinding(name, rex.type)) @@ -236,7 +255,21 @@ internal object RelConverter { return rel(type, op) } - override fun visitFromValue(node: From.Value, nil: Rel): Rel { + @OptIn(PartiQLValueExperimental::class) + override fun visitFrom(node: From, ctx: Rel): Rel { + val tableRefs = node.tableRefs.map { visitFromTableRef(it, ctx) } + return tableRefs.drop(1).fold(tableRefs.first()) { acc, tRef -> + val joinType = Rel.Op.Join.Type.INNER + val condition = rex(BOOL, rexOpLit(boolValue(true))) + val schema = acc.type.schema + tRef.type.schema + val props = emptySet() + val type = relType(schema, props) + val op = relOpJoin(acc, tRef, condition, joinType) + rel(type, op) + } + } + + override fun visitFromExpr(node: FromExpr, nil: Rel): Rel { val rex = RexConverter.applyRel(node.expr, env) val binding = when (val a = node.asAlias) { null -> error("AST not normalized, missing AS alias on $node") @@ -245,8 +278,8 @@ internal object RelConverter { type = rex.type ) } - return when (node.type) { - From.Value.Type.SCAN -> { + return when (node.fromType.code()) { + FromType.SCAN -> { when (val i = node.atAlias) { null -> convertScan(rex, binding) else -> { @@ -258,7 +291,7 @@ internal object RelConverter { } } } - From.Value.Type.UNPIVOT -> { + FromType.UNPIVOT -> { val atAlias = when (val at = node.atAlias) { null -> error("AST not normalized, missing AT alias on UNPIVOT $node") else -> relBinding( @@ -268,6 +301,7 @@ internal object RelConverter { } convertUnpivot(rex, k = atAlias, v = binding) } + else -> error("Unexpected FromType type: ${node.fromType}") } } @@ -277,20 +311,20 @@ internal object RelConverter { * TODO compute basic schema */ @OptIn(PartiQLValueExperimental::class) - override fun visitFromJoin(node: From.Join, nil: Rel): Rel { - val lhs = visitFrom(node.lhs, nil) - val rhs = visitFrom(node.rhs, nil) + override fun visitFromJoin(node: FromJoin, nil: Rel): Rel { + val lhs = visitFromTableRef(node.lhs, nil) + val rhs = visitFromTableRef(node.rhs, nil) val schema = lhs.type.schema + rhs.type.schema // Note: This gets more specific in PlanTyper. It is only used to find binding names here. val props = emptySet() val condition = node.condition?.let { RexConverter.apply(it, env) } ?: rex(BOOL, rexOpLit(boolValue(true))) - val joinType = when (node.type) { - From.Join.Type.LEFT_OUTER, From.Join.Type.LEFT -> Rel.Op.Join.Type.LEFT - From.Join.Type.RIGHT_OUTER, From.Join.Type.RIGHT -> Rel.Op.Join.Type.RIGHT - From.Join.Type.FULL_OUTER, From.Join.Type.FULL -> Rel.Op.Join.Type.FULL - From.Join.Type.COMMA, - From.Join.Type.INNER, - From.Join.Type.CROSS -> Rel.Op.Join.Type.INNER // Cross Joins are just INNER JOIN ON TRUE + val joinType = when (node.joinType?.code()) { + JoinType.LEFT_OUTER, JoinType.LEFT, JoinType.LEFT_CROSS -> Rel.Op.Join.Type.LEFT + JoinType.RIGHT_OUTER, JoinType.RIGHT -> Rel.Op.Join.Type.RIGHT + JoinType.FULL_OUTER, JoinType.FULL -> Rel.Op.Join.Type.FULL + JoinType.INNER, + JoinType.CROSS -> Rel.Op.Join.Type.INNER // Cross Joins are just INNER JOIN ON TRUE null -> Rel.Op.Join.Type.INNER // a JOIN b ON a.id = b.id <--> a INNER JOIN b ON a.id = b.id + else -> error("Unexpected JoinType type: ${node.joinType}") } val type = relType(schema, props) val op = relOpJoin(lhs, rhs, condition, joinType) @@ -329,18 +363,19 @@ internal object RelConverter { return rel(type, op) } - private fun convertProjectionItem(item: Select.Project.Item) = when (item) { - is Select.Project.Item.All -> convertProjectItemAll(item) - is Select.Project.Item.Expression -> convertProjectItemRex(item) + private fun convertSelectItem(item: SelectItem) = when (item) { + is SelectItem.Star -> convertSelectItemStar(item) + is SelectItem.Expr -> convertSelectItemExpr(item) + else -> error("Unexpected SelectItem type: $item") } - private fun convertProjectItemAll(item: Select.Project.Item.All): Pair { + private fun convertSelectItemStar(item: SelectItem.Star): Pair { throw IllegalArgumentException("AST not normalized") } - private fun convertProjectItemRex(item: Select.Project.Item.Expression): Pair { + private fun convertSelectItemExpr(item: SelectItem.Expr): Pair { val name = when (val a = item.asAlias) { - null -> error("AST not normalized, missing AS alias on projection item $item") + null -> error("AST not normalized, missing AS alias on select item $item") else -> a.symbol } val rex = RexConverter.apply(item.expr, env) @@ -407,10 +442,11 @@ internal object RelConverter { args = listOf(exprLit(int32Value(1)).toRex(env)) ) } else { - val setq = when (expr.setq) { + val setq = when (expr.setq?.code()) { null -> org.partiql.planner.internal.ir.SetQuantifier.ALL SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT + else -> error("Unexpected SetQuantifier type: ${expr.setq}") } relOpAggregateCallUnresolved(name, setq, args) } @@ -444,9 +480,10 @@ internal object RelConverter { schema.add(binding) it.expr.toRex(env) } - strategy = when (groupBy.strategy) { - GroupBy.Strategy.FULL -> Rel.Op.Aggregate.Strategy.FULL - GroupBy.Strategy.PARTIAL -> Rel.Op.Aggregate.Strategy.PARTIAL + strategy = when (groupBy.strategy.code()) { + GroupByStrategy.FULL -> Rel.Op.Aggregate.Strategy.FULL + GroupByStrategy.PARTIAL -> Rel.Op.Aggregate.Strategy.PARTIAL + else -> error("Unexpected GroupByStrategy type: ${groupBy.strategy}") } } val type = relType(schema, props) @@ -473,7 +510,7 @@ internal object RelConverter { private fun visitIfQuerySet(expr: Expr): Rel { return when (expr) { - is Expr.QuerySet -> visit(expr, nil) + is ExprQuerySet -> visit(expr, nil) else -> { val rex = RexConverter.applyRel(expr, env) val op = relOpScan(rex) @@ -490,15 +527,17 @@ internal object RelConverter { val lhs = visitIfQuerySet(setExpr.lhs) val rhs = visitIfQuerySet(setExpr.rhs) val type = Rel.Type(listOf(Rel.Binding("_0", ANY)), props = emptySet()) - val quantifier = when (setExpr.type.setq) { + val quantifier = when (setExpr.type.setq?.code()) { SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL null, SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT + else -> error("Unexpected SetQuantifier type: ${setExpr.type.setq}") } val outer = setExpr.isOuter - val op = when (setExpr.type.type) { - SetOp.Type.UNION -> Rel.Op.Union(quantifier, outer, lhs, rhs) - SetOp.Type.EXCEPT -> Rel.Op.Except(quantifier, outer, lhs, rhs) - SetOp.Type.INTERSECT -> Rel.Op.Intersect(quantifier, outer, lhs, rhs) + val op = when (setExpr.type.setOpType.code()) { + SetOpType.UNION -> Rel.Op.Union(quantifier, outer, lhs, rhs) + SetOpType.EXCEPT -> Rel.Op.Except(quantifier, outer, lhs, rhs) + SetOpType.INTERSECT -> Rel.Op.Intersect(quantifier, outer, lhs, rhs) + else -> error("Unexpected SetOpType type: ${setExpr.type.setOpType}") } return rel(type, op) } @@ -513,14 +552,16 @@ internal object RelConverter { val type = input.type.copy(props = setOf(Rel.Prop.ORDERED)) val specs = orderBy.sorts.map { val rex = it.expr.toRex(env) - val order = when (it.dir) { - Sort.Dir.DESC -> when (it.nulls) { - Sort.Nulls.LAST -> Rel.Op.Sort.Order.DESC_NULLS_LAST - else -> Rel.Op.Sort.Order.DESC_NULLS_FIRST + val order = when (it.order?.code()) { + Order.DESC -> when (it.nulls?.code()) { + Nulls.LAST -> Rel.Op.Sort.Order.DESC_NULLS_LAST + Nulls.FIRST, null -> Rel.Op.Sort.Order.DESC_NULLS_FIRST + else -> error("Unexpected Nulls type: ${it.nulls}") } - else -> when (it.nulls) { - Sort.Nulls.FIRST -> Rel.Op.Sort.Order.ASC_NULLS_FIRST - else -> Rel.Op.Sort.Order.ASC_NULLS_LAST + else -> when (it.nulls?.code()) { + Nulls.FIRST -> Rel.Op.Sort.Order.ASC_NULLS_FIRST + Nulls.LAST, null -> Rel.Op.Sort.Order.ASC_NULLS_LAST + else -> error("Unexpected Nulls type: ${it.nulls}") } } relOpSortSpec(rex, order) @@ -560,8 +601,8 @@ internal object RelConverter { return input } val type = input.type // PlanTyper handles typing the exclusion and removing redundant exclude paths - val paths = exclude.items - .groupBy(keySelector = { it.root }, valueTransform = { it.steps }) + val paths = exclude.excludePaths + .groupBy(keySelector = { it.root }, valueTransform = { it.excludeSteps }) .map { (root, exclusions) -> val rootVar = (root.toRex(env)).op as Rex.Op.Var val steps = exclusionsToSteps(exclusions) @@ -571,7 +612,7 @@ internal object RelConverter { return rel(type, op) } - private fun exclusionsToSteps(exclusions: List>): List { + private fun exclusionsToSteps(exclusions: List>): List { if (exclusions.any { it.isEmpty() }) { // if there exists a path with no further steps, can remove the longer paths // e.g. t.a.b, t.a.b.c, t.a.b.d[*].*.e -> can keep just t.a.b @@ -586,17 +627,18 @@ internal object RelConverter { } } - private fun stepToExcludeType(step: Exclude.Step): Rel.Op.Exclude.Type { + private fun stepToExcludeType(step: ExcludeStep): Rel.Op.Exclude.Type { return when (step) { - is Exclude.Step.StructField -> { - when (step.symbol.caseSensitivity) { - Identifier.CaseSensitivity.INSENSITIVE -> relOpExcludeTypeStructSymbol(step.symbol.symbol) - Identifier.CaseSensitivity.SENSITIVE -> relOpExcludeTypeStructKey(step.symbol.symbol) + is ExcludeStep.StructField -> { + when (step.symbol.isDelimited) { + false -> relOpExcludeTypeStructSymbol(step.symbol.symbol) + true -> relOpExcludeTypeStructKey(step.symbol.symbol) } } - is Exclude.Step.CollIndex -> relOpExcludeTypeCollIndex(step.index) - is Exclude.Step.StructWildcard -> relOpExcludeTypeStructWildcard() - is Exclude.Step.CollWildcard -> relOpExcludeTypeCollWildcard() + is ExcludeStep.CollIndex -> relOpExcludeTypeCollIndex(step.index) + is ExcludeStep.StructWildcard -> relOpExcludeTypeStructWildcard() + is ExcludeStep.CollWildcard -> relOpExcludeTypeCollWildcard() + else -> error("Unexpected ExcludeStep type: $step") } } @@ -638,22 +680,22 @@ internal object RelConverter { private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every") private data class Context( - val aggregations: MutableList, + val aggregations: MutableList, val keys: List ) - fun apply(node: QueryBody.SFW): Pair> { - val aggs = mutableListOf() + fun apply(node: QueryBody.SFW): Pair> { + val aggs = mutableListOf() val keys = node.groupBy?.keys ?: emptyList() val context = Context(aggs, keys) val select = super.visitQueryBodySFW(node, context) as QueryBody.SFW return Pair(select, aggs) } - override fun visitSelectValue(node: Select.Value, ctx: Context): AstNode { + override fun visitSelectValue(node: SelectValue, ctx: Context): AstNode { val visited = super.visitSelectValue(node, ctx) val substitutions = ctx.keys.associate { - it.expr to exprVar(identifierSymbol(it.asAlias!!.symbol, Identifier.CaseSensitivity.SENSITIVE), Expr.Var.Scope.DEFAULT) + it.expr to exprVarRef(identifierChain(identifier(it.asAlias!!.symbol, isDelimited = true), next = null), Scope.DEFAULT()) } return SubstitutionVisitor.visit(visited, substitutions) } @@ -661,32 +703,44 @@ internal object RelConverter { // only rewrite top-level SFW override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Context): AstNode = node - override fun visitExprCall(node: Expr.Call, ctx: Context) = ast { + override fun visitExprCall(node: ExprCall, ctx: Context) = // TODO replace w/ proper function resolution to determine whether a function call is a scalar or aggregate. // may require further modification of SPI interfaces to support when (node.function.isAggregateCall()) { true -> { - val id = identifierSymbol { - symbol = syntheticAgg(ctx.aggregations.size) - caseSensitivity = org.partiql.ast.Identifier.CaseSensitivity.INSENSITIVE - } + val id = identifierChain( + identifier( + symbol = syntheticAgg(ctx.aggregations.size), + isDelimited = false + ), + next = null + ) ctx.aggregations += node - exprVar(id, Expr.Var.Scope.DEFAULT) + exprVarRef(id, Scope.DEFAULT()) } else -> node } - } private fun String.isAggregateCall(): Boolean { return aggregates.contains(this) } - private fun Identifier.isAggregateCall(): Boolean { - return when (this) { - is Identifier.Symbol -> this.symbol.lowercase().isAggregateCall() - is Identifier.Qualified -> this.steps.last().symbol.lowercase().isAggregateCall() + private fun IdentifierChain.isAggregateCall(): Boolean { + return when (next) { + null -> root.symbol.lowercase().isAggregateCall() + else -> { + var curId = next + var last = curId + while (curId != null) { + last = curId + curId = curId.next + } + last!!.root.symbol.lowercase().isAggregateCall() + } } } + + override fun defaultReturn(node: AstNode, ctx: Context) = node } private fun syntheticAgg(i: Int) = "\$agg_$i" diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index b458587a54..dc656359fa 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -17,14 +17,43 @@ package org.partiql.planner.internal.transforms import com.amazon.ionelement.api.loadSingleElement -import org.partiql.ast.AstNode -import org.partiql.ast.DatetimeField -import org.partiql.ast.Expr -import org.partiql.ast.QueryBody -import org.partiql.ast.Select -import org.partiql.ast.SetQuantifier -import org.partiql.ast.Type -import org.partiql.ast.visitor.AstBaseVisitor +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.DataType +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprAnd +import org.partiql.ast.v1.expr.ExprArray +import org.partiql.ast.v1.expr.ExprBag +import org.partiql.ast.v1.expr.ExprBetween +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprCase +import org.partiql.ast.v1.expr.ExprCast +import org.partiql.ast.v1.expr.ExprCoalesce +import org.partiql.ast.v1.expr.ExprExtract +import org.partiql.ast.v1.expr.ExprInCollection +import org.partiql.ast.v1.expr.ExprIsType +import org.partiql.ast.v1.expr.ExprLike +import org.partiql.ast.v1.expr.ExprLit +import org.partiql.ast.v1.expr.ExprNot +import org.partiql.ast.v1.expr.ExprNullIf +import org.partiql.ast.v1.expr.ExprOperator +import org.partiql.ast.v1.expr.ExprOr +import org.partiql.ast.v1.expr.ExprOverlay +import org.partiql.ast.v1.expr.ExprPath +import org.partiql.ast.v1.expr.ExprPosition +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.ExprSessionAttribute +import org.partiql.ast.v1.expr.ExprStruct +import org.partiql.ast.v1.expr.ExprSubstring +import org.partiql.ast.v1.expr.ExprTrim +import org.partiql.ast.v1.expr.ExprVarRef +import org.partiql.ast.v1.expr.ExprVariant +import org.partiql.ast.v1.expr.PathStep +import org.partiql.ast.v1.expr.Scope +import org.partiql.ast.v1.expr.TrimSpec import org.partiql.errors.TypeCheckException import org.partiql.planner.internal.Env import org.partiql.planner.internal.ir.Rel @@ -66,7 +95,7 @@ import org.partiql.value.int64Value import org.partiql.value.io.PartiQLValueIonReaderBuilder import org.partiql.value.nullValue import org.partiql.value.stringValue -import org.partiql.ast.Identifier as AstIdentifier +import org.partiql.ast.v1.SetQuantifier as AstSetQuantifier /** * Converts an AST expression node to a Plan Rex node; ignoring any typing. @@ -79,7 +108,7 @@ internal object RexConverter { @OptIn(PartiQLValueExperimental::class) @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") - private object ToRex : AstBaseVisitor() { + private object ToRex : AstVisitor() { private val COLL_AGG_NAMES = setOf( "coll_any", @@ -95,7 +124,7 @@ internal object RexConverter { override fun defaultReturn(node: AstNode, context: Env): Rex = throw IllegalArgumentException("unsupported rex $node") - override fun visitExprLit(node: Expr.Lit, context: Env): Rex { + override fun visitExprLit(node: ExprLit, context: Env): Rex { val type = CompilerType( _delegate = node.value.type.toPType(), isNullValue = node.value.isNull, @@ -108,7 +137,7 @@ internal object RexConverter { /** * TODO PartiQLValue will be replaced by Datum (i.e. IonDatum) is a subsequent PR. */ - override fun visitExprVariant(node: Expr.Variant, ctx: Env): Rex { + override fun visitExprVariant(node: ExprVariant, ctx: Env): Rex { if (node.encoding != "ion") { throw IllegalArgumentException("unsupported encoding ${node.encoding}") } @@ -136,7 +165,7 @@ internal object RexConverter { * @return */ internal fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { - val rex = super.visitExpr(node, ctx) + val rex = node.accept(this, ctx) return when (isSqlSelect(node)) { true -> { val select = rex.op as Rex.Op.Select @@ -153,12 +182,13 @@ internal object RexConverter { } } - override fun visitExprVar(node: Expr.Var, context: Env): Rex { + override fun visitExprVarRef(node: ExprVarRef, context: Env): Rex { val type = (ANY) - val identifier = AstToPlan.convert(node.identifier) - val scope = when (node.scope) { - Expr.Var.Scope.DEFAULT -> Rex.Op.Var.Scope.DEFAULT - Expr.Var.Scope.LOCAL -> Rex.Op.Var.Scope.LOCAL + val identifier = AstToPlan.convert(node.identifierChain) + val scope = when (node.scope.code()) { + Scope.DEFAULT -> Rex.Op.Var.Scope.DEFAULT + Scope.LOCAL -> Rex.Op.Var.Scope.LOCAL + else -> error("Unexpected Scope type: ${node.scope}") } val op = rexOpVarUnresolved(identifier, scope) return rex(type, op) @@ -245,7 +275,7 @@ internal object RexConverter { } } - override fun visitExprOperator(node: Expr.Operator, ctx: Env): Rex { + override fun visitExprOperator(node: ExprOperator, ctx: Env): Rex { val lhs = node.lhs return if (lhs != null) { resolveBinaryOp(lhs, node.symbol, node.rhs, ctx) @@ -254,7 +284,7 @@ internal object RexConverter { } } - override fun visitExprNot(node: Expr.Not, ctx: Env): Rex { + override fun visitExprNot(node: ExprNot, ctx: Env): Rex { val type = (ANY) // Args val arg = visitExprCoerce(node.value, ctx) @@ -265,7 +295,7 @@ internal object RexConverter { return rex(type, op) } - override fun visitExprAnd(node: Expr.And, ctx: Env): Rex { + override fun visitExprAnd(node: ExprAnd, ctx: Env): Rex { val type = (ANY) val l = visitExprCoerce(node.lhs, ctx) val r = visitExprCoerce(node.rhs, ctx) @@ -277,7 +307,7 @@ internal object RexConverter { return rex(type, op) } - override fun visitExprOr(node: Expr.Or, ctx: Env): Rex { + override fun visitExprOr(node: ExprOr, ctx: Env): Rex { val type = (ANY) val l = visitExprCoerce(node.lhs, ctx) val r = visitExprCoerce(node.rhs, ctx) @@ -289,43 +319,47 @@ internal object RexConverter { return rex(type, op) } - private fun isLiteralArray(node: Expr): Boolean = node is Expr.Collection && (node.type == Expr.Collection.Type.ARRAY || node.type == Expr.Collection.Type.LIST) + private fun isLiteralArray(node: Expr): Boolean = node is ExprArray private fun isSqlSelect(node: Expr): Boolean { - return if (node is Expr.QuerySet) { + return if (node is ExprQuerySet) { val body = node.body - body is QueryBody.SFW && (body.select is Select.Project || body.select is Select.Star) + body is QueryBody.SFW && (body.select is SelectList || body.select is SelectStar) } else { false } } - override fun visitExprPath(node: Expr.Path, context: Env): Rex { + override fun visitExprPath(node: ExprPath, context: Env): Rex { // Args val root = visitExprCoerce(node.root, context) // Attempt to create qualified identifier - val (newRoot, newSteps) = when (val op = root.op) { + val (newRoot, nextStep) = when (val op = root.op) { is Rex.Op.Var.Unresolved -> { // convert consecutive symbol path steps to the root identifier var i = 0 val parts = mutableListOf() parts.addAll(op.identifier.getParts()) - for (step in node.steps) { - if (step !is Expr.Path.Step.Symbol) { + var curStep = node.next + while (curStep != null) { + if (curStep !is PathStep.Field) { break } - parts.add(AstToPlan.part(step.symbol)) + parts.add(AstToPlan.part(curStep.field)) i += 1 + curStep = curStep.next } val newRoot = rex(ANY, rexOpVarUnresolved(Identifier.of(parts), op.scope)) - val newSteps = node.steps.subList(i, node.steps.size) + val newSteps = curStep newRoot to newSteps } - else -> root to node.steps + else -> { + root to node.next + } } - if (newSteps.isEmpty()) { + if (nextStep == null) { return newRoot } @@ -333,35 +367,35 @@ internal object RexConverter { var varRefIndex = 0 // tracking var ref index - val pathNavi = newSteps.fold(newRoot) { current, step -> - val path = when (step) { - is Expr.Path.Step.Index -> { - val key = visitExprCoerce(step.key, context) - val op = when (val astKey = step.key) { - is Expr.Lit -> when (astKey.value) { - is StringValue -> rexOpPathKey(current, key) - else -> rexOpPathIndex(current, key) + var curStep = nextStep + var curPathNavi = newRoot + while (curStep != null) { + val path = when (curStep) { + is PathStep.Element -> { + val key = visitExprCoerce(curStep.element, context) + val op = when (val astKey = curStep.element) { + is ExprLit -> when (astKey.value) { + is StringValue -> rexOpPathKey(curPathNavi, key) + else -> rexOpPathIndex(curPathNavi, key) } - - is Expr.Cast -> when (astKey.asType is Type.String) { - true -> rexOpPathKey(current, key) - false -> rexOpPathIndex(current, key) + is ExprCast -> when (astKey.asType.code() == DataType.STRING) { + true -> rexOpPathKey(curPathNavi, key) + false -> rexOpPathIndex(curPathNavi, key) } - - else -> rexOpPathIndex(current, key) + else -> rexOpPathIndex(curPathNavi, key) } op } - is Expr.Path.Step.Symbol -> { - when (step.symbol.caseSensitivity) { - AstIdentifier.CaseSensitivity.SENSITIVE -> { + is PathStep.Field -> { + when (curStep.field.isDelimited) { + true -> { // case-sensitive path step becomes a key lookup - rexOpPathKey(current, rexString(step.symbol.symbol)) + rexOpPathKey(curPathNavi, rexString(curStep.field.symbol)) } - AstIdentifier.CaseSensitivity.INSENSITIVE -> { + false -> { // case-insensitive path step becomes a symbol lookup - rexOpPathSymbol(current, step.symbol.symbol) + rexOpPathSymbol(curPathNavi, curStep.field.symbol) } } } @@ -391,28 +425,29 @@ internal object RexConverter { // each join will produce its own schema and pass the schema as a type Env. // The (k_i) indicate the possible key binding produced by unpivot. // We calculate the var ref on the fly. - is Expr.Path.Step.Unpivot -> { + is PathStep.AllFields -> { // Unpivot produces two binding, in this context we want the value, // which always going to be the second binding val op = rexOpVarLocal(1, varRefIndex + 1) varRefIndex += 2 val index = fromList.size - fromList.add(relFromUnpivot(current, index)) + fromList.add(relFromUnpivot(curPathNavi, index)) op } - is Expr.Path.Step.Wildcard -> { + is PathStep.AllElements -> { // Scan produce only one binding val op = rexOpVarLocal(1, varRefIndex) varRefIndex += 1 val index = fromList.size - fromList.add(relFromDefault(current, index)) + fromList.add(relFromDefault(curPathNavi, index)) op } + else -> error("Unexpected PathStep type: $curStep") } - rex(ANY, path) + curStep = curStep.next + curPathNavi = rex(ANY, path) } - - if (fromList.size == 0) return pathNavi + if (fromList.size == 0) return curPathNavi val fromNode = fromList.reduce { acc, scan -> val schema = acc.type.schema + scan.type.schema val props = emptySet() @@ -424,11 +459,11 @@ internal object RexConverter { // always going to be the last binding val selectRef = fromNode.type.schema.size - 1 - val constructor = when (val op = pathNavi.op) { - is Rex.Op.Path.Index -> rex(pathNavi.type, rexOpPathIndex(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) - is Rex.Op.Path.Key -> rex(pathNavi.type, rexOpPathKey(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) - is Rex.Op.Path.Symbol -> rex(pathNavi.type, rexOpPathSymbol(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) - is Rex.Op.Var.Local -> rex(pathNavi.type, rexOpVarLocal(0, selectRef)) + val constructor = when (val op = curPathNavi.op) { + is Rex.Op.Path.Index -> rex(curPathNavi.type, rexOpPathIndex(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Path.Key -> rex(curPathNavi.type, rexOpPathKey(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Path.Symbol -> rex(curPathNavi.type, rexOpPathSymbol(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Var.Local -> rex(curPathNavi.type, rexOpVarLocal(0, selectRef)) else -> throw IllegalStateException() } val op = rexOpSelect(constructor, fromNode) @@ -475,7 +510,7 @@ internal object RexConverter { private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) - override fun visitExprCall(node: Expr.Call, context: Env): Rex { + override fun visitExprCall(node: ExprCall, context: Env): Rex { val type = (ANY) // Fn val id = AstToPlan.convert(node.function) @@ -506,8 +541,14 @@ internal object RexConverter { /** * @return whether call is `COLL_`. */ - private fun isCollAgg(node: Expr.Call): Boolean { - val id = node.function as? org.partiql.ast.Identifier.Symbol ?: return false + private fun isCollAgg(node: ExprCall): Boolean { + val fn = node.function + val id = if (fn.next == null) { + // is not a qualified identifier chain + node.function.root + } else { + return false + } return COLL_AGG_NAMES.contains(id.symbol.lowercase()) } @@ -519,24 +560,25 @@ internal object RexConverter { * * It is assumed that the [id] has already been vetted by [isCollAgg]. */ - private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier?, args: List): Rex { + private fun callToCollAgg(id: Identifier, setQuantifier: AstSetQuantifier?, args: List): Rex { if (id.hasQualifier()) { error("Qualified function calls are not currently supported.") } if (args.size != 1) { error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.") } - val postfix = when (setQuantifier) { - SetQuantifier.DISTINCT -> "_distinct" - SetQuantifier.ALL -> "_all" + val postfix = when (setQuantifier?.code()) { + AstSetQuantifier.DISTINCT -> "_distinct" + AstSetQuantifier.ALL -> "_all" null -> "_all" + else -> error("Unexpected SetQuantifier type: $setQuantifier") } val newId = Identifier.regular(id.getIdentifier().getText() + postfix) val op = Rex.Op.Call.Unresolved(newId, listOf(args[0])) return Rex(ANY, op) } - private fun visitExprCallTupleUnion(node: Expr.Call, context: Env): Rex { + private fun visitExprCallTupleUnion(node: ExprCall, context: Env): Rex { val type = (STRUCT) val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() val op = rexOpTupleUnion(args) @@ -547,7 +589,7 @@ internal object RexConverter { * Assume that the node's identifier refers to EXISTS. * TODO: This could be better suited as a dedicated node in the future. */ - private fun visitExprCallExists(node: Expr.Call, context: Env): Rex { + private fun visitExprCallExists(node: ExprCall, context: Env): Rex { val type = (BOOL) if (node.args.size != 1) { error("EXISTS requires a single argument.") @@ -557,7 +599,7 @@ internal object RexConverter { return rex(type, op) } - override fun visitExprCase(node: Expr.Case, context: Env) = plan { + override fun visitExprCase(node: ExprCase, context: Env) = plan { val type = (ANY) val rex = when (node.expr) { null -> null @@ -580,7 +622,7 @@ internal object RexConverter { createBranch(branchCondition, branchRex) }.toMutableList() - val defaultRex = when (val default = node.default) { + val defaultRex = when (val default = node.defaultExpr) { null -> rex(type = ANY, op = rexOpLit(value = nullValue())) else -> visitExprCoerce(default, context) } @@ -588,20 +630,19 @@ internal object RexConverter { rex(type, op) } - override fun visitExprCollection(node: Expr.Collection, context: Env): Rex { - val type = when (node.type) { - Expr.Collection.Type.BAG -> BAG - Expr.Collection.Type.ARRAY -> LIST - Expr.Collection.Type.VALUES -> LIST - Expr.Collection.Type.LIST -> LIST - Expr.Collection.Type.SEXP -> SEXP - } - val values = node.values.map { visitExprCoerce(it, context) } + override fun visitExprArray(node: ExprArray, ctx: Env): Rex { + val values = node.values.map { visitExprCoerce(it, ctx) } val op = rexOpCollection(values) - return rex(type, op) + return rex(LIST, op) + } + + override fun visitExprBag(node: ExprBag, ctx: Env): Rex { + val values = node.values.map { visitExprCoerce(it, ctx) } + val op = rexOpCollection(values) + return rex(BAG, op) } - override fun visitExprStruct(node: Expr.Struct, context: Env): Rex { + override fun visitExprStruct(node: ExprStruct, context: Env): Rex { val type = (STRUCT) val fields = node.fields.map { val k = visitExprCoerce(it.name, context) @@ -617,7 +658,7 @@ internal object RexConverter { /** * NOT? LIKE ( ESCAPE )? */ - override fun visitExprLike(node: Expr.Like, ctx: Env): Rex { + override fun visitExprLike(node: ExprLike, ctx: Env): Rex { val type = BOOL // Args val arg0 = visitExprCoerce(node.value, ctx) @@ -638,7 +679,7 @@ internal object RexConverter { /** * NOT? BETWEEN AND */ - override fun visitExprBetween(node: Expr.Between, ctx: Env): Rex = plan { + override fun visitExprBetween(node: ExprBetween, ctx: Env): Rex = plan { val type = BOOL // Args val arg0 = visitExprCoerce(node.value, ctx) @@ -665,11 +706,11 @@ internal object RexConverter { * Otherwise, T in C is unknown. * */ - override fun visitExprInCollection(node: Expr.InCollection, ctx: Env): Rex { + override fun visitExprInCollection(node: ExprInCollection, ctx: Env): Rex { val type = BOOL // Args val arg0 = visitExprCoerce(node.lhs, ctx) - val arg1 = visitExpr(node.rhs, ctx) // !! don't insert scalar subquery coercions + val arg1 = node.rhs.accept(this, ctx) // !! don't insert scalar subquery coercions // Call var call = call("in_collection", arg0, arg1) @@ -683,49 +724,61 @@ internal object RexConverter { /** * IS ? */ - override fun visitExprIsType(node: Expr.IsType, ctx: Env): Rex { + override fun visitExprIsType(node: ExprIsType, ctx: Env): Rex { val type = BOOL // arg val arg0 = visitExprCoerce(node.value, ctx) - - var call = when (val targetType = node.type) { - is Type.NullType -> call("is_null", arg0) - is Type.Missing -> call("is_missing", arg0) - is Type.Bool -> call("is_bool", arg0) - is Type.Tinyint -> call("is_int8", arg0) - is Type.Smallint, is Type.Int2 -> call("is_int16", arg0) - is Type.Int4 -> call("is_int32", arg0) - is Type.Bigint, is Type.Int8 -> call("is_int64", arg0) - is Type.Int -> call("is_int", arg0) - is Type.Real -> call("is_real", arg0) - is Type.Float32 -> call("is_float32", arg0) - is Type.Float64 -> call("is_float64", arg0) - is Type.Decimal -> call("is_decimal", targetType.precision.toRex(), targetType.scale.toRex(), arg0) - is Type.Numeric -> call("is_numeric", targetType.precision.toRex(), targetType.scale.toRex(), arg0) - is Type.Char -> call("is_char", targetType.length.toRex(), arg0) - is Type.Varchar -> call("is_varchar", targetType.length.toRex(), arg0) - is Type.String -> call("is_string", targetType.length.toRex(), arg0) - is Type.Symbol -> call("is_symbol", arg0) - is Type.Bit -> call("is_bit", arg0) - is Type.BitVarying -> call("is_bitVarying", arg0) - is Type.ByteString -> call("is_byteString", arg0) - is Type.Blob -> call("is_blob", arg0) - is Type.Clob -> call("is_clob", arg0) - is Type.Date -> call("is_date", arg0) - is Type.Time -> call("is_time", arg0) + val targetType = node.type + var call = when (targetType.code()) { + // + DataType.NULL -> call("is_null", arg0) + DataType.MISSING -> call("is_missing", arg0) + // + // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT + DataType.CHARACTER, DataType.CHAR -> call("is_char", targetType.length.toRex(), arg0) + DataType.CHARACTER_VARYING, DataType.VARCHAR -> call("is_varchar", targetType.length.toRex(), arg0) + DataType.CLOB -> call("is_clob", arg0) + DataType.STRING -> call("is_string", targetType.length.toRex(), arg0) + DataType.SYMBOL -> call("is_symbol", arg0) + // + // TODO BINARY_LARGE_OBJECT + DataType.BLOB -> call("is_blob", arg0) + // + DataType.BIT -> call("is_bit", arg0) // TODO define in parser + DataType.BIT_VARYING -> call("is_bitVarying", arg0) // TODO define in parser + // - + DataType.NUMERIC -> call("is_numeric", targetType.precision.toRex(), targetType.scale.toRex(), arg0) + DataType.DEC, DataType.DECIMAL -> call("is_decimal", targetType.precision.toRex(), targetType.scale.toRex(), arg0) + DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> call("is_int64", arg0) + DataType.INT4, DataType.INTEGER4, DataType.INTEGER -> call("is_int32", arg0) + DataType.INT -> call("is_int", arg0) + DataType.INT2, DataType.SMALLINT -> call("is_int16", arg0) + DataType.TINYINT -> call("is_int8", arg0) // TODO define in parser + // - + DataType.FLOAT -> call("is_float32", arg0) + DataType.REAL -> call("is_real", arg0) + DataType.DOUBLE_PRECISION -> call("is_float64", arg0) + // + DataType.BOOLEAN, DataType.BOOL -> call("is_bool", arg0) + // + DataType.DATE -> call("is_date", arg0) // TODO: DO we want to seperate with time zone vs without time zone into two different type in the plan? // leave the parameterized type out for now until the above is answered - is Type.TimeWithTz -> call("is_timeWithTz", arg0) - is Type.Timestamp -> call("is_timestamp", arg0) - is Type.TimestampWithTz -> call("is_timestampWithTz", arg0) - is Type.Interval -> call("is_interval", arg0) - is Type.Bag -> call("is_bag", arg0) - is Type.List -> call("is_list", arg0) - is Type.Sexp -> call("is_sexp", arg0) - is Type.Tuple -> call("is_struct", arg0) - is Type.Struct -> call("is_struct", arg0) - is Type.Any -> call("is_any", arg0) - is Type.Custom -> call("is_custom", arg0) + DataType.TIME -> call("is_time", arg0) + DataType.TIME_WITH_TIME_ZONE -> call("is_timeWithTz", arg0) + DataType.TIMESTAMP -> call("is_timestamp", arg0) + DataType.TIMESTAMP_WITH_TIME_ZONE -> call("is_timestampWithTz", arg0) + // + DataType.INTERVAL -> call("is_interval", arg0) // TODO define in parser + // + DataType.STRUCT, DataType.TUPLE -> call("is_struct", arg0) + // + DataType.LIST -> call("is_list", arg0) + DataType.BAG -> call("is_bag", arg0) + DataType.SEXP -> call("is_sexp", arg0) + // + DataType.USER_DEFINED -> call("is_custom", arg0) + else -> error("Unexpected DataType type: $targetType") } if (node.not == true) { @@ -735,7 +788,7 @@ internal object RexConverter { return rex(type, call) } - override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex { + override fun visitExprCoalesce(node: ExprCoalesce, ctx: Env): Rex { val type = ANY val args = node.args.map { arg -> visitExprCoerce(arg, ctx) @@ -744,18 +797,18 @@ internal object RexConverter { return rex(type, op) } - override fun visitExprNullIf(node: Expr.NullIf, ctx: Env): Rex { + override fun visitExprNullIf(node: ExprNullIf, ctx: Env): Rex { val type = ANY - val value = visitExprCoerce(node.value, ctx) - val nullifier = visitExprCoerce(node.nullifier, ctx) - val op = rexOpNullif(value, nullifier) + val v1 = visitExprCoerce(node.v1, ctx) + val v2 = visitExprCoerce(node.v2, ctx) + val op = rexOpNullif(v1, v2) return rex(type, op) } /** * SUBSTRING( (FROM (FOR )?)? ) */ - override fun visitExprSubstring(node: Expr.Substring, ctx: Env): Rex { + override fun visitExprSubstring(node: ExprSubstring, ctx: Env): Rex { val type = ANY // Args val arg0 = visitExprCoerce(node.value, ctx) @@ -772,7 +825,7 @@ internal object RexConverter { /** * POSITION( IN ) */ - override fun visitExprPosition(node: Expr.Position, ctx: Env): Rex { + override fun visitExprPosition(node: ExprPosition, ctx: Env): Rex { val type = ANY // Args val arg0 = visitExprCoerce(node.lhs, ctx) @@ -785,18 +838,18 @@ internal object RexConverter { /** * TRIM([LEADING|TRAILING|BOTH]? ( FROM)? ) */ - override fun visitExprTrim(node: Expr.Trim, ctx: Env): Rex { + override fun visitExprTrim(node: ExprTrim, ctx: Env): Rex { val type = STRING // Args val arg0 = visitExprCoerce(node.value, ctx) val arg1 = node.chars?.let { visitExprCoerce(it, ctx) } // Call Variants - val call = when (node.spec) { - Expr.Trim.Spec.LEADING -> when (arg1) { + val call = when (node.trimSpec?.code()) { + TrimSpec.LEADING -> when (arg1) { null -> call("trim_leading", arg0) else -> call("trim_leading_chars", arg0, arg1) } - Expr.Trim.Spec.TRAILING -> when (arg1) { + TrimSpec.TRAILING -> when (arg1) { null -> call("trim_trailing", arg0) else -> call("trim_trailing_chars", arg0, arg1) } @@ -827,11 +880,11 @@ internal object RexConverter { * RS is the second , * SL is the if specified, otherwise it is char_length(RS). */ - override fun visitExprOverlay(node: Expr.Overlay, ctx: Env): Rex { + override fun visitExprOverlay(node: ExprOverlay, ctx: Env): Rex { val cv = visitExprCoerce(node.value, ctx) - val sp = visitExprCoerce(node.start, ctx) - val rs = visitExprCoerce(node.overlay, ctx) - val sl = node.length?.let { visitExprCoerce(it, ctx) } ?: rex(ANY, call("char_length", rs)) + val sp = visitExprCoerce(node.from, ctx) + val rs = visitExprCoerce(node.placing, ctx) + val sl = node.forLength?.let { visitExprCoerce(it, ctx) } ?: rex(ANY, call("char_length", rs)) val p1 = rex( ANY, call( @@ -852,98 +905,110 @@ internal object RexConverter { ) } - override fun visitExprExtract(node: Expr.Extract, ctx: Env): Rex { - val call = call("extract_${node.field.name.lowercase()}", visitExprCoerce(node.source, ctx)) + override fun visitExprExtract(node: ExprExtract, ctx: Env): Rex { + val call = call("extract_${node.field.name().lowercase()}", visitExprCoerce(node.source, ctx)) return rex(ANY, call) } - override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex { + override fun visitExprCast(node: ExprCast, ctx: Env): Rex { val type = visitType(node.asType) val arg = visitExprCoerce(node.value, ctx) return rex(ANY, rexOpCastUnresolved(type, arg)) } - private fun visitType(type: Type): CompilerType { - return when (type) { - is Type.NullType -> error("Casting to NULL is not supported.") - is Type.Missing -> error("Casting to MISSING is not supported.") - is Type.Bool -> PType.bool() - is Type.Tinyint -> PType.tinyint() - is Type.Smallint, is Type.Int2 -> PType.smallint() - is Type.Int4 -> PType.integer() - is Type.Bigint, is Type.Int8 -> PType.bigint() - is Type.Int -> PType.numeric() - is Type.Real -> PType.real() - is Type.Float32 -> PType.real() - is Type.Float64 -> PType.doublePrecision() - is Type.Decimal -> { + private fun visitType(type: DataType): CompilerType { + return when (type.code()) { + // + DataType.NULL -> error("Casting to NULL is not supported.") + DataType.MISSING -> error("Casting to MISSING is not supported.") + // + // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT + DataType.CHARACTER, DataType.CHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.Kind.CHAR, "length", length, PType::character) + } + DataType.CHARACTER_VARYING, DataType.VARCHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.Kind.VARCHAR, "length", length, PType::varchar) + } + DataType.CLOB -> assertGtZeroAndCreate(PType.Kind.CLOB, "length", type.length ?: Int.MAX_VALUE, PType::clob) + DataType.STRING -> PType.string() + DataType.SYMBOL -> PType.symbol() + // + // TODO BINARY_LARGE_OBJECT + DataType.BLOB -> assertGtZeroAndCreate(PType.Kind.BLOB, "length", type.length ?: Int.MAX_VALUE, PType::blob) + // + DataType.BIT -> error("BIT is not supported yet.") + DataType.BIT_VARYING -> error("BIT VARYING is not supported yet.") + // - + DataType.NUMERIC -> { val p = type.precision val s = type.scale when { p == null && s == null -> PType.decimal() p != null && s != null -> { - assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) - assertParamCompToZero(PType.Kind.DECIMAL, "scale", s, true) + assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) + assertParamCompToZero(PType.Kind.NUMERIC, "scale", s, true) if (s > p) { - throw TypeCheckException("Decimal scale cannot be greater than precision.") + throw TypeCheckException("Numeric scale cannot be greater than precision.") } - PType.decimal(p, s) + PType.decimal(type.precision!!, type.scale!!) } p != null && s == null -> { - assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) + assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) PType.decimal(p, 0) } else -> error("Precision can never be null while scale is specified.") } } - is Type.Numeric -> { + DataType.DEC, DataType.DECIMAL -> { val p = type.precision val s = type.scale when { p == null && s == null -> PType.decimal() p != null && s != null -> { - assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) - assertParamCompToZero(PType.Kind.NUMERIC, "scale", s, true) + assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) + assertParamCompToZero(PType.Kind.DECIMAL, "scale", s, true) if (s > p) { - throw TypeCheckException("Numeric scale cannot be greater than precision.") + throw TypeCheckException("Decimal scale cannot be greater than precision.") } - PType.decimal(type.precision!!, type.scale!!) + PType.decimal(p, s) } p != null && s == null -> { - assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) + assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) PType.decimal(p, 0) } else -> error("Precision can never be null while scale is specified.") } } - is Type.Char -> { - val length = type.length ?: 1 - assertGtZeroAndCreate(PType.Kind.CHAR, "length", length, PType::character) - } - is Type.Varchar -> { - val length = type.length ?: 1 - assertGtZeroAndCreate(PType.Kind.VARCHAR, "length", length, PType::varchar) - } - is Type.String -> PType.string() - is Type.Symbol -> PType.symbol() - is Type.Bit -> error("BIT is not supported yet.") - is Type.BitVarying -> error("BIT VARYING is not supported yet.") - is Type.ByteString -> error("BINARY is not supported yet.") - is Type.Blob -> assertGtZeroAndCreate(PType.Kind.BLOB, "length", type.length ?: Int.MAX_VALUE, PType::blob) - is Type.Clob -> assertGtZeroAndCreate(PType.Kind.CLOB, "length", type.length ?: Int.MAX_VALUE, PType::clob) - is Type.Date -> PType.date() - is Type.Time -> assertGtEqZeroAndCreate(PType.Kind.TIME, "precision", type.precision ?: 0, PType::time) - is Type.TimeWithTz -> assertGtEqZeroAndCreate(PType.Kind.TIMEZ, "precision", type.precision ?: 0, PType::timez) - is Type.Timestamp -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMP, "precision", type.precision ?: 6, PType::timestamp) - is Type.TimestampWithTz -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMPZ, "precision", type.precision ?: 6, PType::timestampz) - is Type.Interval -> error("INTERVAL is not supported yet.") - is Type.Bag -> PType.bag() - is Type.Sexp -> PType.sexp() - is Type.Any -> PType.dynamic() - is Type.Custom -> TODO("Custom type not supported ") - is Type.List -> PType.array() - is Type.Tuple -> PType.struct() - is Type.Struct -> PType.struct() + DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> PType.bigint() + DataType.INT4, DataType.INTEGER4, DataType.INTEGER, DataType.INT -> PType.integer() + DataType.INT2, DataType.SMALLINT -> PType.smallint() + DataType.TINYINT -> PType.tinyint() // TODO define in parser + // - + DataType.FLOAT -> PType.real() + DataType.REAL -> PType.real() + DataType.DOUBLE_PRECISION -> PType.doublePrecision() + // + DataType.BOOL -> PType.bool() + // + DataType.DATE -> PType.date() + DataType.TIME -> assertGtEqZeroAndCreate(PType.Kind.TIME, "precision", type.precision ?: 0, PType::time) + DataType.TIME_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMEZ, "precision", type.precision ?: 0, PType::timez) + DataType.TIMESTAMP -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMP, "precision", type.precision ?: 6, PType::timestamp) + DataType.TIMESTAMP_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMPZ, "precision", type.precision ?: 6, PType::timestampz) + // + DataType.INTERVAL -> error("INTERVAL is not supported yet.") + // + DataType.STRUCT -> PType.struct() + DataType.TUPLE -> PType.struct() + // + DataType.LIST -> PType.array() + DataType.BAG -> PType.bag() + DataType.SEXP -> PType.sexp() + // + DataType.USER_DEFINED -> TODO("Custom type not supported ") + else -> error("Unsupported DataType type: $type") }.toCType() } @@ -970,42 +1035,14 @@ internal object RexConverter { } } - override fun visitExprDateAdd(node: Expr.DateAdd, ctx: Env): Rex { - val type = TIMESTAMP - // Args - val arg0 = visitExprCoerce(node.lhs, ctx) - val arg1 = visitExprCoerce(node.rhs, ctx) - // Call Variants - val call = when (node.field) { - DatetimeField.TIMEZONE_HOUR -> error("Invalid call DATE_ADD(TIMEZONE_HOUR, ...)") - DatetimeField.TIMEZONE_MINUTE -> error("Invalid call DATE_ADD(TIMEZONE_MINUTE, ...)") - else -> call("date_add_${node.field.name.lowercase()}", arg0, arg1) - } - return rex(type, call) - } - - override fun visitExprDateDiff(node: Expr.DateDiff, ctx: Env): Rex { - val type = TIMESTAMP - // Args - val arg0 = visitExprCoerce(node.lhs, ctx) - val arg1 = visitExprCoerce(node.rhs, ctx) - // Call Variants - val call = when (node.field) { - DatetimeField.TIMEZONE_HOUR -> error("Invalid call DATE_DIFF(TIMEZONE_HOUR, ...)") - DatetimeField.TIMEZONE_MINUTE -> error("Invalid call DATE_DIFF(TIMEZONE_MINUTE, ...)") - else -> call("date_diff_${node.field.name.lowercase()}", arg0, arg1) - } - return rex(type, call) - } - - override fun visitExprSessionAttribute(node: Expr.SessionAttribute, ctx: Env): Rex { + override fun visitExprSessionAttribute(node: ExprSessionAttribute, ctx: Env): Rex { val type = ANY - val fn = node.attribute.name.lowercase() + val fn = node.sessionAttribute.name().lowercase() val call = call(fn) return rex(type, call) } - override fun visitExprQuerySet(node: Expr.QuerySet, context: Env): Rex = RelConverter.apply(node, context) + override fun visitExprQuerySet(node: ExprQuerySet, context: Env): Rex = RelConverter.apply(node, context) // Helpers diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/SubstitutionVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/SubstitutionVisitor.kt index 04114b5346..00c77ade7e 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/SubstitutionVisitor.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/SubstitutionVisitor.kt @@ -1,8 +1,8 @@ package org.partiql.planner.internal.transforms -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.util.AstRewriter +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstRewriter +import org.partiql.ast.v1.expr.Expr internal object SubstitutionVisitor : AstRewriter>() { override fun visitExpr(node: Expr, ctx: Map<*, AstNode>): AstNode { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt deleted file mode 100644 index 4f25650a34..0000000000 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - * - */ - -package org.partiql.planner.internal.transforms - -import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstVisitor -import org.partiql.ast.v1.Query -import org.partiql.ast.v1.expr.ExprQuerySet -import org.partiql.planner.internal.Env -import org.partiql.planner.internal.ir.statementQuery -import org.partiql.spi.catalog.Identifier -import org.partiql.ast.v1.Identifier as AstIdentifier -import org.partiql.ast.v1.IdentifierChain as AstIdentifierChain -import org.partiql.ast.v1.Statement as AstStatement -import org.partiql.planner.internal.ir.Statement as PlanStatement - -/** - * Simple translation from AST to an unresolved algebraic IR. - */ -internal object V1AstToPlan { - - // statement.toPlan() - @JvmStatic - fun apply(statement: AstStatement, env: Env): PlanStatement = statement.accept(ToPlanStatement, env) - - @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") - private object ToPlanStatement : AstVisitor() { - - override fun defaultReturn(node: AstNode, env: Env) = throw IllegalArgumentException("Unsupported statement") - - override fun visitQuery(node: Query, env: Env): PlanStatement { - val rex = when (val expr = node.expr) { - is ExprQuerySet -> V1RelConverter.apply(expr, env) - else -> V1RexConverter.apply(expr, env) - } - return statementQuery(rex) - } - } - - // --- Helpers -------------------- - - fun convert(identifier: AstIdentifierChain): Identifier { - val parts = mutableListOf() - parts.add(part(identifier.root)) - var curStep = identifier.next - while (curStep != null) { - parts.add(part(curStep.root)) - curStep = curStep.next - } - return Identifier.of(parts) - } - - fun convert(identifier: AstIdentifier): Identifier { - return Identifier.of(part(identifier)) - } - - fun part(identifier: AstIdentifier): Identifier.Part = when (identifier.isDelimited) { - true -> Identifier.Part.delimited(identifier.symbol) - false -> Identifier.Part.regular(identifier.symbol) - } -} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt deleted file mode 100644 index 96a994ac03..0000000000 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt +++ /dev/null @@ -1,393 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.planner.internal.transforms - -import org.partiql.ast.v1.Ast.exprCall -import org.partiql.ast.v1.Ast.exprCase -import org.partiql.ast.v1.Ast.exprCaseBranch -import org.partiql.ast.v1.Ast.exprIsType -import org.partiql.ast.v1.Ast.exprLit -import org.partiql.ast.v1.Ast.exprQuerySet -import org.partiql.ast.v1.Ast.exprStruct -import org.partiql.ast.v1.Ast.exprStructField -import org.partiql.ast.v1.Ast.exprVarRef -import org.partiql.ast.v1.Ast.identifier -import org.partiql.ast.v1.Ast.identifierChain -import org.partiql.ast.v1.Ast.queryBodySFW -import org.partiql.ast.v1.Ast.queryBodySetOp -import org.partiql.ast.v1.Ast.selectItemExpr -import org.partiql.ast.v1.Ast.selectList -import org.partiql.ast.v1.Ast.selectValue -import org.partiql.ast.v1.AstRewriter -import org.partiql.ast.v1.DataType -import org.partiql.ast.v1.From -import org.partiql.ast.v1.FromExpr -import org.partiql.ast.v1.FromJoin -import org.partiql.ast.v1.FromTableRef -import org.partiql.ast.v1.GroupBy -import org.partiql.ast.v1.QueryBody -import org.partiql.ast.v1.SelectItem -import org.partiql.ast.v1.SelectList -import org.partiql.ast.v1.SelectStar -import org.partiql.ast.v1.SelectValue -import org.partiql.ast.v1.expr.Expr -import org.partiql.ast.v1.expr.ExprCase -import org.partiql.ast.v1.expr.ExprLit -import org.partiql.ast.v1.expr.ExprQuerySet -import org.partiql.ast.v1.expr.ExprStruct -import org.partiql.ast.v1.expr.ExprVarRef -import org.partiql.ast.v1.expr.Scope -import org.partiql.planner.internal.helpers.toBinder -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.stringValue - -/** - * Converts SQL-style SELECT to PartiQL SELECT VALUE. - * - If there is a PROJECT ALL, we use the TUPLEUNION. - * - If there is NOT a PROJECT ALL, we use a literal struct. - * - * Here are some example of rewrites: - * - * ``` - * SELECT * - * FROM - * A AS x, - * B AS y AT i - * ``` - * gets rewritten to: - * ``` - * SELECT VALUE TUPLEUNION( - * CASE WHEN x IS STRUCT THEN x ELSE { '_1': x }, - * CASE WHEN y IS STRUCT THEN y ELSE { '_2': y }, - * { 'i': i } - * ) FROM A AS x, B AS y AT i - * ``` - * - * ``` - * SELECT x.*, x.a FROM A AS x - * ``` - * gets rewritten to: - * ``` - * SELECT VALUE TUPLEUNION( - * CASE WHEN x IS STRUCT THEN x ELSE { '_1': x }, - * { 'a': x.a } - * ) FROM A AS x - * ``` - * - * ``` - * SELECT x.a FROM A AS x - * ``` - * gets rewritten to: - * ``` - * SELECT VALUE { - * 'a': x.a - * } FROM A AS x - * ``` - * - * NOTE: This does NOT transform subqueries. It operates directly on an [QueryExpr.SFW] -- and that is it. Therefore: - * ``` - * SELECT - * (SELECT 1 FROM T AS "T") - * FROM R AS "R" - * ``` - * will be transformed to: - * ``` - * SELECT VALUE { - * '_1': (SELECT 1 FROM T AS "T") -- notice that SELECT 1 didn't get transformed. - * } FROM R AS "R" - * ``` - * - * Requires [NormalizeFromSource]. - */ -internal object V1NormalizeSelect { - - internal fun normalize(node: ExprQuerySet): ExprQuerySet { - return when (val body = node.body) { - is QueryBody.SFW -> { - val sfw = Visitor.visitSFW(body, newCtx()) - exprQuerySet( - body = sfw, - orderBy = node.orderBy, - limit = node.limit, - offset = node.offset - ) - } - is QueryBody.SetOp -> { - val lhs = body.lhs.normalizeOrIdentity() - val rhs = body.rhs.normalizeOrIdentity() - exprQuerySet( - body = queryBodySetOp( - type = body.type, - isOuter = body.isOuter, - lhs = lhs, - rhs = rhs - ), - orderBy = node.orderBy, - limit = node.limit, - offset = node.offset - ) - } - else -> error("Unexpected QueryBody type: $body") - } - } - - private fun Expr.normalizeOrIdentity(): Expr { - return when (this) { - is ExprQuerySet -> normalize(this) - else -> this - } - } - - /** - * Closure for incrementing a derived binding counter - */ - private fun newCtx(): () -> Int = run { - var i = 1; - { i++ } - } - - /** - * The type parameter () -> Int - */ - private object Visitor : AstRewriter<() -> Int>() { - - /** - * This is used to give projections a name. For example: - * ``` - * SELECT t.* FROM t AS t - * ``` - * - * Will get converted into: - * ``` - * SELECT VALUE TUPLEUNION( - * CASE - * WHEN t IS STRUCT THEN t - * ELSE { '_1': t } - * END - * ) - * FROM t AS t - * ``` - * - * In order to produce the struct's key in `{ '_1': t }` above, we use [col] to produce the column name - * given the ordinal. - */ - private val col = { index: Int -> "_${index + 1}" } - - internal fun visitSFW(node: QueryBody.SFW, ctx: () -> Int): QueryBody.SFW { - val sfw = super.visitQueryBodySFW(node, ctx) as QueryBody.SFW - return when (val select = sfw.select) { - is SelectStar -> { - val selectValue = when (val group = sfw.groupBy) { - null -> visitSelectAll(select, sfw.from) - else -> visitSelectAll(select, group) - } - queryBodySFW( - select = selectValue, - exclude = sfw.exclude, - from = sfw.from, - let = sfw.let, - where = sfw.where, - groupBy = sfw.groupBy, - having = sfw.having, - ) - } - else -> sfw - } - } - - override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: () -> Int): QueryBody.SFW { - return node - } - - override fun visitSelectList(node: SelectList, ctx: () -> Int): SelectValue { - - // Visit items, adding a binder if necessary - var diff = false - val visitedItems = ArrayList(node.items.size) - node.items.forEach { n -> - val item = n.accept(this, ctx) as SelectItem - if (item !== n) diff = true - visitedItems.add(item) - } - val visitedNode = if (diff) selectList(visitedItems, node.setq) else node - - // Rewrite selection - return when (node.items.any { it is SelectItem.Star }) { - false -> visitSelectProjectWithoutProjectAll(visitedNode) - true -> visitSelectProjectWithProjectAll(visitedNode) - } - } - - override fun visitSelectItemExpr(node: SelectItem.Expr, ctx: () -> Int): SelectItem.Expr { - val expr = visitExpr(node.expr, newCtx()) as Expr - val alias = when (node.asAlias) { - null -> expr.toBinder(ctx) - else -> node.asAlias - } - return if (expr != node.expr || alias != node.asAlias) { - selectItemExpr(expr, alias) - } else { - node - } - } - - // Helpers - - /** - * We need to call this from [visitExprSFW] and not override [visitSelectStar] because we need access to the - * [From] aliases. - * - * Note: We assume that [select] and [from] have already been visited. - */ - private fun visitSelectAll(select: SelectStar, from: From): SelectValue { - val tupleUnionArgs = from.tableRefs.flatMap { it.aliases() }.flatMapIndexed { i, binding -> - val asAlias = binding.first - val atAlias = binding.second - val atAliasItem = atAlias?.simple()?.let { - val alias = it.asAlias ?: error("The AT alias should be present. This wasn't normalized.") - buildSimpleStruct(it.expr, alias.symbol) - } - listOfNotNull( - buildCaseWhenStruct(asAlias.star(i).expr, i), - atAliasItem, - ) - } - return selectValue( - constructor = exprCall( - function = identifierChain(identifier("TUPLEUNION", isDelimited = true), next = null), - args = tupleUnionArgs, - setq = null // setq = null for scalar fn - ), - setq = select.setq - ) - } - - /** - * We need to call this from [visitExprSFW] and not override [visitSelectStar] because we need access to the - * [GroupBy] aliases. - * - * Note: We assume that [select] and [group] have already been visited. - */ - private fun visitSelectAll(select: SelectStar, group: GroupBy): SelectValue { - val groupAs = group.asAlias?.let { structField(it.symbol, varLocal(it.symbol)) } - val fields = group.keys.map { key -> - val alias = key.asAlias ?: error("Expected a GROUP BY alias.") - structField(alias.symbol, varLocal(alias.symbol)) - } + listOfNotNull(groupAs) - val constructor = exprStruct(fields) - return selectValue( - constructor = constructor, - setq = select.setq - ) - } - - private fun visitSelectProjectWithProjectAll(node: SelectList): SelectValue { - val tupleUnionArgs = node.items.mapIndexed { index, item -> - when (item) { - is SelectItem.Star -> buildCaseWhenStruct(item.expr, index) - is SelectItem.Expr -> buildSimpleStruct( - item.expr, - item.asAlias?.symbol - ?: error("The alias should've been here. This AST is not normalized.") - ) - else -> error("Unexpected SelectItem type: $item") - } - } - return selectValue( - setq = node.setq, - constructor = exprCall( - function = identifierChain(identifier("TUPLEUNION", isDelimited = true), next = null), - args = tupleUnionArgs, - setq = null // setq = null for scalar fn - ) - ) - } - - @OptIn(PartiQLValueExperimental::class) - private fun visitSelectProjectWithoutProjectAll(node: SelectList): SelectValue { - val structFields = node.items.map { item -> - val itemExpr = item as? SelectItem.Expr ?: error("Expected the projection to be an expression.") - exprStructField( - name = exprLit(stringValue(itemExpr.asAlias?.symbol!!)), - value = item.expr - ) - } - return selectValue( - setq = node.setq, - constructor = exprStruct( - fields = structFields - ) - ) - } - - private fun buildCaseWhenStruct(expr: Expr, index: Int): ExprCase = exprCase( - expr = null, - branches = listOf( - exprCaseBranch( - condition = exprIsType(expr, DataType.STRUCT(), not = false), - expr = expr - ) - ), - defaultExpr = buildSimpleStruct(expr, col(index)) - ) - - @OptIn(PartiQLValueExperimental::class) - private fun buildSimpleStruct(expr: Expr, name: String): ExprStruct = exprStruct( - fields = listOf( - exprStructField( - name = exprLit(stringValue(name)), - value = expr - ) - ) - ) - - @OptIn(PartiQLValueExperimental::class) - private fun structField(name: String, expr: Expr): ExprStruct.Field = exprStructField( - name = ExprLit(stringValue(name)), - value = expr - ) - - private fun varLocal(name: String): ExprVarRef = exprVarRef( - identifierChain = identifierChain(identifier(name, isDelimited = true), next = null), - scope = Scope.LOCAL() - ) - - private fun FromTableRef.aliases(): List> = when (this) { - is FromJoin -> lhs.aliases() + rhs.aliases() - is FromExpr -> { - val asAlias = asAlias?.symbol ?: error("AST not normalized, missing asAlias on FROM source.") - val atAlias = atAlias?.symbol - listOf(Pair(asAlias, atAlias)) - } - else -> error("Unexpected FromTableRef type: $this") - } - - // t -> t.* AS _i - private fun String.star(i: Int): SelectItem.Expr { - val expr = exprVarRef(identifierChain(id(this), next = null), Scope.DEFAULT()) - val alias = expr.toBinder(i) - return selectItemExpr(expr, alias) - } - - // t -> t AS t - private fun String.simple(): SelectItem.Expr { - val expr = exprVarRef(identifierChain(id(this), next = null), Scope.DEFAULT()) - val alias = id(this) - return selectItemExpr(expr, alias) - } - - private fun id(symbol: String) = identifier(symbol, isDelimited = false) - } -} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt deleted file mode 100644 index 00010c2963..0000000000 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt +++ /dev/null @@ -1,755 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - * - */ - -package org.partiql.planner.internal.transforms - -import org.partiql.ast.v1.Ast.exprLit -import org.partiql.ast.v1.Ast.exprVarRef -import org.partiql.ast.v1.Ast.identifier -import org.partiql.ast.v1.Ast.identifierChain -import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstRewriter -import org.partiql.ast.v1.AstVisitor -import org.partiql.ast.v1.Exclude -import org.partiql.ast.v1.ExcludeStep -import org.partiql.ast.v1.From -import org.partiql.ast.v1.FromExpr -import org.partiql.ast.v1.FromJoin -import org.partiql.ast.v1.FromType -import org.partiql.ast.v1.GroupBy -import org.partiql.ast.v1.GroupByStrategy -import org.partiql.ast.v1.IdentifierChain -import org.partiql.ast.v1.JoinType -import org.partiql.ast.v1.Nulls -import org.partiql.ast.v1.Order -import org.partiql.ast.v1.OrderBy -import org.partiql.ast.v1.QueryBody -import org.partiql.ast.v1.SelectItem -import org.partiql.ast.v1.SelectList -import org.partiql.ast.v1.SelectPivot -import org.partiql.ast.v1.SelectStar -import org.partiql.ast.v1.SelectValue -import org.partiql.ast.v1.SetOpType -import org.partiql.ast.v1.SetQuantifier -import org.partiql.ast.v1.expr.Expr -import org.partiql.ast.v1.expr.ExprCall -import org.partiql.ast.v1.expr.ExprQuerySet -import org.partiql.ast.v1.expr.Scope -import org.partiql.planner.internal.Env -import org.partiql.planner.internal.helpers.toBinder -import org.partiql.planner.internal.ir.Rel -import org.partiql.planner.internal.ir.Rex -import org.partiql.planner.internal.ir.rel -import org.partiql.planner.internal.ir.relBinding -import org.partiql.planner.internal.ir.relOpAggregate -import org.partiql.planner.internal.ir.relOpAggregateCallUnresolved -import org.partiql.planner.internal.ir.relOpDistinct -import org.partiql.planner.internal.ir.relOpErr -import org.partiql.planner.internal.ir.relOpExclude -import org.partiql.planner.internal.ir.relOpExcludePath -import org.partiql.planner.internal.ir.relOpExcludeStep -import org.partiql.planner.internal.ir.relOpExcludeTypeCollIndex -import org.partiql.planner.internal.ir.relOpExcludeTypeCollWildcard -import org.partiql.planner.internal.ir.relOpExcludeTypeStructKey -import org.partiql.planner.internal.ir.relOpExcludeTypeStructSymbol -import org.partiql.planner.internal.ir.relOpExcludeTypeStructWildcard -import org.partiql.planner.internal.ir.relOpFilter -import org.partiql.planner.internal.ir.relOpJoin -import org.partiql.planner.internal.ir.relOpLimit -import org.partiql.planner.internal.ir.relOpOffset -import org.partiql.planner.internal.ir.relOpProject -import org.partiql.planner.internal.ir.relOpScan -import org.partiql.planner.internal.ir.relOpScanIndexed -import org.partiql.planner.internal.ir.relOpSort -import org.partiql.planner.internal.ir.relOpSortSpec -import org.partiql.planner.internal.ir.relOpUnpivot -import org.partiql.planner.internal.ir.relType -import org.partiql.planner.internal.ir.rex -import org.partiql.planner.internal.ir.rexOpLit -import org.partiql.planner.internal.ir.rexOpPivot -import org.partiql.planner.internal.ir.rexOpSelect -import org.partiql.planner.internal.ir.rexOpStruct -import org.partiql.planner.internal.ir.rexOpStructField -import org.partiql.planner.internal.ir.rexOpVarLocal -import org.partiql.planner.internal.typer.CompilerType -import org.partiql.types.PType -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.boolValue -import org.partiql.value.int32Value -import org.partiql.value.stringValue - -/** - * Lexically scoped state for use in translating an individual SELECT statement. - */ -internal object V1RelConverter { - - // IGNORE — so we don't have to non-null assert on operator inputs - internal val nil = rel(relType(emptyList(), emptySet()), relOpErr("nil")) - - /** - * Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex]. - */ - internal fun apply(qSet: ExprQuerySet, env: Env): Rex { - val newQSet = V1NormalizeSelect.normalize(qSet) - val rex = when (val body = newQSet.body) { - is QueryBody.SFW -> { - val rel = newQSet.accept(ToRel(env), nil) - when (val projection = body.select) { - // PIVOT ... FROM - is SelectPivot -> { - val key = projection.key.toRex(env) - val value = projection.value.toRex(env) - val type = (STRUCT) - val op = rexOpPivot(key, value, rel) - rex(type, op) - } - // SELECT VALUE ... FROM - is SelectValue -> { - assert(rel.type.schema.size == 1) { - "Expected SELECT VALUE's input to have a single binding. " + - "However, it contained: ${rel.type.schema.map { it.name }}." - } - val constructor = rex(ANY, rexOpVarLocal(0, 0)) - val op = rexOpSelect(constructor, rel) - val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { - true -> (LIST) - else -> (BAG) - } - rex(type, op) - } - // SELECT * FROM - is SelectStar -> { - throw IllegalArgumentException("AST not normalized") - } - // SELECT ... FROM - is SelectList -> { - throw IllegalArgumentException("AST not normalized") - } - else -> error("Unexpected Select type: $projection") - } - } - is QueryBody.SetOp -> { - val rel = newQSet.accept(ToRel(env), nil) - val constructor = rex(ANY, rexOpVarLocal(0, 0)) - val op = rexOpSelect(constructor, rel) - val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { - true -> (LIST) - else -> (BAG) - } - rex(type, op) - } - else -> error("Unexpected QueryBody type: ${newQSet.body}") - } - return rex - } - - /** - * Syntax sugar for converting an [Expr] tree to a [Rex] tree. - */ - private fun Expr.toRex(env: Env): Rex = V1RexConverter.apply(this, env) - - @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE", "LocalVariableName") - internal class ToRel(private val env: Env) : AstVisitor() { - - override fun defaultReturn(node: AstNode, input: Rel): Rel = - throw IllegalArgumentException("unsupported rel $node") - - /** - * Translate SFW AST node to a pipeline of [Rel] operators; skip any SELECT VALUE or PIVOT projection. - */ - - override fun visitExprQuerySet(node: ExprQuerySet, ctx: Rel): Rel { - val body = node.body - val orderBy = node.orderBy - val limit = node.limit - val offset = node.offset - when (body) { - is QueryBody.SFW -> { - var sel = body - var rel = visitFrom(sel.from, nil) - rel = convertWhere(rel, sel.where) - // kotlin does not have destructuring reassignment - val (_sel, _rel) = convertAgg(rel, sel, sel.groupBy) - sel = _sel - rel = _rel - // Plan.create (possibly rewritten) sel node - rel = convertHaving(rel, sel.having) - rel = convertOrderBy(rel, orderBy) - // offset should precede limit - rel = convertOffset(rel, offset) - rel = convertLimit(rel, limit) - rel = convertExclude(rel, sel.exclude) - // append SQL projection if present - rel = when (val projection = sel.select) { - is SelectValue -> { - val project = visitSelectValue(projection, rel) - visitSetQuantifier(projection.setq, project) - } - is SelectStar, is SelectList -> { - error("AST not normalized, found ${projection.javaClass.simpleName}") - } - is SelectPivot -> rel // Skip PIVOT - else -> error("Unexpected Select type: $projection") - } - return rel - } - is QueryBody.SetOp -> { - var rel = convertSetOp(body) - rel = convertOrderBy(rel, orderBy) - // offset should precede limit - rel = convertOffset(rel, offset) - rel = convertLimit(rel, limit) - return rel - } - else -> error("Unexpected QueryBody type: $body") - } - } - - /** - * Given a [setQuantifier], this will return a [Rel] of [Rel.Op.Distinct] wrapping the [input]. - * If [setQuantifier] is null or ALL, this will return the [input]. - */ - private fun visitSetQuantifier(setQuantifier: SetQuantifier?, input: Rel): Rel { - return when (setQuantifier?.code()) { - SetQuantifier.DISTINCT -> rel(input.type, relOpDistinct(input)) - SetQuantifier.ALL, null -> input - else -> error("Unexpected SetQuantifier type: $setQuantifier") - } - } - - override fun visitSelectList(node: SelectList, input: Rel): Rel { - // this ignores aggregations - val schema = mutableListOf() - val props = input.type.props - val projections = mutableListOf() - node.items.forEach { - val (binding, projection) = convertSelectItem(it) - schema.add(binding) - projections.add(projection) - } - val type = relType(schema, props) - val op = relOpProject(input, projections) - return rel(type, op) - } - - override fun visitSelectValue(node: SelectValue, input: Rel): Rel { - val name = node.constructor.toBinder(1).symbol - val rex = V1RexConverter.apply(node.constructor, env) - val schema = listOf(relBinding(name, rex.type)) - val props = input.type.props - val type = relType(schema, props) - val op = relOpProject(input, projections = listOf(rex)) - return rel(type, op) - } - - @OptIn(PartiQLValueExperimental::class) - override fun visitFrom(node: From, ctx: Rel): Rel { - val tableRefs = node.tableRefs.map { visitFromTableRef(it, ctx) } - return tableRefs.drop(1).fold(tableRefs.first()) { acc, tRef -> - val joinType = Rel.Op.Join.Type.INNER - val condition = rex(BOOL, rexOpLit(boolValue(true))) - val schema = acc.type.schema + tRef.type.schema - val props = emptySet() - val type = relType(schema, props) - val op = relOpJoin(acc, tRef, condition, joinType) - rel(type, op) - } - } - - override fun visitFromExpr(node: FromExpr, nil: Rel): Rel { - val rex = V1RexConverter.applyRel(node.expr, env) - val binding = when (val a = node.asAlias) { - null -> error("AST not normalized, missing AS alias on $node") - else -> relBinding( - name = a.symbol, - type = rex.type - ) - } - return when (node.fromType.code()) { - FromType.SCAN -> { - when (val i = node.atAlias) { - null -> convertScan(rex, binding) - else -> { - val index = relBinding( - name = i.symbol, - type = (INT) - ) - convertScanIndexed(rex, binding, index) - } - } - } - FromType.UNPIVOT -> { - val atAlias = when (val at = node.atAlias) { - null -> error("AST not normalized, missing AT alias on UNPIVOT $node") - else -> relBinding( - name = at.symbol, - type = (STRING) - ) - } - convertUnpivot(rex, k = atAlias, v = binding) - } - else -> error("Unexpected FromType type: ${node.fromType}") - } - } - - /** - * Appends [Rel.Op.Join] where the left and right sides are converted FROM sources - * - * TODO compute basic schema - */ - @OptIn(PartiQLValueExperimental::class) - override fun visitFromJoin(node: FromJoin, nil: Rel): Rel { - val lhs = visitFromTableRef(node.lhs, nil) - val rhs = visitFromTableRef(node.rhs, nil) - val schema = lhs.type.schema + rhs.type.schema // Note: This gets more specific in PlanTyper. It is only used to find binding names here. - val props = emptySet() - val condition = node.condition?.let { V1RexConverter.apply(it, env) } ?: rex(BOOL, rexOpLit(boolValue(true))) - val joinType = when (node.joinType?.code()) { - JoinType.LEFT_OUTER, JoinType.LEFT, JoinType.LEFT_CROSS -> Rel.Op.Join.Type.LEFT - JoinType.RIGHT_OUTER, JoinType.RIGHT -> Rel.Op.Join.Type.RIGHT - JoinType.FULL_OUTER, JoinType.FULL -> Rel.Op.Join.Type.FULL - JoinType.INNER, - JoinType.CROSS -> Rel.Op.Join.Type.INNER // Cross Joins are just INNER JOIN ON TRUE - null -> Rel.Op.Join.Type.INNER // a JOIN b ON a.id = b.id <--> a INNER JOIN b ON a.id = b.id - else -> error("Unexpected JoinType type: ${node.joinType}") - } - val type = relType(schema, props) - val op = relOpJoin(lhs, rhs, condition, joinType) - return rel(type, op) - } - - // Helpers - private fun convertScan(rex: Rex, binding: Rel.Binding): Rel { - val schema = listOf(binding) - val props = emptySet() - val type = relType(schema, props) - val op = relOpScan(rex) - return rel(type, op) - } - - private fun convertScanIndexed(rex: Rex, binding: Rel.Binding, index: Rel.Binding): Rel { - val schema = listOf(binding, index) - val props = emptySet() - val type = relType(schema, props) - val op = relOpScanIndexed(rex) - return rel(type, op) - } - - /** - * Output schema of an UNPIVOT is < k, v > - * - * @param rex - * @param k - * @param v - */ - private fun convertUnpivot(rex: Rex, k: Rel.Binding, v: Rel.Binding): Rel { - val schema = listOf(k, v) - val props = emptySet() - val type = relType(schema, props) - val op = relOpUnpivot(rex) - return rel(type, op) - } - - private fun convertSelectItem(item: SelectItem) = when (item) { - is SelectItem.Star -> convertSelectItemStar(item) - is SelectItem.Expr -> convertSelectItemExpr(item) - else -> error("Unexpected SelectItem type: $item") - } - - private fun convertSelectItemStar(item: SelectItem.Star): Pair { - throw IllegalArgumentException("AST not normalized") - } - - private fun convertSelectItemExpr(item: SelectItem.Expr): Pair { - val name = when (val a = item.asAlias) { - null -> error("AST not normalized, missing AS alias on select item $item") - else -> a.symbol - } - val rex = V1RexConverter.apply(item.expr, env) - val binding = relBinding(name, rex.type) - return binding to rex - } - - /** - * Append [Rel.Op.Filter] only if a WHERE condition exists - */ - private fun convertWhere(input: Rel, expr: Expr?): Rel { - if (expr == null) { - return input - } - val type = input.type - val predicate = expr.toRex(env) - val op = relOpFilter(input, predicate) - return rel(type, op) - } - - /** - * Append [Rel.Op.Aggregate] only if SELECT contains aggregate expressions. - * - * TODO Set quantifiers - * TODO Group As - * - * @return Pair is returned where - * 1. Ast.Expr.SFW has every Ast.Expr.CallAgg replaced by a synthetic Ast.Expr.Var - * 2. Rel which has the appropriate Rex.Agg calls and groups - */ - @OptIn(PartiQLValueExperimental::class) - private fun convertAgg(input: Rel, select: QueryBody.SFW, groupBy: GroupBy?): Pair { - // Rewrite and extract all aggregations in the SELECT clause - val (sel, aggregations) = AggregationTransform.apply(select) - - // No aggregation planning required for GROUP BY - if (aggregations.isEmpty() && groupBy == null) { - return Pair(select, input) - } - - // Build the schema -> (calls... groups...) - val schema = mutableListOf() - val props = emptySet() - - // Build the rel operator - var strategy = Rel.Op.Aggregate.Strategy.FULL - val calls = aggregations.mapIndexed { i, expr -> - val binding = relBinding( - name = syntheticAgg(i), - type = (ANY), - ) - schema.add(binding) - val args = expr.args.map { arg -> arg.toRex(env) } - val id = V1AstToPlan.convert(expr.function) - if (id.hasQualifier()) { - error("Qualified aggregation calls are not supported.") - } - // lowercase normalize all calls - val name = id.getIdentifier().getText().lowercase() - if (name == "count" && expr.args.isEmpty()) { - relOpAggregateCallUnresolved( - name, - org.partiql.planner.internal.ir.SetQuantifier.ALL, - args = listOf(exprLit(int32Value(1)).toRex(env)) - ) - } else { - val setq = when (expr.setq?.code()) { - null -> org.partiql.planner.internal.ir.SetQuantifier.ALL - SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL - SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT - else -> error("Unexpected SetQuantifier type: ${expr.setq}") - } - relOpAggregateCallUnresolved(name, setq, args) - } - }.toMutableList() - - // Add GROUP_AS aggregation - groupBy?.let { gb -> - gb.asAlias?.let { groupAs -> - val binding = relBinding(groupAs.symbol, ANY) - schema.add(binding) - val fields = input.type.schema.mapIndexed { bindingIndex, currBinding -> - rexOpStructField( - k = rex(STRING, rexOpLit(stringValue(currBinding.name))), - v = rex(ANY, rexOpVarLocal(0, bindingIndex)) - ) - } - val arg = listOf(rex(ANY, rexOpStruct(fields))) - calls.add(relOpAggregateCallUnresolved("group_as", org.partiql.planner.internal.ir.SetQuantifier.ALL, arg)) - } - } - var groups = emptyList() - if (groupBy != null) { - groups = groupBy.keys.map { - if (it.asAlias == null) { - error("not normalized, group key $it missing unique name") - } - val binding = relBinding( - name = it.asAlias!!.symbol, - type = (ANY) - ) - schema.add(binding) - it.expr.toRex(env) - } - strategy = when (groupBy.strategy.code()) { - GroupByStrategy.FULL -> Rel.Op.Aggregate.Strategy.FULL - GroupByStrategy.PARTIAL -> Rel.Op.Aggregate.Strategy.PARTIAL - else -> error("Unexpected GroupByStrategy type: ${groupBy.strategy}") - } - } - val type = relType(schema, props) - val op = relOpAggregate(input, strategy, calls, groups) - val rel = rel(type, op) - return Pair(sel, rel) - } - - /** - * Append [Rel.Op.Filter] only if a HAVING condition exists - * - * Notes: - * - This currently does not support aggregation expressions in the WHERE condition - */ - private fun convertHaving(input: Rel, expr: Expr?): Rel { - if (expr == null) { - return input - } - val type = input.type - val predicate = expr.toRex(env) - val op = relOpFilter(input, predicate) - return rel(type, op) - } - - private fun visitIfQuerySet(expr: Expr): Rel { - return when (expr) { - is ExprQuerySet -> visit(expr, nil) - else -> { - val rex = V1RexConverter.applyRel(expr, env) - val op = relOpScan(rex) - val type = Rel.Type(listOf(Rel.Binding("_1", ANY)), props = emptySet()) - return rel(type, op) - } - } - } - - /** - * Append SQL set operator if present - */ - private fun convertSetOp(setExpr: QueryBody.SetOp): Rel { - val lhs = visitIfQuerySet(setExpr.lhs) - val rhs = visitIfQuerySet(setExpr.rhs) - val type = Rel.Type(listOf(Rel.Binding("_0", ANY)), props = emptySet()) - val quantifier = when (setExpr.type.setq?.code()) { - SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL - null, SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT - else -> error("Unexpected SetQuantifier type: ${setExpr.type.setq}") - } - val outer = setExpr.isOuter - val op = when (setExpr.type.setOpType.code()) { - SetOpType.UNION -> Rel.Op.Union(quantifier, outer, lhs, rhs) - SetOpType.EXCEPT -> Rel.Op.Except(quantifier, outer, lhs, rhs) - SetOpType.INTERSECT -> Rel.Op.Intersect(quantifier, outer, lhs, rhs) - else -> error("Unexpected SetOpType type: ${setExpr.type.setOpType}") - } - return rel(type, op) - } - - /** - * Append [Rel.Op.Sort] only if an ORDER BY clause is present - */ - private fun convertOrderBy(input: Rel, orderBy: OrderBy?): Rel { - if (orderBy == null) { - return input - } - val type = input.type.copy(props = setOf(Rel.Prop.ORDERED)) - val specs = orderBy.sorts.map { - val rex = it.expr.toRex(env) - val order = when (it.order?.code()) { - Order.DESC -> when (it.nulls?.code()) { - Nulls.LAST -> Rel.Op.Sort.Order.DESC_NULLS_LAST - Nulls.FIRST, null -> Rel.Op.Sort.Order.DESC_NULLS_FIRST - else -> error("Unexpected Nulls type: ${it.nulls}") - } - else -> when (it.nulls?.code()) { - Nulls.FIRST -> Rel.Op.Sort.Order.ASC_NULLS_FIRST - Nulls.LAST, null -> Rel.Op.Sort.Order.ASC_NULLS_LAST - else -> error("Unexpected Nulls type: ${it.nulls}") - } - } - relOpSortSpec(rex, order) - } - val op = relOpSort(input, specs) - return rel(type, op) - } - - /** - * Append [Rel.Op.Limit] if there is a LIMIT - */ - private fun convertLimit(input: Rel, limit: Expr?): Rel { - if (limit == null) { - return input - } - val type = input.type - val rex = V1RexConverter.apply(limit, env) - val op = relOpLimit(input, rex) - return rel(type, op) - } - - /** - * Append [Rel.Op.Offset] if there is an OFFSET - */ - private fun convertOffset(input: Rel, offset: Expr?): Rel { - if (offset == null) { - return input - } - val type = input.type - val rex = V1RexConverter.apply(offset, env) - val op = relOpOffset(input, rex) - return rel(type, op) - } - - private fun convertExclude(input: Rel, exclude: Exclude?): Rel { - if (exclude == null) { - return input - } - val type = input.type // PlanTyper handles typing the exclusion and removing redundant exclude paths - val paths = exclude.excludePaths - .groupBy(keySelector = { it.root }, valueTransform = { it.excludeSteps }) - .map { (root, exclusions) -> - val rootVar = (root.toRex(env)).op as Rex.Op.Var - val steps = exclusionsToSteps(exclusions) - relOpExcludePath(rootVar, steps) - } - val op = relOpExclude(input, paths) - return rel(type, op) - } - - private fun exclusionsToSteps(exclusions: List>): List { - if (exclusions.any { it.isEmpty() }) { - // if there exists a path with no further steps, can remove the longer paths - // e.g. t.a.b, t.a.b.c, t.a.b.d[*].*.e -> can keep just t.a.b - return emptyList() - } - return exclusions - .groupBy(keySelector = { it.first() }, valueTransform = { it.drop(1) }) - .map { (head, steps) -> - val type = stepToExcludeType(head) - val substeps = exclusionsToSteps(steps) - relOpExcludeStep(type, substeps) - } - } - - private fun stepToExcludeType(step: ExcludeStep): Rel.Op.Exclude.Type { - return when (step) { - is ExcludeStep.StructField -> { - when (step.symbol.isDelimited) { - false -> relOpExcludeTypeStructSymbol(step.symbol.symbol) - true -> relOpExcludeTypeStructKey(step.symbol.symbol) - } - } - is ExcludeStep.CollIndex -> relOpExcludeTypeCollIndex(step.index) - is ExcludeStep.StructWildcard -> relOpExcludeTypeStructWildcard() - is ExcludeStep.CollWildcard -> relOpExcludeTypeCollWildcard() - else -> error("Unexpected ExcludeStep type: $step") - } - } - - // /** - // * Converts a GROUP AS X clause to a binding of the form: - // * ``` - // * { 'X': group_as({ 'a_0': e_0, ..., 'a_n': e_n }) } - // * ``` - // * - // * Notes: - // * - This was included to be consistent with the existing PartiqlAst and PartiqlLogical representations, - // * but perhaps we don't want to represent GROUP AS with an agg function. - // */ - // private fun convertGroupAs(name: String, from: From): Binding { - // val fields = from.bindings().map { n -> - // Plan.field( - // name = Plan.rexLit(ionString(n), STRING), - // value = Plan.rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = STRUCT) - // ) - // } - // return Plan.binding( - // name = name, - // value = Plan.rexAgg( - // id = "group_as", - // args = listOf(Plan.rexTuple(fields, STRUCT)), - // modifier = Rex.Agg.Modifier.ALL, - // type = STRUCT - // ) - // ) - // } - } - - /** - * Rewrites a SELECT node replacing (and extracting) each aggregation `i` with a synthetic field name `$agg_i`. - */ - private object AggregationTransform : AstRewriter() { - // currently hard-coded - @JvmStatic - private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every") - - private data class Context( - val aggregations: MutableList, - val keys: List - ) - - fun apply(node: QueryBody.SFW): Pair> { - val aggs = mutableListOf() - val keys = node.groupBy?.keys ?: emptyList() - val context = Context(aggs, keys) - val select = super.visitQueryBodySFW(node, context) as QueryBody.SFW - return Pair(select, aggs) - } - - override fun visitSelectValue(node: SelectValue, ctx: Context): AstNode { - val visited = super.visitSelectValue(node, ctx) - val substitutions = ctx.keys.associate { - it.expr to exprVarRef(identifierChain(identifier(it.asAlias!!.symbol, isDelimited = true), next = null), Scope.DEFAULT()) - } - return V1SubstitutionVisitor.visit(visited, substitutions) - } - - // only rewrite top-level SFW - override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Context): AstNode = node - - override fun visitExprCall(node: ExprCall, ctx: Context) = - // TODO replace w/ proper function resolution to determine whether a function call is a scalar or aggregate. - // may require further modification of SPI interfaces to support - when (node.function.isAggregateCall()) { - true -> { - val id = identifierChain( - identifier( - symbol = syntheticAgg(ctx.aggregations.size), - isDelimited = false - ), - next = null - ) - ctx.aggregations += node - exprVarRef(id, Scope.DEFAULT()) - } - else -> node - } - - private fun String.isAggregateCall(): Boolean { - return aggregates.contains(this) - } - - private fun IdentifierChain.isAggregateCall(): Boolean { - return when (next) { - null -> root.symbol.lowercase().isAggregateCall() - else -> { - var curId = next - var last = curId - while (curId != null) { - last = curId - curId = curId.next - } - last!!.root.symbol.lowercase().isAggregateCall() - } - } - } - - override fun defaultReturn(node: AstNode, ctx: Context) = node - } - - private fun syntheticAgg(i: Int) = "\$agg_$i" - - private val ANY: CompilerType = CompilerType(PType.dynamic()) - private val BOOL: CompilerType = CompilerType(PType.bool()) - private val STRING: CompilerType = CompilerType(PType.string()) - private val STRUCT: CompilerType = CompilerType(PType.struct()) - private val BAG: CompilerType = CompilerType(PType.bag()) - private val LIST: CompilerType = CompilerType(PType.array()) - private val INT: CompilerType = CompilerType(PType.numeric()) -} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt deleted file mode 100644 index cd203329af..0000000000 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt +++ /dev/null @@ -1,1077 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - * - */ - -package org.partiql.planner.internal.transforms - -import com.amazon.ionelement.api.loadSingleElement -import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstVisitor -import org.partiql.ast.v1.DataType -import org.partiql.ast.v1.QueryBody -import org.partiql.ast.v1.SelectList -import org.partiql.ast.v1.SelectStar -import org.partiql.ast.v1.expr.Expr -import org.partiql.ast.v1.expr.ExprAnd -import org.partiql.ast.v1.expr.ExprArray -import org.partiql.ast.v1.expr.ExprBag -import org.partiql.ast.v1.expr.ExprBetween -import org.partiql.ast.v1.expr.ExprCall -import org.partiql.ast.v1.expr.ExprCase -import org.partiql.ast.v1.expr.ExprCast -import org.partiql.ast.v1.expr.ExprCoalesce -import org.partiql.ast.v1.expr.ExprExtract -import org.partiql.ast.v1.expr.ExprInCollection -import org.partiql.ast.v1.expr.ExprIsType -import org.partiql.ast.v1.expr.ExprLike -import org.partiql.ast.v1.expr.ExprLit -import org.partiql.ast.v1.expr.ExprNot -import org.partiql.ast.v1.expr.ExprNullIf -import org.partiql.ast.v1.expr.ExprOperator -import org.partiql.ast.v1.expr.ExprOr -import org.partiql.ast.v1.expr.ExprOverlay -import org.partiql.ast.v1.expr.ExprPath -import org.partiql.ast.v1.expr.ExprPosition -import org.partiql.ast.v1.expr.ExprQuerySet -import org.partiql.ast.v1.expr.ExprSessionAttribute -import org.partiql.ast.v1.expr.ExprStruct -import org.partiql.ast.v1.expr.ExprSubstring -import org.partiql.ast.v1.expr.ExprTrim -import org.partiql.ast.v1.expr.ExprVarRef -import org.partiql.ast.v1.expr.ExprVariant -import org.partiql.ast.v1.expr.PathStep -import org.partiql.ast.v1.expr.Scope -import org.partiql.ast.v1.expr.TrimSpec -import org.partiql.errors.TypeCheckException -import org.partiql.planner.internal.Env -import org.partiql.planner.internal.ir.Rel -import org.partiql.planner.internal.ir.Rex -import org.partiql.planner.internal.ir.builder.plan -import org.partiql.planner.internal.ir.rel -import org.partiql.planner.internal.ir.relBinding -import org.partiql.planner.internal.ir.relOpJoin -import org.partiql.planner.internal.ir.relOpScan -import org.partiql.planner.internal.ir.relOpUnpivot -import org.partiql.planner.internal.ir.relType -import org.partiql.planner.internal.ir.rex -import org.partiql.planner.internal.ir.rexOpCallUnresolved -import org.partiql.planner.internal.ir.rexOpCastUnresolved -import org.partiql.planner.internal.ir.rexOpCoalesce -import org.partiql.planner.internal.ir.rexOpCollection -import org.partiql.planner.internal.ir.rexOpLit -import org.partiql.planner.internal.ir.rexOpNullif -import org.partiql.planner.internal.ir.rexOpPathIndex -import org.partiql.planner.internal.ir.rexOpPathKey -import org.partiql.planner.internal.ir.rexOpPathSymbol -import org.partiql.planner.internal.ir.rexOpSelect -import org.partiql.planner.internal.ir.rexOpStruct -import org.partiql.planner.internal.ir.rexOpStructField -import org.partiql.planner.internal.ir.rexOpSubquery -import org.partiql.planner.internal.ir.rexOpTupleUnion -import org.partiql.planner.internal.ir.rexOpVarLocal -import org.partiql.planner.internal.ir.rexOpVarUnresolved -import org.partiql.planner.internal.typer.CompilerType -import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType -import org.partiql.spi.catalog.Identifier -import org.partiql.types.PType -import org.partiql.value.MissingValue -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.StringValue -import org.partiql.value.boolValue -import org.partiql.value.int32Value -import org.partiql.value.int64Value -import org.partiql.value.io.PartiQLValueIonReaderBuilder -import org.partiql.value.nullValue -import org.partiql.value.stringValue -import org.partiql.ast.v1.SetQuantifier as AstSetQuantifier - -/** - * Converts an AST expression node to a Plan Rex node; ignoring any typing. - */ -internal object V1RexConverter { - - internal fun apply(expr: Expr, context: Env): Rex = ToRex.visitExprCoerce(expr, context) - - internal fun applyRel(expr: Expr, context: Env): Rex = expr.accept(ToRex, context) - - @OptIn(PartiQLValueExperimental::class) - @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") - private object ToRex : AstVisitor() { - - private val COLL_AGG_NAMES = setOf( - "coll_any", - "coll_avg", - "coll_count", - "coll_every", - "coll_max", - "coll_min", - "coll_some", - "coll_sum", - ) - - override fun defaultReturn(node: AstNode, context: Env): Rex = - throw IllegalArgumentException("unsupported rex $node") - - override fun visitExprLit(node: ExprLit, context: Env): Rex { - val type = CompilerType( - _delegate = node.value.type.toPType(), - isNullValue = node.value.isNull, - isMissingValue = node.value is MissingValue - ) - val op = rexOpLit(node.value) - return rex(type, op) - } - - /** - * TODO PartiQLValue will be replaced by Datum (i.e. IonDatum) is a subsequent PR. - */ - override fun visitExprVariant(node: ExprVariant, ctx: Env): Rex { - if (node.encoding != "ion") { - throw IllegalArgumentException("unsupported encoding ${node.encoding}") - } - val ion = loadSingleElement(node.value) - val value = PartiQLValueIonReaderBuilder.standard().build(ion).read() - val type = CompilerType(value.type.toPType()) - return rex(type, rexOpLit(value)) - } - - /** - * !! IMPORTANT !! - * - * This is the top-level visit for handling subquery coercion. The default behavior is to coerce to a scalar. - * In some situations, ie comparison to complex types we may make assertions on the desired type. - * - * It is recommended that every method (except for the exceptional cases) recurse the tree from visitExprCoerce. - * - * - RHS of comparison when LHS is an array or collection expression; and visa-versa - * - It is the collection expression of a FROM clause or JOIN - * - It is the RHS of an IN predicate - * - It is an argument of an OUTER set operator. - * - * @param node - * @param ctx - * @return - */ - internal fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { - val rex = node.accept(this, ctx) - return when (isSqlSelect(node)) { - true -> { - val select = rex.op as Rex.Op.Select - rex( - CompilerType(PType.dynamic()), - rexOpSubquery( - constructor = select.constructor, - rel = select.rel, - coercion = coercion - ) - ) - } - false -> rex - } - } - - override fun visitExprVarRef(node: ExprVarRef, context: Env): Rex { - val type = (ANY) - val identifier = V1AstToPlan.convert(node.identifierChain) - val scope = when (node.scope.code()) { - Scope.DEFAULT -> Rex.Op.Var.Scope.DEFAULT - Scope.LOCAL -> Rex.Op.Var.Scope.LOCAL - else -> error("Unexpected Scope type: ${node.scope}") - } - val op = rexOpVarUnresolved(identifier, scope) - return rex(type, op) - } - - private fun resolveUnaryOp(symbol: String, rhs: Expr, context: Env): Rex { - val type = (ANY) - // Args - val arg = visitExprCoerce(rhs, context) - val args = listOf(arg) - // Fn - val name = when (symbol) { - // TODO move hard-coded operator resolution into SPI - "+" -> "pos" - "-" -> "neg" - else -> error("unsupported unary op $symbol") - } - val id = Identifier.delimited(name) - val op = rexOpCallUnresolved(id, args) - return rex(type, op) - } - - private fun resolveBinaryOp(lhs: Expr, symbol: String, rhs: Expr, context: Env): Rex { - val type = (ANY) - val args = when (symbol) { - "<", ">", - "<=", ">=", - "=", "<>", "!=" -> { - when { - // Example: [1, 2] < (SELECT a, b FROM t) - isLiteralArray(lhs) && isSqlSelect(rhs) -> { - val l = visitExprCoerce(lhs, context) - val r = visitExprCoerce(rhs, context, Rex.Op.Subquery.Coercion.ROW) - listOf(l, r) - } - // Example: (SELECT a, b FROM t) < [1, 2] - isSqlSelect(lhs) && isLiteralArray(rhs) -> { - val l = visitExprCoerce(lhs, context, Rex.Op.Subquery.Coercion.ROW) - val r = visitExprCoerce(rhs, context) - listOf(l, r) - } - // Example: 1 < 2 - else -> { - val l = visitExprCoerce(lhs, context) - val r = visitExprCoerce(rhs, context) - listOf(l, r) - } - } - } - // Example: 1 + 2 - else -> { - val l = visitExprCoerce(lhs, context) - val r = visitExprCoerce(rhs, context) - listOf(l, r) - } - } - // Wrap if a NOT, if necessary - return when (symbol) { - "<>", "!=" -> { - val op = negate(call("eq", *args.toTypedArray())) - rex(type, op) - } - else -> { - val name = when (symbol) { - // TODO eventually move hard-coded operator resolution into SPI - "<" -> "lt" - ">" -> "gt" - "<=" -> "lte" - ">=" -> "gte" - "=" -> "eq" - "||" -> "concat" - "+" -> "plus" - "-" -> "minus" - "*" -> "times" - "/" -> "divide" - "%" -> "modulo" - "&" -> "bitwise_and" - else -> error("unsupported binary op $symbol") - } - val id = Identifier.delimited(name) - val op = rexOpCallUnresolved(id, args) - rex(type, op) - } - } - } - - override fun visitExprOperator(node: ExprOperator, ctx: Env): Rex { - val lhs = node.lhs - return if (lhs != null) { - resolveBinaryOp(lhs, node.symbol, node.rhs, ctx) - } else { - resolveUnaryOp(node.symbol, node.rhs, ctx) - } - } - - override fun visitExprNot(node: ExprNot, ctx: Env): Rex { - val type = (ANY) - // Args - val arg = visitExprCoerce(node.value, ctx) - val args = listOf(arg) - // Fn - val id = Identifier.delimited("not") - val op = rexOpCallUnresolved(id, args) - return rex(type, op) - } - - override fun visitExprAnd(node: ExprAnd, ctx: Env): Rex { - val type = (ANY) - val l = visitExprCoerce(node.lhs, ctx) - val r = visitExprCoerce(node.rhs, ctx) - val args = listOf(l, r) - - // Wrap if a NOT, if necessary - val id = Identifier.delimited("and") - val op = rexOpCallUnresolved(id, args) - return rex(type, op) - } - - override fun visitExprOr(node: ExprOr, ctx: Env): Rex { - val type = (ANY) - val l = visitExprCoerce(node.lhs, ctx) - val r = visitExprCoerce(node.rhs, ctx) - val args = listOf(l, r) - - // Wrap if a NOT, if necessary - val id = Identifier.delimited("or") - val op = rexOpCallUnresolved(id, args) - return rex(type, op) - } - - private fun isLiteralArray(node: Expr): Boolean = node is ExprArray - - private fun isSqlSelect(node: Expr): Boolean { - return if (node is ExprQuerySet) { - val body = node.body - body is QueryBody.SFW && (body.select is SelectList || body.select is SelectStar) - } else { - false - } - } - - override fun visitExprPath(node: ExprPath, context: Env): Rex { - // Args - val root = visitExprCoerce(node.root, context) - - // Attempt to create qualified identifier - val (newRoot, nextStep) = when (val op = root.op) { - is Rex.Op.Var.Unresolved -> { - // convert consecutive symbol path steps to the root identifier - var i = 0 - val parts = mutableListOf() - parts.addAll(op.identifier.getParts()) - var curStep = node.next - while (curStep != null) { - if (curStep !is PathStep.Field) { - break - } - parts.add(V1AstToPlan.part(curStep.field)) - i += 1 - curStep = curStep.next - } - val newRoot = rex(ANY, rexOpVarUnresolved(Identifier.of(parts), op.scope)) - val newSteps = curStep - newRoot to newSteps - } - else -> { - root to node.next - } - } - - if (nextStep == null) { - return newRoot - } - - val fromList = mutableListOf() - - var varRefIndex = 0 // tracking var ref index - - var curStep = nextStep - var curPathNavi = newRoot - while (curStep != null) { - val path = when (curStep) { - is PathStep.Element -> { - val key = visitExprCoerce(curStep.element, context) - val op = when (val astKey = curStep.element) { - is ExprLit -> when (astKey.value) { - is StringValue -> rexOpPathKey(curPathNavi, key) - else -> rexOpPathIndex(curPathNavi, key) - } - is ExprCast -> when (astKey.asType.code() == DataType.STRING) { - true -> rexOpPathKey(curPathNavi, key) - false -> rexOpPathIndex(curPathNavi, key) - } - else -> rexOpPathIndex(curPathNavi, key) - } - op - } - - is PathStep.Field -> { - when (curStep.field.isDelimited) { - true -> { - // case-sensitive path step becomes a key lookup - rexOpPathKey(curPathNavi, rexString(curStep.field.symbol)) - } - false -> { - // case-insensitive path step becomes a symbol lookup - rexOpPathSymbol(curPathNavi, curStep.field.symbol) - } - } - } - - // Unpivot and Wildcard steps trigger the rewrite - // According to spec Section 4.3 - // ew1p1...wnpn - // rewrite to: - // SELECT VALUE v_n.p_n - // FROM - // u_1 e as v_1 - // u_2 @v_1.p_1 as v_2 - // ... - // u_n @v_(n-1).p_(n-1) as v_n - // The From clause needs to be rewritten to - // Join <------------------- schema: [(k_1), v_1, (k_2), v_2, ..., (k_(n-1)) v_(n-1)] - // / \ - // ... un @v_(n-1).p_(n-1) <-- stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1, (k_2), v_2, ..., (k_(n-1)) v_(n-1)]]] - // Join <----------------------- schema: [(k_1), v_1, (k_2), v_2, (k_3), v_3] - // / \ - // u_2 @v_1.p_1 as v2 <------- stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1, (k_2), v_2]]] - // JOIN <---------------------------- schema: [(k_1), v_1, (k_2), v_2] - // / \ - // u_1 e as v_1 < ----\----------------------- stack: [global] - // u_2 @v_1.p_1 as v2 <------ stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1]]] - // while doing the traversal, instead of passing the stack, - // each join will produce its own schema and pass the schema as a type Env. - // The (k_i) indicate the possible key binding produced by unpivot. - // We calculate the var ref on the fly. - is PathStep.AllFields -> { - // Unpivot produces two binding, in this context we want the value, - // which always going to be the second binding - val op = rexOpVarLocal(1, varRefIndex + 1) - varRefIndex += 2 - val index = fromList.size - fromList.add(relFromUnpivot(curPathNavi, index)) - op - } - is PathStep.AllElements -> { - // Scan produce only one binding - val op = rexOpVarLocal(1, varRefIndex) - varRefIndex += 1 - val index = fromList.size - fromList.add(relFromDefault(curPathNavi, index)) - op - } - else -> error("Unexpected PathStep type: $curStep") - } - curStep = curStep.next - curPathNavi = rex(ANY, path) - } - if (fromList.size == 0) return curPathNavi - val fromNode = fromList.reduce { acc, scan -> - val schema = acc.type.schema + scan.type.schema - val props = emptySet() - val type = relType(schema, props) - rel(type, relOpJoin(acc, scan, rex(BOOL, rexOpLit(boolValue(true))), Rel.Op.Join.Type.INNER)) - } - - // compute the ref used by select construct - // always going to be the last binding - val selectRef = fromNode.type.schema.size - 1 - - val constructor = when (val op = curPathNavi.op) { - is Rex.Op.Path.Index -> rex(curPathNavi.type, rexOpPathIndex(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) - is Rex.Op.Path.Key -> rex(curPathNavi.type, rexOpPathKey(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) - is Rex.Op.Path.Symbol -> rex(curPathNavi.type, rexOpPathSymbol(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) - is Rex.Op.Var.Local -> rex(curPathNavi.type, rexOpVarLocal(0, selectRef)) - else -> throw IllegalStateException() - } - val op = rexOpSelect(constructor, fromNode) - return rex(ANY, op) - } - - /** - * Construct Rel(Scan([path])). - * - * The constructed rel would produce one binding: _v$[index] - */ - private fun relFromDefault(path: Rex, index: Int): Rel { - val schema = listOf( - relBinding( - name = "_v$index", // fresh variable - type = path.type - ) - ) - val props = emptySet() - val relType = relType(schema, props) - return rel(relType, relOpScan(path)) - } - - /** - * Construct Rel(Unpivot([path])). - * - * The constructed rel would produce two bindings: _k$[index] and _v$[index] - */ - private fun relFromUnpivot(path: Rex, index: Int): Rel { - val schema = listOf( - relBinding( - name = "_k$index", // fresh variable - type = STRING - ), - relBinding( - name = "_v$index", // fresh variable - type = path.type - ) - ) - val props = emptySet() - val relType = relType(schema, props) - return rel(relType, relOpUnpivot(path)) - } - - private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) - - override fun visitExprCall(node: ExprCall, context: Env): Rex { - val type = (ANY) - // Fn - val id = V1AstToPlan.convert(node.function) - if (id.hasQualifier()) { - error("Qualified function calls are not currently supported.") - } - if (id.matches("TUPLEUNION")) { - return visitExprCallTupleUnion(node, context) - } - if (id.matches("EXISTS", ignoreCase = true)) { - return visitExprCallExists(node, context) - } - // Args - val args = node.args.map { visitExprCoerce(it, context) } - - // Check if function is actually coll_ - if (isCollAgg(node)) { - return callToCollAgg(id, node.setq, args) - } - - if (node.setq != null) { - error("Currently, only COLL_ may use set quantifiers.") - } - val op = rexOpCallUnresolved(id, args) - return rex(type, op) - } - - /** - * @return whether call is `COLL_`. - */ - private fun isCollAgg(node: ExprCall): Boolean { - val fn = node.function - val id = if (fn.next == null) { - // is not a qualified identifier chain - node.function.root - } else { - return false - } - return COLL_AGG_NAMES.contains(id.symbol.lowercase()) - } - - /** - * Converts COLL_ to the relevant function calls. For example: - * - `COLL_SUM(x)` becomes `coll_sum_all(x)` - * - `COLL_SUM(ALL x)` becomes `coll_sum_all(x)` - * - `COLL_SUM(DISTINCT x)` becomes `coll_sum_distinct(x)` - * - * It is assumed that the [id] has already been vetted by [isCollAgg]. - */ - private fun callToCollAgg(id: Identifier, setQuantifier: AstSetQuantifier?, args: List): Rex { - if (id.hasQualifier()) { - error("Qualified function calls are not currently supported.") - } - if (args.size != 1) { - error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.") - } - val postfix = when (setQuantifier?.code()) { - AstSetQuantifier.DISTINCT -> "_distinct" - AstSetQuantifier.ALL -> "_all" - null -> "_all" - else -> error("Unexpected SetQuantifier type: $setQuantifier") - } - val newId = Identifier.regular(id.getIdentifier().getText() + postfix) - val op = Rex.Op.Call.Unresolved(newId, listOf(args[0])) - return Rex(ANY, op) - } - - private fun visitExprCallTupleUnion(node: ExprCall, context: Env): Rex { - val type = (STRUCT) - val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() - val op = rexOpTupleUnion(args) - return rex(type, op) - } - - /** - * Assume that the node's identifier refers to EXISTS. - * TODO: This could be better suited as a dedicated node in the future. - */ - private fun visitExprCallExists(node: ExprCall, context: Env): Rex { - val type = (BOOL) - if (node.args.size != 1) { - error("EXISTS requires a single argument.") - } - val arg = visitExpr(node.args[0], context) - val op = rexOpCallUnresolved(V1AstToPlan.convert(node.function), listOf(arg)) - return rex(type, op) - } - - override fun visitExprCase(node: ExprCase, context: Env) = plan { - val type = (ANY) - val rex = when (node.expr) { - null -> null - else -> visitExprCoerce(node.expr!!, context) // match `rex - } - - // Converts AST CASE (x) WHEN y THEN z --> Plan CASE WHEN x = y THEN z - val id = Identifier.delimited("eq") - val createBranch: (Rex, Rex) -> Rex.Op.Case.Branch = { condition: Rex, result: Rex -> - val updatedCondition = when (rex) { - null -> condition - else -> rex(type, rexOpCallUnresolved(id, listOf(rex, condition))) - } - rexOpCaseBranch(updatedCondition, result) - } - - val branches = node.branches.map { - val branchCondition = visitExprCoerce(it.condition, context) - val branchRex = visitExprCoerce(it.expr, context) - createBranch(branchCondition, branchRex) - }.toMutableList() - - val defaultRex = when (val default = node.defaultExpr) { - null -> rex(type = ANY, op = rexOpLit(value = nullValue())) - else -> visitExprCoerce(default, context) - } - val op = rexOpCase(branches = branches, default = defaultRex) - rex(type, op) - } - - override fun visitExprArray(node: ExprArray, ctx: Env): Rex { - val values = node.values.map { visitExprCoerce(it, ctx) } - val op = rexOpCollection(values) - return rex(LIST, op) - } - - override fun visitExprBag(node: ExprBag, ctx: Env): Rex { - val values = node.values.map { visitExprCoerce(it, ctx) } - val op = rexOpCollection(values) - return rex(BAG, op) - } - - override fun visitExprStruct(node: ExprStruct, context: Env): Rex { - val type = (STRUCT) - val fields = node.fields.map { - val k = visitExprCoerce(it.name, context) - val v = visitExprCoerce(it.value, context) - rexOpStructField(k, v) - } - val op = rexOpStruct(fields) - return rex(type, op) - } - - // SPECIAL FORMS - - /** - * NOT? LIKE ( ESCAPE )? - */ - override fun visitExprLike(node: ExprLike, ctx: Env): Rex { - val type = BOOL - // Args - val arg0 = visitExprCoerce(node.value, ctx) - val arg1 = visitExprCoerce(node.pattern, ctx) - val arg2 = node.escape?.let { visitExprCoerce(it, ctx) } - // Call Variants - var call = when (arg2) { - null -> call("like", arg0, arg1) - else -> call("like_escape", arg0, arg1, arg2) - } - // NOT? - if (node.not == true) { - call = negate(call) - } - return rex(type, call) - } - - /** - * NOT? BETWEEN AND - */ - override fun visitExprBetween(node: ExprBetween, ctx: Env): Rex = plan { - val type = BOOL - // Args - val arg0 = visitExprCoerce(node.value, ctx) - val arg1 = visitExprCoerce(node.from, ctx) - val arg2 = visitExprCoerce(node.to, ctx) - // Call - var call = call("between", arg0, arg1, arg2) - // NOT? - if (node.not == true) { - call = negate(call) - } - rex(type, call) - } - - /** - * NOT? IN - * - * SQL Spec 1999 section 8.4 - * RVC IN IPV is equivalent to RVC = ANY IPV -> Quantified Comparison Predicate - * Which means: - * Let the expression be T in C, where C is [a1, ..., an] - * T in C is true iff T = a_x is true for any a_x in [a1, ...., an] - * T in C is false iff T = a_x is false for every a_x in [a1, ....., an ] or cardinality of the collection is 0. - * Otherwise, T in C is unknown. - * - */ - override fun visitExprInCollection(node: ExprInCollection, ctx: Env): Rex { - val type = BOOL - // Args - val arg0 = visitExprCoerce(node.lhs, ctx) - val arg1 = node.rhs.accept(this, ctx) // !! don't insert scalar subquery coercions - - // Call - var call = call("in_collection", arg0, arg1) - // NOT? - if (node.not == true) { - call = negate(call) - } - return rex(type, call) - } - - /** - * IS ? - */ - override fun visitExprIsType(node: ExprIsType, ctx: Env): Rex { - val type = BOOL - // arg - val arg0 = visitExprCoerce(node.value, ctx) - val targetType = node.type - var call = when (targetType.code()) { - // - DataType.NULL -> call("is_null", arg0) - DataType.MISSING -> call("is_missing", arg0) - // - // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT - DataType.CHARACTER, DataType.CHAR -> call("is_char", targetType.length.toRex(), arg0) - DataType.CHARACTER_VARYING, DataType.VARCHAR -> call("is_varchar", targetType.length.toRex(), arg0) - DataType.CLOB -> call("is_clob", arg0) - DataType.STRING -> call("is_string", targetType.length.toRex(), arg0) - DataType.SYMBOL -> call("is_symbol", arg0) - // - // TODO BINARY_LARGE_OBJECT - DataType.BLOB -> call("is_blob", arg0) - // - DataType.BIT -> call("is_bit", arg0) // TODO define in parser - DataType.BIT_VARYING -> call("is_bitVarying", arg0) // TODO define in parser - // - - DataType.NUMERIC -> call("is_numeric", targetType.precision.toRex(), targetType.scale.toRex(), arg0) - DataType.DEC, DataType.DECIMAL -> call("is_decimal", targetType.precision.toRex(), targetType.scale.toRex(), arg0) - DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> call("is_int64", arg0) - DataType.INT4, DataType.INTEGER4, DataType.INTEGER -> call("is_int32", arg0) - DataType.INT -> call("is_int", arg0) - DataType.INT2, DataType.SMALLINT -> call("is_int16", arg0) - DataType.TINYINT -> call("is_int8", arg0) // TODO define in parser - // - - DataType.FLOAT -> call("is_float32", arg0) - DataType.REAL -> call("is_real", arg0) - DataType.DOUBLE_PRECISION -> call("is_float64", arg0) - // - DataType.BOOLEAN, DataType.BOOL -> call("is_bool", arg0) - // - DataType.DATE -> call("is_date", arg0) - // TODO: DO we want to seperate with time zone vs without time zone into two different type in the plan? - // leave the parameterized type out for now until the above is answered - DataType.TIME -> call("is_time", arg0) - DataType.TIME_WITH_TIME_ZONE -> call("is_timeWithTz", arg0) - DataType.TIMESTAMP -> call("is_timestamp", arg0) - DataType.TIMESTAMP_WITH_TIME_ZONE -> call("is_timestampWithTz", arg0) - // - DataType.INTERVAL -> call("is_interval", arg0) // TODO define in parser - // - DataType.STRUCT, DataType.TUPLE -> call("is_struct", arg0) - // - DataType.LIST -> call("is_list", arg0) - DataType.BAG -> call("is_bag", arg0) - DataType.SEXP -> call("is_sexp", arg0) - // - DataType.USER_DEFINED -> call("is_custom", arg0) - else -> error("Unexpected DataType type: $targetType") - } - - if (node.not == true) { - call = negate(call) - } - - return rex(type, call) - } - - override fun visitExprCoalesce(node: ExprCoalesce, ctx: Env): Rex { - val type = ANY - val args = node.args.map { arg -> - visitExprCoerce(arg, ctx) - } - val op = rexOpCoalesce(args) - return rex(type, op) - } - - override fun visitExprNullIf(node: ExprNullIf, ctx: Env): Rex { - val type = ANY - val v1 = visitExprCoerce(node.v1, ctx) - val v2 = visitExprCoerce(node.v2, ctx) - val op = rexOpNullif(v1, v2) - return rex(type, op) - } - - /** - * SUBSTRING( (FROM (FOR )?)? ) - */ - override fun visitExprSubstring(node: ExprSubstring, ctx: Env): Rex { - val type = ANY - // Args - val arg0 = visitExprCoerce(node.value, ctx) - val arg1 = node.start?.let { visitExprCoerce(it, ctx) } ?: rex(INT, rexOpLit(int64Value(1))) - val arg2 = node.length?.let { visitExprCoerce(it, ctx) } - // Call Variants - val call = when (arg2) { - null -> call("substring", arg0, arg1) - else -> call("substring", arg0, arg1, arg2) - } - return rex(type, call) - } - - /** - * POSITION( IN ) - */ - override fun visitExprPosition(node: ExprPosition, ctx: Env): Rex { - val type = ANY - // Args - val arg0 = visitExprCoerce(node.lhs, ctx) - val arg1 = visitExprCoerce(node.rhs, ctx) - // Call - val call = call("position", arg0, arg1) - return rex(type, call) - } - - /** - * TRIM([LEADING|TRAILING|BOTH]? ( FROM)? ) - */ - override fun visitExprTrim(node: ExprTrim, ctx: Env): Rex { - val type = STRING - // Args - val arg0 = visitExprCoerce(node.value, ctx) - val arg1 = node.chars?.let { visitExprCoerce(it, ctx) } - // Call Variants - val call = when (node.trimSpec?.code()) { - TrimSpec.LEADING -> when (arg1) { - null -> call("trim_leading", arg0) - else -> call("trim_leading_chars", arg0, arg1) - } - TrimSpec.TRAILING -> when (arg1) { - null -> call("trim_trailing", arg0) - else -> call("trim_trailing_chars", arg0, arg1) - } - // TODO: We may want to add a trim_both for trim(BOTH FROM arg) - else -> when (arg1) { - null -> call("trim", arg0) - else -> call("trim_chars", arg0, arg1) - } - } - return rex(type, call) - } - - /** - * SQL Spec 1999: Section 6.18 - * - * ::= - * OVERLAY - * PLACING - * FROM - * [ FOR ] - * - * The is equivalent to: - * - * SUBSTRING ( CV FROM 1 FOR SP - 1 ) || RS || SUBSTRING ( CV FROM SP + SL ) - * - * Where CV is the first , - * SP is the - * RS is the second , - * SL is the if specified, otherwise it is char_length(RS). - */ - override fun visitExprOverlay(node: ExprOverlay, ctx: Env): Rex { - val cv = visitExprCoerce(node.value, ctx) - val sp = visitExprCoerce(node.from, ctx) - val rs = visitExprCoerce(node.placing, ctx) - val sl = node.forLength?.let { visitExprCoerce(it, ctx) } ?: rex(ANY, call("char_length", rs)) - val p1 = rex( - ANY, - call( - "substring", - cv, - rex(INT4, rexOpLit(int32Value(1))), - rex(ANY, call("minus", sp, rex(INT4, rexOpLit(int32Value(1))))) - ) - ) - val p2 = rex(ANY, call("concat", p1, rs)) - return rex( - ANY, - call( - "concat", - p2, - rex(ANY, call("substring", cv, rex(ANY, call("plus", sp, sl)))) - ) - ) - } - - override fun visitExprExtract(node: ExprExtract, ctx: Env): Rex { - val call = call("extract_${node.field.name().lowercase()}", visitExprCoerce(node.source, ctx)) - return rex(ANY, call) - } - - override fun visitExprCast(node: ExprCast, ctx: Env): Rex { - val type = visitType(node.asType) - val arg = visitExprCoerce(node.value, ctx) - return rex(ANY, rexOpCastUnresolved(type, arg)) - } - - private fun visitType(type: DataType): CompilerType { - return when (type.code()) { - // - DataType.NULL -> error("Casting to NULL is not supported.") - DataType.MISSING -> error("Casting to MISSING is not supported.") - // - // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT - DataType.CHARACTER, DataType.CHAR -> { - val length = type.length ?: 1 - assertGtZeroAndCreate(PType.Kind.CHAR, "length", length, PType::character) - } - DataType.CHARACTER_VARYING, DataType.VARCHAR -> { - val length = type.length ?: 1 - assertGtZeroAndCreate(PType.Kind.VARCHAR, "length", length, PType::varchar) - } - DataType.CLOB -> assertGtZeroAndCreate(PType.Kind.CLOB, "length", type.length ?: Int.MAX_VALUE, PType::clob) - DataType.STRING -> PType.string() - DataType.SYMBOL -> PType.symbol() - // - // TODO BINARY_LARGE_OBJECT - DataType.BLOB -> assertGtZeroAndCreate(PType.Kind.BLOB, "length", type.length ?: Int.MAX_VALUE, PType::blob) - // - DataType.BIT -> error("BIT is not supported yet.") - DataType.BIT_VARYING -> error("BIT VARYING is not supported yet.") - // - - DataType.NUMERIC -> { - val p = type.precision - val s = type.scale - when { - p == null && s == null -> PType.decimal() - p != null && s != null -> { - assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) - assertParamCompToZero(PType.Kind.NUMERIC, "scale", s, true) - if (s > p) { - throw TypeCheckException("Numeric scale cannot be greater than precision.") - } - PType.decimal(type.precision!!, type.scale!!) - } - p != null && s == null -> { - assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) - PType.decimal(p, 0) - } - else -> error("Precision can never be null while scale is specified.") - } - } - DataType.DEC, DataType.DECIMAL -> { - val p = type.precision - val s = type.scale - when { - p == null && s == null -> PType.decimal() - p != null && s != null -> { - assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) - assertParamCompToZero(PType.Kind.DECIMAL, "scale", s, true) - if (s > p) { - throw TypeCheckException("Decimal scale cannot be greater than precision.") - } - PType.decimal(p, s) - } - p != null && s == null -> { - assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) - PType.decimal(p, 0) - } - else -> error("Precision can never be null while scale is specified.") - } - } - DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> PType.bigint() - DataType.INT4, DataType.INTEGER4, DataType.INTEGER, DataType.INT -> PType.integer() - DataType.INT2, DataType.SMALLINT -> PType.smallint() - DataType.TINYINT -> PType.tinyint() // TODO define in parser - // - - DataType.FLOAT -> PType.real() - DataType.REAL -> PType.real() - DataType.DOUBLE_PRECISION -> PType.doublePrecision() - // - DataType.BOOL -> PType.bool() - // - DataType.DATE -> PType.date() - DataType.TIME -> assertGtEqZeroAndCreate(PType.Kind.TIME, "precision", type.precision ?: 0, PType::time) - DataType.TIME_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMEZ, "precision", type.precision ?: 0, PType::timez) - DataType.TIMESTAMP -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMP, "precision", type.precision ?: 6, PType::timestamp) - DataType.TIMESTAMP_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMPZ, "precision", type.precision ?: 6, PType::timestampz) - // - DataType.INTERVAL -> error("INTERVAL is not supported yet.") - // - DataType.STRUCT -> PType.struct() - DataType.TUPLE -> PType.struct() - // - DataType.LIST -> PType.array() - DataType.BAG -> PType.bag() - DataType.SEXP -> PType.sexp() - // - DataType.USER_DEFINED -> TODO("Custom type not supported ") - else -> error("Unsupported DataType type: $type") - }.toCType() - } - - private fun assertGtZeroAndCreate(type: PType.Kind, param: String, value: Int, create: (Int) -> PType): PType { - assertParamCompToZero(type, param, value, false) - return create.invoke(value) - } - - private fun assertGtEqZeroAndCreate(type: PType.Kind, param: String, value: Int, create: (Int) -> PType): PType { - assertParamCompToZero(type, param, value, true) - return create.invoke(value) - } - - /** - * @param allowZero when FALSE, this asserts that [value] > 0. If TRUE, this asserts that [value] >= 0. - */ - private fun assertParamCompToZero(type: PType.Kind, param: String, value: Int, allowZero: Boolean) { - val (result, compString) = when (allowZero) { - true -> (value >= 0) to "greater than" - false -> (value > 0) to "greater than or equal to" - } - if (!result) { - throw TypeCheckException("$type $param must be an integer value $compString 0.") - } - } - - override fun visitExprSessionAttribute(node: ExprSessionAttribute, ctx: Env): Rex { - val type = ANY - val fn = node.sessionAttribute.name().lowercase() - val call = call(fn) - return rex(type, call) - } - - override fun visitExprQuerySet(node: ExprQuerySet, context: Env): Rex = V1RelConverter.apply(node, context) - - // Helpers - - private fun negate(call: Rex.Op): Rex.Op.Call { - val id = Identifier.delimited("not") - val arg = rex(BOOL, call) - return rexOpCallUnresolved(id, listOf(arg)) - } - - /** - * Create a [Rex.Op.Call.Static] node which has a hidden unresolved Function. - * The purpose of having such hidden function is to prevent usage of generated function name in query text. - */ - private fun call(name: String, vararg args: Rex): Rex.Op.Call { - val id = Identifier.regular(name) - return rexOpCallUnresolved(id, args.toList()) - } - - private fun Int?.toRex() = rex(INT4, rexOpLit(int32Value(this))) - - private val ANY: CompilerType = CompilerType(PType.dynamic()) - private val BOOL: CompilerType = CompilerType(PType.bool()) - private val STRING: CompilerType = CompilerType(PType.string()) - private val STRUCT: CompilerType = CompilerType(PType.struct()) - private val BAG: CompilerType = CompilerType(PType.bag()) - private val LIST: CompilerType = CompilerType(PType.array()) - private val SEXP: CompilerType = CompilerType(PType.sexp()) - private val INT: CompilerType = CompilerType(PType.numeric()) - private val INT4: CompilerType = CompilerType(PType.integer()) - private val TIMESTAMP: CompilerType = CompilerType(PType.timestamp(6)) - } -} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt deleted file mode 100644 index 497d83f783..0000000000 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt +++ /dev/null @@ -1,15 +0,0 @@ -package org.partiql.planner.internal.transforms - -import org.partiql.ast.v1.AstNode -import org.partiql.ast.v1.AstRewriter -import org.partiql.ast.v1.expr.Expr - -internal object V1SubstitutionVisitor : AstRewriter>() { - override fun visitExpr(node: Expr, ctx: Map<*, AstNode>): AstNode { - val visited = super.visitExpr(node, ctx) - if (ctx.containsKey(visited)) { - return ctx[visited]!! - } - return visited - } -} diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt index bc5aa6c9c8..79e82fd04a 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt @@ -5,7 +5,7 @@ import org.junit.jupiter.api.DynamicContainer.dynamicContainer import org.junit.jupiter.api.DynamicNode import org.junit.jupiter.api.DynamicTest import org.junit.jupiter.api.TestFactory -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.plan.Plan import org.partiql.planner.internal.TestCatalog import org.partiql.planner.test.PartiQLTest @@ -76,7 +76,7 @@ class PlanTest { ) .namespace("SCHEMA") .build() - val parseResult = PartiQLParserV1.standard().parse(test.statement) + val parseResult = PartiQLParser.standard().parse(test.statement) assertEquals(1, parseResult.statements.size) val ast = parseResult.statements[0] val planner = PartiQLPlanner.builder().signal(isSignalMode).build() diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt index ba609f886b..4c08378467 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt @@ -3,7 +3,7 @@ package org.partiql.planner import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.ast.v1.Statement -import org.partiql.parser.PartiQLParserBuilderV1 +import org.partiql.parser.PartiQLParser import org.partiql.plan.Operation import org.partiql.planner.internal.typer.CompilerType import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType @@ -42,7 +42,7 @@ internal class PlannerPErrorReportingTests { .catalogs(catalog) .build() - private val parser = PartiQLParserBuilderV1().build() + private val parser = PartiQLParser.builder().build() private val statement: ((String) -> Statement) = { query -> val parseResult = parser.parse(query) diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt index c56a1852aa..191ccf9216 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt @@ -7,7 +7,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.plan.Exclusion import org.partiql.plan.Operation import org.partiql.plan.builder.PlanFactory @@ -26,7 +26,7 @@ class SubsumptionTest { companion object { private val planner = PartiQLPlanner.standard() - private val parser = PartiQLParserV1.standard() + private val parser = PartiQLParser.standard() private val catalog = Catalog.builder().name("default").build() } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt index ce81b193da..cee56f362b 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/transforms/NormalizeSelectTest.kt @@ -46,7 +46,7 @@ class NormalizeSelectTest { "b" to variable("b"), "c" to variable("c"), ) - val actual = V1NormalizeSelect.normalize(input) + val actual = NormalizeSelect.normalize(input) assertEquals(expected, actual) } @@ -71,7 +71,7 @@ class NormalizeSelectTest { "_2" to lit(2), "_3" to lit(3), ) - val actual = V1NormalizeSelect.normalize(input) + val actual = NormalizeSelect.normalize(input) assertEquals(expected, actual) } @@ -96,7 +96,7 @@ class NormalizeSelectTest { "_1" to lit(2), "_2" to lit(3), ) - val actual = V1NormalizeSelect.normalize(input) + val actual = NormalizeSelect.normalize(input) assertEquals(expected, actual) } @@ -121,7 +121,7 @@ class NormalizeSelectTest { "b" to lit(2), "c" to lit(3), ) - val actual = V1NormalizeSelect.normalize(input) + val actual = NormalizeSelect.normalize(input) assertEquals(expected, actual) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt index 33d5c71a80..23e681520f 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt @@ -2,7 +2,7 @@ package org.partiql.planner.internal.typer import org.junit.jupiter.api.DynamicContainer import org.junit.jupiter.api.DynamicTest -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.plan.Operation import org.partiql.planner.PartiQLPlanner import org.partiql.planner.test.PartiQLTest @@ -38,7 +38,7 @@ abstract class PartiQLTyperTestBase { companion object { - public val parser = PartiQLParserV1.standard() + public val parser = PartiQLParser.standard() public val planner = PartiQLPlanner.standard() internal val session: ((String, Catalog) -> Session) = { catalog, metadata -> diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index 123a57dc71..490a64ee73 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -12,7 +12,7 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.MethodSource -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.PErrors import org.partiql.planner.internal.TestCatalog @@ -125,7 +125,7 @@ internal class PlanTyperTestsPorted { companion object { - private val parser = PartiQLParserV1.standard() + private val parser = PartiQLParser.standard() private val planner = PartiQLPlanner.builder().signal().build() private fun assertProblemExists(problem: PError) = ProblemHandler { problems, _ -> diff --git a/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt b/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt index 8fa1dd3db3..432b7531a6 100644 --- a/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt +++ b/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt @@ -1,7 +1,7 @@ package org.partiql.lang.randomized.eval import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.planner.PartiQLPlanner import org.partiql.spi.catalog.Catalog import org.partiql.spi.catalog.Session @@ -22,7 +22,7 @@ fun runEvaluatorTestCase( @OptIn(PartiQLValueExperimental::class) private fun execute(query: String): PartiQLValue { - val parser = PartiQLParserV1.builder().build() + val parser = PartiQLParser.builder().build() val planner = PartiQLPlanner.builder().build() val catalog = object : Catalog { override fun getName(): String = "default" diff --git a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt index 63aabb3b27..71335eeeea 100644 --- a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt +++ b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt @@ -10,7 +10,7 @@ import com.amazon.ionelement.api.toIonValue import org.partiql.eval.Mode import org.partiql.eval.Statement import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParserV1 +import org.partiql.parser.PartiQLParser import org.partiql.plan.Operation.Query import org.partiql.planner.PartiQLPlanner import org.partiql.runner.CompileType @@ -143,7 +143,7 @@ class EvalExecutor( companion object { val compiler = PartiQLCompiler.standard() - val parser = PartiQLParserV1.standard() + val parser = PartiQLParser.standard() val planner = PartiQLPlanner.standard() // TODO REPLACE WITH DATUM COMPARATOR val comparator = PartiQLValue.comparator()