From b1d67b23dda50b854f1199725e6d4e5a4361d88a Mon Sep 17 00:00:00 2001 From: "R. C. Howell" Date: Thu, 13 Jul 2023 09:51:04 -0700 Subject: [PATCH] Initialize partiql-parser package with partiql-ast IR (#1142) --- .github/workflows/build.yml | 4 +- CHANGELOG.md | 3 + buildSrc/src/main/kotlin/partiql.versions.kt | 2 +- partiql-ast/README.adoc | 364 +++ partiql-ast/build.gradle.kts | 29 +- .../src/main/kotlin/org/partiql/ast/Ast.kt | 15 + .../kotlin/org/partiql/ast/helpers/ToIon.kt | 128 ++ .../org/partiql/ast/helpers/ToLegacyAst.kt | 1366 +++++++++++ .../main/kotlin/org/partiql/ast/impl/.gitkeep | 0 .../lang/ast/AggregateCallSiteListMeta.kt | 0 .../org/partiql/lang/ast/InternalMetas.kt | 0 .../org/partiql/lang/ast/IsCountStarMeta.kt | 0 .../lang/ast/IsGroupAttributeReferenceMeta.kt | 2 +- .../org/partiql/lang/ast/IsImplictJoinMeta.kt | 0 .../org/partiql/lang/ast/IsIonLiteralMeta.kt | 0 .../lang/ast/IsListParenthesizedMeta.kt | 0 .../org/partiql/lang/ast/IsOrderedMeta.kt | 0 .../org/partiql/lang/ast/IsPathIndexMeta.kt | 0 .../lang/ast/IsTransformedOrderByAliasMeta.kt | 7 +- .../org/partiql/lang/ast/IsValuesExprMeta.kt | 0 .../partiql/lang/ast/LegacyLogicalNotMeta.kt | 18 +- .../partiql/lang/ast/SourceLocationMeta.kt | 15 +- .../org/partiql/lang/ast/StaticTypeMeta.kt | 17 + .../main/kotlin/org/partiql/lang/ast/meta.kt | 4 +- .../src/main/resources/partiql_ast.ion | 792 +++++++ .../partiql/ast/helpers/ToLegacyAstTest.kt | 721 ++++++ partiql-lang/build.gradle.kts | 2 +- .../lang/prettyprint/QueryPrettyPrinter.kt | 2 +- .../lang/syntax/PartiQLParserBuilder.kt | 16 +- .../lang/syntax/impl/PartiQLPigVisitor.kt | 2 +- .../lang/syntax/impl/PartiQLShimParser.kt | 89 + .../lang/ast/passes/StatementRedactorTest.kt | 19 +- .../errors/WindowRelatedParserErrorsTest.kt | 65 - .../AggregateSupportVisitorTransformTests.kt | 6 +- ...StaticTypeInferenceVisitorTransformTest.kt | 4 +- .../StaticTypeVisitorTransformTests.kt | 4 +- .../eval/visitors/VisitorTransformTestBase.kt | 4 +- .../prettyprint/QueryPrettyPrinterTest.kt | 4 +- .../lang/syntax/PartiQLParserCastTests.kt | 2 + .../PartiQLParserCorrelatedJoinTests.kt | 3 + .../lang/syntax/PartiQLParserDateTimeTests.kt | 16 +- .../lang/syntax/PartiQLParserExplainTest.kt | 6 +- .../lang/syntax/PartiQLParserJoinTest.kt | 3 + .../lang/syntax/PartiQLParserMatchTest.kt | 2 + .../lang/syntax/PartiQLParserMetaTests.kt | 22 +- .../syntax/PartiQLParserPrecedenceTest.kt | 20 +- .../partiql/lang/syntax/PartiQLParserTest.kt | 129 +- .../lang/syntax/PartiQLParserTestBase.kt | 30 +- .../lang/syntax/PartiQLParserWindowTests.kt | 65 + partiql-parser/README.adoc | 66 + partiql-parser/build.gradle.kts | 15 +- .../kotlin/org/partiql/parser/Exceptions.kt | 83 + .../org/partiql/parser/PartiQLParser.kt | 29 + .../partiql/parser/PartiQLParserBuilder.kt | 35 + .../org/partiql/parser/SourceLocation.kt | 44 + .../org/partiql/parser/SourceLocations.kt | 49 + .../parser/impl/PartiQLParserDefault.kt | 2005 +++++++++++++++++ .../partiql/parser/impl/util/DateTimeUtils.kt | 76 + 58 files changed, 6186 insertions(+), 218 deletions(-) create mode 100644 partiql-ast/README.adoc create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToIon.kt create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/impl/.gitkeep rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/AggregateCallSiteListMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/InternalMetas.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsCountStarMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsGroupAttributeReferenceMeta.kt (92%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsImplictJoinMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsIonLiteralMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsListParenthesizedMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsOrderedMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsPathIndexMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsTransformedOrderByAliasMeta.kt (82%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/IsValuesExprMeta.kt (100%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/LegacyLogicalNotMeta.kt (53%) rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/SourceLocationMeta.kt (88%) create mode 100644 partiql-ast/src/main/kotlin/org/partiql/lang/ast/StaticTypeMeta.kt rename {partiql-lang => partiql-ast}/src/main/kotlin/org/partiql/lang/ast/meta.kt (96%) create mode 100644 partiql-ast/src/main/resources/partiql_ast.ion create mode 100644 partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt create mode 100644 partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLShimParser.kt delete mode 100644 partiql-lang/src/test/kotlin/org/partiql/lang/errors/WindowRelatedParserErrorsTest.kt create mode 100644 partiql-parser/README.adoc create mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/Exceptions.kt create mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt create mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt create mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocation.kt create mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocations.kt create mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt create mode 100644 partiql-parser/src/main/kotlin/org/partiql/parser/impl/util/DateTimeUtils.kt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 278066a4c7..e69b3331b3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,11 +6,13 @@ on: - '**' - '!docs/**' - '!**/*.md' + - '!**/*.adoc' pull_request: paths: - '**' - '!docs/**' - - '!**.*.md' + - '!**/*.md' + - '!**/*.adoc' jobs: test: diff --git a/CHANGELOG.md b/CHANGELOG.md index bed6f1c485..8e11bbb4ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,9 +31,12 @@ Thank you to all who have contributed! ### Added - Adds `org.partiql.value` (experimental) package for reading/writing PartiQL values +- Adds `org.partiql.ast` package and usage documentation +- Adds `org.partiql.parser` package and usage documentation - Adds PartiQL's Timestamp Data Model. - Adds support for Timestamp constructor call in Parser. + ### Changed ### Deprecated diff --git a/buildSrc/src/main/kotlin/partiql.versions.kt b/buildSrc/src/main/kotlin/partiql.versions.kt index 19a37c4ece..47f911f54c 100644 --- a/buildSrc/src/main/kotlin/partiql.versions.kt +++ b/buildSrc/src/main/kotlin/partiql.versions.kt @@ -41,7 +41,7 @@ object Versions { const val kotlinxCollections = "0.3.5" const val picoCli = "4.7.0" const val kasechange = "1.3.0" - const val ktlint = "10.2.1" + const val ktlint = "11.5.0" const val pig = "0.6.2" // Testing diff --git a/partiql-ast/README.adoc b/partiql-ast/README.adoc new file mode 100644 index 0000000000..40935b00b3 --- /dev/null +++ b/partiql-ast/README.adoc @@ -0,0 +1,364 @@ += PartiQL AST + +The PartiQL AST package contains interfaces, data classes, and utilities for manipulating a syntax tree. + +NOTE: If you are on an older version of PartiQL, you can convert to the old AST via `.toLegacyAst()` in `org.partiql.ast.helpers`. + +== Interfaces + +The interfaces are generated from `resources/partiql_ast.ion` (details in `lib/sprout/README`) + +=== Node + +[source,kotlin] +---- +public interface AstNode { + + // Every node gets an _id for associating any metadata + public val _id: String + + public val children: List + + public fun accept(visitor: AstVisitor, ctx: C): R +} +---- + +=== Example + +.Example Definition +[source,ion] +---- +expr::[ + // ... + binary::{ + op: [ + PLUS, MINUS, TIMES, DIVIDE, MODULO, CONCAT, + AND, OR, + EQ, NE, GT, GTE, LT, LTE, + ], + lhs: expr, + rhs: expr, + }, + // ... +] +---- + +.Generated Interface +[source,kotlin] +---- +// Note: `Expr:AstNode` is a sealed interface of all expr variants + +public interface Binary : Expr { + public val op: Op + public val lhs: Expr + public val rhs: Expr + + public fun copy( + op: Op = this.op, + lhs: Expr = this.lhs, + rhs: Expr = this.rhs, + ): Binary + + public enum class Op { + PLUS, + MINUS, + TIMES, + DIVIDE, + MODULO, + CONCAT, + AND, + OR, + EQ, + NE, + GT, + GTE, + LT, + LTE, + } +} +---- + +== Factory, DSL, and Builders + +The PartiQL AST library provides several creational patterns in `org.partiql.ast.builder` such as an abstract base factory, Kotlin DSL, and Java fluent-builders. +These patterns enable customers to extend the AST to fit their needs, while providing a base which can be decorated appropriately. + +=== Factory Usage + +The factory is how you instantiate a node. The default factory can be called directly like, + +[source,kotlin] +---- +import org.partiql.ast.Ast + +Ast.exprLit(int32Value(1)) // expr.lit +---- + +==== Custom Nodes + +Additionally, you can extend the abstract base factory and use it in builders as well as the DSL. This gives you full +control over how your nodes are instantiated. If you are ambitious, you can implement your own versions of AST node interfaces and implement a base factory. This +will allow you to create custom behaviors. For example, generated equals functions do not consider semantics. Perhaps +we want to improve how we compare nodes? Here's an example that considers the case-sensitivity of identifiers. + +.Custom Node and Factory Example +[source,kotlin] +---- +public abstract class MyFactory : AstBaseFactory() { + + override fun identifierSymbol(symbol: String, caseSensitivity: Identifier.CaseSensitivity): Identifier.Symbol { + return ComparableIdentifier(_id(), symbol, caseSensitivity) + } +} + +class ComparableIdentifier( + override val _id: String, + override val symbol: String, + override val caseSensitivity: Identifier.CaseSensitivity, +) : Identifier.Symbol { + + // override copy, children + + override fun equals(other: Any?): Boolean { + if (other == null || other !is Identifier.Symbol) return false // different type + if (other === this) return true // same object + return when (caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> this.symbol == other.symbol + Identifier.CaseSensitivity.INSENSITIVE -> this.symbol.lowercase() == other.symbol.lowercase() + } + } +} +---- + +=== DSL Usage + +The DSL is useful from Kotlin and is some syntax sugar over fluent builders. Here is how its used: + +.Default Factory DSL Example +[source,kotlin] +---- +import org.partiql.ast.builder.ast + +// Tree for PartiQL `VALUES (1, 2)` +ast { + exprCollection(Expr.Collection.Type.VALUES) { + values += exprLit(int32Value(1)) + values += exprLit(int32Value(2)) + } +} + +// Tree for `SELECT a FROM T` +ast { + exprSFW { + select = selectProject { + items += selectProjectItemExpression { + expr = exprVar { + identifier = identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE) + scope = Expr.Var.Scope.DEFAULT + } + } + } + from = fromValue { + expr = v(symbol) + type = From.Value.Type.SCAN + } + } +} +---- + +.Fancier DSL Usage +[source,kotlin] +---- +import org.partiql.ast.builder.ast +import org.partiql.ast.builder.AstBuilder + +// define some helpers +private fun AstBuilder.select(vararg s: String) = selectProject { + s.forEach { + items += selectProjectItemExpression(v(it)) + } +} + +private fun AstBuilder.table(symbol: String) = fromValue { + expr = v(symbol) + type = From.Value.Type.SCAN +} + +private fun AstBuilder.v(symbol: String) = this.exprVar { + identifier = id(symbol) + scope = Expr.Var.Scope.DEFAULT +} + + +// Tree for `SELECT x, y, z FROM T` + +ast { + exprSFW { + select = select("x", "y", "z") + from = table("T") + } +} +---- + +.Custom Factory DSL Example +[source,kotlin] +---- +import org.partiql.ast.builder.ast + +// This will instantiate your custom `ComparableIdentifier`. Nice! +ast(myFactory) { + exprSFW { + select = select("x", "y", "z") + from = table("T") + } +} +---- + +IMPORTANT: The last examples works because the DSL block closes over the factory with an AstBuilder. This means that +the helper extensions or any DSL usage will use the provided factory! + +=== Builder Usage + +The DSL is not much more than Kotlin syntactic sugar over traditional fluent-builder classes. If you are coming from Java, these will be useful. +Every node defines a static `builder()` function. Keeping with the previous example, let's see how we can inject our custom +factory. + +[source,kotlin] +---- +// instance of default IdentifierSymbolImpl +val a = Identifier.Symbol.builder() + .symbol("HELLO") + .caseSensitivity(Identifier.CaseSensitivity.INSENSITIVE) + .build() // empty, build with default factory + +// instance of ComparableIdentifier +val b = Identifier.Symbol.builder() + .symbol("hello") + .caseSensitivity(Identifier.CaseSensitivity.INSENSITIVE) + .build(myFactory) // nice! + +assert(b == a) // TRUE +assert(a == b) // !! FALSE !! consider always using the same type of factory +---- + +== Visitor and Rewriter + +The PartiQL AST is a set of interfaces, so how might we extend these for our own purposes? We do not have pattern matching in Kotlin/Java, so we use the visitor pattern. + +The visitor pattern is effectively adding methods to each object with some compile safety. You define a behavior and use the node `accept` the behavior. The visitor provides an additional parameter `ctx: C` which is the equivalent of arguments to each method for your behavior. + +[source,kotlin] +---- +public abstract class AstBaseVisitor : AstVisitor { + + public override fun visit(node: AstNode, ctx: C): R = node.accept(this, ctx) + + public open fun defaultVisit(node: AstNode, ctx: C): R { + for (child in node.children) { + child.accept(this, ctx) + } + return defaultReturn(node, ctx) + } + + public abstract fun defaultReturn(node: AstNode, ctx: C): R +} +---- + +For example, let's implement a `ToSimpleNameString(case: Case)` function on some basic nodes. + +[source,kotlin] +---- +// +// Usage: +// node.accept(ToSimpleNameString, Case.UPPER) +// +object ToSimpleNameString : AstBaseVisitor() { + + override fun defaultVisit(node: AstNode, ctx: Case) = defaultReturn(node, ctx) + + override fun defaultReturn(node: AstNode, ctx: Case): String = when (ctx) { + Case.UPPER -> node::class.simpleName.uppercase() + Case.LOWER -> node::class.simpleName.lowercase() + Case.PASCAL -> node::class.simpleName + Case.SNAKE -> snakeCaseHelper(node::class.simpleName) + } + + // Any other overrides you want! +} +---- + + +=== Folding + +Folding is straightforward by using either mutable context or an immutable accumulators. The structure you fold to is +entirely dependent on your use case, but here is a simple example with a mutable list that you can generalize. Often times you may fold to an entirely new domain — or fold to the same domain which we'll cover in the rewriter. + +.Example "ClassName" Collector +[source,kotlin] +---- +// Traverse the tree collecting all node names +object AstClassNameCollector { + + // Public static entry for Java style consumption + @JvmStatic + fun collect(node: AstNode): List { + val acc = mutableListOf() + node.accept(ToSimpleNameString, acc) + return acc + } + + // Private implementation + private object ToSimpleNameString : AstBaseVisitor>() { + + override fun defaultVisit(node: AstNode, ctx: MutableList): String? { + node.children.forEach { child -> child.accept(this, ctx) } // traverse + defaultReturn(this, ctx)?.let { ctx.add(it) } + } + + override fun defaultReturn(node: AstNode, ctx: MutableList) = node::class.simpleName + + // Any other overrides you want! + } +} +---- + +=== Rewriter + +See `org.partiql.ast.util.AstRewriter`. This class facilitates rewriting an AST; you need only override the relevant methods for your rewriter. + +=== Tips + +- Each `visit` is a function call; adding state to a visitor is akin to global variables. _Consider keeping state in the context parameter_. This is beneficial because you state is naturally scoped via the call stack. +- Sometimes state in a visitor makes an implementation much cleaner (go for it!). Just remember that the visitor might not be re-usable or idempotent. +- Consider using singletons/objects for stateless visitors +- Consider making your visitors private with a single public static entry point. +- When you make a private visitor, you can rename the ctx parameter to something relevant. Use the `Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE")` to make the linter to relax. +- If writing and using Kotlin, consider adding an extension method to the base class. This _really_ makes it look like you've opened the classes (but really it's just a static method). + +=== Understanding Visitors + +I believe Robert Nystrom captured the misunderstanding of visitors quite well: + +[quote] +____ +The Visitor pattern is the most widely misunderstood pattern in all of Design Patterns, which is really saying something when you look at the software architecture excesses of the past couple of decades. + +The trouble starts with terminology. The pattern isn’t about “visiting”, and the “accept” method in it doesn’t conjure up any helpful imagery either. Many think the pattern has to do with traversing trees, which isn’t the case at all. We are going to use it on a set of classes that are tree-like, but that’s a coincidence. As you’ll see, the pattern works as well on a single object. + +The Visitor pattern is really about approximating the functional style within an OOP language. It lets us add new columns to that table easily. We can define all of the behavior for a new operation on a set of types in one place, without having to touch the types themselves. It does this the same way we solve almost every problem in computer science: by adding a layer of indirection. + +-- Robert Nystrom, Crafting Interpreters +____ + +Additionally, see how the wiki page explicitly mentions pattern matching. Kotlin is interesting because we have something _like_ pattern matching, but the PartiQL AST library is intended for consumption from both Kotlin and Java. + +[quote] +____ +A visitor pattern is a software design pattern and separates the algorithm from the object structure. Because of this separation new operations can be added to existing object structures without modifying the structures. It is one way to follow the open/closed principle in object-oriented programming and software engineering. + +In essence, the visitor allows adding new virtual functions to a family of classes, without modifying the classes. Instead, a visitor class is created that implements all of the appropriate specializations of the virtual function. The visitor takes the instance reference as input, and implements the goal through double dispatch. + +Programming languages with sum types and pattern matching obviate many of the benefits of the visitor pattern, as the visitor class is able to both easily branch on the type of the object and generate a compiler error if a new object type is defined which the visitor does not yet handle. + +https://en.wikipedia.org/wiki/Visitor_pattern +____ diff --git a/partiql-ast/build.gradle.kts b/partiql-ast/build.gradle.kts index 1df7d02100..0814ffaa75 100644 --- a/partiql-ast/build.gradle.kts +++ b/partiql-ast/build.gradle.kts @@ -22,6 +22,13 @@ plugins { dependencies { api(Deps.pigRuntime) api(Deps.ionElement) + api(project(":partiql-types")) +} + +publish { + artifactId = "partiql-ast" + name = "PartiQL AST" + description = "PartiQL's Abstract Syntax Tree" } pig { @@ -45,8 +52,22 @@ kotlin { explicitApi = null } -publish { - artifactId = "partiql-ast" - name = "PartiQL AST" - description = "PartiQL's Abstract Syntax Tree" +val generate = tasks.register("generate") { + dependsOn(":lib:sprout:install") + workingDir(projectDir) + commandLine( + "../lib/sprout/build/install/sprout/bin/sprout", "generate", "kotlin", + "-o", "$buildDir/generated-src", + "-p", "org.partiql.ast", + "-u", "Ast", + "--poems", "visitor", + "--poems", "builder", + "--poems", "util", + "--opt-in", "org.partiql.value.PartiQLValueExperimental", + "./src/main/resources/partiql_ast.ion" + ) +} + +tasks.compileKotlin { + dependsOn(generate) } diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt new file mode 100644 index 0000000000..380a924651 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt @@ -0,0 +1,15 @@ +package org.partiql.ast + +import org.partiql.ast.builder.AstFactoryImpl + +/** + * Singleton instance of the default factory; also accessible via `AstFactory.DEFAULT`. + */ +object Ast : AstBaseFactory() + +/** + * AstBaseFactory can be used to create a factory which extends from the factory provided by AstFactory.DEFAULT. + */ +public abstract class AstBaseFactory : AstFactoryImpl() { + // internal default overrides here +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToIon.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToIon.kt new file mode 100644 index 0000000000..3bed70c0f3 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToIon.kt @@ -0,0 +1,128 @@ +package org.partiql.ast.helpers + +import com.amazon.ion.Decimal +import com.amazon.ion.Timestamp +import com.amazon.ionelement.api.IonElement +import com.amazon.ionelement.api.field +import com.amazon.ionelement.api.ionBlob +import com.amazon.ionelement.api.ionBool +import com.amazon.ionelement.api.ionClob +import com.amazon.ionelement.api.ionDecimal +import com.amazon.ionelement.api.ionFloat +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionListOf +import com.amazon.ionelement.api.ionNull +import com.amazon.ionelement.api.ionSexpOf +import com.amazon.ionelement.api.ionString +import com.amazon.ionelement.api.ionStructOf +import com.amazon.ionelement.api.ionSymbol +import com.amazon.ionelement.api.ionTimestamp +import org.partiql.value.BlobValue +import org.partiql.value.BoolValue +import org.partiql.value.ClobValue +import org.partiql.value.CollectionValue +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue +import org.partiql.value.ListValue +import org.partiql.value.NullValue +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.ScalarValue +import org.partiql.value.SexpValue +import org.partiql.value.StringValue +import org.partiql.value.StructValue +import org.partiql.value.SymbolValue +import org.partiql.value.TimestampValue +import org.partiql.value.datetime.TimeZone +import org.partiql.value.util.PartiQLValueBaseVisitor + +/** + * PartiQL Value .toIon helper; to be replaced by https://github.com/partiql/partiql-lang-kotlin/pull/1131/files + * + * TODO add `lower` mode, this just errors + */ +@OptIn(PartiQLValueExperimental::class) +internal object ToIon : PartiQLValueBaseVisitor() { + + private inline fun ScalarValue.toIon(block: ScalarValue.() -> IonElement): IonElement { + val e = this.block() + return e.withAnnotations(this.annotations) + } + + private inline fun CollectionValue<*>.toIon(block: CollectionValue<*>.(elements: List) -> IonElement): IonElement { + val elements = this.elements.map { it.accept(ToIon, Unit) } + val e = this.block(elements) + return e.withAnnotations(this.annotations) + } + + override fun defaultVisit(v: PartiQLValue, ctx: Unit) = defaultReturn(v, ctx) + + override fun defaultReturn(v: PartiQLValue, ctx: Unit) = + throw IllegalArgumentException("Cannot represent $v as Ion in strict mode") + + override fun visitNull(v: NullValue, ctx: Unit) = ionNull().withAnnotations(v.annotations) + + override fun visitBool(v: BoolValue, ctx: Unit) = v.toIon { ionBool(value) } + + override fun visitInt8(v: Int8Value, ctx: Unit) = v.toIon { ionInt(value.toLong()) } + + override fun visitInt16(v: Int16Value, ctx: Unit) = v.toIon { ionInt(value.toLong()) } + + override fun visitInt32(v: Int32Value, ctx: Unit) = v.toIon { ionInt(value.toLong()) } + + override fun visitInt64(v: Int64Value, ctx: Unit) = v.toIon { ionInt(value) } + + // Call .toLong() because IonElement .equals() is failing with BigInteger (it's comparing by reference). + override fun visitInt(v: IntValue, ctx: Unit) = v.toIon { ionInt(value.toLong()) } + + override fun visitDecimal(v: DecimalValue, ctx: Unit) = v.toIon { ionDecimal(Decimal.valueOf(value)) } + + override fun visitFloat32(v: Float32Value, ctx: Unit) = v.toIon { ionFloat(value.toString().toDouble()) } + + override fun visitFloat64(v: Float64Value, ctx: Unit) = v.toIon { ionFloat(value) } + + override fun visitString(v: StringValue, ctx: Unit) = v.toIon { ionString(value) } + + override fun visitSymbol(v: SymbolValue, ctx: Unit) = v.toIon { ionSymbol(value) } + + override fun visitClob(v: ClobValue, ctx: Unit) = v.toIon { ionClob(value) } + + override fun visitBlob(v: BlobValue, ctx: Unit) = v.toIon { ionBlob(value) } + + override fun visitTimestamp(v: TimestampValue, ctx: Unit) = v.toIon { + val offset = when (val z = v.value.timeZone) { + TimeZone.UnknownTimeZone -> null + is TimeZone.UtcOffset -> z.totalOffsetMinutes + null -> 0 + } + val timestamp = Timestamp.forSecond( + v.value.year, + v.value.month, + v.value.day, + v.value.hour, + v.value.minute, + v.value.decimalSecond, + offset, + ) + ionTimestamp(timestamp) + } + + override fun visitList(v: ListValue<*>, ctx: Unit) = v.toIon { elements -> ionListOf(elements) } + + override fun visitSexp(v: SexpValue<*>, ctx: Unit) = v.toIon { elements -> ionSexpOf(elements) } + + override fun visitStruct(v: StructValue<*>, ctx: Unit): IonElement { + val fields = v.fields.map { + val key = it.first + val value = it.second.accept(this, ctx) + field(key, value) + } + return ionStructOf(fields, v.annotations) + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt new file mode 100644 index 0000000000..24f7d497b3 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -0,0 +1,1366 @@ +@file:OptIn(PartiQLValueExperimental::class) + +package org.partiql.ast.helpers + +import com.amazon.ion.Decimal +import com.amazon.ionelement.api.DecimalElement +import com.amazon.ionelement.api.FloatElement +import com.amazon.ionelement.api.IntElement +import com.amazon.ionelement.api.IntElementSize +import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.emptyMetaContainer +import com.amazon.ionelement.api.ionDecimal +import com.amazon.ionelement.api.ionFloat +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString +import com.amazon.ionelement.api.ionSymbol +import com.amazon.ionelement.api.metaContainerOf +import org.partiql.ast.AstNode +import org.partiql.ast.DatetimeField +import org.partiql.ast.Expr +import org.partiql.ast.From +import org.partiql.ast.GraphMatch +import org.partiql.ast.GroupBy +import org.partiql.ast.Identifier +import org.partiql.ast.Let +import org.partiql.ast.OnConflict +import org.partiql.ast.OrderBy +import org.partiql.ast.Path +import org.partiql.ast.Returning +import org.partiql.ast.Select +import org.partiql.ast.SetOp +import org.partiql.ast.SetQuantifier +import org.partiql.ast.Sort +import org.partiql.ast.Statement +import org.partiql.ast.TableDefinition +import org.partiql.ast.Type +import org.partiql.ast.visitor.AstBaseVisitor +import org.partiql.lang.ast.IsListParenthesizedMeta +import org.partiql.lang.ast.IsValuesExprMeta +import org.partiql.lang.ast.Meta +import org.partiql.lang.domains.PartiqlAst +import org.partiql.value.DateValue +import org.partiql.value.MissingValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.TimeValue +import org.partiql.value.TimestampValue +import org.partiql.value.datetime.TimeZone +import java.math.BigDecimal +import java.math.BigInteger + +/** + * Translates an [AstNode] tree to the legacy PIG AST. + * + * Optionally, you can provide a Map of MetaContainers to attach to the legacy AST nodes. + */ +public fun AstNode.toLegacyAst(metas: Map = emptyMap()): PartiqlAst.PartiqlAstNode { + val translator = AstTranslator(metas) + return accept(translator, Ctx()) +} + +/** + * Empty visitor method arguments + */ +private class Ctx + +/** + * Traverses an [AstNode] tree, folding to a [PartiqlAst.PartiqlAstNode] tree. + */ +@OptIn(PartiQLValueExperimental::class) +private class AstTranslator(val metas: Map) : AstBaseVisitor() { + + private val pig = PartiqlAst.BUILDER() + + override fun defaultReturn(node: AstNode, ctx: Ctx): Nothing { + val fromClass = node::class.qualifiedName + val toClass = PartiqlAst.PartiqlAstNode::class.qualifiedName + throw IllegalArgumentException("$fromClass cannot be translated to $toClass") + } + + override fun defaultVisit(node: AstNode, ctx: Ctx) = defaultReturn(node, ctx) + + /** + * Attach Metas if-any + */ + private inline fun translate( + node: AstNode, + block: PartiqlAst.Builder.(metas: MetaContainer) -> T, + ): T { + val metas = metas[node._id] ?: emptyMetaContainer() + return pig.block(metas) + } + + override fun visitStatement(node: Statement, ctx: Ctx) = super.visitStatement(node, ctx) as PartiqlAst.Statement + + override fun visitStatementQuery(node: Statement.Query, ctx: Ctx) = translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + query(expr, metas) + } + + override fun visitStatementExec(node: Statement.Exec, ctx: Ctx) = translate(node) { metas -> + val procedureName = node.procedure + val args = node.args.translate(ctx) + exec(procedureName, args, metas) + } + + override fun visitStatementExplain(node: Statement.Explain, ctx: Ctx) = translate(node) { metas -> + val target = visitStatementExplainTarget(node.target, ctx) + explain(target, metas) + } + + override fun visitStatementExplainTarget(node: Statement.Explain.Target, ctx: Ctx) = + super.visitStatementExplainTarget(node, ctx) as PartiqlAst.ExplainTarget + + override fun visitStatementExplainTargetDomain(node: Statement.Explain.Target.Domain, ctx: Ctx) = + translate(node) { metas -> + val statement = visitStatement(node.statement, ctx) + val type = node.type + val format = node.format + domain(statement, type, format, metas) + } + + override fun visitStatementDDL(node: Statement.DDL, ctx: Ctx) = super.visit(node, ctx) as PartiqlAst.Statement.Ddl + + override fun visitStatementDDLCreateTable( + node: Statement.DDL.CreateTable, + ctx: Ctx, + ) = translate(node) { metas -> + if (node.name !is Identifier.Symbol) { + error("The legacy AST does not support qualified identifiers as table names") + } + val tableName = (node.name as Identifier.Symbol).symbol + val def = node.definition?.let { visitTableDefinition(it, ctx) } + ddl(createTable(tableName, def), metas) + } + + override fun visitStatementDDLCreateIndex( + node: Statement.DDL.CreateIndex, + ctx: Ctx, + ) = translate(node) { metas -> + if (node.index != null) { + error("The legacy AST does not support index names") + } + if (node.table !is Identifier.Symbol) { + error("The legacy AST does not support qualified identifiers as table names") + } + val tableName = visitIdentifierSymbol((node.table as Identifier.Symbol), ctx) + val fields = node.fields.map { visitPathUnpack(it, ctx) } + ddl(createIndex(tableName, fields), metas) + } + + override fun visitStatementDDLDropTable(node: Statement.DDL.DropTable, ctx: Ctx) = translate(node) { metas -> + if (node.table !is Identifier.Symbol) { + error("The legacy AST does not support qualified identifiers as table names") + } + // !! Legacy AST "index_name" mix up !! + val tableName = visitIdentifierSymbol((node.table as Identifier.Symbol), ctx) + ddl(dropTable(tableName), metas) + } + + override fun visitStatementDDLDropIndex(node: Statement.DDL.DropIndex, ctx: Ctx) = translate(node) { metas -> + if (node.index !is Identifier.Symbol) { + error("The legacy AST does not support qualified identifiers as index names") + } + if (node.table !is Identifier.Symbol) { + error("The legacy AST does not support qualified identifiers as table names") + } + // !! Legacy AST "table" mix up !! + val index = visitIdentifierSymbol(node.index as Identifier.Symbol, ctx) + // !! Legacy AST "keys" mix up !! + val table = visitIdentifierSymbol(node.table as Identifier.Symbol, ctx) + ddl(dropIndex(table, index), metas) + } + + override fun visitTableDefinition(node: TableDefinition, ctx: Ctx) = translate(node) { metas -> + val parts = node.columns.translate(ctx) + tableDef(parts, metas) + } + + override fun visitTableDefinitionColumn(node: TableDefinition.Column, ctx: Ctx) = translate(node) { metas -> + val name = node.name + val type = visitType(node.type, ctx) + val constraints = node.constraints.translate(ctx) + columnDeclaration(name, type, constraints, metas) + } + + override fun visitTableDefinitionColumnConstraint( + node: TableDefinition.Column.Constraint, + ctx: Ctx, + ) = translate(node) { metas -> + val name = node.name + val def = when (node.body) { + is TableDefinition.Column.Constraint.Body.Check -> { + throw IllegalArgumentException("PIG AST does not support CHECK () constraint") + } + is TableDefinition.Column.Constraint.Body.NotNull -> columnNotnull() + is TableDefinition.Column.Constraint.Body.Nullable -> columnNull() + } + columnConstraint(name, def, metas) + } + + /** + * IDENTIFIERS / PATHS - Always expressions in legacy AST + */ + + override fun visitIdentifier(node: Identifier, ctx: Ctx) = when (node) { + is Identifier.Qualified -> visitIdentifierQualified(node, ctx) + is Identifier.Symbol -> visitIdentifierSymbolAsExpr(node) + } + + override fun visitIdentifierSymbol(node: Identifier.Symbol, ctx: Ctx) = translate(node) { metas -> + val name = node.symbol + val case = node.caseSensitivity.toLegacyCaseSensitivity() + // !! NOT AN EXPRESSION!! + identifier(name, case, metas) + } + + fun visitIdentifierSymbolAsExpr(node: Identifier.Symbol) = translate(node) { metas -> + val name = node.symbol + val case = node.caseSensitivity.toLegacyCaseSensitivity() + val scope = unqualified() + // !! ID EXPRESSION!! + id(name, case, scope, metas) + } + + override fun visitIdentifierQualified(node: Identifier.Qualified, ctx: Ctx) = translate(node) { metas -> + // !! Legacy AST represents qualified identifiers as Expr.Path !! + val root = visitIdentifierSymbolAsExpr(node.root) + val steps = node.steps.map { + // Legacy AST wraps id twice and always uses CaseSensitive + val expr = visitIdentifierSymbolAsExpr(it) + pathExpr(expr, caseSensitive()) + } + path(root, steps, metas) + } + + override fun visitPath(node: Path, ctx: Ctx) = translate(node) { metas -> + val root = visitIdentifierSymbolAsExpr(node.root) + val steps = node.steps.translate(ctx) + path(root, steps, metas) + } + + override fun visitPathStep(node: Path.Step, ctx: Ctx) = super.visitPathStep(node, ctx) as PartiqlAst.PathStep + + override fun visitPathStepSymbol(node: Path.Step.Symbol, ctx: Ctx) = translate(node) { metas -> + // val index = visitIdentifierSymbolAsExpr(node.symbol, ctx) + val index = lit(ionString(node.symbol.symbol), metas) + val case = node.symbol.caseSensitivity.toLegacyCaseSensitivity() + pathExpr(index, case, metas) + } + + override fun visitPathStepIndex(node: Path.Step.Index, ctx: Ctx) = translate(node) { metas -> + val index = lit(ionInt(node.index.toLong())) + val case = caseSensitive() // ??? + pathExpr(index, case, metas) + } + + /** + * EXPRESSIONS + */ + + override fun visitExpr(node: Expr, ctx: Ctx): PartiqlAst.Expr = super.visitExpr(node, ctx) as PartiqlAst.Expr + + override fun visitExprLit(node: Expr.Lit, ctx: Ctx) = translate(node) { metas -> + when (val v = node.value) { + is MissingValue -> missing(metas) + is DateValue -> v.toLegacyAst(metas) + is TimeValue -> v.toLegacyAst(metas) + is TimestampValue -> v.toLegacyAst(metas) + else -> { + val ion = v.accept(ToIon, Unit) // v.toIon() + lit(ion, metas) + } + } + } + + override fun visitExprIon(node: Expr.Ion, ctx: Ctx) = translate(node) { metas -> + lit(node.value, metas) + } + + override fun visitExprVar(node: Expr.Var, ctx: Ctx) = translate(node) { metas -> + if (node.identifier is Identifier.Qualified) { + error("Qualified identifiers not allowed in legacy AST `id` variable references") + } + val v = node.identifier as Identifier.Symbol + val name = v.symbol + val case = v.caseSensitivity.toLegacyCaseSensitivity() + val qualifier = node.scope.toLegacyScope() + id(name, case, qualifier, metas) + } + + override fun visitExprCall(node: Expr.Call, ctx: Ctx) = translate(node) { metas -> + if (node.function is Identifier.Qualified) { + error("Qualified identifiers are not allowed in legacy AST `call` function identifiers") + } + val funcName = (node.function as Identifier.Symbol).symbol.lowercase() + val args = node.args.translate(ctx) + call(funcName, args, metas) + } + + override fun visitExprAgg(node: Expr.Agg, ctx: Ctx) = translate(node) { metas -> + val setq = node.setq?.toLegacySetQuantifier() ?: all() + // Legacy AST translates COUNT(*) to COUNT(1) + if (node.function is Identifier.Symbol && (node.function as Identifier.Symbol).symbol == "COUNT_STAR") { + return callAgg(setq, "count", lit(ionInt(1)), metas) + } + // Default Case + if (node.args.size != 1) { + error("Legacy `call_agg` must have exactly one argument") + } + if (node.function is Identifier.Qualified) { + error("Qualified identifiers are not allowed in legacy AST `call_agg` function identifiers") + } + // Legacy parser/ast always inserts ALL quantifier + val funcName = (node.function as Identifier.Symbol).symbol.lowercase() + val arg = visitExpr(node.args[0], ctx) + callAgg(setq, funcName, arg, metas) + } + + override fun visitExprUnary(node: Expr.Unary, ctx: Ctx) = translate(node) { metas -> + val arg = visitExpr(node.expr, ctx) + when (node.op) { + Expr.Unary.Op.NOT -> not(arg, metas) + Expr.Unary.Op.POS -> { + when { + arg !is PartiqlAst.Expr.Lit -> pos(arg) + arg.value is IntElement -> arg + arg.value is FloatElement -> arg + arg.value is DecimalElement -> arg + else -> pos(arg) + } + } + Expr.Unary.Op.NEG -> { + when { + arg !is PartiqlAst.Expr.Lit -> neg(arg, metas) + arg.value is IntElement -> { + val intValue = when (arg.value.integerSize) { + IntElementSize.LONG -> ionInt(-arg.value.longValue) + IntElementSize.BIG_INTEGER -> when (arg.value.bigIntegerValue) { + Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE) + else -> ionInt(arg.value.bigIntegerValue * BigInteger.valueOf(-1L)) + } + } + arg.copy( + value = intValue.asAnyElement(), + metas = metas, + ) + } + arg.value is FloatElement -> arg.copy( + value = ionFloat(-(arg.value.doubleValue)).asAnyElement(), + metas = metas, + ) + arg.value is DecimalElement -> arg.copy( + value = ionDecimal(Decimal.valueOf(-(arg.value.decimalValue))).asAnyElement(), + metas = metas, + ) + else -> neg(arg, metas) + } + } + } + } + + override fun visitExprBinary(node: Expr.Binary, ctx: Ctx) = translate(node) { metas -> + val lhs = visitExpr(node.lhs, ctx) + val rhs = visitExpr(node.rhs, ctx) + val operands = listOf(lhs, rhs) + when (node.op) { + Expr.Binary.Op.PLUS -> plus(operands, metas) + Expr.Binary.Op.MINUS -> minus(operands, metas) + Expr.Binary.Op.TIMES -> times(operands, metas) + Expr.Binary.Op.DIVIDE -> divide(operands, metas) + Expr.Binary.Op.MODULO -> modulo(operands, metas) + Expr.Binary.Op.CONCAT -> concat(operands, metas) + Expr.Binary.Op.AND -> and(operands, metas) + Expr.Binary.Op.OR -> or(operands, metas) + Expr.Binary.Op.EQ -> eq(operands, metas) + Expr.Binary.Op.NE -> ne(operands, metas) + Expr.Binary.Op.GT -> gt(operands, metas) + Expr.Binary.Op.GTE -> gte(operands, metas) + Expr.Binary.Op.LT -> lt(operands, metas) + Expr.Binary.Op.LTE -> lte(operands, metas) + } + } + + override fun visitExprPath(node: Expr.Path, ctx: Ctx) = translate(node) { metas -> + val root = visitExpr(node.root, ctx) + val steps = node.steps.map { visitExprPathStep(it, ctx) } + path(root, steps, metas) + } + + override fun visitExprPathStep(node: Expr.Path.Step, ctx: Ctx) = + super.visitExprPathStep(node, ctx) as PartiqlAst.PathStep + + override fun visitExprPathStepSymbol(node: Expr.Path.Step.Symbol, ctx: Ctx) = translate(node) { metas -> + val index = lit(ionString(node.symbol.symbol)) + val case = node.symbol.caseSensitivity.toLegacyCaseSensitivity() + pathExpr(index, case, metas) + } + + override fun visitExprPathStepIndex(node: Expr.Path.Step.Index, ctx: Ctx) = translate(node) { metas -> + val index = visitExpr(node.key, ctx) + // Legacy AST marks every index step as CaseSensitive + val case = when (index) { + is PartiqlAst.Expr.Id -> index.case + else -> PartiqlAst.CaseSensitivity.CaseSensitive() + } + pathExpr(index, case, metas) + } + + override fun visitExprPathStepWildcard(node: Expr.Path.Step.Wildcard, ctx: Ctx) = translate(node) { metas -> + pathWildcard(metas) + } + + override fun visitExprPathStepUnpivot(node: Expr.Path.Step.Unpivot, ctx: Ctx) = translate(node) { metas -> + pathUnpivot(metas) + } + + override fun visitExprParameter(node: Expr.Parameter, ctx: Ctx) = translate(node) { metas -> + parameter(node.index.toLong(), metas) + } + + override fun visitExprValues(node: Expr.Values, ctx: Ctx) = translate(node) { metas -> + val rows = node.rows.map { visitExprValuesRow(it, ctx) } + bag(rows, metas + metaContainerOf(IsValuesExprMeta.instance)) + } + + override fun visitExprValuesRow(node: Expr.Values.Row, ctx: Ctx) = translate(node) { metas -> + val exprs = node.items.translate(ctx) + list(exprs, metas + metaContainerOf(IsListParenthesizedMeta)) + } + + override fun visitExprCollection(node: Expr.Collection, ctx: Ctx) = translate(node) { metas -> + val values = node.values.translate(ctx) + when (node.type) { + Expr.Collection.Type.BAG -> bag(values, metas) + Expr.Collection.Type.ARRAY -> list(values, metas) + Expr.Collection.Type.VALUES -> list(values, metas + metaContainerOf(IsValuesExprMeta.instance)) + Expr.Collection.Type.LIST -> list(values, metas + metaContainerOf(IsListParenthesizedMeta)) + Expr.Collection.Type.SEXP -> sexp(values, metas) + } + } + + override fun visitExprStruct(node: Expr.Struct, ctx: Ctx) = translate(node) { metas -> + val fields = node.fields.translate(ctx) + struct(fields, metas) + } + + override fun visitExprStructField(node: Expr.Struct.Field, ctx: Ctx) = translate(node) { metas -> + val first = visitExpr(node.name, ctx) + val second = visitExpr(node.value, ctx) + exprPair(first, second, metas) + } + + override fun visitExprLike(node: Expr.Like, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val pattern = visitExpr(node.pattern, ctx) + val escape = visitOrNull(node.escape, ctx) + if (node.not != null && node.not!!) { + not(like(value, pattern, escape), metas) + } else { + like(value, pattern, escape, metas) + } + } + + override fun visitExprBetween(node: Expr.Between, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val from = visitExpr(node.from, ctx) + val to = visitExpr(node.to, ctx) + if (node.not != null && node.not!!) { + not(between(value, from, to), metas) + } else { + between(value, from, to, metas) + } + } + + override fun visitExprInCollection(node: Expr.InCollection, ctx: Ctx) = translate(node) { metas -> + val lhs = visitExpr(node.lhs, ctx) + val rhs = visitExpr(node.rhs, ctx) + val operands = listOf(lhs, rhs) + if (node.not != null && node.not!!) { + not(inCollection(operands), metas) + } else { + inCollection(operands, metas) + } + } + + override fun visitExprIsType(node: Expr.IsType, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val type = visitType(node.type, ctx) + if (node.not != null && node.not!!) { + not(isType(value, type), metas) + } else { + isType(value, type, metas) + } + } + + override fun visitExprCase(node: Expr.Case, ctx: Ctx) = translate(node) { metas -> + val cases = exprPairList(node.branches.translate(ctx)) + val condition = visitOrNull(node.expr, ctx) + val default = visitOrNull(node.default, ctx) + when (condition) { + null -> searchedCase(cases, default, metas) + else -> simpleCase(condition, cases, default, metas) + } + } + + override fun visitExprCaseBranch(node: Expr.Case.Branch, ctx: Ctx) = translate(node) { metas -> + val first = visitExpr(node.condition, ctx) + val second = visitExpr(node.expr, ctx) + exprPair(first, second, metas) + } + + override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Ctx) = translate(node) { metas -> + val args = node.args.translate(ctx) + coalesce(args, metas) + } + + override fun visitExprNullIf(node: Expr.NullIf, ctx: Ctx) = translate(node) { metas -> + val expr1 = visitExpr(node.value, ctx) + val expr2 = visitExpr(node.nullifier, ctx) + nullIf(expr1, expr2, metas) + } + + override fun visitExprSubstring(node: Expr.Substring, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val start = visitOrNull(node.start, ctx) + val length = visitOrNull(node.length, ctx) + val operands = listOfNotNull(value, start, length) + call("substring", operands, metas) + } + + override fun visitExprPosition(node: Expr.Position, ctx: Ctx) = translate(node) { metas -> + val lhs = visitExpr(node.lhs, ctx) + val rhs = visitExpr(node.rhs, ctx) + val operands = listOf(lhs, rhs) + call("position", operands, metas) + } + + override fun visitExprTrim(node: Expr.Trim, ctx: Ctx) = translate(node) { metas -> + val operands = mutableListOf() + // Legacy AST requires adding the spec as an argument + val spec = node.spec?.toString()?.lowercase() + val chars = node.chars?.let { visitExpr(it, ctx) } + val value = visitExpr(node.value, ctx) + if (spec != null) operands.add(lit(ionSymbol(spec))) + if (chars != null) operands.add(chars) + operands.add(value) + call("trim", operands, metas) + } + + override fun visitExprOverlay(node: Expr.Overlay, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val overlay = visitExpr(node.overlay, ctx) + val start = visitExpr(node.start, ctx) + val length = visitOrNull(node.length, ctx) + val operands = listOfNotNull(value, overlay, start, length) + call("overlay", operands, metas) + } + + override fun visitExprExtract(node: Expr.Extract, ctx: Ctx) = translate(node) { metas -> + val field = node.field.toLegacyDatetimePart() + val source = visitExpr(node.source, ctx) + val operands = listOf(field, source) + call("extract", operands, metas) + } + + override fun visitExprCast(node: Expr.Cast, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val asType = visitType(node.asType, ctx) + cast(value, asType, metas) + } + + override fun visitExprCanCast(node: Expr.CanCast, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val asType = visitType(node.asType, ctx) + canCast(value, asType, metas) + } + + override fun visitExprCanLosslessCast(node: Expr.CanLosslessCast, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val asType = visitType(node.asType, ctx) + canLosslessCast(value, asType, metas) + } + + override fun visitExprDateAdd(node: Expr.DateAdd, ctx: Ctx) = translate(node) { metas -> + val field = node.field.toLegacyDatetimePart() + val lhs = visitExpr(node.lhs, ctx) + val rhs = visitExpr(node.rhs, ctx) + val operands = listOf(field, lhs, rhs) + call("date_add", operands, metas) + } + + override fun visitExprDateDiff(node: Expr.DateDiff, ctx: Ctx) = translate(node) { metas -> + val field = node.field.toLegacyDatetimePart() + val lhs = visitExpr(node.lhs, ctx) + val rhs = visitExpr(node.rhs, ctx) + val operands = listOf(field, lhs, rhs) + call("date_diff", operands, metas) + } + + override fun visitExprBagOp(node: Expr.BagOp, ctx: Ctx) = translate(node) { metas -> + val lhs = visitExpr(node.lhs, ctx) + val rhs = visitExpr(node.rhs, ctx) + val op = when (node.outer) { + true -> when (node.type.type) { + SetOp.Type.UNION -> outerUnion() + SetOp.Type.INTERSECT -> outerIntersect() + SetOp.Type.EXCEPT -> outerExcept() + } + else -> when (node.type.type) { + SetOp.Type.UNION -> union() + SetOp.Type.INTERSECT -> intersect() + SetOp.Type.EXCEPT -> except() + } + } + val setq = node.type.setq?.toLegacySetQuantifier() ?: distinct() + val operands = listOf(lhs, rhs) + bagOp(op, setq, operands, metas) + } + + override fun visitExprMatch(node: Expr.Match, ctx: Ctx) = translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + val match = visitGraphMatch(node.pattern, ctx) + graphMatch(expr, match, metas) + } + + override fun visitExprWindow(node: Expr.Window, ctx: Ctx) = translate(node) { metas -> + val funcName = node.function.name.lowercase() + val over = visitExprWindowOver(node.over, ctx) + val args = listOfNotNull(node.expression, node.offset, node.default).translate(ctx) + callWindow(funcName, over, args, metas) + } + + override fun visitExprWindowOver(node: Expr.Window.Over, ctx: Ctx) = translate(node) { metas -> + val partitionBy = node.partitions?.let { + val partitions = it.translate(ctx) + windowPartitionList(partitions) + } + val orderBy = node.sorts?.let { + val sorts = it.translate(ctx) + windowSortSpecList(sorts) + } + over(partitionBy, orderBy, metas) + } + + override fun visitExprSessionAttribute(node: Expr.SessionAttribute, ctx: Ctx) = translate(node) { metas -> + sessionAttribute(node.attribute.name.lowercase(), metas) + } + + /** + * SELECT-FROM-WHERE + */ + + override fun visitExprSFW(node: Expr.SFW, ctx: Ctx) = translate(node) { metas -> + var setq = when (val s = node.select) { + is Select.Pivot -> null + is Select.Project -> s.setq?.toLegacySetQuantifier() + is Select.Star -> s.setq?.toLegacySetQuantifier() + is Select.Value -> s.setq?.toLegacySetQuantifier() + } + // Legacy AST removes (setq (all)) + if (setq != null && setq is PartiqlAst.SetQuantifier.All) { + setq = null + } + val project = visitSelect(node.select, ctx) + val from = visitFrom(node.from, ctx) + val fromLet = node.let?.let { visitLet(it, ctx) } + val where = node.where?.let { visitExpr(it, ctx) } + val groupBy = node.groupBy?.let { visitGroupBy(it, ctx) } + val having = node.having?.let { visitExpr(it, ctx) } + val orderBy = node.orderBy?.let { visitOrderBy(it, ctx) } + val limit = node.limit?.let { visitExpr(it, ctx) } + val offset = node.offset?.let { visitExpr(it, ctx) } + select(setq, project, from, fromLet, where, groupBy, having, orderBy, limit, offset, metas) + } + + /** + * UNSUPPORTED in legacy AST + */ + override fun visitExprSFWSetOp(node: Expr.SFW.SetOp, ctx: Ctx) = defaultVisit(node, ctx) + + override fun visitSelect(node: Select, ctx: Ctx) = super.visitSelect(node, ctx) as PartiqlAst.Projection + + override fun visitSelectStar(node: Select.Star, ctx: Ctx) = translate(node) { metas -> + projectStar(metas) + } + + override fun visitSelectProject(node: Select.Project, ctx: Ctx) = translate(node) { metas -> + val items = node.items.translate(ctx) + projectList(items, metas) + } + + override fun visitSelectProjectItem(node: Select.Project.Item, ctx: Ctx) = + super.visitSelectProjectItem(node, ctx) as PartiqlAst.ProjectItem + + override fun visitSelectProjectItemAll(node: Select.Project.Item.All, ctx: Ctx) = translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + projectAll(expr, metas) + } + + override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Ctx) = + translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + val alias = node.asAlias + projectExpr(expr, alias, metas) + } + + override fun visitSelectPivot(node: Select.Pivot, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.value, ctx) + val key = visitExpr(node.key, ctx) + projectPivot(value, key, metas) + } + + override fun visitSelectValue(node: Select.Value, ctx: Ctx) = translate(node) { metas -> + val value = visitExpr(node.constructor, ctx) + projectValue(value, metas) + } + + override fun visitFrom(node: From, ctx: Ctx) = super.visitFrom(node, ctx) as PartiqlAst.FromSource + override fun visitFromValue(node: From.Value, ctx: Ctx) = translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + val asAlias = node.asAlias + val atAlias = node.atAlias + val byAlias = node.byAlias + when (node.type) { + From.Value.Type.SCAN -> scan(expr, asAlias, atAlias, byAlias, metas) + From.Value.Type.UNPIVOT -> unpivot(expr, asAlias, atAlias, byAlias, metas) + } + } + + // Legacy AST models CROSS JOIN and COMMA-syntax CROSS JOIN as FULL JOIN + // Legacy AST does not have OUTER variants + override fun visitFromJoin(node: From.Join, ctx: Ctx) = translate(node) { metas -> + val type = when (node.type) { + From.Join.Type.INNER -> inner() + From.Join.Type.LEFT -> left() + From.Join.Type.LEFT_OUTER -> left() + From.Join.Type.RIGHT -> right() + From.Join.Type.RIGHT_OUTER -> right() + From.Join.Type.FULL -> full() + From.Join.Type.FULL_OUTER -> full() + From.Join.Type.CROSS -> full() + From.Join.Type.COMMA -> full() + null -> inner() + } + val lhs = visitFrom(node.lhs, ctx) + val rhs = visitFrom(node.rhs, ctx) + val condition = visitOrNull(node.condition, ctx) + join(type, lhs, rhs, condition, metas) + } + + override fun visitLet(node: Let, ctx: Ctx) = translate(node) { metas -> + val bindings = node.bindings.translate(ctx) + let(bindings, metas) + } + + override fun visitLetBinding(node: Let.Binding, ctx: Ctx) = translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + val name = node.asAlias + letBinding(expr, name, metas) + } + + override fun visitGroupBy(node: GroupBy, ctx: Ctx) = translate(node) { metas -> + val strategy = when (node.strategy) { + GroupBy.Strategy.FULL -> groupFull() + GroupBy.Strategy.PARTIAL -> groupPartial() + } + val keyList = groupKeyList(node.keys.translate(ctx)) + val groupAsAlias = node.asAlias + groupBy(strategy, keyList, groupAsAlias, metas) + } + + override fun visitGroupByKey(node: GroupBy.Key, ctx: Ctx) = translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + val asAlias = node.asAlias + groupKey(expr, asAlias, metas) + } + + override fun visitOrderBy(node: OrderBy, ctx: Ctx) = translate(node) { metas -> + val sortSpecs = node.sorts.translate(ctx) + orderBy(sortSpecs, metas) + } + + override fun visitSort(node: Sort, ctx: Ctx) = translate(node) { metas -> + val expr = visitExpr(node.expr, ctx) + val orderingSpec = when (node.dir) { + Sort.Dir.ASC -> asc() + Sort.Dir.DESC -> desc() + null -> null + } + val nullsSpec = when (node.nulls) { + Sort.Nulls.FIRST -> nullsFirst() + Sort.Nulls.LAST -> nullsLast() + null -> null + } + sortSpec(expr, orderingSpec, nullsSpec, metas) + } + + /** + * UNSUPPORTED in legacy AST + */ + override fun visitSetOp(node: SetOp, ctx: Ctx) = defaultVisit(node, ctx) + + /** + * GPML + */ + + override fun visitGraphMatch(node: GraphMatch, ctx: Ctx) = translate(node) { metas -> + val selector = node.selector?.let { visitGraphMatchSelector(it, ctx) } + val patterns = node.patterns.translate(ctx) + gpmlPattern(selector, patterns, metas) + } + + override fun visitGraphMatchPattern(node: GraphMatch.Pattern, ctx: Ctx) = translate(node) { metas -> + val restrictor = when (node.restrictor) { + GraphMatch.Restrictor.TRAIL -> restrictorTrail() + GraphMatch.Restrictor.ACYCLIC -> restrictorAcyclic() + GraphMatch.Restrictor.SIMPLE -> restrictorSimple() + null -> null + } + val prefilter = node.prefilter?.let { visitExpr(it, ctx) } + val variable = node.variable + val quantifier = node.quantifier?.let { visitGraphMatchQuantifier(it, ctx) } + val parts = node.parts.translate(ctx) + graphMatchPattern(restrictor, prefilter, variable, quantifier, parts, metas) + } + + override fun visitGraphMatchPatternPart(node: GraphMatch.Pattern.Part, ctx: Ctx) = + super.visitGraphMatchPatternPart(node, ctx) as PartiqlAst.GraphMatchPatternPart + + override fun visitGraphMatchPatternPartNode(node: GraphMatch.Pattern.Part.Node, ctx: Ctx) = + translate(node) { metas -> + val prefilter = node.prefilter?.let { visitExpr(it, ctx) } + val variable = node.variable + val label = node.label + node(prefilter, variable, label, metas) + } + + override fun visitGraphMatchPatternPartEdge(node: GraphMatch.Pattern.Part.Edge, ctx: Ctx) = + translate(node) { metas -> + val direction = when (node.direction) { + GraphMatch.Direction.LEFT -> edgeLeft() + GraphMatch.Direction.UNDIRECTED -> edgeUndirected() + GraphMatch.Direction.RIGHT -> edgeRight() + GraphMatch.Direction.LEFT_OR_UNDIRECTED -> edgeLeftOrUndirected() + GraphMatch.Direction.UNDIRECTED_OR_RIGHT -> edgeUndirectedOrRight() + GraphMatch.Direction.LEFT_OR_RIGHT -> edgeLeftOrRight() + GraphMatch.Direction.LEFT_UNDIRECTED_OR_RIGHT -> edgeLeftOrUndirectedOrRight() + } + val quantifier = node.quantifier?.let { visitGraphMatchQuantifier(it, ctx) } + val prefilter = node.prefilter?.let { visitExpr(it, ctx) } + val variable = node.variable + val label = node.label + edge(direction, quantifier, prefilter, variable, label, metas) + } + + override fun visitGraphMatchPatternPartPattern(node: GraphMatch.Pattern.Part.Pattern, ctx: Ctx) = + translate(node) { metas -> + pattern(visitGraphMatchPattern(node.pattern, ctx), metas) + } + + override fun visitGraphMatchQuantifier(node: GraphMatch.Quantifier, ctx: Ctx) = translate(node) { metas -> + val lower = node.lower + val upper = node.upper + graphMatchQuantifier(lower, upper, metas) + } + + override fun visitGraphMatchSelector(node: GraphMatch.Selector, ctx: Ctx) = + super.visitGraphMatchSelector(node, ctx) as PartiqlAst.GraphMatchSelector + + override fun visitGraphMatchSelectorAnyShortest(node: GraphMatch.Selector.AnyShortest, ctx: Ctx) = + translate(node) { metas -> + selectorAnyShortest(metas) + } + + override fun visitGraphMatchSelectorAllShortest(node: GraphMatch.Selector.AllShortest, ctx: Ctx) = + translate(node) { metas -> + selectorAllShortest(metas) + } + + override fun visitGraphMatchSelectorAny(node: GraphMatch.Selector.Any, ctx: Ctx) = translate(node) { metas -> + selectorAny(metas) + } + + override fun visitGraphMatchSelectorAnyK(node: GraphMatch.Selector.AnyK, ctx: Ctx) = translate(node) { metas -> + val k = node.k + selectorAnyK(k, metas) + } + + override fun visitGraphMatchSelectorShortestK( + node: GraphMatch.Selector.ShortestK, + ctx: Ctx, + ) = translate(node) { metas -> + val k = node.k + selectorShortestK(k, metas) + } + + override fun visitGraphMatchSelectorShortestKGroup( + node: GraphMatch.Selector.ShortestKGroup, + ctx: Ctx, + ) = translate(node) { + val k = node.k + selectorShortestKGroup(k) + } + + /** + * DML + */ + + override fun visitStatementDML(node: Statement.DML, ctx: Ctx) = + super.visitStatementDML(node, ctx) as PartiqlAst.Statement + + override fun visitStatementDMLInsert(node: Statement.DML.Insert, ctx: Ctx) = translate(node) { metas -> + val target = visitIdentifier(node.target, ctx) + val asAlias = node.asAlias + val values = visitExpr(node.values, ctx) + val conflictAction = node.onConflict?.let { visitOnConflictAction(it.action, ctx) } + val op = insert(target, asAlias, values, conflictAction) + dml(dmlOpList(op), null, null, null, metas) + } + + override fun visitStatementDMLInsertLegacy( + node: Statement.DML.InsertLegacy, + ctx: Ctx, + ) = translate(node) { metas -> + val target = visitPathUnpack(node.target, ctx) + val values = visitExpr(node.value, ctx) + val index = node.index?.let { visitExpr(it, ctx) } + val onConflict = node.conflictCondition?.let { + val condition = visitExpr(it, ctx) + onConflict(condition, doNothing()) + } + val op = insertValue(target, values, index, onConflict) + dml(dmlOpList(op), null, null, null, metas) + } + + override fun visitStatementDMLUpsert(node: Statement.DML.Upsert, ctx: Ctx) = translate(node) { metas -> + val target = visitIdentifier(node.target, ctx) + val asAlias = node.asAlias + val values = visitExpr(node.values, ctx) + val conflictAction = doUpdate(excluded()) + // UPSERT overloads legacy INSERT + val op = insert(target, asAlias, values, conflictAction) + dml(dmlOpList(op), null, null, null, metas) + } + + override fun visitStatementDMLReplace(node: Statement.DML.Replace, ctx: Ctx) = translate(node) { metas -> + val target = visitIdentifier(node.target, ctx) + val asAlias = node.asAlias + val values = visitExpr(node.values, ctx) + val conflictAction = doReplace(excluded()) + // REPLACE overloads legacy INSERT + val op = insert(target, asAlias, values, conflictAction) + dml(dmlOpList(op), null, null, null, metas) + } + + override fun visitStatementDMLUpdate(node: Statement.DML.Update, ctx: Ctx) = translate(node) { metas -> + // Current PartiQL.g4 grammar models a SET with no UPDATE target as valid DML command. + // We don't want the target to be nullable in the AST because it's not in the SQL grammar. + // val target = visitPathUnpack(node.target, ctx) + // val from = scan(target) + // UPDATE becomes multiple sets + val operations = node.assignments.map { + val assignment = visitStatementDMLUpdateAssignment(it, ctx) + set(assignment) + } + dml(dmlOpList(operations), null, null, null, metas) + } + + override fun visitStatementDMLUpdateAssignment( + node: Statement.DML.Update.Assignment, + ctx: Ctx, + ) = translate(node) { metas -> + val target = visitPathUnpack(node.target, ctx) + val value = visitExpr(node.value, ctx) + assignment(target, value, metas) + } + + override fun visitStatementDMLRemove(node: Statement.DML.Remove, ctx: Ctx) = translate(node) { metas -> + val target = visitPathUnpack(node.target, ctx) + val op = remove(target) + dml(dmlOpList(op), null, null, null, metas) + } + + override fun visitStatementDMLDelete(node: Statement.DML.Delete, ctx: Ctx) = translate(node) { metas -> + val from = visitStatementDMLDeleteTarget(node.target, ctx) + val where = node.where?.let { visitExpr(it, ctx) } + val returning = node.returning?.let { visitReturning(it, ctx) } + val op = delete() + dml(dmlOpList(op), from, where, returning, metas) + } + + override fun visitStatementDMLDeleteTarget(node: Statement.DML.Delete.Target, ctx: Ctx) = translate(node) { metas -> + val path = visitPathUnpack(node.path, ctx) + val asAlias = node.asAlias + val atAlias = node.atAlias + val byAlias = node.byAlias + scan(path, asAlias, atAlias, byAlias, metas) + } + + override fun visitStatementDMLBatchLegacy(node: Statement.DML.BatchLegacy, ctx: Ctx) = translate(node) { metas -> + val from = node.target?.let { visitFrom(it, ctx) } + val ops = node.ops.translate(ctx).flatMap { it.ops } + val where = node.where?.let { visitExpr(it, ctx) } + val returning = node.returning?.let { visitReturning(it, ctx) } + dml(dmlOpList(ops), from, where, returning, metas) + } + + override fun visitStatementDMLBatchLegacyOp(node: Statement.DML.BatchLegacy.Op, ctx: Ctx) = + super.visitStatementDMLBatchLegacyOp(node, ctx) as PartiqlAst.DmlOpList + + override fun visitStatementDMLBatchLegacyOpSet( + node: Statement.DML.BatchLegacy.Op.Set, + ctx: Ctx, + ) = translate(node) { metas -> + val ops = node.assignments.map { + val assignment = visitStatementDMLUpdateAssignment(it, ctx) + set(assignment) + } + dmlOpList(ops, metas) + } + + override fun visitStatementDMLBatchLegacyOpRemove( + node: Statement.DML.BatchLegacy.Op.Remove, + ctx: Ctx, + ) = translate(node) { metas -> + val target = visitPathUnpack(node.target, ctx) + val ops = listOf(remove(target)) + dmlOpList(ops, metas) + } + + override fun visitStatementDMLBatchLegacyOpDelete( + node: Statement.DML.BatchLegacy.Op.Delete, + ctx: Ctx, + ) = translate(node) { metas -> + val ops = listOf(delete()) + dmlOpList(ops, metas) + } + + override fun visitStatementDMLBatchLegacyOpInsert( + node: Statement.DML.BatchLegacy.Op.Insert, + ctx: Ctx, + ) = translate(node) { metas -> + val target = visitIdentifier(node.target, ctx) + val asAlias = node.asAlias + val values = visitExpr(node.values, ctx) + val conflictAction = node.onConflict?.let { visitOnConflictAction(it.action, ctx) } + dmlOpList(insert(target, asAlias, values, conflictAction, metas)) + } + + override fun visitStatementDMLBatchLegacyOpInsertLegacy( + node: Statement.DML.BatchLegacy.Op.InsertLegacy, + ctx: Ctx, + ) = translate(node) { + val target = visitPathUnpack(node.target, ctx) + val values = visitExpr(node.value, ctx) + val index = node.index?.let { visitExpr(it, ctx) } + val onConflict = node.conflictCondition?.let { + val condition = visitExpr(it, ctx) + onConflict(condition, doNothing()) + } + dmlOpList(insertValue(target, values, index, onConflict)) + } + + override fun visitOnConflict(node: OnConflict, ctx: Ctx) = translate(node) { metas -> + val action = visitOnConflictAction(node.action, ctx) + if (node.target == null) { + // Legacy PartiQLPifVisitor doesn't respect the return type for the OnConflict rule + // - visitOnConflictLegacy returns an OnConflict node + // - visitOnConflict returns an OnConflict.Action + // Essentially, the on_conflict target appears in the grammar but not the PIG model + // Which means you technically can't use the #OnConflict alternative in certain contexts. + // We generally shouldn't have parser rule alternatives which are not variants of the same type. + throw IllegalArgumentException("PIG OnConflict (#OnConflictLegacy grammar rule) requires an expression") + } + val expr = visitOnConflictTarget(node.target!!, ctx) + onConflict(expr, action, metas) + } + + override fun visitOnConflictTarget(node: OnConflict.Target, ctx: Ctx) = + super.visitOnConflictTarget(node, ctx) as PartiqlAst.Expr + + override fun visitOnConflictTargetSymbols( + node: OnConflict.Target.Symbols, + ctx: Ctx, + ) = translate(node) { metas -> + val symbols = node.symbols.map { + if (it !is Identifier.Symbol) { + throw IllegalArgumentException("Legacy AST does not support qualified identifiers as index names") + } + lit(ionSymbol(it.symbol)) + } + list(symbols, metas) + } + + override fun visitOnConflictTargetConstraint( + node: OnConflict.Target.Constraint, + ctx: Ctx, + ) = translate(node) { metas -> + if (node.constraint !is Identifier.Symbol) { + throw IllegalArgumentException("Legacy AST does not support qualified identifiers as a constraint name") + } + val constraint = (node.constraint as Identifier.Symbol).symbol + lit(ionSymbol(constraint), metas) + } + + override fun visitOnConflictAction(node: OnConflict.Action, ctx: Ctx) = + super.visitOnConflictAction(node, ctx) as PartiqlAst.ConflictAction + + override fun visitOnConflictActionDoReplace( + node: OnConflict.Action.DoReplace, + ctx: Ctx, + ) = translate(node) { metas -> + val value = excluded() + val condition = node.condition?.let { visitExpr(it, ctx) } + doReplace(value, condition, metas) + } + + override fun visitOnConflictActionDoUpdate( + node: OnConflict.Action.DoUpdate, + ctx: Ctx, + ) = translate(node) { metas -> + val value = excluded() + val condition = node.condition?.let { visitExpr(it, ctx) } + doUpdate(value, condition, metas) + } + + override fun visitOnConflictActionDoNothing( + node: OnConflict.Action.DoNothing, + ctx: Ctx, + ) = translate(node) { metas -> + doNothing(metas) + } + + override fun visitReturning(node: Returning, ctx: Ctx) = translate(node) { metas -> + val elems = node.columns.translate(ctx) + returningExpr(elems, metas) + } + + override fun visitReturningColumn(node: Returning.Column, ctx: Ctx) = translate(node) { + // a fine example of `when` is `if`, not pattern matching + val mapping = when (node.status) { + Returning.Column.Status.MODIFIED -> when (node.age) { + Returning.Column.Age.OLD -> modifiedOld() + Returning.Column.Age.NEW -> modifiedNew() + } + Returning.Column.Status.ALL -> when (node.age) { + Returning.Column.Age.OLD -> allOld() + Returning.Column.Age.NEW -> allNew() + } + } + val column = visitReturningColumnValue(node.value, ctx) + returningElem(mapping, column) + } + + override fun visitReturningColumnValue(node: Returning.Column.Value, ctx: Ctx) = + super.visitReturningColumnValue(node, ctx) as PartiqlAst.ColumnComponent + + override fun visitReturningColumnValueWildcard( + node: Returning.Column.Value.Wildcard, + ctx: Ctx, + ) = translate(node) { + returningWildcard() + } + + override fun visitReturningColumnValueExpression( + node: Returning.Column.Value.Expression, + ctx: Ctx, + ) = translate(node) { + val expr = visitExpr(node.expr, ctx) + returningColumn(expr) + } + + /** + * TYPE + */ + + override fun visitType(node: Type, ctx: Ctx) = super.visitType(node, ctx) as PartiqlAst.Type + + override fun visitTypeNullType(node: Type.NullType, ctx: Ctx) = translate(node) { metas -> nullType(metas) } + + override fun visitTypeMissing(node: Type.Missing, ctx: Ctx) = translate(node) { metas -> missingType(metas) } + + override fun visitTypeBool(node: Type.Bool, ctx: Ctx) = translate(node) { metas -> booleanType(metas) } + + override fun visitTypeTinyint(node: Type.Tinyint, ctx: Ctx) = + throw IllegalArgumentException("TINYINT type not supported") + + override fun visitTypeSmallint(node: Type.Smallint, ctx: Ctx) = translate(node) { metas -> smallintType(metas) } + + override fun visitTypeInt2(node: Type.Int2, ctx: Ctx) = translate(node) { metas -> smallintType(metas) } + + override fun visitTypeInt4(node: Type.Int4, ctx: Ctx) = translate(node) { metas -> integer4Type(metas) } + + override fun visitTypeBigint(node: Type.Bigint, ctx: Ctx) = translate(node) { metas -> integer8Type(metas) } + + override fun visitTypeInt8(node: Type.Int8, ctx: Ctx) = translate(node) { metas -> integer8Type(metas) } + + override fun visitTypeInt(node: Type.Int, ctx: Ctx) = translate(node) { metas -> integerType(metas) } + + override fun visitTypeReal(node: Type.Real, ctx: Ctx) = translate(node) { metas -> realType(metas) } + + override fun visitTypeFloat32(node: Type.Float32, ctx: Ctx) = translate(node) { metas -> floatType(null, metas) } + + override fun visitTypeFloat64(node: Type.Float64, ctx: Ctx) = + translate(node) { metas -> doublePrecisionType(metas) } + + override fun visitTypeDecimal(node: Type.Decimal, ctx: Ctx) = translate(node) { metas -> + decimalType( + precision = node.precision?.toLong(), + scale = node.scale?.toLong(), + metas = metas, + ) + } + + override fun visitTypeNumeric(node: Type.Numeric, ctx: Ctx) = translate(node) { metas -> + numericType( + precision = node.precision?.toLong(), + scale = node.scale?.toLong(), + metas = metas, + ) + } + + override fun visitTypeChar(node: Type.Char, ctx: Ctx) = + translate(node) { metas -> characterType(node.length?.toLong(), metas) } + + override fun visitTypeVarchar(node: Type.Varchar, ctx: Ctx) = + translate(node) { metas -> characterVaryingType(node.length?.toLong(), metas) } + + override fun visitTypeString(node: Type.String, ctx: Ctx) = translate(node) { metas -> stringType(metas) } + + override fun visitTypeSymbol(node: Type.Symbol, ctx: Ctx) = translate(node) { metas -> symbolType(metas) } + + override fun visitTypeBit(node: Type.Bit, ctx: Ctx) = throw IllegalArgumentException("BIT type not supported") + + override fun visitTypeBitVarying(node: Type.BitVarying, ctx: Ctx) = + throw IllegalArgumentException("BIT VARYING type not supported") + + override fun visitTypeByteString(node: Type.ByteString, ctx: Ctx) = + throw IllegalArgumentException("BYTESTRING type not supported") + + override fun visitTypeBlob(node: Type.Blob, ctx: Ctx) = translate(node) { metas -> blobType(metas) } + + override fun visitTypeClob(node: Type.Clob, ctx: Ctx) = translate(node) { metas -> clobType(metas) } + + override fun visitTypeDate(node: Type.Date, ctx: Ctx) = translate(node) { metas -> dateType(metas) } + + override fun visitTypeTime(node: Type.Time, ctx: Ctx) = + translate(node) { metas -> timeType(node.precision?.toLong(), metas) } + + override fun visitTypeTimeWithTz(node: Type.TimeWithTz, ctx: Ctx) = + translate(node) { metas -> timeWithTimeZoneType(node.precision?.toLong(), metas) } + + override fun visitTypeTimestamp(node: Type.Timestamp, ctx: Ctx) = + translate(node) { metas -> timestampType(node.precision?.toLong(), metas) } + + override fun visitTypeTimestampWithTz(node: Type.TimestampWithTz, ctx: Ctx) = + throw IllegalArgumentException("TIMESTAMP [WITH TIMEZONE] type not supported") + + override fun visitTypeInterval(node: Type.Interval, ctx: Ctx) = + throw IllegalArgumentException("INTERVAL type not supported") + + override fun visitTypeBag(node: Type.Bag, ctx: Ctx) = translate(node) { metas -> bagType(metas) } + + override fun visitTypeList(node: Type.List, ctx: Ctx) = translate(node) { metas -> listType(metas) } + + override fun visitTypeSexp(node: Type.Sexp, ctx: Ctx) = translate(node) { metas -> sexpType(metas) } + + override fun visitTypeTuple(node: Type.Tuple, ctx: Ctx) = translate(node) { metas -> tupleType(metas) } + + override fun visitTypeStruct(node: Type.Struct, ctx: Ctx) = translate(node) { metas -> structType(metas) } + + override fun visitTypeAny(node: Type.Any, ctx: Ctx) = translate(node) { metas -> anyType(metas) } + + override fun visitTypeCustom(node: Type.Custom, ctx: Ctx) = + translate(node) { metas -> customType(node.name.lowercase(), metas) } + + /** + * HELPERS + */ + + private inline fun List.translate(ctx: Ctx): List = + this.map { visit(it, ctx) as S } + + private inline fun visitOrNull(node: AstNode?, ctx: Ctx): T? = + node?.let { visit(it, ctx) as T } + + private fun Identifier.CaseSensitivity.toLegacyCaseSensitivity() = when (this) { + Identifier.CaseSensitivity.SENSITIVE -> PartiqlAst.CaseSensitivity.CaseSensitive() + Identifier.CaseSensitivity.INSENSITIVE -> PartiqlAst.CaseSensitivity.CaseInsensitive() + } + + private fun Expr.Var.Scope.toLegacyScope() = when (this) { + Expr.Var.Scope.DEFAULT -> PartiqlAst.ScopeQualifier.Unqualified() + Expr.Var.Scope.LOCAL -> PartiqlAst.ScopeQualifier.LocalsFirst() + } + + private fun SetQuantifier.toLegacySetQuantifier() = when (this) { + SetQuantifier.ALL -> PartiqlAst.SetQuantifier.All() + SetQuantifier.DISTINCT -> PartiqlAst.SetQuantifier.Distinct() + } + + private fun DatetimeField.toLegacyDatetimePart(): PartiqlAst.Expr.Lit { + val symbol = this.toString().lowercase() + return pig.lit(ionSymbol(symbol)) + } + + // Legacy AST models targets as expressions + private fun visitPathUnpack(path: Path, ctx: Ctx): PartiqlAst.Expr { + val ex = visitPath(path, ctx) + return if (ex.steps.isEmpty()) ex.root else ex + } + + private fun metaContainerOf(vararg metas: Meta): MetaContainer = metaContainerOf(metas.map { Pair(it.tag, it) }) + + // Time Value is not an Expr.Lit in the legacy AST; needs special treatment. + private fun TimeValue.toLegacyAst(metas: MetaContainer): PartiqlAst.Expr.LitTime { + val d = this.value.decimalSecond + val seconds = d.toLong() + val nano = d.subtract(BigDecimal(seconds)).scaleByPowerOfTen(9).toLong() + val time = pig.timeValue( + hour = this.value.hour.toLong(), + minute = this.value.minute.toLong(), + second = seconds, + nano = nano, + precision = this.value.decimalSecond.precision().toLong(), + withTimeZone = this.value.timeZone != null, + tzMinutes = this.value.timeZone?.let { + when (it) { + is TimeZone.UtcOffset -> it.totalOffsetMinutes.toLong() + else -> 0 + } + }, + ) + return pig.litTime(time, metas) + } + + // Timestamp Value is not an Expr.Lit in the legacy AST; needs special treatment. + private fun TimestampValue.toLegacyAst(metas: MetaContainer): PartiqlAst.Expr.Timestamp { + val timeZone = value.timeZone?.toLegacyAst(metas) + val precision = value.decimalSecond.precision().toLong() + return pig.timestamp( + pig.timestampValue( + value.year.toLong(), value.month.toLong(), value.day.toLong(), + value.hour.toLong(), value.minute.toLong(), ionDecimal(Decimal.valueOf(value.decimalSecond)), + timeZone, precision + ) + ) + } + + // Date Value is not an Expr.Lit in the legacy AST; needs special treatment. + private fun DateValue.toLegacyAst(metas: MetaContainer): PartiqlAst.Expr.Date { + return pig.date( + year = this.value.year.toLong(), + month = this.value.month.toLong(), + day = this.value.day.toLong(), + metas = metas, + ) + } + + private fun TimeZone.toLegacyAst(metas: MetaContainer): PartiqlAst.Timezone { + return when (this) { + TimeZone.UnknownTimeZone -> pig.unknownTimezone(metas) + is TimeZone.UtcOffset -> pig.utcOffset(totalOffsetMinutes.toLong(), metas) + } + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/impl/.gitkeep b/partiql-ast/src/main/kotlin/org/partiql/ast/impl/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/AggregateCallSiteListMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/AggregateCallSiteListMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/AggregateCallSiteListMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/AggregateCallSiteListMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/InternalMetas.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/InternalMetas.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/InternalMetas.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/InternalMetas.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsCountStarMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsCountStarMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsCountStarMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsCountStarMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsGroupAttributeReferenceMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsGroupAttributeReferenceMeta.kt similarity index 92% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsGroupAttributeReferenceMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsGroupAttributeReferenceMeta.kt index 2ee9989c20..7999101530 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsGroupAttributeReferenceMeta.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsGroupAttributeReferenceMeta.kt @@ -16,7 +16,7 @@ package org.partiql.lang.ast /** * Meta attached to an identifier when replacing the identifier with a reference to a group key variable declaration. */ -internal class IsGroupAttributeReferenceMeta private constructor() : Meta { +class IsGroupAttributeReferenceMeta private constructor() : Meta { override val tag = TAG companion object { diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsImplictJoinMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsImplictJoinMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsImplictJoinMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsImplictJoinMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsIonLiteralMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsIonLiteralMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsIonLiteralMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsIonLiteralMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsListParenthesizedMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsListParenthesizedMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsListParenthesizedMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsListParenthesizedMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsOrderedMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsOrderedMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsOrderedMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsOrderedMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsPathIndexMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsPathIndexMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsPathIndexMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsPathIndexMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsTransformedOrderByAliasMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsTransformedOrderByAliasMeta.kt similarity index 82% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsTransformedOrderByAliasMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsTransformedOrderByAliasMeta.kt index b1487190e6..092320e570 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsTransformedOrderByAliasMeta.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsTransformedOrderByAliasMeta.kt @@ -14,11 +14,9 @@ package org.partiql.lang.ast -import org.partiql.lang.eval.visitors.OrderBySortSpecVisitorTransform - /** - * A [Meta] to help the [OrderBySortSpecVisitorTransform] to know when the OrderBy SortSpec has already been transformed. It - * essentially helps to turn + * A [Meta] to help the OrderBySortSpecVisitorTransform to know when the OrderBy SortSpec has already been transformed. + * It essentially helps to turn * * ```SELECT a + 1 AS b FROM c ORDER BY b``` * @@ -30,6 +28,7 @@ import org.partiql.lang.eval.visitors.OrderBySortSpecVisitorTransform */ class IsTransformedOrderByAliasMeta private constructor() : Meta { override val tag = TAG + companion object { const val TAG = "\$is_transformed_order_by_alias" diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsValuesExprMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsValuesExprMeta.kt similarity index 100% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/IsValuesExprMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/IsValuesExprMeta.kt diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/LegacyLogicalNotMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/LegacyLogicalNotMeta.kt similarity index 53% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/LegacyLogicalNotMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/LegacyLogicalNotMeta.kt index f3177fd39d..842e66a200 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/LegacyLogicalNotMeta.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/LegacyLogicalNotMeta.kt @@ -15,23 +15,9 @@ package org.partiql.lang.ast /** - * The old AST has nodes: `not_like`, `not_between` and `is_not` which are respectively paired with its - * `like`, `between`, and `is` nodes. - * - * The new AST lacks the `not` version of these, instead wrapping the non-`not` versions of these notes in a `not` - * n-ary expression to achieve the same semantics. - * - * For example: - * - Legacy: (not_like ) - * - New AST (will be something like): - * (nary - * (op not) - * (args (nary like )) - * (metas ((name LegacyLogicalNotMeta)))) - * - * [LegacyLogicalNotMeta] is added to `(nary (op not) ...` node so that the [ToLegacyAstPass] knows to emit the - * `not_like`, `not_between` or `is_not` s-expression nodes. + * A legacy meta that is no longer used. */ +@Deprecated("To be removed in the next minor version") class LegacyLogicalNotMeta private constructor() : Meta { override val tag = TAG companion object { diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/SourceLocationMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/SourceLocationMeta.kt similarity index 88% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/SourceLocationMeta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/SourceLocationMeta.kt index aa32f718d0..c0105845c9 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/SourceLocationMeta.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/SourceLocationMeta.kt @@ -16,8 +16,9 @@ package org.partiql.lang.ast import com.amazon.ion.IonWriter import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionStructOf import com.amazon.ionelement.api.metaOrNull -import org.partiql.lang.util.IonWriterContext /** * Represents a specific location within a source file. @@ -28,13 +29,11 @@ data class SourceLocationMeta(val lineNum: Long, val charOffset: Long, val lengt override val tag = TAG override fun serialize(writer: IonWriter) { - IonWriterContext(writer).apply { - struct { - int("line_num", lineNum) - int("char_offset", charOffset) - int("length", length) - } - } + ionStructOf( + "line_num" to ionInt(lineNum), + "char_offset" to ionInt(charOffset), + "length" to ionInt(length), + ).writeTo(writer) } override fun equals(other: Any?): Boolean { diff --git a/partiql-ast/src/main/kotlin/org/partiql/lang/ast/StaticTypeMeta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/StaticTypeMeta.kt new file mode 100644 index 0000000000..f32f0dbfa3 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/StaticTypeMeta.kt @@ -0,0 +1,17 @@ +package org.partiql.lang.ast + +import org.partiql.types.StaticType + +/** + * Represents a static type for an AST element. + */ +data class StaticTypeMeta(val type: StaticType) : Meta { + + override fun toString() = type.toString() + + override val tag = TAG + + companion object { + const val TAG = "\$static_type" + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/meta.kt b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/meta.kt similarity index 96% rename from partiql-lang/src/main/kotlin/org/partiql/lang/ast/meta.kt rename to partiql-ast/src/main/kotlin/org/partiql/lang/ast/meta.kt index d1bb88cbb4..0925b5afb3 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/ast/meta.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/lang/ast/meta.kt @@ -16,8 +16,8 @@ package org.partiql.lang.ast import com.amazon.ion.IonWriter import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.metaContainerOf import com.amazon.ionelement.api.plus -import org.partiql.lang.domains.metaContainerOf /** * The [Meta] interface is implemented by classes that provide an object mapping view to AST meta nodes. @@ -66,4 +66,4 @@ fun MetaContainer.find(tagName: String): Meta? = this[tagName] as Meta? fun MetaContainer.hasMeta(tagName: String) = this.containsKey(tagName) -fun MetaContainer.add(meta: Meta): MetaContainer = this.plus(metaContainerOf(meta)) +fun MetaContainer.add(meta: Meta): MetaContainer = this.plus(metaContainerOf(meta.tag to meta)) diff --git a/partiql-ast/src/main/resources/partiql_ast.ion b/partiql-ast/src/main/resources/partiql_ast.ion new file mode 100644 index 0000000000..104b1b69d7 --- /dev/null +++ b/partiql-ast/src/main/resources/partiql_ast.ion @@ -0,0 +1,792 @@ +imports::{ kotlin: [ + ion::'com.amazon.ionelement.api.IonElement', + value::'org.partiql.value.PartiQLValue', + ], +} + +statement::[ + + // PartiQL Expressions + query::{ + expr: expr, + }, + + // Data Manipulation Language + d_m_l::[ + + // INSERT INTO [AS ] [] + insert::{ + target: identifier, + values: expr, + as_alias: optional::string, + on_conflict: optional::on_conflict, + }, + + // INSERT INTO VALUE [AT ] [] + insert_legacy::{ + target: path, + value: expr, + index: optional::expr, + conflict_condition: optional::expr, + }, + + // UPSERT INTO [] + upsert::{ + target: identifier, + values: expr, + as_alias: optional::string, + }, + + // REPLACE INTO [AS ] + replace::{ + target: identifier, + values: expr, + as_alias: optional::string, + }, + + // UPDATE SET WHERE + update::{ + target: path, + assignments: list::[assignment], + _: [ + assignment::{ + target: path, + value: expr, + }, + ], + }, + + // REMOVE + remove::{ + target: path, + }, + + // DELETE [FROM [AS ] [AT ] [BY ]] [ WHERE ] RETURNING ... + delete::{ + target: { + path: path, + as_alias: optional::string, + at_alias: optional::string, + by_alias: optional::string, + }, + where: optional::expr, + returning: optional::returning, + }, + + // [UPDATE|FROM] + WHERE RETURNING + batch_legacy::{ + ops: list::[op], + target: optional::from, + where: optional::expr, + returning: optional::returning, + _: [ + op::[ + insert::{ + target: identifier, + values: expr, + as_alias: optional::string, + on_conflict: optional::on_conflict, + }, + insert_legacy::{ + target: path, + value: expr, + index: optional::expr, + conflict_condition: optional::expr, + }, + set::{ + assignments: list::[assignment], + }, + remove::{ + target: path, + }, + delete::{}, + ], + ], + }, + ], + + // Data Definition Language + d_d_l::[ + + // CREATE TABLE [] + create_table::{ + name: identifier, + definition: optional::table_definition, + }, + + // CREATE INDEX [] ON ( [, ]...) + create_index::{ + index: optional::identifier, + table: identifier, + fields: list::[path], + }, + + // DROP TABLE + drop_table::{ + table: identifier, + }, + + // DROP INDEX ON + drop_index::{ + index: identifier, // [0] + table: identifier, // [1] + }, + ], + + // EXEC [.*] + exec::{ + procedure: string, + args: list::[expr], + }, + + // EXPLAIN + explain::{ + target: [ + domain::{ + statement: statement, + type: optional::string, + format: optional::string, + }, + ], + }, +] + +// PartiQL Type AST nodes +// +// Several of these are the same "type", but have various syntax rules we wish to capture. +// +type::[ + null_type::{}, // NULL + missing::{}, // MISSING + bool::{}, // BOOL + tinyint::{}, // TINYINT + smallint::{}, // SMALLINT + int2::{}, // INT2 | INTEGER2 + int4::{}, // INT4 | INTEGER4 + bigint::{}, // BIGINT + int8::{}, // INT8 + int::{}, // INTEGER + real::{}, // REAL + float32::{}, // FLOAT + float64::{}, // DOUBLE PRECISION + decimal::{ // DECIMAL [([ [,])] + precision: optional::int, + scale: optional::int, + }, + numeric::{ // NUMERIC [([ [,])] + precision: optional::int, + scale: optional::int, + }, + char::{ length: optional::int }, // CHARACTER [()] | CHAR [()] + varchar::{ length: optional::int }, // CHARACTER VARYING [()] | VARCHAR [()] + string::{ length: optional::int }, // STRING + symbol::{}, // SYMBOL + bit::{ length: optional::int }, // BIT [()] + bit_varying::{ length: optional::int }, // BIT_VARYING [()] + byte_string::{ length: optional::int }, // BYTE [()] + blob::{ length: optional::int }, // BLOB [()] + clob::{ length: optional::int }, // CLOB [()] + date::{}, // DATE + time::{ precision: optional::int }, // TIME [()] [WITHOUT TIMEZONE] + time_with_tz::{ precision: optional::int }, // TIME [()] WITH TIMEZONE + timestamp::{ precision: optional::int }, // TIMESTAMP [()] [WITHOUT TIMEZONE] + timestamp_with_tz::{ precision: optional::int }, // TIMESTAMP [()] WITH TIMEZONE + interval::{ precision: optional::int }, // INTERVAL + bag::{}, // BAG + list::{}, // LIST + sexp::{}, // SEXP + tuple::{}, // TUPLE + struct::{}, // STRUCT + any::{}, // ANY + custom::{ name: string }, // +] + +// Identifiers and Qualified Identifiers +//---------------------------------------------- +// ::= | +// +// ::= // case-insensitive +// | "" // case-sensitive +// +// ::= ('.' )+; +// +identifier::[ + symbol::{ + symbol: string, + case_sensitivity: case_sensitivity, + }, + qualified::{ + root: symbol, + steps: list::[symbol], + }, + _::[ + case_sensitivity::[ + SENSITIVE, + INSENSITIVE, + ], + ], +] + +// Path Literals +// - Much like qualified identifier but allowing bracket notation '[' | ']' +// - Not a variant of `identifier`, as path literals are not explicit in the specification. +path::{ + root: '.identifier.symbol', + steps: list::[step], + _: [ + step::[ + symbol::{ + symbol: '.identifier.symbol', + }, + index::{ + index: int, + }, + ], + ], +} + +// [ ALL | DISTINCT ] +set_quantifier::[ + ALL, + DISTINCT, +] + +// PartiQL Expression +expr::[ + + // PartiQL Literal Value + lit::{ + value: '.value', + }, + + // Ion Literal Value, ie `` + ion::{ + value: '.ion', + }, + + // Variable Reference + var::{ + identifier: identifier, + scope: [ + DEFAULT, // x.y.z + LOCAL, // @x.y.z + ], + }, + + // SQL Session Keywords (CURRENT_USER, CURRENT_ROLE, etc.) + session_attribute::{ + attribute: [ + CURRENT_USER, + ], + }, + + // Expression Paths + path::{ + root: expr, + steps: list::[step], + _: [ + step::[ + symbol::{ symbol: '.identifier.symbol' }, + index::{ key: expr }, + wildcard::{}, + unpivot::{}, + ], + ], + }, + + // Scalar Function Call + call::{ + function: identifier, + args: list::[expr], + }, + + // Aggregate Function Call + agg::{ + function: identifier, + args: list::[expr], + setq: optional::set_quantifier, + }, + + // Parameter `?` + parameter::{ + index: int, + }, + + // Unary Operators + unary::{ + op: [ NOT, POS, NEG ], + expr: expr, + }, + + // Binary Operators + binary::{ + op: [ + PLUS, MINUS, TIMES, DIVIDE, MODULO, CONCAT, + AND, OR, + EQ, NE, GT, GTE, LT, LTE, + ], + lhs: expr, + rhs: expr, + }, + + // VALUES (',' )* + values::{ + rows: list::[row], + _:[ + row::{ + items: list::[expr], + }, + ], + }, + + // Collection Constructors + collection::{ + type: [ + BAG, // << ... >> + ARRAY, // [ ... ] + VALUES, // ( ... ) + LIST, // LIST ( ... ) + SEXP, // SEXP ( ... ) + ], + values: list::[expr], + }, + + // Struct Constructor + struct::{ + fields: list::[field], + _: [ + field::{ + name: expr, + value: expr, + }, + ], + }, + + // SQL special form `[NOT] LIKE` + like::{ + value: expr, + pattern: expr, + escape: optional::expr, + not: optional::bool, + }, + + // SQL special form `[NOT] BETWEEN` + between::{ + value: expr, + from: expr, + to: expr, + not: optional::bool, + }, + + // SQL special form `[NOT] IN` + in_collection::{ + lhs: expr, + rhs: expr, + not: optional::bool, + }, + + // PartiQL special form `IS [NOT]` + is_type::{ + value: expr, + type: '.type', + not: optional::bool, + }, + + // The simple and searched `case` switch SQL special form F261-01, F261-02 + case::{ + expr: optional::expr, + branches: list::[branch], + default: optional::expr, + _: [ + branch::{ + condition: expr, + expr: expr, + }, + ], + }, + + // SQL special form F261-04 `COALESCE` + coalesce::{ + args: list::[expr], + }, + + // SQL special form F261-03 `NULLIF` + null_if::{ + value: expr, + nullifier: expr, + }, + + // SQL special form E021-06 `SUBSTRING ( [FROM ] [FOR ] )` + substring::{ + value: expr, // [0] + start: optional::expr, // [1] + length: optional::expr, // [2] + }, + + // SQL special form E021-11 `POSITION ( IN )` + position::{ + lhs: expr, + rhs: expr, + }, + + // SQL special form E021-09 `TRIM ( [LEADING|TRAILING|BOTH] [ FROM] )` + trim::{ + value: expr, // [0] + chars: optional::expr, // [1] + spec: optional::[ LEADING, TRAILING, BOTH ], // + }, + + // SQL special form `OVERLAY ( PLACING FROM [FOR ] )` + overlay::{ + value: expr, // [0] + overlay: expr, // [1] + start: expr, // [2] + length: optional::expr, // [3] + }, + + // SQL special form `EXTRACT ( FROM )` + extract::{ + field: datetime_field, + source: expr, + }, + + // SQL special form F201 `CAST` + cast::{ + value: expr, + as_type: '.type', + }, + + // PartiQL special form `CAN_CAST` + can_cast::{ + value: expr, + as_type: '.type', + }, + + // PartiQL special form `CAN_LOSSLESS_CAST` + can_lossless_cast::{ + value: expr, + as_type: '.type', + }, + + // PartiQL special form `DATE_ADD ( , , )` + date_add::{ + field: datetime_field, + lhs: expr, + rhs: expr, + }, + + // PartiQL special form `DATE_DIFF ( , , )` + date_diff::{ + field: datetime_field, + lhs: expr, + rhs: expr, + }, + + // PartiQL special form `[OUTER] (UNION|INTERSECT|EXCEPT) [ALL|DISTINCT]` + bag_op::{ + type: '.set_op', + lhs: expr, + rhs: expr, + outer: optional::bool, + }, + + // The PartiQL `` query expression, think SQL `` + s_f_w::{ + select: select, // oneof SELECT / SELECT VALUE / PIVOT + from: from, + let: optional::let, + where: optional::expr, + group_by: optional::group_by, + having: optional::expr, + set_op: optional::{ + type: '.set_op', + operand: '.expr.s_f_w', + }, + order_by: optional::order_by, + limit: optional::expr, + offset: optional::expr, + }, + + // GPML ( MATCH ) + match::{ + expr: expr, + pattern: graph_match, + }, + + // [LAG|LEAD] ( [ [ ] ] ) + // OVER ([PARTITION BY [, ]... ] [ORDER BY [, ]... ]) + window::{ + function: [ LAG, LEAD ], + expression: expr, // [0] + offset: optional::expr, // [1] + default: optional::expr, // [2] + over: { + partitions: optional::list::[expr], + sorts: optional::list::[sort], + }, + }, +] + +// PartiQL SELECT Clause Variants — https://partiql.org/dql/select.html +select::[ + + // SELECT [ALL|DISTINCT] * + star::{ + setq: optional::set_quantifier, + }, + + // SELECT [ALL|DISTINCT] + project::{ + items: list::[item], + setq: optional::set_quantifier, + _: [ + item::[ + all::{ expr: expr }, // .* + expression::{ expr: expr, as_alias: optional::string } // [as ] + ], + ], + }, + + // PIVOT AT + pivot::{ + key: expr, // [0] + value: expr, // [1] + }, + + // SELECT [ALL|DISTINCT] VALUE + value::{ + constructor: expr, + setq: optional::set_quantifier, + }, +] + +// PartiQL FROM Clause Variants — https://partiql.org/dql/from.html +from::[ + + // FROM [UNPIVOT] [AS ] [AT ] [BY ] + value::{ + expr: expr, + type: [ SCAN, UNPIVOT ], + as_alias: optional::string, + at_alias: optional::string, + by_alias: optional::string, + }, + + // TODO https://github.com/partiql/partiql-spec/issues/41 + // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1013 + join::{ + lhs: from, + rhs: from, + type: optional::[ + INNER, + LEFT, + LEFT_OUTER, + RIGHT, + RIGHT_OUTER, + FULL, + FULL_OUTER, + CROSS, + COMMA, + ], + condition: optional::expr, + }, +] + +let::{ + bindings: list::[binding], + _: [ + binding::{ + expr: expr, + as_alias: string, + }, + ], +} + +// GROUP BY Clause — https://partiql.org/dql/group_by.html +group_by::{ + strategy: [ FULL, PARTIAL ], + keys: list::[key], + as_alias: optional::string, + _: [ + key::{ + expr: expr, + as_alias: optional::string, + }, + ], +} + +// ORDER BY Clause — https://partiql.org/dql/order_by.html +order_by::{ + sorts: list::[sort], +} + +// [ASC|DESC] [NULLS FIRST | NULLS LAST] +sort::{ + expr: expr, + dir: optional::[ ASC, DESC ], + nulls: optional::[ FIRST, LAST ], +} + +// (UNION|INTERSECT|EXCEPT) [ALL|DISTINCT] +set_op::{ + type: [ UNION, INTERSECT, EXCEPT ], + setq: optional::set_quantifier, +} + +// Graph Match Nodes — https://partiql.org/gpml/graph_query.html +graph_match::{ + patterns: list::[pattern], + selector: optional::selector, + + _: [ + pattern::{ + restrictor: optional::restrictor, + prefilter: optional::expr, // An optional pattern pre-filter, e.g.: `WHERE a.name=b.name` in `MATCH [(a)->(b) WHERE a.name=b.name]` + variable: optional::string, // The optional element variable of the pattern, e.g.: `p` in `MATCH p = (a) −[t]−> (b)` + quantifier: optional::quantifier, // An optional quantifier for the entire pattern match, e.g. `{2,5}` in `MATCH (a:Account)−[:Transfer]−>{2,5}(b:Account)` + parts: list::[part], // The ordered pattern parts + _: [ + part::[ + // A single node in a graph pattern + node::{ + prefilter: optional::expr, // An optional node pre-filter, e.g.: `WHERE c.name='Alarm'` in `MATCH (c WHERE c.name='Alarm')` + variable: optional::string, // The optional element variable of the node match, e.g.: `x` in `MATCH (x)` + label: list::[string], // The optional label(s) to match for the node, e.g.: `Entity` in `MATCH (x:Entity)` + }, + // A single edge in a graph pattern + edge::{ + direction: direction, // Edge Direction + quantifier: optional::quantifier, // An optional quantifier for the entire pattern match, e.g. `{2,5}` in `MATCH (a:Account)−[:Transfer]−>{2,5}(b:Account)` + prefilter: optional::expr, // An optional edge pre-filter, e.g.: `WHERE t.capacity>100` in `MATCH −[t:hasSupply WHERE t.capacity>100]−>` + variable: optional::string, // The optional element variable of the edge match, e.g.: `t` in `MATCH −[t]−>` + label: list::[string], // The optional label(s) to match for the edge. e.g.: `Target` in `MATCH −[t:Target]−>` + }, + // A sub-pattern + pattern::{ + pattern: '.graph_match.pattern' + }, + ], + ], + }, + + // Edge Direction // | Orientation | Edge pattern | Abbreviation | + direction::[ // |---------------------------+--------------+--------------| + LEFT, // | Pointing left | <−[ spec ]− | <− | + UNDIRECTED, // | Undirected | ~[ spec ]~ | ~ | + RIGHT, // | Pointing right | −[ spec ]−> | −> | + LEFT_OR_UNDIRECTED, // | Left or undirected | <~[ spec ]~ | <~ | + UNDIRECTED_OR_RIGHT, // | Undirected or right | ~[ spec ]~> | ~> | + LEFT_OR_RIGHT, // | Left or right | <−[ spec ]−> | <−> | + LEFT_UNDIRECTED_OR_RIGHT, // | Left, undirected or right | −[ spec ]− | − | + ], // Fig. 5 — https://arxiv.org/abs/2112.06217 + + // Path Restrictor // | Keyword | Description + restrictor::[ // |----------------+-------------- + TRAIL, // | TRAIL | No repeated edges. + ACYCLIC, // | ACYCLIC | No repeated nodes. + SIMPLE, // | SIMPLE | No repeated nodes, except that the first and last nodes may be the same. + ], // Fig. 7 — https://arxiv.org/abs/2112.06217 + + // Graph Edge Quantifier (e.g., the `{2,5}` in `MATCH (x)->{2,5}(y)`) + quantifier::{ + lower: long, + upper: optional::long, + }, + + // Path Selector + selector::[ + any_shortest::{}, // ANY SHORTEST + all_shortest::{}, // ALL SHORTEST + any::{}, // ANY + any_k::{ k: long }, // ANY k + shortest_k::{ k: long }, // SHORTEST k + shortest_k_group::{ k: long }, // SHORTEST k GROUP + ], // Fig. 8 — https://arxiv.org/abs/2112.06217 + ], +} + +// LEGACY `ON CONFLICT WHERE ` +// `ON CONFLICT [] ` +on_conflict::{ + target: optional::target, + action: action, + + _: [ + // ::= ( [, ]... ) + // | ( { | } ) + // | ON CONSTRAINT + target::[ + symbols::{ + symbols: list::[identifier], + }, + constraint::{ + constraint: identifier, + }, + ], + + // ::= DO NOTHING + // | DO UPDATE + // | DO REPLACE + action::[ + do_nothing::{}, + do_replace::{ + condition: optional::expr, + }, + do_update::{ + condition: optional::expr, + }, + ], + ] +} + +// RETURNING returningColumn ( COMMA returningColumn )* +returning::{ + columns: list::[column], + _: [ + column::{ + status: [ MODIFIED, ALL ], + age: [ OLD, NEW ], + value: [ + wildcard::{}, + expression::{ expr: expr } + ], + }, + ], +} + +// ` *` +// `( CONSTRAINT )? ` +table_definition::{ + columns: list::[column], + _: [ + column::{ + name: string, + type: '.type', + constraints: list::[constraint], + _: [ + // TODO improve modeling language to avoid these wrapped unions + // Also, prefer not to nest more than twice + constraint::{ + name: optional::string, + body: [ + nullable::{}, + not_null::{}, + check::{ expr: expr }, + ], + }, + ], + }, + ], +} + +// SQL-99 Table 11 +datetime_field::[ + YEAR, // 0001-9999 + MONTH, // 01-12 + DAY, // 01-31 + HOUR, // 00-23 + MINUTE, // 00-59 + SECOND, // 00-61.9(N) + TIMEZONE_HOUR, // -12-13 + TIMEZONE_MINUTE, // -59-59 +] diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt new file mode 100644 index 0000000000..ef99f949d7 --- /dev/null +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt @@ -0,0 +1,721 @@ +@file:OptIn(PartiQLValueExperimental::class) + +package org.partiql.ast.helpers + +import com.amazon.ion.Decimal +import com.amazon.ionelement.api.MetaContainer +import com.amazon.ionelement.api.ionBool +import com.amazon.ionelement.api.ionDecimal +import com.amazon.ionelement.api.ionFloat +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionNull +import com.amazon.ionelement.api.ionString +import com.amazon.ionelement.api.ionSymbol +import com.amazon.ionelement.api.loadSingleElement +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.partiql.ast.Ast +import org.partiql.ast.AstNode +import org.partiql.ast.Expr +import org.partiql.ast.From +import org.partiql.ast.GroupBy +import org.partiql.ast.Identifier +import org.partiql.ast.SetQuantifier +import org.partiql.ast.Sort +import org.partiql.ast.builder.AstBuilder +import org.partiql.ast.builder.AstFactory +import org.partiql.ast.builder.ast +import org.partiql.lang.domains.PartiqlAst +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.blobValue +import org.partiql.value.boolValue +import org.partiql.value.clobValue +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import org.partiql.value.missingValue +import org.partiql.value.nullValue +import org.partiql.value.stringValue +import org.partiql.value.symbolValue +import java.math.BigDecimal +import java.math.BigInteger +import kotlin.test.assertFails + +/** + * Tests for translation of the org.partiql.ast.AstNode trees to org.partiql.lang.domains.PartiqlAst.AstNode trees. + * + * The `null` expression value is used extensively because we are testing single node structural translations. + * We don't want convoluted tests with deep trees. More complex tests are covered in end-to-end translation. + * + * Similarly, PartiqlAst.Identifier and Identifier nodes are avoided when their meaning doesn't matter. + * Scan of a string "table" is semantically different than the identifier with name 'table', but structurally + * it doesn't matter for testing where an arbitrary expression is used. Keep the tree shallow. + */ +class ToLegacyAstTest { + + @ParameterizedTest + @MethodSource("literals") + @Execution(ExecutionMode.CONCURRENT) + fun testLiterals(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("ion") + @Execution(ExecutionMode.CONCURRENT) + fun testIon(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("identifiers") + @Execution(ExecutionMode.CONCURRENT) + fun testVars(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("calls") + @Execution(ExecutionMode.CONCURRENT) + fun testCalls(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("operators") + @Execution(ExecutionMode.CONCURRENT) + fun testOperators(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("paths") + @Execution(ExecutionMode.CONCURRENT) + fun testPaths(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("collections") + @Execution(ExecutionMode.CONCURRENT) + fun testCollections(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("types") + @Execution(ExecutionMode.CONCURRENT) + fun testTypes(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("specialForms") + @Execution(ExecutionMode.CONCURRENT) + fun testSpecialForms(case: Case) = case.assert() + + @ParameterizedTest + @MethodSource("sfw") + @Execution(ExecutionMode.CONCURRENT) + fun testSfw(case: Case) = case.assert() + + companion object { + + private fun expect(expected: String, block: AstBuilder.() -> AstNode): Case { + val i = ast(AstFactory.DEFAULT, block) + val e = PartiqlAst.transform(loadSingleElement(expected)) + return Case.Translate(i, e) + } + + private fun fail(message: String, block: AstBuilder.() -> AstNode): Case { + val i = ast(AstFactory.DEFAULT, block) + return Case.Fail(i, message) + } + + private val NULL = Ast.exprLit(nullValue()) + + // Shortcut to construct a "legacy-compatible" simple identifier + private fun id(name: String) = Ast.identifierSymbol(name, Identifier.CaseSensitivity.INSENSITIVE) + + @JvmStatic + fun literals() = listOf( + expect("(lit null)") { + exprLit(nullValue()) + }, + expect("(missing)") { + exprLit(missingValue()) + }, + expect("(lit true)") { + exprLit(boolValue(true)) + }, + expect("(lit 1)") { + exprLit(int8Value(1)) + }, + expect("(lit 2)") { + exprLit(int16Value(2)) + }, + expect("(lit 3)") { + exprLit(int32Value(3)) + }, + expect("(lit 4)") { + exprLit(int64Value(4)) + }, + expect("(lit 5)") { + exprLit(intValue(BigInteger.valueOf(5))) + }, + expect("(lit 1.1e0)") { + exprLit(float32Value(1.1f)) + }, + expect("(lit 1.2e0)") { + exprLit(float64Value(1.2)) + }, + expect("(lit 1.3)") { + exprLit(decimalValue(BigDecimal.valueOf(1.3))) + }, + expect("""(lit "hello")""") { + exprLit(stringValue("hello")) + }, + expect("""(lit 'hello')""") { + exprLit(symbolValue("hello")) + }, + expect("""(lit {{ '''Hello''' '''World''' }})""") { + exprLit(clobValue("HelloWorld".toByteArray())) + }, + expect("""(lit {{ VG8gaW5maW5pdHkuLi4gYW5kIGJleW9uZCE= }})""") { + exprLit(blobValue("To infinity... and beyond!".toByteArray())) + }, + // TODO detailed tests just for _DateTime_ types + ) + + @JvmStatic + fun ion() = listOf( + expect("(lit null)") { + exprIon(ionNull()) + }, + expect("(lit true)") { + exprIon(ionBool(true)) + }, + expect("(lit 1)") { + exprIon(ionInt(1)) + }, + expect("(lit 1.2e0)") { + exprIon(ionFloat(1.2)) + }, + expect("(lit 1.3)") { + exprIon(ionDecimal(Decimal.valueOf(1.3))) + }, + expect("""(lit "hello")""") { + exprIon(ionString("hello")) + }, + expect("""(lit 'hello')""") { + exprIon(ionSymbol("hello")) + }, + // TODO detailed tests just for _DateTime_ types + ) + + @JvmStatic + fun identifiers() = listOf( + expect("(id 'a' (case_sensitive) (unqualified))") { + exprVar { + identifier = identifierSymbol("a", Identifier.CaseSensitivity.SENSITIVE) + scope = Expr.Var.Scope.DEFAULT + } + }, + expect("(id 'a' (case_insensitive) (unqualified))") { + exprVar { + identifier = identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE) + scope = Expr.Var.Scope.DEFAULT + } + }, + expect("(id 'a' (case_sensitive) (locals_first))") { + exprVar { + identifier = identifierSymbol("a", Identifier.CaseSensitivity.SENSITIVE) + scope = Expr.Var.Scope.LOCAL + } + }, + expect("(id 'a' (case_insensitive) (locals_first))") { + exprVar { + identifier = identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE) + scope = Expr.Var.Scope.LOCAL + } + }, + expect("(identifier 'a' (case_insensitive))") { + identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE) + }, + // + fail("Cannot translate qualified identifiers in variable references") { + exprVar { + identifier = identifierQualified { + root = identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE) + steps += identifierSymbol("b", Identifier.CaseSensitivity.INSENSITIVE) + } + scope = Expr.Var.Scope.LOCAL + } + } + ) + + @JvmStatic + fun calls() = listOf( + expect("(call 'a' (lit null))") { + exprCall { + function = id("a") + args += NULL + } + }, + expect("(call_agg (all) 'a' (lit null))") { + exprAgg { + function = id("a") + args += NULL + } + }, + expect("(call_agg (all) 'a' (lit null))") { + exprAgg { + setq = SetQuantifier.ALL + function = id("a") + args += NULL + } + }, + expect("(call_agg (distinct) 'a' (lit null))") { + exprAgg { + setq = SetQuantifier.DISTINCT + function = id("a") + args += NULL + } + }, + fail("Cannot translate `call_agg` with more than one argument") { + exprAgg { + function = id("a") + args += listOf(NULL, NULL) + } + }, + ) + + @JvmStatic + fun operators() = listOf( + expect("(not (lit null))") { + exprUnary { + op = Expr.Unary.Op.NOT + expr = NULL + } + }, + expect("(pos (lit null))") { + exprUnary { + op = Expr.Unary.Op.POS + expr = NULL + } + }, + expect("(neg (lit null))") { + exprUnary { + op = Expr.Unary.Op.NEG + expr = NULL + } + }, + // we don't really need to test _all_ binary operators + expect("(plus (lit null) (lit null))") { + exprBinary { + op = Expr.Binary.Op.PLUS + lhs = NULL + rhs = NULL + } + }, + ) + + @JvmStatic + fun paths() = listOf( + expect("(path (lit null) (path_expr (lit null) (case_sensitive)))") { + exprPath { + root = NULL + steps += exprPathStepIndex(NULL) + } + }, + expect("(path (lit null) (path_wildcard))") { + exprPath { + root = NULL + steps += exprPathStepWildcard() + } + }, + expect("(path (lit null) (path_unpivot))") { + exprPath { + root = NULL + steps += exprPathStepUnpivot() + } + }, + ) + + @JvmStatic + fun collections() = listOf( + expect("(bag (lit null))") { + exprCollection(Expr.Collection.Type.BAG) { + values += NULL + } + }, + expect("(list (lit null))") { + exprCollection(Expr.Collection.Type.ARRAY) { + values += NULL + } + }, + expect("(list (lit null))") { + exprCollection(Expr.Collection.Type.VALUES) { + values += NULL + } + }, + expect("(list (lit null))") { + exprCollection(Expr.Collection.Type.LIST) { + values += NULL + } + }, + expect("(sexp (lit null))") { + exprCollection(Expr.Collection.Type.SEXP) { + values += NULL + } + }, + expect("(struct (expr_pair (lit null) (lit null)))") { + exprStruct { + fields += exprStructField { + name = NULL + value = NULL + } + } + }, + ) + + @JvmStatic + fun types() = listOf( + // SQL + expect("(null_type)") { typeNullType() }, + expect("(boolean_type)") { typeBool() }, + expect("(smallint_type)") { typeSmallint() }, + expect("(integer_type)") { typeInt() }, + expect("(real_type)") { typeReal() }, + expect("(float_type null)") { typeFloat32() }, + expect("(double_precision_type)") { typeFloat64() }, + expect("(decimal_type null null)") { typeDecimal() }, + expect("(decimal_type 2 null)") { typeDecimal(2) }, + expect("(decimal_type 2 1)") { typeDecimal(2, 1) }, + expect("(numeric_type null null)") { typeNumeric() }, + expect("(numeric_type 2 null)") { typeNumeric(2) }, + expect("(numeric_type 2 1)") { typeNumeric(2, 1) }, + expect("(timestamp_type null)") { typeTimestamp() }, + expect("(timestamp_type 1)") { typeTimestamp(1) }, + expect("(character_type null)") { typeChar() }, + expect("(character_type 1)") { typeChar(1) }, + expect("(character_varying_type null)") { typeVarchar() }, + expect("(character_varying_type 1)") { typeVarchar(1) }, + expect("(blob_type)") { typeBlob() }, + expect("(clob_type)") { typeClob() }, + expect("(date_type)") { typeDate() }, + expect("(time_type null)") { typeTime() }, + expect("(time_type 1)") { typeTime(1) }, + expect("(time_with_time_zone_type null)") { typeTimeWithTz() }, + expect("(time_with_time_zone_type 1)") { typeTimeWithTz(1) }, + // PartiQL + expect("(missing_type)") { typeMissing() }, + expect("(string_type)") { typeString() }, + expect("(symbol_type)") { typeSymbol() }, + expect("(struct_type)") { typeStruct() }, + expect("(tuple_type)") { typeTuple() }, + expect("(list_type)") { typeList() }, + expect("(sexp_type)") { typeSexp() }, + expect("(bag_type)") { typeBag() }, + expect("(any_type)") { typeAny() }, + // Other (??) + expect("(integer4_type)") { typeInt4() }, + expect("(integer8_type)") { typeInt8() }, + expect("(custom_type dog)") { typeCustom("dog") } + // LEGACY AST does not have TIMESTAMP or INTERVAL + // LEGACY AST does not have parameterized blob/clob + ) + + @JvmStatic + fun specialForms() = listOf( + expect("(like (lit 'a') (lit 'b') null)") { + exprLike { + value = exprLit(symbolValue("a")) + pattern = exprLit(symbolValue("b")) + } + }, + expect("(like (lit 'a') (lit 'b') (lit 'c'))") { + exprLike { + value = exprLit(symbolValue("a")) + pattern = exprLit(symbolValue("b")) + escape = exprLit(symbolValue("c")) + } + }, + expect("(not (like (lit 'a') (lit 'b') (lit 'c')))") { + exprLike { + value = exprLit(symbolValue("a")) + pattern = exprLit(symbolValue("b")) + escape = exprLit(symbolValue("c")) + not = true + } + }, + expect("(between (lit 'a') (lit 'b') (lit 'c'))") { + exprBetween { + value = exprLit(symbolValue("a")) + from = exprLit(symbolValue("b")) + to = exprLit(symbolValue("c")) + } + }, + expect("(not (between (lit 'a') (lit 'b') (lit 'c')))") { + exprBetween { + value = exprLit(symbolValue("a")) + from = exprLit(symbolValue("b")) + to = exprLit(symbolValue("c")) + not = true + } + }, + expect("(in_collection (lit 'a') (lit 'b'))") { + exprInCollection { + lhs = exprLit(symbolValue("a")) + rhs = exprLit(symbolValue("b")) + } + }, + expect("(not (in_collection (lit 'a') (lit 'b')))") { + exprInCollection { + lhs = exprLit(symbolValue("a")) + rhs = exprLit(symbolValue("b")) + not = true + } + }, + expect("(is_type (lit 'a') (any_type))") { + exprIsType { + value = exprLit(symbolValue(("a"))) + type = typeAny() + } + }, + expect("(not (is_type (lit 'a') (any_type)))") { + exprIsType { + value = exprLit(symbolValue(("a"))) + type = typeAny() + not = true + } + }, + // TODO case + // TODO coalesce + // TODO nullif + // TODO substring + // TODO position + expect("""(call 'trim' (lit "xyz"))""") { + exprTrim { + value = exprLit(stringValue("xyz")) + } + }, + expect("""(call 'trim' (lit "xyz"))""") { + exprTrim { + value = exprLit(stringValue("xyz")) + } + }, + expect("""(call 'trim' (lit "xyz"))""") { + exprTrim { + value = exprLit(stringValue("xyz")) + } + }, + // TODO overlay + // TODO extract + // TODO cast + // TODO can_cast + // TODO can_lossless_cast + // TODO date_add + // TODO date_diff + ) + + @JvmStatic + fun sfw() = listOf( + // PROJECT Variants + expect("(project_star)") { + selectStar() + }, + expect( + """ + (project_list + (project_all (id 'a' (case_sensitive) (unqualified))) + (project_expr (lit 1) 'x') + ) + """ + ) { + selectProject { + items += selectProjectItemAll { + expr = + exprVar(identifierSymbol("a", Identifier.CaseSensitivity.SENSITIVE), Expr.Var.Scope.DEFAULT) + } + items += selectProjectItemExpression { + expr = exprLit(int32Value(1)) + asAlias = "x" + } + } + }, + expect("(project_pivot (lit 1) (lit 2))") { + selectPivot { + value = exprLit(int32Value(1)) + key = exprLit(int32Value(2)) + } + }, + expect("(project_value (lit null))") { + selectValue { + constructor = NULL + } + }, + // FROM_SOURCE Variants + expect("(scan (lit null) null null null)") { + fromValue { + expr = NULL + type = From.Value.Type.SCAN + } + }, + expect("(scan (lit null) 'a' 'b' 'c')") { + fromValue { + expr = NULL + type = From.Value.Type.SCAN + asAlias = "a" + atAlias = "b" + byAlias = "c" + } + }, + expect("(unpivot (lit null) null null null)") { + fromValue { + expr = NULL + type = From.Value.Type.UNPIVOT + } + }, + expect("(unpivot (lit null) 'a' 'b' 'c')") { + fromValue { + expr = NULL + type = From.Value.Type.UNPIVOT + asAlias = "a" + atAlias = "b" + byAlias = "c" + } + }, + expect( + """ + (join (inner) + (scan (lit "lhs") null null null) + (scan (lit "rhs") null null null) + null + ) + """ + ) { + fromJoin { + type = From.Join.Type.INNER + lhs = fromValue { + expr = exprLit(stringValue("lhs")) + type = From.Value.Type.SCAN + } + rhs = fromValue { + expr = exprLit(stringValue("rhs")) + type = From.Value.Type.SCAN + } + } + }, + expect( + """ + (join (inner) + (scan (lit "lhs") null null null) + (scan (lit "rhs") null null null) + (lit true) + ) + """ + ) { + fromJoin { + // DEFAULT + // type = From.Join.Type.INNER + lhs = fromValue { + expr = exprLit(stringValue("lhs")) + type = From.Value.Type.SCAN + } + rhs = fromValue { + expr = exprLit(stringValue("rhs")) + type = From.Value.Type.SCAN + } + condition = exprLit(boolValue(true)) + } + }, + expect("(let (let_binding (lit null) 'x'))") { + let { + bindings += letBinding { + expr = NULL + asAlias = "x" + } + } + }, + expect( + """ + (group_by (group_full) + (group_key_list + (group_key (lit "a") null) + (group_key (lit "b") 'x') + ) + null + ) + """ + ) { + groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(exprLit(stringValue("a")), null) + keys += groupByKey(exprLit(stringValue("b")), "x") + } + }, + expect( + """ + (group_by (group_partial) + (group_key_list + (group_key (lit "a") null) + (group_key (lit "b") 'x') + ) + 'as' + ) + """ + ) { + groupBy { + strategy = GroupBy.Strategy.PARTIAL + keys += groupByKey(exprLit(stringValue("a")), null) + keys += groupByKey(exprLit(stringValue("b")), "x") + asAlias = "as" + } + }, + expect( + """ + (order_by + (sort_spec (lit "a") null null) + (sort_spec (lit "b") (asc) (nulls_first)) + (sort_spec (lit "c") (asc) (nulls_last)) + (sort_spec (lit "d") (desc) (nulls_first)) + (sort_spec (lit "e") (desc) (nulls_last)) + ) + """ + ) { + orderBy { + sorts += sort(exprLit(stringValue("a"))) + sorts += sort(exprLit(stringValue("b")), Sort.Dir.ASC, Sort.Nulls.FIRST) + sorts += sort(exprLit(stringValue("c")), Sort.Dir.ASC, Sort.Nulls.LAST) + sorts += sort(exprLit(stringValue("d")), Sort.Dir.DESC, Sort.Nulls.FIRST) + sorts += sort(exprLit(stringValue("e")), Sort.Dir.DESC, Sort.Nulls.LAST) + } + }, + ) + } + + sealed class Case { + + abstract fun assert() + + class Translate( + private val input: AstNode, + private val expected: PartiqlAst.PartiqlAstNode, + private val metas: Map = emptyMap(), + ) : Case() { + + override fun assert() { + val actual = input.toLegacyAst(metas) + val aIon = actual.toIonElement() + val eIon = expected.toIonElement() + assertEquals(eIon, aIon) + } + } + + class Fail( + private val input: AstNode, + private val message: String, + private val metas: Map = emptyMap(), + ) : Case() { + + override fun assert() { + assertFails(message) { + input.toLegacyAst(metas) + } + } + } + } +} diff --git a/partiql-lang/build.gradle.kts b/partiql-lang/build.gradle.kts index 2168eb5fd3..43642bd51d 100644 --- a/partiql-lang/build.gradle.kts +++ b/partiql-lang/build.gradle.kts @@ -33,13 +33,13 @@ kotlin { dependencies { api(project(":partiql-ast")) + api(project(":partiql-parser")) api(project(":partiql-spi")) api(project(":partiql-types")) api(Deps.ionElement) api(Deps.ionJava) api(Deps.ionSchema) // libs are included in partiql-lang-kotlin JAR, but are not published independently yet. - libs(project(":partiql-parser")) libs(project(":partiql-plan")) implementation(Deps.antlrRuntime) implementation(Deps.csv) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinter.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinter.kt index 788954e78a..1fa489de4a 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinter.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinter.kt @@ -333,7 +333,7 @@ class QueryPrettyPrinter { } private fun writeSessionAttribute(node: PartiqlAst.Expr.SessionAttribute, sb: StringBuilder) { - sb.append(node.value.text) + sb.append(node.value.text.uppercase()) } /** diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt index 330f4c4be5..b1537c2ece 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt @@ -15,6 +15,7 @@ package org.partiql.lang.syntax import org.partiql.lang.syntax.impl.PartiQLPigParser +import org.partiql.lang.syntax.impl.PartiQLShimParser import org.partiql.lang.types.CustomType /** @@ -29,12 +30,25 @@ import org.partiql.lang.types.CustomType */ class PartiQLParserBuilder { + private var constructor: (customTypes: List) -> Parser = ::PartiQLPigParser + companion object { @JvmStatic fun standard(): PartiQLParserBuilder { return PartiQLParserBuilder() } + + @JvmStatic + fun experimental(): PartiQLParserBuilder { + val builder = PartiQLParserBuilder() + builder.constructor = { _ -> + // currently don't pass custom types + val delegate = org.partiql.parser.PartiQLParserBuilder.standard().build() + PartiQLShimParser(delegate) + } + return builder + } } private var customTypes: List = emptyList() @@ -44,6 +58,6 @@ class PartiQLParserBuilder { } fun build(): Parser { - return PartiQLPigParser(this.customTypes) + return constructor(this.customTypes) } } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt index ac74588c37..c900b52157 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt @@ -112,7 +112,7 @@ internal class PartiQLPigVisitor( override fun visitExprTermCurrentUser(ctx: PartiQLParser.ExprTermCurrentUserContext): PartiqlAst.Expr.SessionAttribute { val metas = ctx.CURRENT_USER().getSourceMetaContainer() return PartiqlAst.Expr.SessionAttribute( - value = SymbolPrimitive(ctx.CURRENT_USER().text, metas), + value = SymbolPrimitive(ctx.CURRENT_USER().text.toLowerCase(), metas), metas = metas ) } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLShimParser.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLShimParser.kt new file mode 100644 index 0000000000..6c5c63ae56 --- /dev/null +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLShimParser.kt @@ -0,0 +1,89 @@ +package org.partiql.lang.syntax.impl + +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ionelement.api.MetaContainer +import org.partiql.ast.helpers.toLegacyAst +import org.partiql.lang.ast.SourceLocationMeta +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.metaContainerOf +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.PropertyValueMap +import org.partiql.lang.syntax.Parser +import org.partiql.lang.syntax.ParserException +import org.partiql.parser.PartiQLLexerException +import org.partiql.parser.PartiQLParser +import org.partiql.parser.PartiQLParserException +import org.partiql.parser.SourceLocations + +/** + * Implementation of [Parser] which uses a [org.partiql.ast.AstNode] tree, then translates to the legacy interface. + * + * @property delegate + */ +internal class PartiQLShimParser( + private val delegate: PartiQLParser, +) : Parser { + + // required for PropertyValueMap debug information + private val ion = IonSystemBuilder.standard().build() + + override fun parseAstStatement(source: String): PartiqlAst.Statement { + val result = try { + delegate.parse(source) + } catch (ex: PartiQLLexerException) { + throw ex.shim() + } catch (ex: PartiQLParserException) { + throw ex.shim() + } + val statement = try { + val metas = result.locations.toMetas() + result.root.toLegacyAst(metas) + } catch (ex: Exception) { + throw ParserException( + message = ex.message ?: "", + errorCode = ErrorCode.PARSE_INVALID_QUERY, + cause = ex, + ) + } + if (statement !is PartiqlAst.Statement) { + throw ParserException( + message = "Expected statement, got ${statement::class.qualifiedName}", + errorCode = ErrorCode.PARSE_INVALID_QUERY, + ) + } + return statement + } + + /** + * The legacy parser tests assert on ParserExcept, not LexerException. + */ + private fun PartiQLLexerException.shim(): ParserException { + val ctx = PropertyValueMap() + ctx[Property.LINE_NUMBER] = location.line.toLong() + ctx[Property.COLUMN_NUMBER] = location.offset.toLong() + ctx[Property.TOKEN_STRING] = token + ctx[Property.TOKEN_DESCRIPTION] = tokenType + ctx[Property.TOKEN_VALUE] = ion.newSymbol(token) + return ParserException(message, ErrorCode.PARSE_UNEXPECTED_TOKEN, ctx, cause) + } + + private fun PartiQLParserException.shim(): ParserException { + val ctx = PropertyValueMap() + ctx[Property.LINE_NUMBER] = location.line.toLong() + ctx[Property.COLUMN_NUMBER] = location.offset.toLong() + ctx[Property.TOKEN_DESCRIPTION] = tokenType + ctx[Property.TOKEN_VALUE] = ion.newSymbol(token) + return ParserException(message, ErrorCode.PARSE_UNEXPECTED_TOKEN, ctx, cause) + } + + private fun SourceLocations.toMetas(): Map = mapValues { + metaContainerOf( + SourceLocationMeta( + lineNum = it.value.line.toLong(), + charOffset = it.value.offset.toLong(), + length = it.value.lengthLegacy.toLong(), + ) + ) + } +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/ast/passes/StatementRedactorTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/ast/passes/StatementRedactorTest.kt index 0275ff08f4..5b315b3ffe 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/ast/passes/StatementRedactorTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/ast/passes/StatementRedactorTest.kt @@ -43,7 +43,12 @@ class StatementRedactorTest : PartiQLParserTestBase() { * Return true if the parsed results of input [statement] is the same as input [ast] */ private fun validateInputAstParsedFromInputStatement(statement: String, ast: PartiqlAst.Statement): Boolean { - return parser.parseAstStatement(statement) == ast + targets.forEach { target -> + if (target.parser.parseAstStatement(statement) != ast) { + return false + } + } + return true } @Test @@ -404,23 +409,23 @@ class StatementRedactorTest : PartiQLParserTestBase() { ) @Test - fun testDefaultArguments() { + fun testDefaultArguments(): Unit = forEachTarget { val originalStatement = "SELECT * FROM tb WHERE hk = 1 AND begins_with(Attr, 'foo', bar)" val expectedRedactedStatement = "SELECT * FROM tb WHERE hk = ***(Redacted) AND begins_with(Attr, ***(Redacted), bar)" - val redactedStatement1 = redact(originalStatement, super.parser.parseAstStatement(originalStatement)) + val redactedStatement1 = redact(originalStatement, parser.parseAstStatement(originalStatement)) assertEquals(expectedRedactedStatement, redactedStatement1) - val redactedStatement2 = redact(originalStatement, super.parser.parseAstStatement(originalStatement), providedSafeFieldNames = emptySet()) + val redactedStatement2 = redact(originalStatement, parser.parseAstStatement(originalStatement), providedSafeFieldNames = emptySet()) assertEquals(expectedRedactedStatement, redactedStatement2) - val redactedStatement3 = redact(originalStatement, super.parser.parseAstStatement(originalStatement), userDefinedFunctionRedactionConfig = emptyMap()) + val redactedStatement3 = redact(originalStatement, parser.parseAstStatement(originalStatement), userDefinedFunctionRedactionConfig = emptyMap()) assertEquals(expectedRedactedStatement, redactedStatement3) } @Test - fun testInputStatementAstMismatch() { + fun testInputStatementAstMismatch() = forEachTarget { val inputStatement = "SELECT * FROM tb WHERE nonKey = 'a'" - val inputAst = super.parser.parseAstStatement("SELECT * FROM tb WHERE hk = 1 AND nonKey = 'a'") + val inputAst = parser.parseAstStatement("SELECT * FROM tb WHERE hk = 1 AND nonKey = 'a'") assertFalse(validateInputAstParsedFromInputStatement(inputStatement, inputAst)) try { diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/errors/WindowRelatedParserErrorsTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/errors/WindowRelatedParserErrorsTest.kt deleted file mode 100644 index 9ffd017c9a..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/errors/WindowRelatedParserErrorsTest.kt +++ /dev/null @@ -1,65 +0,0 @@ -package org.partiql.lang.errors - -import org.junit.Test -import org.partiql.lang.syntax.PartiQLParserTestBase -import org.partiql.lang.util.getAntlrDisplayString -import org.partiql.parser.antlr.PartiQLParser - -// TODO: move this to [ParserErrorsTest] once https://github.com/partiql/partiql-docs/issues/31 is resolved and a RFC is approved - -class WindowRelatedParserErrorsTest : PartiQLParserTestBase() { - @Test - fun lagWithoutOrderBy() { - checkInputThrowingParserException( - "SELECT lag(a) OVER () FROM b", - ErrorCode.PARSE_EXPECTED_WINDOW_ORDER_BY, - mapOf( - Property.LINE_NUMBER to 1L, - Property.COLUMN_NUMBER to 8L, - Property.TOKEN_DESCRIPTION to PartiQLParser.LAG.getAntlrDisplayString(), - Property.TOKEN_VALUE to ion.newSymbol("lag") - ) - ) - } - - @Test - fun lagWrongNumberOfParameter() { - checkInputThrowingParserException( - "SELECT lag(a,b,c,d) OVER (ORDER BY e) FROM f", - ErrorCode.PARSE_UNEXPECTED_TOKEN, - mapOf( - Property.LINE_NUMBER to 1L, - Property.COLUMN_NUMBER to 17L, - Property.TOKEN_DESCRIPTION to PartiQLParser.COMMA.getAntlrDisplayString(), - Property.TOKEN_VALUE to ion.newSymbol(",") - ) - ) - } - - fun leadWithoutOrderBy() { - checkInputThrowingParserException( - "SELECT lead(a) OVER () FROM b", - ErrorCode.PARSE_EXPECTED_WINDOW_ORDER_BY, - mapOf( - Property.LINE_NUMBER to 1L, - Property.COLUMN_NUMBER to 8L, - Property.TOKEN_DESCRIPTION to PartiQLParser.LAG.getAntlrDisplayString(), - Property.TOKEN_VALUE to ion.newSymbol("lag") - ) - ) - } - - @Test - fun leadWrongNumberOfParameter() { - checkInputThrowingParserException( - "SELECT lead(a,b,c,d) OVER (ORDER BY e) FROM f", - ErrorCode.PARSE_UNEXPECTED_TOKEN, - mapOf( - Property.LINE_NUMBER to 1L, - Property.COLUMN_NUMBER to 18L, - Property.TOKEN_DESCRIPTION to PartiQLParser.COMMA.getAntlrDisplayString(), - Property.TOKEN_VALUE to ion.newSymbol(",") - ) - ) - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/AggregateSupportVisitorTransformTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/AggregateSupportVisitorTransformTests.kt index 0c95529723..b49114c42b 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/AggregateSupportVisitorTransformTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/AggregateSupportVisitorTransformTests.kt @@ -24,9 +24,13 @@ import org.partiql.lang.ast.AggregateCallSiteListMeta import org.partiql.lang.ast.AggregateRegisterIdMeta import org.partiql.lang.ast.SourceLocationMeta import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.syntax.PartiQLParserBuilder import org.partiql.lang.util.ArgumentsProviderBase class AggregateSupportVisitorTransformTests : VisitorTransformTestBase() { + + private val parser = PartiQLParserBuilder.standard().build() + private val transformer = AggregateSupportVisitorTransform() data class AggSupportTestCase(val query: String, val expectedCallAggs: List>) @@ -37,7 +41,7 @@ class AggregateSupportVisitorTransformTests : VisitorTransformTestBase() { */ private fun String.parseAndTransformQuery(): PartiqlAst.Expr.Select { val query = this - val statement = super.parser.parseAstStatement(query) + val statement = parser.parseAstStatement(query) val transformedNode = (transformer).transformStatement(statement) as PartiqlAst.Statement.Query return (transformedNode.expr) as PartiqlAst.Expr.Select } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeInferenceVisitorTransformTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeInferenceVisitorTransformTest.kt index e63ef75758..9acb7c905f 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeInferenceVisitorTransformTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeInferenceVisitorTransformTest.kt @@ -165,7 +165,7 @@ class StaticTypeInferenceVisitorTransformTest : VisitorTransformTestBase() { @MethodSource("parametersForAggFunctionTests") fun aggFunctionTests(tc: TestCase) = runTest(tc) - private fun runTest(tc: TestCase) { + private fun runTest(tc: TestCase) = forEachTarget { val globalBindings = Bindings.ofMap(tc.globals) val ion = IonSystemBuilder.standard().build() val inferencer = StaticTypeInferencer( @@ -176,7 +176,7 @@ class StaticTypeInferenceVisitorTransformTest : VisitorTransformTestBase() { val defaultVisitorTransforms = basicVisitorTransforms() val staticTypeVisitorTransform = StaticTypeVisitorTransform(ion, globalBindings) - val originalStatement = parse(tc.originalSql).let { + val originalStatement = parser.parseAstStatement(tc.originalSql).let { // We always pass the query under test through all of the basic VisitorTransforms primarily because we need // FromSourceAliasVisitorTransform to execute first but also to help ensure the queries we're testing // make sense when they're all run. diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt index bb829bad79..ddd9e0cfa0 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/StaticTypeVisitorTransformTests.kt @@ -975,7 +975,7 @@ class StaticTypeVisitorTransformTests : VisitorTransformTestBase() { private fun runSTRTest( tc: STRTestCase - ) { + ): Unit = forEachTarget { val globalBindings = Bindings.ofMap(tc.globals) val transformer = StaticTypeVisitorTransform(ion, globalBindings, tc.constraints) @@ -983,7 +983,7 @@ class StaticTypeVisitorTransformTests : VisitorTransformTestBase() { // FromSourceAliasVisitorTransform to execute first but also to help ensure the queries we're testing // make sense when they're all run. val defaultTransforms = basicVisitorTransforms() - val originalAst = defaultTransforms.transformStatement(parse(tc.originalSql)) + val originalAst = defaultTransforms.transformStatement(parser.parseAstStatement(tc.originalSql)) val transformedAst = try { transformer.transformStatement(originalAst) diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/VisitorTransformTestBase.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/VisitorTransformTestBase.kt index 8dd19ee862..2cd73ff1bd 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/VisitorTransformTestBase.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/visitors/VisitorTransformTestBase.kt @@ -100,14 +100,14 @@ abstract class VisitorTransformTestBase : PartiQLParserTestBase() { assertEquals("The expected AST must match the transformed AST", tc.expected, actualStatement) } - protected fun runErrorTest(tc: TransformErrorTestCase, transform: PartiqlAst.VisitorTransform) { + protected fun runErrorTest(tc: TransformErrorTestCase, transform: PartiqlAst.VisitorTransform): Unit = forEachTarget { val ex = assertThrowsSqlException( EvaluatorTestFailureReason.EXPECTED_SQL_EXCEPTION_BUT_THERE_WAS_NONE, { tc.testDetails() } ) { val ast = org.junit.jupiter.api.assertDoesNotThrow("Parsing Original SQL") { - this.parser.parseAstStatement(tc.query) + parser.parseAstStatement(tc.query) } transform.transformStatement(ast) } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinterTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinterTest.kt index cb16bc0abf..e5e6e9ee35 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinterTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/prettyprint/QueryPrettyPrinterTest.kt @@ -915,7 +915,7 @@ class QueryPrettyPrinterTest { fun checkCurrentUserMixedCase() { checkPrettyPrintQuery( query = "CURRENT_user", - expected = "CURRENT_user" + expected = "CURRENT_USER" ) } @@ -925,7 +925,7 @@ class QueryPrettyPrinterTest { query = "SELECT * FROM [ CURRENT_user ]", expected = """ SELECT * - FROM [ CURRENT_user ] + FROM [ CURRENT_USER ] """.trimIndent() ) } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCastTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCastTests.kt index 489ba962b7..442b81f646 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCastTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCastTests.kt @@ -12,6 +12,8 @@ import org.partiql.lang.util.ArgumentsProviderBase class PartiQLParserCastTests : PartiQLParserTestBase() { + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + @ParameterizedTest @ArgumentsSource(ConfiguredCastArguments::class) fun configuredCast(configuredCastCase: ConfiguredCastParseTest) = configuredCastCase.assertCase() diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCorrelatedJoinTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCorrelatedJoinTests.kt index ff9ccf1e54..86b0ec3be7 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCorrelatedJoinTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserCorrelatedJoinTests.kt @@ -5,6 +5,9 @@ import org.partiql.lang.domains.PartiqlAst import org.partiql.lang.domains.id class PartiQLParserCorrelatedJoinTests : PartiQLParserTestBase() { + + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + private fun PartiqlAst.Builder.callFWithS() = call("f", id("s", caseInsensitive(), unqualified())) diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDateTimeTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDateTimeTests.kt index 8475ee32df..404658a73b 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDateTimeTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDateTimeTests.kt @@ -16,17 +16,24 @@ import org.partiql.parser.antlr.PartiQLParser class PartiQLParserDateTimeTests : PartiQLParserTestBase() { + // TODO we do not model precision within the expression node + // For example, TIME (0) WITH TIME ZONE '23:59:59.123456789'` will have precision 0 which means + // the underlying literal value does not preserve the extraneous places. + // We should consider preserving the AST exactly as is text. + // override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + override val targets: Array = arrayOf(ParserTarget.DEFAULT) + data class DateTimeTestCase( val source: String, val skipTest: Boolean = false, - val block: PartiqlAst.Builder.() -> PartiqlAst.PartiqlAstNode + val block: PartiqlAst.Builder.() -> PartiqlAst.PartiqlAstNode, ) data class ErrorDateTimeTestCase( val source: String, val errorCode: ErrorCode, val ctx: Map, - val skipTest: Boolean = false + val skipTest: Boolean = false, ) companion object { @@ -1217,7 +1224,7 @@ class PartiQLParserDateTimeTests : PartiQLParserTestBase() { col: Long, tokenType: Int, tokenValue: IonValue, - skipTest: Boolean = false + skipTest: Boolean = false, ): ErrorDateTimeTestCase { val displayTokenType = tokenType.getAntlrDisplayString() val ctx = mapOf( @@ -1256,7 +1263,8 @@ class PartiQLParserDateTimeTests : PartiQLParserTestBase() { checkInputThrowingParserException( tc.source, tc.errorCode, - tc.ctx + tc.ctx, + assertContext = false, ) } } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserExplainTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserExplainTest.kt index ace137b975..86e1c03b77 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserExplainTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserExplainTest.kt @@ -26,6 +26,8 @@ import org.partiql.parser.antlr.PartiQLParser class PartiQLParserExplainTest : PartiQLParserTestBase() { + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + data class ParserTestCase( val description: String? = null, val query: String, @@ -45,7 +47,7 @@ class PartiQLParserExplainTest : PartiQLParserTestBase() { @ArgumentsSource(ErrorTestProvider::class) @ParameterizedTest - fun errorTests(tc: ParserErrorTestCase) = checkInputThrowingParserException(tc.query, tc.code, tc.context) + fun errorTests(tc: ParserErrorTestCase) = checkInputThrowingParserException(tc.query, tc.code, tc.context, assertContext = false) class SuccessTestProvider : ArgumentsProviderBase() { override fun getParameters() = listOf( @@ -145,7 +147,7 @@ class PartiQLParserExplainTest : PartiQLParserTestBase() { Property.COLUMN_NUMBER to 36L, Property.TOKEN_DESCRIPTION to PartiQLParser.EOF.getAntlrDisplayString(), Property.TOKEN_VALUE to ION.newSymbol("EOF") - ) + ), ), ParserErrorTestCase( description = "Setting option twice", diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserJoinTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserJoinTest.kt index 4154e6c983..7e198be3cf 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserJoinTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserJoinTest.kt @@ -7,6 +7,9 @@ import org.partiql.lang.domains.PartiqlAst import org.partiql.lang.domains.id class PartiQLParserJoinTest : PartiQLParserTestBase() { + + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + private val projectX = PartiqlAst.build { projectList(projectExpr(id("x"))) } private fun PartiqlAst.Builder.selectWithOneJoin( diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMatchTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMatchTest.kt index 252683d54f..f80aae127b 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMatchTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMatchTest.kt @@ -11,6 +11,8 @@ import kotlin.test.assertFailsWith class PartiQLParserMatchTest : PartiQLParserTestBase() { + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + @Test fun loneMatchExpr1path() = assertExpression( "(MyGraph MATCH (x))" diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMetaTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMetaTests.kt index b7a71a06bc..e3f7105011 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMetaTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserMetaTests.kt @@ -6,28 +6,30 @@ import org.partiql.lang.domains.PartiqlAst internal class PartiQLParserMetaTests : PartiQLParserTestBase() { + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + @Test - fun listParenthesized() { + fun listParenthesized(): Unit = forEachTarget { val query = "(0, 1, 2)" - val ast = parse(query) as PartiqlAst.Statement.Query + val ast = parser.parseAstStatement(query) as PartiqlAst.Statement.Query val list = ast.expr as PartiqlAst.Expr.List assert(list.metas.containsKey(IsListParenthesizedMeta.tag)) } @Test - fun listParenthesizedNot() { + fun listParenthesizedNot(): Unit = forEachTarget { val query = "[0, 1, 2]" - val ast = parse(query) as PartiqlAst.Statement.Query + val ast = parser.parseAstStatement(query) as PartiqlAst.Statement.Query val list = ast.expr as PartiqlAst.Expr.List assert(list.metas.containsKey(IsListParenthesizedMeta.tag).not()) } @Test - fun inListParenthesized() { + fun inListParenthesized(): Unit = forEachTarget { val query = "0 IN (0, 1, 2)" - val ast = parse(query) as PartiqlAst.Statement.Query + val ast = parser.parseAstStatement(query) as PartiqlAst.Statement.Query val inCollection = ast.expr as PartiqlAst.Expr.InCollection val list = inCollection.operands[1] as PartiqlAst.Expr.List @@ -35,9 +37,9 @@ internal class PartiQLParserMetaTests : PartiQLParserTestBase() { } @Test - fun inListParenthesizedNot() { + fun inListParenthesizedNot(): Unit = forEachTarget { val query = "0 IN [0, 1, 2]" - val ast = parse(query) as PartiqlAst.Statement.Query + val ast = parser.parseAstStatement(query) as PartiqlAst.Statement.Query val inCollection = ast.expr as PartiqlAst.Expr.InCollection val list = inCollection.operands[1] as PartiqlAst.Expr.List @@ -45,9 +47,9 @@ internal class PartiQLParserMetaTests : PartiQLParserTestBase() { } @Test - fun inListParenthesizedSingleElement() { + fun inListParenthesizedSingleElement(): Unit = forEachTarget { val query = "0 IN (0)" - val ast = parse(query) as PartiqlAst.Statement.Query + val ast = parser.parseAstStatement(query) as PartiqlAst.Statement.Query val inCollection = ast.expr as PartiqlAst.Expr.InCollection val list = inCollection.operands[1] as PartiqlAst.Expr.List diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserPrecedenceTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserPrecedenceTest.kt index be23232b32..7a0caa51d4 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserPrecedenceTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserPrecedenceTest.kt @@ -22,6 +22,8 @@ import org.partiql.lang.domains.PartiqlAst class PartiQLParserPrecedenceTest : PartiQLParserTestBase() { + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + @Test @Parameters @TestCaseName("{0}") @@ -1083,18 +1085,22 @@ class PartiQLParserPrecedenceTest : PartiQLParserTestBase() { ) private fun runTest(pair: Pair) { - val (source, expectedAst) = pair + targets.forEach { target -> + val (source, expectedAst) = pair - val expectedExpr = PartiqlAst.transform(ion.singleValue(expectedAst).toIonElement()) as PartiqlAst.Expr - val expectedStatement = PartiqlAst.build { query(expectedExpr) } + val expectedExpr = PartiqlAst.transform(ion.singleValue(expectedAst).toIonElement()) as PartiqlAst.Expr + val expectedStatement = PartiqlAst.build { query(expectedExpr) } - val actualStatement = parser.parseAstStatement(source) - assertEquals(expectedStatement, actualStatement) + val actualStatement = target.parser.parseAstStatement(source) + assertEquals(expectedStatement, actualStatement) + } } private fun assertFailure(case: String) { - assertThrows(ParserException::class.java) { - parser.parseAstStatement(case) + targets.forEach { target -> + assertThrows(ParserException::class.java) { + target.parser.parseAstStatement(case) + } } } } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt index 07bb1c38a6..717d8fd820 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt @@ -33,15 +33,12 @@ import org.partiql.parser.antlr.PartiQLParser import kotlin.concurrent.thread /** - * Originally just meant to test the parser, this class now tests several different things because - * the same test cases can be used for all three: - * - * - Parsing of query to PIG-generated ast - * - Conversion of PIG-generated ast to [ExprNode]. - * - Conversion of [ExprNode] to legacy and new s-exp ASTs. + * Test parsing of query to PIG-generated AST */ class PartiQLParserTest : PartiQLParserTestBase() { + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + // **************************************** // literals // **************************************** @@ -438,6 +435,7 @@ class PartiQLParserTest : PartiQLParserTestBase() { (path_expr (lit "b") (case_insensitive)) (path_expr (lit "c") (case_insensitive)))""".trimMargin() ) + @Test fun dot_case_3_insensitive_components() = assertExpression( "a.b.c.d", @@ -474,6 +472,7 @@ class PartiQLParserTest : PartiQLParserTestBase() { """(path (id a (case_insensitive) (unqualified)) (path_expr (lit 5) (case_sensitive)))""".trimMargin() ) + @Test fun pathWith3SquareBrackets() = assertExpression( """a[5]['b'][(a + 3)]""", @@ -940,13 +939,13 @@ class PartiQLParserTest : PartiQLParserTestBase() { @Test fun currentUserUpperCase() = assertExpression( "CURRENT_USER", - "(session_attribute CURRENT_USER)" + "(session_attribute current_user)" ) @Test fun currentUserMixedCase() = assertExpression( "CURRENT_user", - "(session_attribute CURRENT_user)" + "(session_attribute current_user)" ) @Test @@ -1580,48 +1579,52 @@ class PartiQLParserTest : PartiQLParserTestBase() { } @Test - fun orderByAscWithNullsSpec() = assertExpression("SELECT x FROM tb ORDER BY rk1 asc NULLS FIRST, rk2 asc NULLS LAST") { - select( - project = projectX, - from = scan(id("tb")), - order = orderBy( - listOf( - sortSpec(id("rk1"), asc(), nullsFirst()), - sortSpec(id("rk2"), asc(), nullsLast()) + fun orderByAscWithNullsSpec() = + assertExpression("SELECT x FROM tb ORDER BY rk1 asc NULLS FIRST, rk2 asc NULLS LAST") { + select( + project = projectX, + from = scan(id("tb")), + order = orderBy( + listOf( + sortSpec(id("rk1"), asc(), nullsFirst()), + sortSpec(id("rk2"), asc(), nullsLast()) + ) ) ) - ) - } + } @Test - fun orderByDescWithNullsSpec() = assertExpression("SELECT x FROM tb ORDER BY rk1 desc NULLS FIRST, rk2 desc NULLS LAST") { - select( - project = projectX, - from = scan(id("tb")), - order = orderBy( - listOf( - sortSpec(id("rk1"), desc(), nullsFirst()), - sortSpec(id("rk2"), desc(), nullsLast()) + fun orderByDescWithNullsSpec() = + assertExpression("SELECT x FROM tb ORDER BY rk1 desc NULLS FIRST, rk2 desc NULLS LAST") { + select( + project = projectX, + from = scan(id("tb")), + order = orderBy( + listOf( + sortSpec(id("rk1"), desc(), nullsFirst()), + sortSpec(id("rk2"), desc(), nullsLast()) + ) ) ) - ) - } + } @Test - fun orderByWithOrderingAndNullsSpec() = assertExpression("SELECT x FROM tb ORDER BY rk1 desc NULLS FIRST, rk2 asc NULLS LAST, rk3 desc NULLS LAST, rk4 asc NULLS FIRST") { - select( - project = projectX, - from = scan(id("tb")), - order = orderBy( - listOf( - sortSpec(id("rk1"), desc(), nullsFirst()), - sortSpec(id("rk2"), asc(), nullsLast()), - sortSpec(id("rk3"), desc(), nullsLast()), - sortSpec(id("rk4"), asc(), nullsFirst()) + fun orderByWithOrderingAndNullsSpec() = + assertExpression("SELECT x FROM tb ORDER BY rk1 desc NULLS FIRST, rk2 asc NULLS LAST, rk3 desc NULLS LAST, rk4 asc NULLS FIRST") { + select( + project = projectX, + from = scan(id("tb")), + order = orderBy( + listOf( + sortSpec(id("rk1"), desc(), nullsFirst()), + sortSpec(id("rk2"), asc(), nullsLast()), + sortSpec(id("rk3"), desc(), nullsLast()), + sortSpec(id("rk4"), asc(), nullsFirst()) + ) ) ) - ) - } + } + // **************************************** // GROUP BY and GROUP PARTIAL BY // **************************************** @@ -4308,33 +4311,38 @@ class PartiQLParserTest : PartiQLParserTestBase() { @Test fun rootSelectNodeHasSourceLocation() { - val ast = parse("select 1 from dogs") - assertEquals(SourceLocationMeta(1L, 1L, 6L), ast.metas.sourceLocation) + targets.forEach { target -> + val ast = target.parser.parseAstStatement("select 1 from dogs") + assertEquals(SourceLocationMeta(1L, 1L, 6L), ast.metas.sourceLocation) + } } @Test fun semicolonAtEndOfQueryHasNoEffect() { - val query = "SELECT * FROM <<1>>" - val withSemicolon = parse("$query;") - val withoutSemicolon = parse(query) - - assertEquals(withoutSemicolon, withSemicolon) + targets.forEach { target -> + val query = "SELECT * FROM <<1>>" + val withSemicolon = target.parser.parseAstStatement("$query;") + val withoutSemicolon = target.parser.parseAstStatement(query) + assertEquals(withoutSemicolon, withSemicolon) + } } @Test fun semicolonAtEndOfLiteralHasNoEffect() { - val withSemicolon = parse("1;") - val withoutSemicolon = parse("1") - - assertEquals(withoutSemicolon, withSemicolon) + targets.forEach { target -> + val withSemicolon = target.parser.parseAstStatement("1;") + val withoutSemicolon = target.parser.parseAstStatement("1") + assertEquals(withoutSemicolon, withSemicolon) + } } @Test fun semicolonAtEndOfExpressionHasNoEffect() { - val withSemicolon = parse("(1+1);") - val withoutSemicolon = parse("(1+1)") - - assertEquals(withoutSemicolon, withSemicolon) + targets.forEach { target -> + val withSemicolon = target.parser.parseAstStatement("(1+1);") + val withoutSemicolon = target.parser.parseAstStatement("(1+1)") + assertEquals(withoutSemicolon, withSemicolon) + } } // **************************************** @@ -4478,7 +4486,10 @@ class PartiQLParserTest : PartiQLParserTestBase() { fun execMultipleArg() = assertExpression( "EXEC foo 'bar0', `1d0`, 2, [3]" ) { - exec("foo", listOf(lit(ionString("bar0")), lit(ionDecimal(Decimal.valueOf(1))), lit(ionInt(2)), list(lit(ionInt(3))))) + exec( + "foo", + listOf(lit(ionString("bar0")), lit(ionDecimal(Decimal.valueOf(1))), lit(ionInt(2)), list(lit(ionInt(3)))) + ) } @Test @@ -4511,10 +4522,10 @@ class PartiQLParserTest : PartiQLParserTestBase() { } @Test - fun manyNestedNotPerformanceRegressionTest() { + fun manyNestedNotPerformanceRegressionTest(): Unit = forEachTarget { val startTime = System.currentTimeMillis() val t = thread { - parse( + parser.parseAstStatement( """ not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not not @@ -4544,7 +4555,7 @@ class PartiQLParserTest : PartiQLParserTestBase() { } @Test - fun testOrderByMetas() { + fun testOrderByMetas(): Unit = forEachTarget { // Arrange val query = "SELECT * FROM << { 'x': 2 } >> ORDER BY x" val expected = SourceLocationMeta(1, 32, 5) diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTestBase.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTestBase.kt index 78abf0d5ae..2887e75301 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTestBase.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTestBase.kt @@ -34,9 +34,20 @@ import org.partiql.pig.runtime.toIonElement abstract class PartiQLParserTestBase : TestBase() { - val parser = PartiQLParserBuilder().customTypes(CUSTOM_TEST_TYPES).build() + /** + * We can change the parser target for an entire test suite by overriding this list. + */ + open val targets = arrayOf(ParserTarget.DEFAULT) + + /** + * Executes a test block for each target. + */ + inline fun forEachTarget(block: ParserTarget.() -> Unit) = targets.forEach { it.block() } - protected fun parse(source: String): PartiqlAst.Statement = parser.parseAstStatement(source) + enum class ParserTarget(val parser: Parser) { + DEFAULT(PartiQLParserBuilder().customTypes(CUSTOM_TEST_TYPES).build()), + EXPERIMENTAL(PartiQLParserBuilder.experimental().customTypes(CUSTOM_TEST_TYPES).build()), + } private fun assertSexpEquals( expectedValue: IonValue, @@ -60,7 +71,7 @@ abstract class PartiQLParserTestBase : TestBase() { protected fun assertExpression( source: String, expectedPigAst: String, - ) { + ): Unit = forEachTarget { val actualStatement = parser.parseAstStatement(source) val expectedIonSexp = loadIonSexp(expectedPigAst) @@ -141,14 +152,19 @@ abstract class PartiQLParserTestBase : TestBase() { protected fun checkInputThrowingParserException( input: String, errorCode: ErrorCode, - expectErrorContextValues: Map - ) { + expectErrorContextValues: Map, + targets: Array = arrayOf(ParserTarget.DEFAULT), + assertContext: Boolean = true, + ): Unit = forEachTarget { softAssert { try { parser.parseAstStatement(input) fail("Expected ParserException but there was no Exception") - } catch (pex: ParserException) { - checkErrorAndErrorContext(errorCode, pex, expectErrorContextValues) + } catch (ex: ParserException) { + // split parser target does not use ErrorCode + if (assertContext && (this@forEachTarget == ParserTarget.EXPERIMENTAL)) { + checkErrorAndErrorContext(errorCode, ex, expectErrorContextValues) + } } catch (ex: Exception) { fail("Expected ParserException but a different exception was thrown \n\t $ex") } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserWindowTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserWindowTests.kt index ce6dca7e02..1fa0d05adf 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserWindowTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserWindowTests.kt @@ -1,8 +1,15 @@ package org.partiql.lang.syntax import org.junit.Test +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.util.getAntlrDisplayString +import org.partiql.parser.antlr.PartiQLParser class PartiQLParserWindowTests : PartiQLParserTestBase() { + + override val targets: Array = arrayOf(ParserTarget.DEFAULT, ParserTarget.EXPERIMENTAL) + // TODO: In the future when we support custom-defined window frame, we will need to change this @Test fun lagWithInlinePartitionBYOrderBy() = assertExpression( @@ -231,4 +238,62 @@ class PartiQLParserWindowTests : PartiQLParserTestBase() { null))) """ ) + + @Test + fun lagWithoutOrderBy() { + checkInputThrowingParserException( + "SELECT lag(a) OVER () FROM b", + ErrorCode.PARSE_EXPECTED_WINDOW_ORDER_BY, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 8L, + Property.TOKEN_DESCRIPTION to PartiQLParser.LAG.getAntlrDisplayString(), + Property.TOKEN_VALUE to ion.newSymbol("lag") + ), + assertContext = false, + ) + } + + @Test + fun lagWrongNumberOfParameter() { + checkInputThrowingParserException( + "SELECT lag(a,b,c,d) OVER (ORDER BY e) FROM f", + ErrorCode.PARSE_UNEXPECTED_TOKEN, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 17L, + Property.TOKEN_DESCRIPTION to PartiQLParser.COMMA.getAntlrDisplayString(), + Property.TOKEN_VALUE to ion.newSymbol(",") + ) + ) + } + + @Test + fun leadWithoutOrderBy() { + checkInputThrowingParserException( + "SELECT lead(a) OVER () FROM b", + ErrorCode.PARSE_EXPECTED_WINDOW_ORDER_BY, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 8L, + Property.TOKEN_DESCRIPTION to PartiQLParser.LAG.getAntlrDisplayString(), + Property.TOKEN_VALUE to ion.newSymbol("lag") + ), + assertContext = false, + ) + } + + @Test + fun leadWrongNumberOfParameter() { + checkInputThrowingParserException( + "SELECT lead(a,b,c,d) OVER (ORDER BY e) FROM f", + ErrorCode.PARSE_UNEXPECTED_TOKEN, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 18L, + Property.TOKEN_DESCRIPTION to PartiQLParser.COMMA.getAntlrDisplayString(), + Property.TOKEN_VALUE to ion.newSymbol(",") + ) + ) + } } diff --git a/partiql-parser/README.adoc b/partiql-parser/README.adoc new file mode 100644 index 0000000000..b416c6a0b3 --- /dev/null +++ b/partiql-parser/README.adoc @@ -0,0 +1,66 @@ += PartiQL Parser + +The PartiQL Parser can be used to parse PartiQL queries into an AST (`org.partiql.ast`). +This interface expands the legacy `org.partiql.syntax.Parser` interface allowing for a richer return type as well as the latest AST. + +== Interfaces + +[source,kotlin] +---- +// PartiQLParser.kt + +public interface PartiQLParser { + + @Throws(PartiQLSyntaxException::class, InterruptedException::class) + public fun parse(source: String): Result + + public data class Result( + val source: String, + val root: AstNode, + val locations: SourceLocations, + ) +} + +// Exceptions.kt + +/** + * Generic PartiQLParser Syntax Exception + */ +public open class PartiQLSyntaxException( + override val message: String, + override val cause: Throwable? = null, + public val location: SourceLocation = SourceLocation.UNKNOWN, +) : Exception() + +/** + * PartiQLParser Exception upon lexing. + */ +public class PartiQLLexerException( + public val token: String, + public val tokenType: String, + message: String = "", + cause: Throwable? = null, + location: SourceLocation = SourceLocation.UNKNOWN, +) : PartiQLSyntaxException(message, cause, location) + +/** + * PartiQLParser Exception upon parsing. + */ +public class PartiQLParserException( + public val rule: String, + public val token: String, + public val tokenType: String, + message: String = "", + cause: Throwable? = null, + location: SourceLocation = SourceLocation.UNKNOWN, +) : PartiQLSyntaxException(message, cause, location) + +---- + +== Usage + +[source,kotlin] +---- +val parser = PartiQLParserBuilder.standard().build() +val ast = parser.parse("SELECT a FROM T") +---- diff --git a/partiql-parser/build.gradle.kts b/partiql-parser/build.gradle.kts index 51cb9e9c37..7c72a2f9df 100644 --- a/partiql-parser/build.gradle.kts +++ b/partiql-parser/build.gradle.kts @@ -15,12 +15,13 @@ plugins { id(Plugins.antlr) id(Plugins.conventions) - // id(Plugins.publish) + id(Plugins.publish) } dependencies { antlr(Deps.antlr) - implementation(project(":partiql-ast")) + api(project(":partiql-ast")) + api(project(":partiql-types")) implementation(Deps.ionElement) implementation(Deps.antlrRuntime) } @@ -51,8 +52,8 @@ tasks.processResources { } } -// publish { -// artifactId = "partiql-parser" -// name = "PartiQL Parser" -// description = "PartiQL's Parser" -// } +publish { + artifactId = "partiql-parser" + name = "PartiQL Parser" + description = "PartiQL's Parser" +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/Exceptions.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/Exceptions.kt new file mode 100644 index 0000000000..67edf1d4f6 --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/Exceptions.kt @@ -0,0 +1,83 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser + +/** + * PartiQLParser Syntax Exception + * + * @property message + * @property cause + * @property location + */ +public open class PartiQLSyntaxException( + override val message: String, + override val cause: Throwable? = null, + public val location: SourceLocation = SourceLocation.UNKNOWN, +) : Exception() { + + internal companion object { + + internal fun wrap(cause: Throwable) = when (cause) { + is PartiQLSyntaxException -> cause + is StackOverflowError -> PartiQLSyntaxException( + message = """ + |Input query too large. This error typically occurs when there are several nested + |expressions/predicates and can usually be fixed by simplifying expressions. + """.trimMargin(), + cause = cause, + ) + is InterruptedException -> cause + else -> PartiQLSyntaxException("Unhandled exception.", cause) + } + } +} + +/** + * PartiQLParser Exception upon lexing. + * + * @property token — Debug token where the Exception occurred + * @constructor + * + * @param message + * @param cause + * @param location + */ +public class PartiQLLexerException( + public val token: String, + public val tokenType: String, + message: String = "", + cause: Throwable? = null, + location: SourceLocation = SourceLocation.UNKNOWN, +) : PartiQLSyntaxException(message, cause, location) + +/** + * PartiQLParser Exception upon parsing. + * + * @property rule Debug rule where the Exception occurred + * @property token Debug token where the Exception occurred + * @constructor + * + * @param message + * @param cause + * @param location + */ +public class PartiQLParserException( + public val rule: String, + public val token: String, + public val tokenType: String, + message: String = "", + cause: Throwable? = null, + location: SourceLocation = SourceLocation.UNKNOWN, +) : PartiQLSyntaxException(message, cause, location) diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt new file mode 100644 index 0000000000..9274e8c66e --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt @@ -0,0 +1,29 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser + +import org.partiql.ast.AstNode + +public interface PartiQLParser { + + @Throws(PartiQLSyntaxException::class, InterruptedException::class) + public fun parse(source: String): Result + + public data class Result( + val source: String, + val root: AstNode, + val locations: SourceLocations, + ) +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt new file mode 100644 index 0000000000..56318f47f6 --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt @@ -0,0 +1,35 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser + +import org.partiql.parser.impl.PartiQLParserDefault + +/** + * A builder class to instantiate a [PartiQLParser]. + */ +public class PartiQLParserBuilder { + + public companion object { + + @JvmStatic + public fun standard(): PartiQLParserBuilder { + return PartiQLParserBuilder() + } + } + + public fun build(): PartiQLParser { + return PartiQLParserDefault() + } +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocation.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocation.kt new file mode 100644 index 0000000000..521934f30c --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocation.kt @@ -0,0 +1,44 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser + +/** + * SourceLocation represents the span of a given grammar rule; which corresponds to an AST subtree. + * + * TODO Fix Source Location Tests https://github.com/partiql/partiql-lang-kotlin/issues/1114 + * Unfortunately several mistakes were made that are hard to undo altogether. The legacy parser incorrectly + * used the first token length rather than rule span for source location length. Then we have asserted on these + * incorrect SourceLocations in many unit tests unrelated to SourceLocations. + * + * @property line + * @property offset + * @property length + * @property lengthLegacy + */ +public data class SourceLocation( + public val line: Int, + public val offset: Int, + public val length: Int, + public val lengthLegacy: Int = 0, +) { + + public companion object { + + /** + * This is a flag for backwards compatibility when converting to the legacy AST. + */ + public val UNKNOWN: SourceLocation = SourceLocation(-1, -1, -1) + } +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocations.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocations.kt new file mode 100644 index 0000000000..5d7c2ce4ae --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/SourceLocations.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser + +/** + * Each node is hashable and has a unique identifier. Metadata is kept externally. + * Delegate once we are on Kotlin 1.7 + */ +public class SourceLocations private constructor( + private val delegate: Map +) : Map { + + override val entries: Set> = delegate.entries + + override val keys: Set = delegate.keys + + override val size: Int = delegate.size + + override val values: Collection = delegate.values + + override fun containsKey(key: String): Boolean = delegate.containsKey(key) + + override fun containsValue(value: SourceLocation): Boolean = delegate.containsValue(value) + + override fun get(key: String): SourceLocation? = delegate[key] + + override fun isEmpty(): Boolean = delegate.isEmpty() + + internal class Mutable { + + private val delegate = mutableMapOf() + + operator fun set(id: String, value: SourceLocation) = delegate.put(id, value) + + fun toMap() = SourceLocations(delegate) + } +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt new file mode 100644 index 0000000000..a5b3c9431a --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt @@ -0,0 +1,2005 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser.impl + +import com.amazon.ionelement.api.IntElement +import com.amazon.ionelement.api.IntElementSize +import com.amazon.ionelement.api.IonElement +import com.amazon.ionelement.api.IonElementException +import com.amazon.ionelement.api.loadSingleElement +import org.antlr.v4.runtime.BailErrorStrategy +import org.antlr.v4.runtime.BaseErrorListener +import org.antlr.v4.runtime.CharStreams +import org.antlr.v4.runtime.CommonTokenStream +import org.antlr.v4.runtime.ParserRuleContext +import org.antlr.v4.runtime.RecognitionException +import org.antlr.v4.runtime.Recognizer +import org.antlr.v4.runtime.Token +import org.antlr.v4.runtime.TokenSource +import org.antlr.v4.runtime.TokenStream +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException +import org.antlr.v4.runtime.tree.TerminalNode +import org.partiql.ast.Ast +import org.partiql.ast.AstNode +import org.partiql.ast.DatetimeField +import org.partiql.ast.Expr +import org.partiql.ast.From +import org.partiql.ast.GraphMatch +import org.partiql.ast.GroupBy +import org.partiql.ast.Identifier +import org.partiql.ast.Let +import org.partiql.ast.OnConflict +import org.partiql.ast.Path +import org.partiql.ast.Returning +import org.partiql.ast.Select +import org.partiql.ast.SetOp +import org.partiql.ast.SetQuantifier +import org.partiql.ast.Sort +import org.partiql.ast.Statement +import org.partiql.ast.TableDefinition +import org.partiql.ast.Type +import org.partiql.ast.builder.AstFactory +import org.partiql.parser.PartiQLLexerException +import org.partiql.parser.PartiQLParser +import org.partiql.parser.PartiQLParserException +import org.partiql.parser.PartiQLSyntaxException +import org.partiql.parser.SourceLocation +import org.partiql.parser.SourceLocations +import org.partiql.parser.antlr.PartiQLBaseVisitor +import org.partiql.parser.impl.util.DateTimeUtils +import org.partiql.value.NumericValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.StringValue +import org.partiql.value.boolValue +import org.partiql.value.dateValue +import org.partiql.value.datetime.DateTimeException +import org.partiql.value.datetime.DateTimeValue +import org.partiql.value.decimalValue +import org.partiql.value.int64Value +import org.partiql.value.intValue +import org.partiql.value.missingValue +import org.partiql.value.nullValue +import org.partiql.value.stringValue +import org.partiql.value.timeValue +import org.partiql.value.timestampValue +import java.math.BigDecimal +import java.math.BigInteger +import java.math.MathContext +import java.math.RoundingMode +import java.nio.channels.ClosedByInterruptException +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.time.format.DateTimeParseException +import org.partiql.parser.antlr.PartiQLParser as GeneratedParser +import org.partiql.parser.antlr.PartiQLTokens as GeneratedLexer + +/** + * ANTLR Based Implementation of a PartiQLParser + * + * SLL Prediction Mode + * ------------------- + * The [PredictionMode.SLL] mode uses the [BailErrorStrategy]. The [GeneratedParser], upon seeing a syntax error, + * will throw a [ParseCancellationException] due to the [GeneratedParser.getErrorHandler] + * being a [BailErrorStrategy]. The purpose of this is to throw syntax errors as quickly as possible once encountered. + * As noted by the [PredictionMode.SLL] documentation, to guarantee results, it is useful to follow up a failed parse + * by parsing with [PredictionMode.LL]. See the JavaDocs for [PredictionMode.SLL] and [BailErrorStrategy] for more. + * + * LL Prediction Mode + * ------------------ + * The [PredictionMode.LL] mode is capable of parsing all valid inputs for a grammar, + * but is slower than [PredictionMode.SLL]. Upon seeing a syntax error, this parser throws a [PartiQLParserException]. + */ +internal class PartiQLParserDefault : PartiQLParser { + + @Throws(PartiQLSyntaxException::class, InterruptedException::class) + override fun parse(source: String): PartiQLParser.Result { + try { + return PartiQLParserDefault.parse(source) + } catch (throwable: Throwable) { + throw PartiQLSyntaxException.wrap(throwable) + } + } + + companion object { + + /** + * To reduce latency costs, the [PartiQLParserDefault] attempts to use [PredictionMode.SLL] and falls back to + * [PredictionMode.LL] if a [ParseCancellationException] is thrown by the [BailErrorStrategy]. + */ + private fun parse(source: String): PartiQLParser.Result = try { + parse(source, PredictionMode.SLL) + } catch (ex: ParseCancellationException) { + parse(source, PredictionMode.LL) + } + + /** + * Parses an input string [source] using the given prediction mode. + */ + private fun parse(source: String, mode: PredictionMode): PartiQLParser.Result { + val tokens = createTokenStream(source) + val parser = InterruptibleParser(tokens) + parser.reset() + parser.removeErrorListeners() + parser.interpreter.predictionMode = mode + when (mode) { + PredictionMode.SLL -> parser.errorHandler = BailErrorStrategy() + PredictionMode.LL -> parser.addErrorListener(ParseErrorListener()) + else -> throw IllegalArgumentException("Unsupported parser mode: $mode") + } + val tree = parser.root() + return Visitor.translate(source, tokens, tree) + } + + private fun createTokenStream(source: String): CountingTokenStream { + val queryStream = source.byteInputStream(StandardCharsets.UTF_8) + val inputStream = try { + CharStreams.fromStream(queryStream) + } catch (ex: ClosedByInterruptException) { + throw InterruptedException() + } + val handler = TokenizeErrorListener() + val lexer = GeneratedLexer(inputStream) + lexer.removeErrorListeners() + lexer.addErrorListener(handler) + return CountingTokenStream(lexer) + } + } + + /** + * Catches Lexical errors (unidentified tokens) and throws a [PartiQLParserException] + */ + private class TokenizeErrorListener : BaseErrorListener() { + @Throws(PartiQLParserException::class) + override fun syntaxError( + recognizer: Recognizer<*, *>?, + offendingSymbol: Any?, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException? + ) { + if (offendingSymbol is Token) { + val token = offendingSymbol.text + val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type) + throw PartiQLLexerException( + token = token, + tokenType = tokenType, + message = msg, + cause = e, + location = SourceLocation( + line = line, + offset = charPositionInLine + 1, + length = token.length, + lengthLegacy = token.length, + ), + ) + } else { + throw IllegalArgumentException("Offending symbol is not a Token.") + } + } + } + + /** + * Catches Parser errors (malformed syntax) and throws a [PartiQLParserException] + */ + private class ParseErrorListener : BaseErrorListener() { + + private val rules = GeneratedParser.ruleNames.asList() + + @Throws(PartiQLParserException::class) + override fun syntaxError( + recognizer: Recognizer<*, *>?, + offendingSymbol: Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException? + ) { + if (offendingSymbol is Token) { + val rule = e?.ctx?.toString(rules) ?: "UNKNOWN" + val token = offendingSymbol.text + val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type) + throw PartiQLParserException( + rule = rule, + token = token, + tokenType = tokenType, + message = msg, + cause = e, + location = SourceLocation( + line = line, + offset = charPositionInLine + 1, + length = msg.length, + lengthLegacy = offendingSymbol.text.length, + ), + ) + } else { + throw IllegalArgumentException("Offending symbol is not a Token.") + } + } + } + + /** + * A wrapped [GeneratedParser] to allow thread interruption during parse. + */ + internal class InterruptibleParser(input: TokenStream) : GeneratedParser(input) { + override fun enterRule(localctx: ParserRuleContext?, state: Int, ruleIndex: Int) { + if (Thread.interrupted()) { + throw InterruptedException() + } + super.enterRule(localctx, state, ruleIndex) + } + } + + /** + * This token stream creates [parameterIndexes], which is a map, where the keys represent the + * indexes of all [GeneratedLexer.QUESTION_MARK]'s and the values represent their relative index amongst all other + * [GeneratedLexer.QUESTION_MARK]'s. + */ + internal open class CountingTokenStream(tokenSource: TokenSource) : CommonTokenStream(tokenSource) { + // TODO: Research use-case of parameters and implementation -- see https://github.com/partiql/partiql-docs/issues/23 + val parameterIndexes = mutableMapOf() + private var parametersFound = 0 + override fun LT(k: Int): Token? { + val token = super.LT(k) + token?.let { + if (it.type == GeneratedLexer.QUESTION_MARK && parameterIndexes.containsKey(token.tokenIndex).not()) { + parameterIndexes[token.tokenIndex] = ++parametersFound + } + } + return token + } + } + + /** + * Translate an ANTLR ParseTree to a PartiQL + */ + @OptIn(PartiQLValueExperimental::class) + private class Visitor( + private val locations: SourceLocations.Mutable, + private val parameters: Map = mapOf(), + ) : PartiQLBaseVisitor() { + + // Use default factory + private val factory = Ast + + companion object { + + private val rules = GeneratedParser.ruleNames.asList() + + /** + * Expose an (internal) friendly entry point into the traversal; mostly for keeping mutable state contained. + */ + fun translate( + source: String, + tokens: CountingTokenStream, + tree: GeneratedParser.RootContext + ): PartiQLParser.Result { + val locations = SourceLocations.Mutable() + val visitor = Visitor(locations, tokens.parameterIndexes) + val root = visitor.visitAs(tree) + return PartiQLParser.Result( + source = source, + root = root, + locations = locations.toMap(), + ) + } + + fun error( + ctx: ParserRuleContext, + message: String, + cause: Throwable? = null, + ) = PartiQLParserException( + rule = ctx.toStringTree(rules), + token = ctx.start.text, + tokenType = GeneratedParser.VOCABULARY.getSymbolicName(ctx.start.type), + message = message, + cause = cause, + location = SourceLocation( + line = ctx.start.line, + offset = ctx.start.charPositionInLine + 1, + length = ctx.stop.stopIndex - ctx.start.startIndex, + lengthLegacy = ctx.start.text.length, + ), + ) + + fun error( + token: Token, + message: String, + cause: Throwable? = null, + ) = PartiQLLexerException( + token = token.text, + tokenType = GeneratedParser.VOCABULARY.getSymbolicName(token.type), + message = message, + cause = cause, + location = SourceLocation( + line = token.line, + offset = token.charPositionInLine + 1, + length = token.stopIndex - token.startIndex, + lengthLegacy = token.text.length, + ), + ) + + internal val DATE_PATTERN_REGEX = Regex("\\d\\d\\d\\d-\\d\\d-\\d\\d") + + internal val GENERIC_TIME_REGEX = Regex("\\d\\d:\\d\\d:\\d\\d(\\.\\d*)?([+|-]\\d\\d:\\d\\d)?") + } + + /** + * Each visit attaches source locations from the given parse tree node; constructs nodes via the factory. + */ + private inline fun translate(ctx: ParserRuleContext, block: AstFactory.() -> T): T { + val node = factory.block() + if (ctx.start != null) { + locations[node._id] = SourceLocation( + line = ctx.start.line, + offset = ctx.start.charPositionInLine + 1, + length = (ctx.stop?.stopIndex ?: ctx.start.stopIndex) - ctx.start.startIndex + 1, + lengthLegacy = ctx.start.text.length, // LEGACY LENGTH + ) + } + return node + } + + /** + * + * TOP LEVEL + * + */ + + override fun visitQueryDql(ctx: GeneratedParser.QueryDqlContext): AstNode = visitDql(ctx.dql()) + + override fun visitQueryDml(ctx: GeneratedParser.QueryDmlContext): AstNode = visit(ctx.dml()) + + override fun visitRoot(ctx: GeneratedParser.RootContext) = translate(ctx) { + when (ctx.EXPLAIN()) { + null -> visit(ctx.statement()) as Statement + else -> { + var type: String? = null + var format: String? = null + ctx.explainOption().forEach { option -> + val parameter = try { + ExplainParameters.valueOf(option.param.text.uppercase()) + } catch (ex: java.lang.IllegalArgumentException) { + throw error(option.param, "Unknown EXPLAIN parameter.", ex) + } + when (parameter) { + ExplainParameters.TYPE -> { + type = parameter.getCompliantString(type, option.value) + } + ExplainParameters.FORMAT -> { + format = parameter.getCompliantString(format, option.value) + } + } + } + statementExplain( + target = statementExplainTargetDomain( + statement = visit(ctx.statement()) as Statement, + type = type, + format = format, + ), + ) + } + } + } + + /** + * + * COMMON USAGES + * + */ + + override fun visitAsIdent(ctx: GeneratedParser.AsIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitAtIdent(ctx: GeneratedParser.AtIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitByIdent(ctx: GeneratedParser.ByIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitSymbolPrimitive(ctx: GeneratedParser.SymbolPrimitiveContext) = translate(ctx) { + when (ctx.ident.type) { + GeneratedParser.IDENTIFIER_QUOTED -> identifierSymbol( + ctx.IDENTIFIER_QUOTED().getStringValue(), + Identifier.CaseSensitivity.SENSITIVE, + ) + GeneratedParser.IDENTIFIER -> identifierSymbol( + ctx.IDENTIFIER().getStringValue(), + Identifier.CaseSensitivity.INSENSITIVE, + ) + else -> throw error(ctx, "Invalid symbol reference.") + } + } + + /** + * + * DATA DEFINITION LANGUAGE (DDL) + * + */ + + override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = visitDdl(ctx.ddl()) + + override fun visitDropTable(ctx: GeneratedParser.DropTableContext) = translate(ctx) { + val table = visitSymbolPrimitive(ctx.tableName().symbolPrimitive()) + statementDDLDropTable(table) + } + + override fun visitDropIndex(ctx: GeneratedParser.DropIndexContext) = translate(ctx) { + val table = visitSymbolPrimitive(ctx.on) + val index = visitSymbolPrimitive(ctx.target) + statementDDLDropIndex(index, table) + } + + override fun visitCreateTable(ctx: GeneratedParser.CreateTableContext) = translate(ctx) { + val table = visitSymbolPrimitive(ctx.tableName().symbolPrimitive()) + val definition = ctx.tableDef()?.let { visitTableDef(it) } + statementDDLCreateTable(table, definition) + } + + override fun visitCreateIndex(ctx: GeneratedParser.CreateIndexContext) = translate(ctx) { + // TODO add index name to ANTLR grammar + val name: Identifier? = null + val table = visitSymbolPrimitive(ctx.symbolPrimitive()) + val fields = ctx.pathSimple().map { path -> visitPathSimple(path) } + statementDDLCreateIndex(name, table, fields) + } + + override fun visitTableDef(ctx: GeneratedParser.TableDefContext) = translate(ctx) { + // Column Definitions are the only thing we currently allow as table definition parts + val columns = ctx.tableDefPart().filterIsInstance().map { + visitColumnDeclaration(it) + } + tableDefinition(columns) + } + + override fun visitColumnDeclaration(ctx: GeneratedParser.ColumnDeclarationContext) = translate(ctx) { + val name = symbol(ctx.columnName().symbolPrimitive()) + val type = visit(ctx.type()) as Type + val constraints = ctx.columnConstraint().map { + visitColumnConstraint(it) + } + tableDefinitionColumn(name, type, constraints) + } + + override fun visitColumnConstraint(ctx: GeneratedParser.ColumnConstraintContext) = translate(ctx) { + val identifier = ctx.columnConstraintName()?.let { symbol(it.symbolPrimitive()) } + val body = visit(ctx.columnConstraintDef()) as TableDefinition.Column.Constraint.Body + tableDefinitionColumnConstraint(identifier, body) + } + + override fun visitColConstrNotNull(ctx: GeneratedParser.ColConstrNotNullContext) = translate(ctx) { + tableDefinitionColumnConstraintBodyNotNull() + } + + override fun visitColConstrNull(ctx: GeneratedParser.ColConstrNullContext) = translate(ctx) { + tableDefinitionColumnConstraintBodyNullable() + } + + /** + * + * EXECUTE + * + */ + + override fun visitQueryExec(ctx: GeneratedParser.QueryExecContext) = visitExecCommand(ctx.execCommand()) + + /** + * TODO EXEC accepts an `expr` as the procedure name so we have to unpack the string. + * - https://github.com/partiql/partiql-lang-kotlin/issues/707 + */ + override fun visitExecCommand(ctx: GeneratedParser.ExecCommandContext) = translate(ctx) { + val expr = visitExpr(ctx.name) + if (expr !is Expr.Var || expr.identifier !is Identifier.Symbol) { + throw error(ctx, "EXEC procedure must be a symbol identifier") + } + val procedure = (expr.identifier as Identifier.Symbol).symbol + val args = visitOrEmpty(ctx.args) + statementExec(procedure, args) + } + + /** + * + * DATA MANIPULATION LANGUAGE (DML) + * + */ + + /** + * The PartiQL grammars allows for multiple DML commands in one UPDATE statement. + * This function unwraps DML commands to the more limited DML.BatchLegacy.Op commands. + */ + override fun visitDmlBaseWrapper(ctx: GeneratedParser.DmlBaseWrapperContext) = translate(ctx) { + val table = when { + ctx.updateClause() != null -> ctx.updateClause().tableBaseReference() + ctx.fromClause() != null -> ctx.fromClause().tableReference() + else -> throw error(ctx, "Expected UPDATE or FROM
") + } + val from = visitOrNull(table) + var returning: Returning? = null + val ops = ctx.dmlBaseCommand().map { + val op = visitDmlBaseCommand(it) + when (op) { + is Statement.DML.Update -> statementDMLBatchLegacyOpSet(op.assignments) + is Statement.DML.Remove -> statementDMLBatchLegacyOpRemove(op.target) + is Statement.DML.Delete -> statementDMLBatchLegacyOpDelete() + is Statement.DML.Insert -> statementDMLBatchLegacyOpInsert( + op.target, op.values, op.asAlias, op.onConflict + ) + is Statement.DML.InsertLegacy -> statementDMLBatchLegacyOpInsertLegacy( + op.target, op.value, op.index, op.conflictCondition + ) + is Statement.DML.BatchLegacy -> { + // UNPACK InsertLegacy with returning + assert(op.ops.size == 1) { "wrapped batch op can only have one item" } + returning = op.returning + op.ops[0] + } + else -> throw error(ctx, "Invalid DML operator in BatchLegacy update") + } + } + val where = ctx.whereClause()?.let { visitExpr(it.expr()) } + // outer returning + if (ctx.returningClause() != null) { + returning = visitReturningClause(ctx.returningClause()!!) + } + statementDMLBatchLegacy(ops, from, where, returning) + } + + override fun visitDmlDelete(ctx: GeneratedParser.DmlDeleteContext) = visitDeleteCommand(ctx.deleteCommand()) + + override fun visitDmlInsertReturning(ctx: GeneratedParser.DmlInsertReturningContext): Statement.DML = + super.visit(ctx.insertCommandReturning()) as Statement.DML + + override fun visitDmlBase(ctx: GeneratedParser.DmlBaseContext) = + super.visitDmlBaseCommand(ctx.dmlBaseCommand()) as Statement.DML + + override fun visitDmlBaseCommand(ctx: GeneratedParser.DmlBaseCommandContext) = + super.visitDmlBaseCommand(ctx) as Statement.DML + + override fun visitRemoveCommand(ctx: GeneratedParser.RemoveCommandContext) = translate(ctx) { + val target = visitPathSimple(ctx.pathSimple()) + statementDMLRemove(target) + } + + override fun visitDeleteCommand(ctx: GeneratedParser.DeleteCommandContext) = translate(ctx) { + val from = visitAs(ctx.fromClauseSimple()) + val where = ctx.whereClause()?.let { visitExpr(it.arg) } + val returning = ctx.returningClause()?.let { visitReturningClause(it) } + statementDMLDelete(from, where, returning) + } + + /** + * Legacy INSERT with RETURNING clause is not represented in the AST as this grammar .. + * .. only exists for backwards compatibility. The RETURNING clause is ignored. + * + * TODO remove insertCommandReturning grammar rule + * - https://github.com/partiql/partiql-lang-kotlin/issues/698 + * - https://github.com/partiql/partiql-lang-kotlin/issues/708 + */ + override fun visitInsertCommandReturning(ctx: GeneratedParser.InsertCommandReturningContext) = translate(ctx) { + val target = visitPathSimple(ctx.pathSimple()) + val value = visitExpr(ctx.value) + val index = visitOrNull(ctx.pos) + val conflictCondition = ctx.onConflictLegacy()?.let { visitOnConflictLegacy(it) } + if (ctx.returningClause() != null) { + val returning = visitReturningClause(ctx.returningClause()!!) + val insert = statementDMLBatchLegacyOpInsertLegacy(target, value, index, conflictCondition) + statementDMLBatchLegacy(listOf(insert), null, null, returning) + } else { + statementDMLInsertLegacy(target, value, index, conflictCondition) + } + } + + override fun visitInsertStatementLegacy(ctx: GeneratedParser.InsertStatementLegacyContext) = translate(ctx) { + val target = visitPathSimple(ctx.pathSimple()) + val value = visitExpr(ctx.value) + val index = visitOrNull(ctx.pos) + val conflictCondition = ctx.onConflictLegacy()?.let { visitOnConflictLegacy(it) } + statementDMLInsertLegacy(target, value, index, conflictCondition) + } + + override fun visitInsertStatement(ctx: GeneratedParser.InsertStatementContext) = translate(ctx) { + val target = visitSymbolPrimitive(ctx.symbolPrimitive()) + val values = visitExpr(ctx.value) + val asAlias = visitOrNull(ctx.asIdent())?.symbol + val onConflict = ctx.onConflict()?.let { visitOnConflictClause(it) } + statementDMLInsert(target, values, asAlias, onConflict) + } + + override fun visitReplaceCommand(ctx: GeneratedParser.ReplaceCommandContext) = translate(ctx) { + val target = visitSymbolPrimitive(ctx.symbolPrimitive()) + val values = visitExpr(ctx.value) + val asAlias = visitOrNull(ctx.asIdent())?.symbol + statementDMLReplace(target, values, asAlias) + } + + override fun visitUpsertCommand(ctx: GeneratedParser.UpsertCommandContext) = translate(ctx) { + val target = visitSymbolPrimitive(ctx.symbolPrimitive()) + val values = visitExpr(ctx.value) + val asAlias = visitOrNull(ctx.asIdent())?.symbol + statementDMLUpsert(target, values, asAlias) + } + + override fun visitReturningClause(ctx: GeneratedParser.ReturningClauseContext) = translate(ctx) { + val columns = visitOrEmpty(ctx.returningColumn()) + returning(columns) + } + + override fun visitReturningColumn(ctx: GeneratedParser.ReturningColumnContext) = translate(ctx) { + val status = when (ctx.status.type) { + GeneratedParser.MODIFIED -> Returning.Column.Status.MODIFIED + GeneratedParser.ALL -> Returning.Column.Status.ALL + else -> throw error(ctx.status, "Expected MODIFIED or ALL") + } + val age = when (ctx.age.type) { + GeneratedParser.OLD -> Returning.Column.Age.OLD + GeneratedParser.NEW -> Returning.Column.Age.NEW + else -> throw error(ctx.status, "Expected OLD or NEW") + } + val value = when (ctx.ASTERISK()) { + null -> returningColumnValueExpression(visitExpr(ctx.expr())) + else -> returningColumnValueWildcard() + } + returningColumn(status, age, value) + } + + private fun visitOnConflictClause(ctx: GeneratedParser.OnConflictContext) = ctx.accept(this) as OnConflict + + override fun visitOnConflict(ctx: GeneratedParser.OnConflictContext) = translate(ctx) { + val target = ctx.conflictTarget()?.let { visitConflictTarget(it) } + val action = visitConflictAction(ctx.conflictAction()) + onConflict(target, action) + } + + /** + * TODO Remove this when we remove INSERT LEGACY as no other conflict actions are allowed in PartiQL.g4. + */ + override fun visitOnConflictLegacy(ctx: GeneratedParser.OnConflictLegacyContext) = translate(ctx) { + visitExpr(ctx.expr()) + } + + override fun visitConflictTarget(ctx: GeneratedParser.ConflictTargetContext) = translate(ctx) { + if (ctx.constraintName() != null) { + onConflictTargetConstraint(visitSymbolPrimitive(ctx.constraintName().symbolPrimitive())) + } else { + val symbols = ctx.symbolPrimitive().map { visitSymbolPrimitive(it) } + onConflictTargetSymbols(symbols) + } + } + + override fun visitConflictAction(ctx: GeneratedParser.ConflictActionContext) = when { + ctx.NOTHING() != null -> translate(ctx) { onConflictActionDoNothing() } + ctx.REPLACE() != null -> visitDoReplace(ctx.doReplace()) + ctx.UPDATE() != null -> visitDoUpdate(ctx.doUpdate()) + else -> throw error(ctx, "ON CONFLICT only supports `DO REPLACE` and `DO NOTHING` actions at the moment.") + } + + override fun visitDoReplace(ctx: GeneratedParser.DoReplaceContext) = translate(ctx) { + val condition = ctx.condition?.let { visitExpr(it) } + onConflictActionDoReplace(condition) + } + + override fun visitDoUpdate(ctx: GeneratedParser.DoUpdateContext) = translate(ctx) { + val condition = ctx.condition?.let { visitExpr(it) } + onConflictActionDoUpdate(condition) + } + + override fun visitPathSimple(ctx: GeneratedParser.PathSimpleContext) = translate(ctx) { + val root = visitSymbolPrimitive(ctx.symbolPrimitive()) + val steps = visitOrEmpty(ctx.pathSimpleSteps()) + path(root, steps) + } + + override fun visitPathSimpleLiteral(ctx: GeneratedParser.PathSimpleLiteralContext) = translate(ctx) { + val v = visit(ctx.literal()) + if (v !is Expr.Lit) { + throw error(ctx, "Expected a path element literal") + } + when (val i = v.value) { + is NumericValue<*> -> pathStepIndex(i.int) + is StringValue -> pathStepSymbol( + identifierSymbol( + i.value, Identifier.CaseSensitivity.SENSITIVE + ) + ) + else -> throw error(ctx, "Expected an integer or string literal, found literal ${i.type}") + } + } + + override fun visitPathSimpleSymbol(ctx: GeneratedParser.PathSimpleSymbolContext) = translate(ctx) { + val identifier = visitSymbolPrimitive(ctx.symbolPrimitive()) + pathStepSymbol(identifier) + } + + override fun visitPathSimpleDotSymbol(ctx: GeneratedParser.PathSimpleDotSymbolContext) = translate(ctx) { + val identifier = visitSymbolPrimitive(ctx.symbolPrimitive()) + pathStepSymbol(identifier) + } + + /** + * TODO current PartiQL.g4 grammar models a SET with no UPDATE target as valid DML command. + */ + override fun visitSetCommand(ctx: GeneratedParser.SetCommandContext) = translate(ctx) { + // We put a blank target, because we'll have to unpack this. + val target = path( + root = identifierSymbol("_blank", Identifier.CaseSensitivity.INSENSITIVE), + steps = emptyList(), + ) + val assignments = visitOrEmpty(ctx.setAssignment()) + statementDMLUpdate(target, assignments) + } + + override fun visitSetAssignment(ctx: GeneratedParser.SetAssignmentContext) = translate(ctx) { + val target = visitPathSimple(ctx.pathSimple()) + val value = visitExpr(ctx.expr()) + statementDMLUpdateAssignment(target, value) + } + + /** + * + * DATA QUERY LANGUAGE (DQL) + * + */ + + override fun visitDql(ctx: GeneratedParser.DqlContext) = translate(ctx) { + val expr = visitAs(ctx.expr()) + statementQuery(expr) + } + + override fun visitQueryBase(ctx: GeneratedParser.QueryBaseContext): AstNode = visit(ctx.exprSelect()) + + override fun visitSfwQuery(ctx: GeneratedParser.SfwQueryContext) = translate(ctx) { + val select = visit(ctx.select) as Select + val from = visitFromClause(ctx.from) + val let = visitOrNull(ctx.let) + val where = visitOrNull(ctx.where) + val groupBy = ctx.group?.let { visitGroupClause(it) } + val having = visitOrNull(ctx.having?.arg) + // TODO Add SQL UNION, INTERSECT, EXCEPT to PartiQL.g4 + val setOp: Expr.SFW.SetOp? = null + val orderBy = ctx.order?.let { visitOrderByClause(it) } + val limit = visitOrNull(ctx.limit?.arg) + val offset = visitOrNull(ctx.offset?.arg) + exprSFW(select, from, let, where, groupBy, having, setOp, orderBy, limit, offset) + } + + /** + * + * SELECT & PROJECTIONS + * + */ + + override fun visitSelectAll(ctx: GeneratedParser.SelectAllContext) = translate(ctx) { + val quantifier = convertSetQuantifier(ctx.setQuantifierStrategy()) + selectStar(quantifier) + } + + override fun visitSelectItems(ctx: GeneratedParser.SelectItemsContext) = translate(ctx) { + val items = visitOrEmpty(ctx.projectionItems().projectionItem()) + val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) + selectProject(items, setq) + } + + override fun visitSelectPivot(ctx: GeneratedParser.SelectPivotContext) = translate(ctx) { + val key = visitExpr(ctx.pivot) + val value = visitExpr(ctx.at) + selectPivot(key, value) + } + + override fun visitSelectValue(ctx: GeneratedParser.SelectValueContext) = translate(ctx) { + val constructor = visitExpr(ctx.expr()) + val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) + selectValue(constructor, setq) + } + + override fun visitProjectionItem(ctx: GeneratedParser.ProjectionItemContext) = translate(ctx) { + val expr = visitExpr(ctx.expr()) + val alias = ctx.symbolPrimitive()?.let { symbol(it) } + if (expr is Expr.Path) { + convertPathToProjectionItem(ctx, expr, alias) + } else { + selectProjectItemExpression(expr, alias) + } + } + + /** + * + * SIMPLE CLAUSES + * + */ + + override fun visitLimitClause(ctx: GeneratedParser.LimitClauseContext): Expr = visitAs(ctx.arg) + + override fun visitExpr(ctx: GeneratedParser.ExprContext): Expr { + if (Thread.interrupted()) { + throw InterruptedException() + } + return visitAs(ctx.exprBagOp()) + } + + override fun visitOffsetByClause(ctx: GeneratedParser.OffsetByClauseContext) = visitAs(ctx.arg) + + override fun visitWhereClause(ctx: GeneratedParser.WhereClauseContext) = visitExpr(ctx.arg) + + override fun visitWhereClauseSelect(ctx: GeneratedParser.WhereClauseSelectContext) = visitAs(ctx.arg) + + override fun visitHavingClause(ctx: GeneratedParser.HavingClauseContext) = visitAs(ctx.arg) + + /** + * + * LET CLAUSE + * + */ + + override fun visitLetClause(ctx: GeneratedParser.LetClauseContext) = translate(ctx) { + val bindings = visitOrEmpty(ctx.letBinding()) + let(bindings) + } + + override fun visitLetBinding(ctx: GeneratedParser.LetBindingContext) = translate(ctx) { + val expr = visitAs(ctx.expr()) + val alias = symbol(ctx.symbolPrimitive()) + letBinding(expr, alias) + } + + /** + * + * ORDER BY CLAUSE + * + */ + + override fun visitOrderByClause(ctx: GeneratedParser.OrderByClauseContext) = translate(ctx) { + val sorts = visitOrEmpty(ctx.orderSortSpec()) + orderBy(sorts) + } + + override fun visitOrderSortSpec(ctx: GeneratedParser.OrderSortSpecContext) = translate(ctx) { + val expr = visitAs(ctx.expr()) + val dir = when { + ctx.dir == null -> null + ctx.dir.type == GeneratedParser.ASC -> Sort.Dir.ASC + ctx.dir.type == GeneratedParser.DESC -> Sort.Dir.DESC + else -> throw error(ctx.dir, "Invalid ORDER BY direction; expected ASC or DESC") + } + val nulls = when { + ctx.nulls == null -> null + ctx.nulls.type == GeneratedParser.FIRST -> Sort.Nulls.FIRST + ctx.nulls.type == GeneratedParser.LAST -> Sort.Nulls.LAST + else -> throw error(ctx.nulls, "Invalid ORDER BY null ordering; expected FIRST or LAST") + } + sort(expr, dir, nulls) + } + + /** + * + * GROUP BY CLAUSE + * + */ + + override fun visitGroupClause(ctx: GeneratedParser.GroupClauseContext) = translate(ctx) { + val strategy = if (ctx.PARTIAL() != null) GroupBy.Strategy.PARTIAL else GroupBy.Strategy.FULL + val keys = visitOrEmpty(ctx.groupKey()) + val alias = ctx.groupAlias()?.let { symbol(ctx.groupAlias().symbolPrimitive()) } + groupBy(strategy, keys, alias) + } + + override fun visitGroupKey(ctx: GeneratedParser.GroupKeyContext) = translate(ctx) { + val expr = visitAs(ctx.key) + val alias = ctx.symbolPrimitive()?.let { symbol(it) } + groupByKey(expr, alias) + } + + /** + * + * BAG OPERATIONS + * + */ + + override fun visitIntersect(ctx: GeneratedParser.IntersectContext) = translate(ctx) { + val setq = when { + ctx.ALL() != null -> SetQuantifier.ALL + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT + else -> null + } + val op = setOp(SetOp.Type.INTERSECT, setq) + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val outer = ctx.OUTER() != null + exprBagOp(op, lhs, rhs, outer) + } + + override fun visitExcept(ctx: GeneratedParser.ExceptContext) = translate(ctx) { + val setq = when { + ctx.ALL() != null -> SetQuantifier.ALL + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT + else -> null + } + val op = setOp(SetOp.Type.EXCEPT, setq) + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val outer = ctx.OUTER() != null + exprBagOp(op, lhs, rhs, outer) + } + + override fun visitUnion(ctx: GeneratedParser.UnionContext) = translate(ctx) { + val setq = when { + ctx.ALL() != null -> SetQuantifier.ALL + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT + else -> null + } + val op = setOp(SetOp.Type.UNION, setq) + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val outer = ctx.OUTER() != null + exprBagOp(op, lhs, rhs, outer) + } + + /** + * + * GRAPH PATTERN MANIPULATION LANGUAGE (GPML) + * + */ + + override fun visitGpmlPattern(ctx: GeneratedParser.GpmlPatternContext) = translate(ctx) { + val pattern = visitMatchPattern(ctx.matchPattern()) + val selector = visitOrNull(ctx.matchSelector()) + graphMatch(listOf(pattern), selector) + } + + override fun visitGpmlPatternList(ctx: GeneratedParser.GpmlPatternListContext) = translate(ctx) { + val patterns = ctx.matchPattern().map { pattern -> visitMatchPattern(pattern) } + val selector = visitOrNull(ctx.matchSelector()) + graphMatch(patterns, selector) + } + + override fun visitMatchPattern(ctx: GeneratedParser.MatchPatternContext) = translate(ctx) { + val parts = visitOrEmpty(ctx.graphPart()) + val restrictor = ctx.restrictor?.let { + when (ctx.restrictor.text.lowercase()) { + "trail" -> GraphMatch.Restrictor.TRAIL + "acyclic" -> GraphMatch.Restrictor.ACYCLIC + "simple" -> GraphMatch.Restrictor.SIMPLE + else -> throw error(ctx.restrictor, "Unrecognized pattern restrictor") + } + } + val variable = visitOrNull(ctx.variable)?.symbol + graphMatchPattern(restrictor, null, variable, null, parts) + } + + override fun visitPatternPathVariable(ctx: GeneratedParser.PatternPathVariableContext) = + visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitSelectorBasic(ctx: GeneratedParser.SelectorBasicContext) = translate(ctx) { + when (ctx.mod.type) { + GeneratedParser.ANY -> graphMatchSelectorAnyShortest() + GeneratedParser.ALL -> graphMatchSelectorAllShortest() + else -> throw error(ctx, "Unsupported match selector.") + } + } + + override fun visitSelectorAny(ctx: GeneratedParser.SelectorAnyContext) = translate(ctx) { + when (ctx.k) { + null -> graphMatchSelectorAny() + else -> graphMatchSelectorAnyK(ctx.k.text.toLong()) + } + } + + override fun visitSelectorShortest(ctx: GeneratedParser.SelectorShortestContext) = translate(ctx) { + val k = ctx.k.text.toLong() + when (ctx.GROUP()) { + null -> graphMatchSelectorShortestK(k) + else -> graphMatchSelectorShortestKGroup(k) + } + } + + override fun visitPatternPartLabel(ctx: GeneratedParser.PatternPartLabelContext) = + visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitPattern(ctx: GeneratedParser.PatternContext) = translate(ctx) { + val restrictor = visitRestrictor(ctx.restrictor) + val variable = visitOrNull(ctx.variable)?.symbol + val prefilter = ctx.where?.let { visitExpr(it.expr()) } + val quantifier = ctx.quantifier?.let { visitPatternQuantifier(it) } + val parts = visitOrEmpty(ctx.graphPart()) + graphMatchPattern(restrictor, prefilter, variable, quantifier, parts) + } + + override fun visitEdgeAbbreviated(ctx: GeneratedParser.EdgeAbbreviatedContext) = translate(ctx) { + val direction = visitEdge(ctx.edgeAbbrev()) + val quantifier = visitOrNull(ctx.quantifier) + graphMatchPatternPartEdge(direction, quantifier, null, null, emptyList()) + } + + override fun visitEdgeWithSpec(ctx: GeneratedParser.EdgeWithSpecContext) = translate(ctx) { + val quantifier = visitOrNull(ctx.quantifier) + val edge = visitOrNull(ctx.edgeWSpec()) + edge!!.copy(quantifier = quantifier) + } + + override fun visitEdgeSpec(ctx: GeneratedParser.EdgeSpecContext) = translate(ctx) { + val placeholderDirection = GraphMatch.Direction.RIGHT + val variable = visitOrNull(ctx.symbolPrimitive())?.symbol + val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } + val label = visitOrNull(ctx.patternPartLabel())?.symbol + graphMatchPatternPartEdge(placeholderDirection, null, prefilter, variable, listOfNotNull(label)) + } + + override fun visitEdgeSpecLeft(ctx: GeneratedParser.EdgeSpecLeftContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphMatch.Direction.LEFT) + } + + override fun visitEdgeSpecRight(ctx: GeneratedParser.EdgeSpecRightContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphMatch.Direction.RIGHT) + } + + override fun visitEdgeSpecBidirectional(ctx: GeneratedParser.EdgeSpecBidirectionalContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphMatch.Direction.LEFT_OR_RIGHT) + } + + override fun visitEdgeSpecUndirectedBidirectional(ctx: GeneratedParser.EdgeSpecUndirectedBidirectionalContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphMatch.Direction.LEFT_UNDIRECTED_OR_RIGHT) + } + + override fun visitEdgeSpecUndirected(ctx: GeneratedParser.EdgeSpecUndirectedContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphMatch.Direction.UNDIRECTED) + } + + override fun visitEdgeSpecUndirectedLeft(ctx: GeneratedParser.EdgeSpecUndirectedLeftContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphMatch.Direction.LEFT_OR_UNDIRECTED) + } + + override fun visitEdgeSpecUndirectedRight(ctx: GeneratedParser.EdgeSpecUndirectedRightContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphMatch.Direction.UNDIRECTED_OR_RIGHT) + } + + private fun visitEdge(ctx: GeneratedParser.EdgeAbbrevContext): GraphMatch.Direction = when { + ctx.TILDE() != null && ctx.ANGLE_RIGHT() != null -> GraphMatch.Direction.UNDIRECTED_OR_RIGHT + ctx.TILDE() != null && ctx.ANGLE_LEFT() != null -> GraphMatch.Direction.LEFT_OR_UNDIRECTED + ctx.TILDE() != null -> GraphMatch.Direction.UNDIRECTED + ctx.MINUS() != null && ctx.ANGLE_LEFT() != null && ctx.ANGLE_RIGHT() != null -> GraphMatch.Direction.LEFT_OR_RIGHT + ctx.MINUS() != null && ctx.ANGLE_LEFT() != null -> GraphMatch.Direction.LEFT + ctx.MINUS() != null && ctx.ANGLE_RIGHT() != null -> GraphMatch.Direction.RIGHT + ctx.MINUS() != null -> GraphMatch.Direction.LEFT_UNDIRECTED_OR_RIGHT + else -> throw error(ctx, "Unsupported edge type") + } + + override fun visitGraphPart(ctx: GeneratedParser.GraphPartContext): GraphMatch.Pattern.Part { + val part = super.visitGraphPart(ctx) + if (part is GraphMatch.Pattern) { + return translate(ctx) { graphMatchPatternPartPattern(part) } + } + return part as GraphMatch.Pattern.Part + } + + override fun visitPatternQuantifier(ctx: GeneratedParser.PatternQuantifierContext) = translate(ctx) { + when { + ctx.quant == null -> graphMatchQuantifier(ctx.lower.text.toLong(), ctx.upper?.text?.toLong()) + ctx.quant.type == GeneratedParser.PLUS -> graphMatchQuantifier(1L, null) + ctx.quant.type == GeneratedParser.ASTERISK -> graphMatchQuantifier(0L, null) + else -> throw error(ctx, "Unsupported quantifier") + } + } + + override fun visitNode(ctx: GeneratedParser.NodeContext) = translate(ctx) { + val variable = visitOrNull(ctx.symbolPrimitive())?.symbol + val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } + val label = visitOrNull(ctx.patternPartLabel())?.symbol + graphMatchPatternPartNode(prefilter, variable, listOfNotNull(label)) + } + + private fun visitRestrictor(ctx: GeneratedParser.PatternRestrictorContext?): GraphMatch.Restrictor? { + if (ctx == null) return null + return when (ctx.restrictor.text.lowercase()) { + "trail" -> GraphMatch.Restrictor.TRAIL + "acyclic" -> GraphMatch.Restrictor.ACYCLIC + "simple" -> GraphMatch.Restrictor.SIMPLE + else -> throw error(ctx, "Unrecognized pattern restrictor") + } + } + + /** + * + * TABLE REFERENCES & JOINS & FROM CLAUSE + * + */ + + override fun visitFromClause(ctx: GeneratedParser.FromClauseContext) = visitAs(ctx.tableReference()) + + override fun visitTableBaseRefClauses(ctx: GeneratedParser.TableBaseRefClausesContext) = translate(ctx) { + val expr = visitAs(ctx.source) + val asAlias = ctx.asIdent()?.let { symbol(it.symbolPrimitive()) } + val atAlias = ctx.atIdent()?.let { symbol(it.symbolPrimitive()) } + val byAlias = ctx.byIdent()?.let { symbol(it.symbolPrimitive()) } + fromValue(expr, From.Value.Type.SCAN, asAlias, atAlias, byAlias) + } + + override fun visitTableBaseRefMatch(ctx: GeneratedParser.TableBaseRefMatchContext) = translate(ctx) { + val expr = visitAs(ctx.source) + val asAlias = ctx.asIdent()?.let { symbol(it.symbolPrimitive()) } + val atAlias = ctx.atIdent()?.let { symbol(it.symbolPrimitive()) } + val byAlias = ctx.byIdent()?.let { symbol(it.symbolPrimitive()) } + fromValue(expr, From.Value.Type.SCAN, asAlias, atAlias, byAlias) + } + + /** + * TODO Remove as/at/by aliases from DELETE command grammar in PartiQL.g4 + */ + override fun visitFromClauseSimpleExplicit(ctx: GeneratedParser.FromClauseSimpleExplicitContext) = + translate(ctx) { + val path = visitPathSimple(ctx.pathSimple()) + val asAlias = ctx.asIdent()?.let { visitAsIdent(it) }?.symbol + val atAlias = ctx.atIdent()?.let { visitAtIdent(it) }?.symbol + val byAlias = ctx.byIdent()?.let { visitByIdent(it) }?.symbol + statementDMLDeleteTarget(path, asAlias, atAlias, byAlias) + } + + /** + * TODO Remove fromClauseSimple rule from DELETE command grammar in PartiQL.g4 + */ + override fun visitFromClauseSimpleImplicit(ctx: GeneratedParser.FromClauseSimpleImplicitContext) = + translate(ctx) { + val path = visitPathSimple(ctx.pathSimple()) + val asAlias = visitSymbolPrimitive(ctx.symbolPrimitive()).symbol + statementDMLDeleteTarget(path, asAlias, null, null) + } + + override fun visitTableUnpivot(ctx: GeneratedParser.TableUnpivotContext) = translate(ctx) { + val expr = visitAs(ctx.expr()) + val asAlias = ctx.asIdent()?.let { symbol(it.symbolPrimitive()) } + val atAlias = ctx.atIdent()?.let { symbol(it.symbolPrimitive()) } + val byAlias = ctx.byIdent()?.let { symbol(it.symbolPrimitive()) } + fromValue(expr, From.Value.Type.UNPIVOT, asAlias, atAlias, byAlias) + } + + override fun visitTableCrossJoin(ctx: GeneratedParser.TableCrossJoinContext) = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val type = convertJoinType(ctx.joinType()) + fromJoin(lhs, rhs, type, null) + } + + private fun convertJoinType(ctx: GeneratedParser.JoinTypeContext?): From.Join.Type? { + if (ctx == null) return null + return when (ctx.mod.type) { + GeneratedParser.INNER -> From.Join.Type.INNER + GeneratedParser.LEFT -> when (ctx.OUTER()) { + null -> From.Join.Type.LEFT + else -> From.Join.Type.LEFT_OUTER + } + GeneratedParser.RIGHT -> when (ctx.OUTER()) { + null -> From.Join.Type.RIGHT + else -> From.Join.Type.RIGHT_OUTER + } + GeneratedParser.FULL -> when (ctx.OUTER()) { + null -> From.Join.Type.FULL + else -> From.Join.Type.FULL_OUTER + } + GeneratedParser.OUTER -> { + // TODO https://github.com/partiql/partiql-spec/issues/41 + // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1013 + From.Join.Type.FULL_OUTER + } + else -> null + } + } + + override fun visitTableQualifiedJoin(ctx: GeneratedParser.TableQualifiedJoinContext) = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val type = convertJoinType(ctx.joinType()) + val condition = ctx.joinSpec()?.let { visitExpr(it.expr()) } + fromJoin(lhs, rhs, type, condition) + } + + override fun visitTableBaseRefSymbol(ctx: GeneratedParser.TableBaseRefSymbolContext) = translate(ctx) { + val expr = visitAs(ctx.source) + val asAlias = symbol(ctx.symbolPrimitive()) + fromValue(expr, From.Value.Type.SCAN, asAlias, null, null) + } + + override fun visitTableWrapped(ctx: GeneratedParser.TableWrappedContext): AstNode = visit(ctx.tableReference()) + + override fun visitJoinSpec(ctx: GeneratedParser.JoinSpecContext) = visitExpr(ctx.expr()) + + override fun visitJoinRhsTableJoined(ctx: GeneratedParser.JoinRhsTableJoinedContext) = + visitAs(ctx.tableReference()) + + /** + * SIMPLE EXPRESSIONS + */ + + override fun visitOr(ctx: GeneratedParser.OrContext) = translate(ctx) { + convertBinaryExpr(ctx.lhs, ctx.rhs, Expr.Binary.Op.OR) + } + + override fun visitAnd(ctx: GeneratedParser.AndContext) = translate(ctx) { + convertBinaryExpr(ctx.lhs, ctx.rhs, Expr.Binary.Op.AND) + } + + override fun visitNot(ctx: GeneratedParser.NotContext) = translate(ctx) { + val expr = visit(ctx.exprNot()) as Expr + exprUnary(Expr.Unary.Op.NOT, expr) + } + + override fun visitMathOp00(ctx: GeneratedParser.MathOp00Context) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertBinaryExpr(ctx.lhs, ctx.rhs, convertBinaryOp(ctx.op)) + } + + override fun visitMathOp01(ctx: GeneratedParser.MathOp01Context) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertBinaryExpr(ctx.lhs, ctx.rhs, convertBinaryOp(ctx.op)) + } + + override fun visitMathOp02(ctx: GeneratedParser.MathOp02Context) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertBinaryExpr(ctx.lhs, ctx.rhs, convertBinaryOp(ctx.op)) + } + + override fun visitValueExpr(ctx: GeneratedParser.ValueExprContext) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + val expr = visit(ctx.rhs) as Expr + exprUnary(convertUnaryOp(ctx.sign), expr) + } + + private fun convertBinaryExpr(lhs: ParserRuleContext, rhs: ParserRuleContext, op: Expr.Binary.Op): Expr { + val l = visit(lhs) as Expr + val r = visit(rhs) as Expr + return factory.exprBinary(op, l, r) + } + + private fun convertBinaryOp(token: Token) = when (token.type) { + GeneratedParser.AND -> Expr.Binary.Op.AND + GeneratedParser.OR -> Expr.Binary.Op.OR + GeneratedParser.ASTERISK -> Expr.Binary.Op.TIMES + GeneratedParser.SLASH_FORWARD -> Expr.Binary.Op.DIVIDE + GeneratedParser.PLUS -> Expr.Binary.Op.PLUS + GeneratedParser.MINUS -> Expr.Binary.Op.MINUS + GeneratedParser.PERCENT -> Expr.Binary.Op.MODULO + GeneratedParser.CONCAT -> Expr.Binary.Op.CONCAT + GeneratedParser.ANGLE_LEFT -> Expr.Binary.Op.LT + GeneratedParser.LT_EQ -> Expr.Binary.Op.LTE + GeneratedParser.ANGLE_RIGHT -> Expr.Binary.Op.GT + GeneratedParser.GT_EQ -> Expr.Binary.Op.GTE + GeneratedParser.NEQ -> Expr.Binary.Op.NE + GeneratedParser.EQ -> Expr.Binary.Op.EQ + else -> throw error(token, "Invalid binary operator") + } + + private fun convertUnaryOp(token: Token) = when (token.type) { + GeneratedParser.PLUS -> Expr.Unary.Op.POS + GeneratedParser.MINUS -> Expr.Unary.Op.NEG + GeneratedParser.NOT -> Expr.Unary.Op.NOT + else -> throw error(token, "Invalid unary operator") + } + + /** + * + * PREDICATES + * + */ + + override fun visitPredicateComparison(ctx: GeneratedParser.PredicateComparisonContext) = translate(ctx) { + val op = convertBinaryOp(ctx.op) + convertBinaryExpr(ctx.lhs, ctx.rhs, op) + } + + /** + * TODO Fix the IN collection grammar, also label alternative forms + * - https://github.com/partiql/partiql-lang-kotlin/issues/1115 + * - https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitPredicateIn(ctx: GeneratedParser.PredicateInContext) = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs ?: ctx.expr()).let { + // Wrap rhs in an array unless it's a query or already a collection + if (it is Expr.SFW || it is Expr.Collection || ctx.PAREN_LEFT() == null) { + it + } else { + // IN ( expr ) + exprCollection(Expr.Collection.Type.LIST, listOf(it)) + } + } + val not = ctx.NOT() != null + exprInCollection(lhs, rhs, not) + } + + override fun visitPredicateIs(ctx: GeneratedParser.PredicateIsContext) = translate(ctx) { + val value = visitAs(ctx.lhs) + val type = visitAs(ctx.type()) + val not = ctx.NOT() != null + exprIsType(value, type, not) + } + + override fun visitPredicateBetween(ctx: GeneratedParser.PredicateBetweenContext) = translate(ctx) { + val value = visitAs(ctx.lhs) + val lower = visitAs(ctx.lower) + val upper = visitAs(ctx.upper) + val not = ctx.NOT() != null + exprBetween(value, lower, upper, not) + } + + override fun visitPredicateLike(ctx: GeneratedParser.PredicateLikeContext) = translate(ctx) { + val value = visitAs(ctx.lhs) + val pattern = visitAs(ctx.rhs) + val escape = visitOrNull(ctx.escape) + val not = ctx.NOT() != null + exprLike(value, pattern, escape, not) + } + + /** + * + * PRIMARY EXPRESSIONS + * + */ + + override fun visitExprTermWrappedQuery(ctx: GeneratedParser.ExprTermWrappedQueryContext): AstNode = + visit(ctx.expr()) + + override fun visitVariableIdentifier(ctx: GeneratedParser.VariableIdentifierContext) = translate(ctx) { + val symbol = ctx.ident.getStringValue() + val case = when (ctx.ident.type) { + GeneratedParser.IDENTIFIER -> Identifier.CaseSensitivity.INSENSITIVE + else -> Identifier.CaseSensitivity.SENSITIVE + } + val scope = when (ctx.qualifier) { + null -> Expr.Var.Scope.DEFAULT + else -> Expr.Var.Scope.LOCAL + } + exprVar(identifierSymbol(symbol, case), scope) + } + + override fun visitVariableKeyword(ctx: GeneratedParser.VariableKeywordContext) = translate(ctx) { + val symbol = ctx.key.text + val case = Identifier.CaseSensitivity.INSENSITIVE + val scope = when (ctx.qualifier) { + null -> Expr.Var.Scope.DEFAULT + else -> Expr.Var.Scope.LOCAL + } + exprVar(identifierSymbol(symbol, case), scope) + } + + override fun visitParameter(ctx: GeneratedParser.ParameterContext) = translate(ctx) { + val index = parameters[ctx.QUESTION_MARK().symbol.tokenIndex] ?: throw error( + ctx, + "Unable to find index of parameter." + ) + exprParameter(index) + } + + override fun visitSequenceConstructor(ctx: GeneratedParser.SequenceConstructorContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + val type = when (ctx.datatype.type) { + GeneratedParser.LIST -> Expr.Collection.Type.LIST + GeneratedParser.SEXP -> Expr.Collection.Type.SEXP + else -> throw error(ctx.datatype, "Invalid sequence type") + } + exprCollection(type, expressions) + } + + override fun visitExprPrimaryPath(ctx: GeneratedParser.ExprPrimaryPathContext) = translate(ctx) { + val base = visitAs(ctx.exprPrimary()) + val steps = ctx.pathStep().map { visit(it) as Expr.Path.Step } + exprPath(base, steps) + } + + override fun visitPathStepIndexExpr(ctx: GeneratedParser.PathStepIndexExprContext) = translate(ctx) { + val key = visitAs(ctx.key) + exprPathStepIndex(key) + } + + override fun visitPathStepDotExpr(ctx: GeneratedParser.PathStepDotExprContext) = translate(ctx) { + val symbol = visitSymbolPrimitive(ctx.symbolPrimitive()) + exprPathStepSymbol(symbol) + } + + override fun visitPathStepIndexAll(ctx: GeneratedParser.PathStepIndexAllContext) = translate(ctx) { + exprPathStepWildcard() + } + + override fun visitPathStepDotAll(ctx: GeneratedParser.PathStepDotAllContext) = translate(ctx) { + exprPathStepUnpivot() + } + + override fun visitValues(ctx: GeneratedParser.ValuesContext) = translate(ctx) { + val rows = visitOrEmpty(ctx.valueRow()) + exprCollection(Expr.Collection.Type.BAG, rows) + } + + override fun visitValueRow(ctx: GeneratedParser.ValueRowContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprCollection(Expr.Collection.Type.LIST, expressions) + } + + override fun visitValueList(ctx: GeneratedParser.ValueListContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprCollection(Expr.Collection.Type.LIST, expressions) + } + + override fun visitExprGraphMatchMany(ctx: GeneratedParser.ExprGraphMatchManyContext) = translate(ctx) { + val graph = visit(ctx.exprPrimary()) as Expr + val pattern = visitGpmlPatternList(ctx.gpmlPatternList()) + exprMatch(graph, pattern) + } + + override fun visitExprGraphMatchOne(ctx: GeneratedParser.ExprGraphMatchOneContext) = translate(ctx) { + val graph = visit(ctx.exprPrimary()) as Expr + val pattern = visitGpmlPattern(ctx.gpmlPattern()) + exprMatch(graph, pattern) + } + + override fun visitExprTermCurrentUser(ctx: GeneratedParser.ExprTermCurrentUserContext) = translate(ctx) { + exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER) + } + + /** + * + * FUNCTIONS + * + */ + + override fun visitNullIf(ctx: GeneratedParser.NullIfContext) = translate(ctx) { + val value = visitExpr(ctx.expr(0)) + val nullifier = visitExpr(ctx.expr(1)) + exprNullIf(value, nullifier) + } + + override fun visitCoalesce(ctx: GeneratedParser.CoalesceContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprCoalesce(expressions) + } + + override fun visitCaseExpr(ctx: GeneratedParser.CaseExprContext) = translate(ctx) { + val expr = ctx.case_?.let { visitExpr(it) } + val branches = ctx.whens.indices.map { i -> + // consider adding locations + val w = visitExpr(ctx.whens[i]) + val t = visitExpr(ctx.thens[i]) + exprCaseBranch(w, t) + } + val default = ctx.else_?.let { visitExpr(it) } + exprCase(expr, branches, default) + } + + override fun visitCast(ctx: GeneratedParser.CastContext) = translate(ctx) { + val expr = visitExpr(ctx.expr()) + val type = visitAs(ctx.type()) + exprCast(expr, type) + } + + override fun visitCanCast(ctx: GeneratedParser.CanCastContext) = translate(ctx) { + val expr = visitExpr(ctx.expr()) + val type = visitAs(ctx.type()) + exprCanCast(expr, type) + } + + override fun visitCanLosslessCast(ctx: GeneratedParser.CanLosslessCastContext) = translate(ctx) { + val expr = visitExpr(ctx.expr()) + val type = visitAs(ctx.type()) + exprCanLosslessCast(expr, type) + } + + override fun visitFunctionCallIdent(ctx: GeneratedParser.FunctionCallIdentContext) = translate(ctx) { + val function = visitSymbolPrimitive(ctx.name) + val args = visitOrEmpty(ctx.expr()) + exprCall(function, args) + } + + override fun visitFunctionCallReserved(ctx: GeneratedParser.FunctionCallReservedContext) = translate(ctx) { + val function = ctx.name.text.toIdentifier() + val args = visitOrEmpty(ctx.expr()) + exprCall(function, args) + } + + /** + * + * FUNCTIONS WITH SPECIAL FORMS + * + */ + + override fun visitDateFunction(ctx: GeneratedParser.DateFunctionContext) = translate(ctx) { + val field = try { + DatetimeField.valueOf(ctx.dt.text.uppercase()) + } catch (ex: IllegalArgumentException) { + throw error(ctx.dt, "Expected one of: ${DatetimeField.values().joinToString()}", ex) + } + val lhs = visitExpr(ctx.expr(0)) + val rhs = visitExpr(ctx.expr(1)) + when { + ctx.DATE_ADD() != null -> exprDateAdd(field, lhs, rhs) + ctx.DATE_DIFF() != null -> exprDateDiff(field, lhs, rhs) + else -> throw error(ctx, "Expected DATE_ADD or DATE_DIFF") + } + } + + /** + * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitSubstring(ctx: GeneratedParser.SubstringContext) = translate(ctx) { + if (ctx.FROM() == null) { + // normal form + val function = "SUBSTRING".toIdentifier() + val args = visitOrEmpty(ctx.expr()) + exprCall(function, args) + } else { + // special form + val value = visitExpr(ctx.expr(0)) + val start = visitOrNull(ctx.expr(1)) + val length = visitOrNull(ctx.expr(2)) + exprSubstring(value, start, length) + } + } + + /** + * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitPosition(ctx: GeneratedParser.PositionContext) = translate(ctx) { + if (ctx.IN() == null) { + // normal form + val function = "POSITION".toIdentifier() + val args = visitOrEmpty(ctx.expr()) + exprCall(function, args) + } else { + // special form + val lhs = visitExpr(ctx.expr(0)) + val rhs = visitExpr(ctx.expr(1)) + exprPosition(lhs, rhs) + } + } + + /** + * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitOverlay(ctx: GeneratedParser.OverlayContext) = translate(ctx) { + if (ctx.PLACING() == null) { + // normal form + val function = "OVERLAY".toIdentifier() + val args = visitOrEmpty(ctx.expr()) + exprCall(function, args) + } else { + // special form + val value = visitExpr(ctx.expr(0)) + val overlay = visitExpr(ctx.expr(1)) + val start = visitExpr(ctx.expr(2)) + val length = visitOrNull(ctx.expr(3)) + exprOverlay(value, overlay, start, length) + } + } + + /** + * COUNT(*) + */ + override fun visitCountAll(ctx: GeneratedParser.CountAllContext) = translate(ctx) { + val function = "COUNT_STAR".toIdentifier() + exprAgg(function, emptyList(), SetQuantifier.ALL) + } + + override fun visitExtract(ctx: GeneratedParser.ExtractContext) = translate(ctx) { + val field = try { + DatetimeField.valueOf(ctx.IDENTIFIER().text.uppercase()) + } catch (ex: IllegalArgumentException) { + throw error(ctx.IDENTIFIER().symbol, "Expected one of: ${DatetimeField.values().joinToString()}", ex) + } + val source = visitExpr(ctx.expr()) + exprExtract(field, source) + } + + override fun visitTrimFunction(ctx: GeneratedParser.TrimFunctionContext) = translate(ctx) { + val spec = ctx.mod?.let { + try { + Expr.Trim.Spec.valueOf(it.text.uppercase()) + } catch (ex: IllegalArgumentException) { + throw error(it, "Expected on of: ${Expr.Trim.Spec.values().joinToString()}", ex) + } + } + val (chars, value) = when (ctx.expr().size) { + 1 -> null to visitExpr(ctx.expr(0)) + 2 -> visitExpr(ctx.expr(0)) to visitExpr(ctx.expr(1)) + else -> throw error(ctx, "Expected one or two TRIM expression arguments") + } + exprTrim(value, chars, spec) + } + + override fun visitAggregateBase(ctx: GeneratedParser.AggregateBaseContext) = translate(ctx) { + val function = ctx.func.text.toIdentifier() + val args = listOf(visitExpr(ctx.expr())) + val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) + exprAgg(function, args, setq) + } + + /** + * Window Functions + */ + + override fun visitLagLeadFunction(ctx: GeneratedParser.LagLeadFunctionContext) = translate(ctx) { + val function = when { + ctx.LAG() != null -> Expr.Window.Function.LAG + ctx.LEAD() != null -> Expr.Window.Function.LEAD + else -> throw error(ctx, "Expected LAG or LEAD") + } + val expression = visitExpr(ctx.expr(0)) + val offset = visitOrNull(ctx.expr(1)) + val default = visitOrNull(ctx.expr(2)) + val over = visitOver(ctx.over()) + if (over.sorts == null) { + throw error(ctx.over(), "$function requires Window ORDER BY") + } + exprWindow(function, expression, offset, default, over) + } + + override fun visitOver(ctx: GeneratedParser.OverContext) = translate(ctx) { + val partitions = ctx.windowPartitionList()?.let { visitOrEmpty(it.expr()) } + val sorts = ctx.windowSortSpecList()?.let { visitOrEmpty(it.orderSortSpec()) } + exprWindowOver(partitions, sorts) + } + + /** + * + * LITERALS + * + */ + + override fun visitBag(ctx: GeneratedParser.BagContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprCollection(Expr.Collection.Type.BAG, expressions) + } + + override fun visitLiteralDecimal(ctx: GeneratedParser.LiteralDecimalContext) = translate(ctx) { + val decimal = try { + val v = ctx.LITERAL_DECIMAL().text.trim() + BigDecimal(v, MathContext(38, RoundingMode.HALF_EVEN)) + } catch (e: NumberFormatException) { + throw error(ctx, "Invalid decimal literal", e) + } + exprLit(decimalValue(decimal)) + } + + override fun visitArray(ctx: GeneratedParser.ArrayContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprCollection(Expr.Collection.Type.ARRAY, expressions) + } + + override fun visitLiteralNull(ctx: GeneratedParser.LiteralNullContext) = translate(ctx) { + exprLit(nullValue()) + } + + override fun visitLiteralMissing(ctx: GeneratedParser.LiteralMissingContext) = translate(ctx) { + exprLit(missingValue()) + } + + override fun visitLiteralTrue(ctx: GeneratedParser.LiteralTrueContext) = translate(ctx) { + exprLit(boolValue(true)) + } + + override fun visitLiteralFalse(ctx: GeneratedParser.LiteralFalseContext) = translate(ctx) { + exprLit(boolValue(false)) + } + + override fun visitLiteralIon(ctx: GeneratedParser.LiteralIonContext) = translate(ctx) { + val value = try { + loadSingleElement(ctx.ION_CLOSURE().getStringValue()) + } catch (e: IonElementException) { + throw error(ctx, "Unable to parse Ion value.", e) + } + exprIon(value) + } + + override fun visitLiteralString(ctx: GeneratedParser.LiteralStringContext) = translate(ctx) { + val value = ctx.LITERAL_STRING().getStringValue() + exprLit(stringValue(value)) + } + + override fun visitLiteralInteger(ctx: GeneratedParser.LiteralIntegerContext) = translate(ctx) { + val n = ctx.LITERAL_INTEGER().text.toInt() + val v = try { + int64Value(n.toLong()) + } catch (_: java.lang.NumberFormatException) { + intValue(n.toBigInteger()) + } + exprLit(v) + } + + override fun visitLiteralDate(ctx: GeneratedParser.LiteralDateContext) = translate(ctx) { + val pattern = ctx.LITERAL_STRING().symbol + val dateString = ctx.LITERAL_STRING().getStringValue() + if (DATE_PATTERN_REGEX.matches(dateString).not()) { + throw error(pattern, "Expected DATE string to be of the format yyyy-MM-dd") + } + val value = try { + LocalDate.parse(dateString, DateTimeFormatter.ISO_LOCAL_DATE) + } catch (e: DateTimeParseException) { + throw error(pattern, e.localizedMessage, e) + } catch (e: IndexOutOfBoundsException) { + throw error(pattern, e.localizedMessage, e) + } + val date = DateTimeValue.date(value.year, value.monthValue, value.dayOfMonth) + exprLit(dateValue(date)) + } + + override fun visitLiteralTime(ctx: GeneratedParser.LiteralTimeContext) = translate(ctx) { + val (timeString, precision) = getTimeStringAndPrecision(ctx.LITERAL_STRING(), ctx.LITERAL_INTEGER()) + val time = try { + DateTimeUtils.parseTimeLiteral(timeString) + } catch (e: DateTimeException) { + throw error(ctx, "Invalid Date Time Literal", e) + } + val value = time.toPrecision(precision) + exprLit(timeValue(value)) + } + + override fun visitLiteralTimestamp(ctx: GeneratedParser.LiteralTimestampContext) = translate(ctx) { + val (timeString, precision) = getTimeStringAndPrecision(ctx.LITERAL_STRING(), ctx.LITERAL_INTEGER()) + val timestamp = try { + DateTimeUtils.parseTimestamp(timeString) + } catch (e: DateTimeException) { + throw error(ctx, "Invalid Date Time Literal", e) + } + val value = timestamp.toPrecision(precision) + exprLit(timestampValue(value)) + } + + override fun visitTuple(ctx: GeneratedParser.TupleContext) = translate(ctx) { + val fields = ctx.pair().map { + val k = visitExpr(it.lhs) + val v = visitExpr(it.rhs) + exprStructField(k, v) + } + exprStruct(fields) + } + + /** + * + * TYPES + * + */ + + override fun visitTypeAtomic(ctx: GeneratedParser.TypeAtomicContext) = translate(ctx) { + when (ctx.datatype.type) { + GeneratedParser.NULL -> typeNullType() + GeneratedParser.BOOL, GeneratedParser.BOOLEAN -> typeBool() + GeneratedParser.SMALLINT, GeneratedParser.INT2, GeneratedParser.INTEGER2 -> typeInt2() + GeneratedParser.INT4, GeneratedParser.INTEGER4 -> typeInt4() + GeneratedParser.BIGINT, GeneratedParser.INT8, GeneratedParser.INTEGER8 -> typeInt8() + GeneratedParser.INT, GeneratedParser.INTEGER -> typeInt() + GeneratedParser.FLOAT -> typeFloat32() + GeneratedParser.DOUBLE -> typeFloat64() + GeneratedParser.REAL -> typeReal() + GeneratedParser.TIMESTAMP -> typeTimestamp(null) + GeneratedParser.CHAR, GeneratedParser.CHARACTER -> typeChar(null) + GeneratedParser.MISSING -> typeMissing() + GeneratedParser.STRING -> typeString(null) + GeneratedParser.SYMBOL -> typeSymbol() + // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1125 + GeneratedParser.BLOB -> typeBlob(null) + GeneratedParser.CLOB -> typeClob(null) + GeneratedParser.DATE -> typeDate() + GeneratedParser.STRUCT -> typeStruct() + GeneratedParser.TUPLE -> typeTuple() + GeneratedParser.LIST -> typeList() + GeneratedParser.SEXP -> typeSexp() + GeneratedParser.BAG -> typeBag() + GeneratedParser.ANY -> typeAny() + else -> throw error(ctx, "Unknown atomic type.") + } + } + + override fun visitTypeVarChar(ctx: GeneratedParser.TypeVarCharContext) = translate(ctx) { + val n = ctx.arg0?.text?.toInt() + typeVarchar(n) + } + + override fun visitTypeArgSingle(ctx: GeneratedParser.TypeArgSingleContext) = translate(ctx) { + val n = ctx.arg0?.text?.toInt() + when (ctx.datatype.type) { + GeneratedParser.FLOAT -> when (n) { + 32 -> typeFloat32() + 64 -> typeFloat64() + else -> throw error(ctx.datatype, "Invalid FLOAT precision. Expected 32 or 64") + } + GeneratedParser.CHAR, GeneratedParser.CHARACTER -> typeChar(n) + GeneratedParser.VARCHAR -> typeVarchar(n) + else -> throw error(ctx.datatype, "Invalid datatype") + } + } + + override fun visitTypeArgDouble(ctx: GeneratedParser.TypeArgDoubleContext) = translate(ctx) { + val arg0 = ctx.arg0?.text?.toInt() + val arg1 = ctx.arg1?.text?.toInt() + when (ctx.datatype.type) { + GeneratedParser.DECIMAL, GeneratedParser.DEC -> typeDecimal(arg0, arg1) + GeneratedParser.NUMERIC -> typeNumeric(arg0, arg1) + else -> throw error(ctx.datatype, "Invalid datatype") + } + } + + override fun visitTypeTimeZone(ctx: GeneratedParser.TypeTimeZoneContext) = translate(ctx) { + val precision = ctx.precision?.let { + val p = ctx.precision.text.toInt() + if (p < 0 || 9 < p) throw error(ctx.precision, "Unsupported time precision") + p + } + when (ctx.ZONE()) { + null -> typeTime(precision) + else -> typeTimeWithTz(precision) + } + } + + override fun visitTypeCustom(ctx: GeneratedParser.TypeCustomContext) = translate(ctx) { + typeCustom(ctx.text.uppercase()) + } + + private inline fun visitOrEmpty(ctx: List?): List = when { + ctx.isNullOrEmpty() -> emptyList() + else -> ctx.map { visit(it) as T } + } + + private inline fun visitOrNull(ctx: ParserRuleContext?): T? = ctx?.let { it.accept(this) as T } + + private inline fun visitAs(ctx: ParserRuleContext): T = visit(ctx) as T + + /** + * Visiting a symbol to get a string, skip the wrapping, unwrapping, and location tracking. + */ + private fun symbol(ctx: GeneratedParser.SymbolPrimitiveContext) = when (ctx.ident.type) { + GeneratedParser.IDENTIFIER_QUOTED -> ctx.IDENTIFIER_QUOTED().getStringValue() + GeneratedParser.IDENTIFIER -> ctx.IDENTIFIER().getStringValue() + else -> throw error(ctx, "Invalid symbol reference.") + } + + /** + * Convert [ALL|DISTINCT] to SetQuantifier Enum + */ + private fun convertSetQuantifier(ctx: GeneratedParser.SetQuantifierStrategyContext?): SetQuantifier? = when { + ctx == null -> null + ctx.ALL() != null -> SetQuantifier.ALL + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT + else -> throw error(ctx, "Expected set quantifier ALL or DISTINCT") + } + + /** + * With the and nodes of a literal time expression, returns the parsed string and precision. + * TIME ()? (WITH TIME ZONE)? + */ + private fun getTimeStringAndPrecision( + stringNode: TerminalNode, + integerNode: TerminalNode? + ): Pair { + val timeString = stringNode.getStringValue() + val precision = when (integerNode) { + null -> { + try { + getPrecisionFromTimeString(timeString) + } catch (e: Exception) { + throw error(stringNode.symbol, "Unable to parse precision.", e) + } + } + else -> { + val p = integerNode.text.toBigInteger().toInt() + if (p < 0 || 9 < p) throw error(integerNode.symbol, "Precision out of bounds") + p + } + } + return timeString to precision + } + + private fun getPrecisionFromTimeString(timeString: String): Int { + val matcher = GENERIC_TIME_REGEX.toPattern().matcher(timeString) + if (!matcher.find()) { + throw IllegalArgumentException("Time string does not match the format 'HH:MM:SS[.ddd....][+|-HH:MM]'") + } + val fraction = matcher.group(1)?.removePrefix(".") + return fraction?.length ?: 0 + } + + /** + * Converts a Path expression into a Projection Item (either ALL or EXPR). Note: A Projection Item only allows a + * subset of a typical Path expressions. See the following examples. + * + * Examples of valid projections are: + * + * ```partiql + * SELECT * FROM foo + * SELECT foo.* FROM foo + * SELECT f.* FROM foo as f + * SELECT foo.bar.* FROM foo + * SELECT f.bar.* FROM foo as f + * ``` + * Also validates that the expression is valid for select list context. It does this by making + * sure that expressions looking like the following do not appear: + * + * ```partiql + * SELECT foo[*] FROM foo + * SELECT f.*.bar FROM foo as f + * SELECT foo[1].* FROM foo + * SELECT foo.*.bar FROM foo + * ``` + */ + protected fun convertPathToProjectionItem(ctx: ParserRuleContext, path: Expr.Path, alias: String?) = + translate(ctx) { + val steps = mutableListOf() + var containsIndex = false + path.steps.forEachIndexed { index, step -> + // Only last step can have a '.*' + if (step is Expr.Path.Step.Unpivot && index != path.steps.lastIndex) { + throw error(ctx, "Projection item cannot unpivot unless at end.") + } + // No step can have an indexed wildcard: '[*]' + if (step is Expr.Path.Step.Wildcard) { + throw error(ctx, "Projection item cannot index using wildcard.") + } + // TODO If the last step is '.*', no indexing is allowed + // if (step.metas.containsKey(IsPathIndexMeta.TAG)) { + // containsIndex = true + // } + if (step !is Expr.Path.Step.Unpivot) { + steps.add(step) + } + } + if (path.steps.last() is Expr.Path.Step.Unpivot && containsIndex) { + throw error(ctx, "Projection item use wildcard with any indexing.") + } + when { + path.steps.last() is Expr.Path.Step.Unpivot && steps.isEmpty() -> { + selectProjectItemAll(path.root) + } + path.steps.last() is Expr.Path.Step.Unpivot -> { + selectProjectItemAll(factory.exprPath(path.root, steps)) + } + else -> { + selectProjectItemExpression(path, alias) + } + } + } + + private fun TerminalNode.getStringValue(): String = this.symbol.getStringValue() + + private fun Token.getStringValue(): String = when (this.type) { + GeneratedParser.IDENTIFIER -> this.text + GeneratedParser.IDENTIFIER_QUOTED -> this.text.removePrefix("\"").removeSuffix("\"").replace("\"\"", "\"") + GeneratedParser.LITERAL_STRING -> this.text.removePrefix("'").removeSuffix("'").replace("''", "'") + GeneratedParser.ION_CLOSURE -> this.text.removePrefix("`").removeSuffix("`") + else -> throw error(this, "Unsupported token for grabbing string value.") + } + + private fun String.toIdentifier(): Identifier.Symbol = factory.identifierSymbol( + symbol = this, + caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE, + ) + + private fun String.toBigInteger() = BigInteger(this, 10) + + private fun assertIntegerElement(token: Token?, value: IonElement?) { + if (value == null || token == null) return + if (value !is IntElement) throw error(token, "Expected an integer value.") + if (value.integerSize == IntElementSize.BIG_INTEGER || value.longValue > Int.MAX_VALUE || value.longValue < Int.MIN_VALUE) throw error( + token, "Type parameter exceeded maximum value" + ) + } + + private enum class ExplainParameters { + TYPE, FORMAT; + + fun getCompliantString(target: String?, input: Token): String = when (target) { + null -> input.text!! + else -> throw error(input, "Cannot set EXPLAIN parameter ${this.name} multiple times.") + } + } + } +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/impl/util/DateTimeUtils.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/util/DateTimeUtils.kt new file mode 100644 index 0000000000..ea4b9a1a70 --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/util/DateTimeUtils.kt @@ -0,0 +1,76 @@ +package org.partiql.parser.impl.util + +import org.partiql.value.datetime.Date +import org.partiql.value.datetime.DateTimeException +import org.partiql.value.datetime.DateTimeValue +import org.partiql.value.datetime.Time +import org.partiql.value.datetime.TimeZone +import org.partiql.value.datetime.Timestamp +import java.math.BigDecimal +import java.util.regex.Matcher +import java.util.regex.Pattern + +internal object DateTimeUtils { + private val DATE_PATTERN: Pattern = Pattern.compile("(?\\d{4,})-(?\\d{2,})-(?\\d{2,})") + private val TIME_PATTERN: Pattern = Pattern.compile("(?\\d{2,}):(?\\d{2,}):(?\\d{2,})(?:\\.(?\\d+))?\\s*(?([+-]\\d\\d:\\d\\d)|(?[Zz]))?") + private val SQL_TIMESTAMP_DATE_TIME_DELIMITER = "\\s+".toRegex() + private val RFC8889_TIMESTAMP_DATE_TIME_DELIMITER = "[Tt]".toRegex() + private val TIMESTAMP_PATTERN = "(?$DATE_PATTERN)($SQL_TIMESTAMP_DATE_TIME_DELIMITER|$RFC8889_TIMESTAMP_DATE_TIME_DELIMITER)(?