diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt index c850e7fb6d..1b649ce707 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt @@ -263,7 +263,7 @@ class ServiceLoaderUtil { val fraction = partiqlTime.decimalSecond.remainder(BigDecimal.ONE) val precision = when { - fraction.scale() > 9 -> (partiqlTime.decimalSecond.remainder(BigDecimal.ONE) * BigDecimal.valueOf(1_000_000_000)).toInt() + fraction.scale() > 9 -> throw DateTimeException("Precision greater than nano seconds not supported") else -> fraction.scale() } @@ -387,7 +387,7 @@ class ServiceLoaderUtil { val fraction = partiqlTime.decimalSecond.remainder(BigDecimal.ONE) val precision = when { - fraction.scale() > 9 -> (partiqlTime.decimalSecond.remainder(BigDecimal.ONE) * BigDecimal.valueOf(1_000_000_000)).toInt() + fraction.scale() > 9 -> throw DateTimeException("Precision greater than nano seconds not supported") else -> fraction.scale() } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt index c900b52157..5e8043c1ba 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt @@ -35,8 +35,6 @@ import com.amazon.ionelement.api.ionSymbol import com.amazon.ionelement.api.loadSingleElement import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.Token -import org.antlr.v4.runtime.tree.ErrorNode -import org.antlr.v4.runtime.tree.RuleNode import org.antlr.v4.runtime.tree.TerminalNode import org.partiql.lang.ast.IsCountStarMeta import org.partiql.lang.ast.IsImplictJoinMeta @@ -73,12 +71,49 @@ import java.time.LocalTime import java.time.OffsetTime import java.time.format.DateTimeFormatter import java.time.format.DateTimeParseException -import kotlin.reflect.KClass -import kotlin.reflect.cast /** * Extends ANTLR's generated [PartiQLBaseVisitor] to visit an ANTLR ParseTree and convert it into a PartiQL AST. This * class uses the [PartiqlAst.PartiqlAstNode] to represent all nodes within the new AST. + * + * When the grammar in PartiQL.g4 is extended with a new rule, one needs to override corresponding visitor methods + * in this class, in order to extend the transformation from an ANTLR parse tree into a [PartqlAst] tree. + * (Trivial implementations of these methods are generated into [PartiQLBaseVisitor].) + * + * For a rule of the form + * ``` + * Aaa + * : B1 ... Bn ; + * ``` + * it generates the `visitAaa(ctx: PartiQLParser.AaaContext ctx)` method, + * while for a rule of the form + * ``` + * Aaa + * : B1 ... Bn # A1 + * | C1 ... Cm # A2 + * ; + * ``` + * it generates methods `visitA1(ctx: PartiQLParser.A1Context ctx)` and `visitA2(ctx: PartiQLParser.A2Context ctx)`, + * but not `visitAaa`. + * The context objects `ctx` provide access to the terminals and non-terminals (`Bi`, `Cj`) necessary for + * implementing the methods suitably. + * + * Conversely, when implementing the visitor for another rule that *references* `Aaa`, + * - The visitor for a rule of the 1st form can be recursively invoked as `visitAaa(ctx.Aaa)`, + * which usually returns an AST node of the desired type. + * - For the rule of the 2nd form, as there is no `visitAaa`, one has to invoke `AbstractParseTreeVisitor.visit()` + * and then cast the result to the expected AST type, doing something like + * ``` + * visit(ctx.Aaa) as PartiqlAst.Aaa + * ``` + * This delegates to `accept()` which, at run time, invokes the appropriate visitor (`visitA1` or `visitA2`). + * However, any static guarantee is lost (in principle, `accept` can dispatch to any visitor of any rule + * in the grammar), hence the need for the cast. + * + * Note: A rule of an intermediate form between the above two is allowed: when there are multiple alternative clauses, + * but no labels on the clauses. In this case, it generates `visitAaa` whose context object `ctx` provides access + * to the combined set of non-terminals of the rule's clauses -- which are then visible at nullable types. + * There could be clever ways of exploiting this, to avoid the dispatch via `visit()`. */ internal class PartiQLPigVisitor( val customTypes: List = listOf(), @@ -213,20 +248,20 @@ internal class PartiQLPigVisitor( } override fun visitTableDef(ctx: PartiQLParser.TableDefContext) = PartiqlAst.build { - val parts = visitOrEmpty(ctx.tableDefPart(), PartiqlAst.TableDefPart::class) + val parts = ctx.tableDefPart().map { visit(it) as PartiqlAst.TableDefPart } tableDef(parts) } override fun visitColumnDeclaration(ctx: PartiQLParser.ColumnDeclarationContext) = PartiqlAst.build { val name = visitSymbolPrimitive(ctx.columnName().symbolPrimitive()).name.text - val type = visit(ctx.type(), PartiqlAst.Type::class) + val type = visit(ctx.type()) as PartiqlAst.Type val constrs = ctx.columnConstraint().map { visitColumnConstraint(it) } columnDeclaration(name, type, constrs) } override fun visitColumnConstraint(ctx: PartiQLParser.ColumnConstraintContext) = PartiqlAst.build { val name = ctx.columnConstraintName()?.let { visitSymbolPrimitive(it.symbolPrimitive()).name.text } - val def = visit(ctx.columnConstraintDef(), PartiqlAst.ColumnConstraintDef::class) + val def = visit(ctx.columnConstraintDef()) as PartiqlAst.ColumnConstraintDef columnConstraint(name, def) } @@ -248,7 +283,7 @@ internal class PartiQLPigVisitor( override fun visitExecCommand(ctx: PartiQLParser.ExecCommandContext) = PartiqlAst.build { val name = visitExpr(ctx.name).getStringValue(ctx.name.getStart()) - val args = visitOrEmpty(ctx.args, PartiqlAst.Expr::class) + val args = ctx.args.map { visitExpr(it) } exec_( SymbolPrimitive(name.lowercase(), emptyMetaContainer()), args, @@ -268,9 +303,9 @@ internal class PartiQLPigVisitor( ctx.fromClause() != null -> ctx.fromClause() else -> throw ParserException("Unable to deduce from source in DML", ErrorCode.PARSE_INVALID_QUERY) } - val from = visitOrNull(sourceContext, PartiqlAst.FromSource::class) - val where = visitOrNull(ctx.whereClause(), PartiqlAst.Expr::class) - val returning = visitOrNull(ctx.returningClause(), PartiqlAst.ReturningExpr::class) + val from = sourceContext?.let { visit(it) as PartiqlAst.FromSource } + val where = ctx.whereClause()?.let { visitWhereClause(it) } + val returning = ctx.returningClause()?.let { visitReturningClause(it) } val operations = ctx.dmlBaseCommand().map { command -> getCommandList(visit(command)) }.flatten() dml(dmlOpList(operations, operations[0].metas), from, where, returning, metas = operations[0].metas) } @@ -294,9 +329,9 @@ internal class PartiQLPigVisitor( } override fun visitDeleteCommand(ctx: PartiQLParser.DeleteCommandContext) = PartiqlAst.build { - val from = visit(ctx.fromClauseSimple(), PartiqlAst.FromSource::class) - val where = visitOrNull(ctx.whereClause(), PartiqlAst.Expr::class) - val returning = visitOrNull(ctx.returningClause(), PartiqlAst.ReturningExpr::class) + val from = visit(ctx.fromClauseSimple()) as PartiqlAst.FromSource + val where = ctx.whereClause()?.let { visitWhereClause(it) } + val returning = ctx.returningClause()?.let { visitReturningClause(it) } dml( dmlOpList(delete(ctx.DELETE().getSourceMetaContainer()), metas = ctx.DELETE().getSourceMetaContainer()), from, @@ -309,16 +344,16 @@ internal class PartiQLPigVisitor( override fun visitInsertStatementLegacy(ctx: PartiQLParser.InsertStatementLegacyContext) = PartiqlAst.build { val metas = ctx.INSERT().getSourceMetaContainer() val target = visitPathSimple(ctx.pathSimple()) - val index = visitOrNull(ctx.pos, PartiqlAst.Expr::class) + val index = ctx.pos?.let { visitExpr(it) } val onConflict = ctx.onConflictLegacy()?.let { visitOnConflictLegacy(it) } - insertValue(target, visit(ctx.value, PartiqlAst.Expr::class), index, onConflict, metas) + insertValue(target, visitExpr(ctx.value), index, onConflict, metas) } override fun visitInsertStatement(ctx: PartiQLParser.InsertStatementContext) = PartiqlAst.build { insert( target = visitSymbolPrimitive(ctx.symbolPrimitive()), - asAlias = visitOrNull(ctx.asIdent(), PartiqlAst.Expr.Id::class)?.name?.text, - values = visit(ctx.value, PartiqlAst.Expr::class), + asAlias = ctx.asIdent()?.let { visitAsIdent(it).name.text }, + values = visitExpr(ctx.value), conflictAction = ctx.onConflict()?.let { visitOnConflict(it) }, metas = ctx.INSERT().getSourceMetaContainer() ) @@ -327,8 +362,8 @@ internal class PartiQLPigVisitor( override fun visitReplaceCommand(ctx: PartiQLParser.ReplaceCommandContext) = PartiqlAst.build { insert( target = visitSymbolPrimitive(ctx.symbolPrimitive()), - asAlias = visitOrNull(ctx.asIdent(), PartiqlAst.Expr.Id::class)?.name?.text, - values = visit(ctx.value, PartiqlAst.Expr::class), + asAlias = ctx.asIdent()?.let { visitAsIdent(it).name.text }, + values = visitExpr(ctx.value), conflictAction = doReplace(excluded()), metas = ctx.REPLACE().getSourceMetaContainer() ) @@ -338,8 +373,8 @@ internal class PartiQLPigVisitor( override fun visitUpsertCommand(ctx: PartiQLParser.UpsertCommandContext) = PartiqlAst.build { insert( target = visitSymbolPrimitive(ctx.symbolPrimitive()), - asAlias = visitOrNull(ctx.asIdent(), PartiqlAst.Expr.Id::class)?.name?.text, - values = visit(ctx.value, PartiqlAst.Expr::class), + asAlias = ctx.asIdent()?.let { visitAsIdent(it).name.text }, + values = visitExpr(ctx.value), conflictAction = doUpdate(excluded()), metas = ctx.UPSERT().getSourceMetaContainer() ) @@ -349,14 +384,14 @@ internal class PartiQLPigVisitor( override fun visitInsertCommandReturning(ctx: PartiQLParser.InsertCommandReturningContext) = PartiqlAst.build { val metas = ctx.INSERT().getSourceMetaContainer() val target = visitPathSimple(ctx.pathSimple()) - val index = visitOrNull(ctx.pos, PartiqlAst.Expr::class) + val index = ctx.pos?.let { visitExpr(it) } val onConflictLegacy = ctx.onConflictLegacy()?.let { visitOnConflictLegacy(it) } - val returning = visitOrNull(ctx.returningClause(), PartiqlAst.ReturningExpr::class) + val returning = ctx.returningClause()?.let { visitReturningClause(it) } dml( dmlOpList( insertValue( target, - visit(ctx.value, PartiqlAst.Expr::class), + visitExpr(ctx.value), index = index, onConflict = onConflictLegacy, ctx.INSERT().getSourceMetaContainer() @@ -369,7 +404,7 @@ internal class PartiQLPigVisitor( } override fun visitReturningClause(ctx: PartiQLParser.ReturningClauseContext) = PartiqlAst.build { - val elements = visitOrEmpty(ctx.returningColumn(), PartiqlAst.ReturningElem::class) + val elements = ctx.returningColumn().map { visit(it) as PartiqlAst.ReturningElem } returningExpr(elements, ctx.RETURNING().getSourceMetaContainer()) } @@ -392,7 +427,7 @@ internal class PartiQLPigVisitor( } override fun visitOnConflict(ctx: PartiQLParser.OnConflictContext) = PartiqlAst.build { - visit(ctx.conflictAction(), PartiqlAst.ConflictAction::class) + visitConflictAction(ctx.conflictAction()) } override fun visitOnConflictLegacy(ctx: PartiQLParser.OnConflictLegacyContext) = PartiqlAst.build { @@ -417,7 +452,7 @@ internal class PartiQLPigVisitor( ctx.EXCLUDED() != null -> excluded() else -> TODO("DO REPLACE doesn't support values other than `EXCLUDED` yet.") } - val condition = visitOrNull(ctx.condition, PartiqlAst.Expr::class) + val condition = ctx.condition?.let { visitExpr(it) } doReplace(value, condition) } @@ -426,14 +461,14 @@ internal class PartiQLPigVisitor( ctx.EXCLUDED() != null -> excluded() else -> TODO("DO UPDATE doesn't support values other than `EXCLUDED` yet.") } - val condition = visitOrNull(ctx.condition, PartiqlAst.Expr::class) + val condition = ctx.condition?.let { visitExpr(it) } doUpdate(value, condition) } override fun visitPathSimple(ctx: PartiQLParser.PathSimpleContext) = PartiqlAst.build { val root = visitSymbolPrimitive(ctx.symbolPrimitive()) if (ctx.pathSimpleSteps().isEmpty()) return@build root - val steps = visitOrEmpty(ctx.pathSimpleSteps(), PartiqlAst.PathStep::class) + val steps = ctx.pathSimpleSteps().map { visit(it) as PartiqlAst.PathStep } path(root, steps, root.metas) } @@ -449,7 +484,7 @@ internal class PartiQLPigVisitor( getSymbolPathExpr(ctx.symbolPrimitive()) override fun visitSetCommand(ctx: PartiQLParser.SetCommandContext) = PartiqlAst.build { - val assignments = visitOrEmpty(ctx.setAssignment(), PartiqlAst.DmlOp.Set::class) + val assignments = ctx.setAssignment().map { visitSetAssignment(it) } val newSets = assignments.map { assignment -> assignment.copy(metas = ctx.SET().getSourceMetaContainer()) } dmlOpList(newSets, ctx.SET().getSourceMetaContainer()) } @@ -459,7 +494,7 @@ internal class PartiQLPigVisitor( } override fun visitUpdateClause(ctx: PartiQLParser.UpdateClauseContext) = - visit(ctx.tableBaseReference(), PartiqlAst.FromSource::class) + visit(ctx.tableBaseReference()) as PartiqlAst.FromSource /** * @@ -468,23 +503,24 @@ internal class PartiQLPigVisitor( */ override fun visitDql(ctx: PartiQLParser.DqlContext) = PartiqlAst.build { - val query = visit(ctx.expr(), PartiqlAst.Expr::class) + val query = visitExpr(ctx.expr()) query(query, query.metas) } - override fun visitQueryBase(ctx: PartiQLParser.QueryBaseContext) = visit(ctx.exprSelect(), PartiqlAst.Expr::class) + override fun visitQueryBase(ctx: PartiQLParser.QueryBaseContext) = + visit(ctx.exprSelect()) as PartiqlAst.Expr override fun visitSfwQuery(ctx: PartiQLParser.SfwQueryContext) = PartiqlAst.build { - val projection = visit(ctx.select, PartiqlAst.Projection::class) + val projection = visit(ctx.select) as PartiqlAst.Projection val strategy = getSetQuantifierStrategy(ctx.select) val from = visitFromClause(ctx.from) - val order = visitOrNull(ctx.order, PartiqlAst.OrderBy::class) - val group = visitOrNull(ctx.group, PartiqlAst.GroupBy::class) - val limit = visitOrNull(ctx.limit, PartiqlAst.Expr::class) - val offset = visitOrNull(ctx.offset, PartiqlAst.Expr::class) - val where = visitOrNull(ctx.where, PartiqlAst.Expr::class) - val having = visitOrNull(ctx.having, PartiqlAst.Expr::class) - val let = visitOrNull(ctx.let, PartiqlAst.Let::class) + val order = ctx.order?.let { visitOrderByClause(it) } + val group = ctx.group?.let { visitGroupClause(it) } + val limit = ctx.limit?.let { visitLimitClause(it) } + val offset = ctx.offset?.let { visitOffsetByClause(it) } + val where = ctx.where?.let { visitWhereClauseSelect(it) } + val having = ctx.having?.let { visitHavingClause(it) } + val let = ctx.let?.let { visitLetClause(it) } val metas = ctx.selectClause().getMetas() select( project = projection, @@ -520,7 +556,10 @@ internal class PartiQLPigVisitor( } override fun visitSelectItems(ctx: PartiQLParser.SelectItemsContext) = - convertProjectionItems(ctx.projectionItems(), ctx.SELECT().getSourceMetaContainer()) + PartiqlAst.build { + val projections = ctx.projectionItems().projectionItem().map { visitProjectionItem(it) } + projectList(projections, ctx.SELECT().getSourceMetaContainer()) + } override fun visitSelectPivot(ctx: PartiQLParser.SelectPivotContext) = PartiqlAst.build { projectPivot(visitExpr(ctx.at), visitExpr(ctx.pivot)) @@ -531,8 +570,8 @@ internal class PartiQLPigVisitor( } override fun visitProjectionItem(ctx: PartiQLParser.ProjectionItemContext) = PartiqlAst.build { - val expr = visit(ctx.expr(), PartiqlAst.Expr::class) - val alias = visitOrNull(ctx.symbolPrimitive(), PartiqlAst.Expr.Id::class)?.name + val expr = visitExpr(ctx.expr()) + val alias = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it).name } if (expr is PartiqlAst.Expr.Path) convertPathToProjectionItem(expr, alias) else projectExpr_(expr, asAlias = alias, expr.metas) } @@ -544,21 +583,23 @@ internal class PartiQLPigVisitor( */ override fun visitLimitClause(ctx: PartiQLParser.LimitClauseContext): PartiqlAst.Expr = - visit(ctx.arg, PartiqlAst.Expr::class) + visit(ctx.arg) as PartiqlAst.Expr override fun visitExpr(ctx: PartiQLParser.ExprContext): PartiqlAst.Expr { checkThreadInterrupted() - return visit(ctx.exprBagOp(), PartiqlAst.Expr::class) + return visit(ctx.exprBagOp()) as PartiqlAst.Expr } - override fun visitOffsetByClause(ctx: PartiQLParser.OffsetByClauseContext) = visit(ctx.arg, PartiqlAst.Expr::class) + override fun visitOffsetByClause(ctx: PartiQLParser.OffsetByClauseContext) = + visit(ctx.arg) as PartiqlAst.Expr override fun visitWhereClause(ctx: PartiQLParser.WhereClauseContext) = visitExpr(ctx.arg) override fun visitWhereClauseSelect(ctx: PartiQLParser.WhereClauseSelectContext) = - visit(ctx.arg, PartiqlAst.Expr::class) + visit(ctx.arg) as PartiqlAst.Expr - override fun visitHavingClause(ctx: PartiQLParser.HavingClauseContext) = visit(ctx.arg, PartiqlAst.Expr::class) + override fun visitHavingClause(ctx: PartiQLParser.HavingClauseContext) = + visit(ctx.arg) as PartiqlAst.Expr /** * @@ -567,12 +608,12 @@ internal class PartiQLPigVisitor( */ override fun visitLetClause(ctx: PartiQLParser.LetClauseContext) = PartiqlAst.build { - val letBindings = visitOrEmpty(ctx.letBinding(), PartiqlAst.LetBinding::class) + val letBindings = ctx.letBinding().map { visitLetBinding(it) } let(letBindings) } override fun visitLetBinding(ctx: PartiQLParser.LetBindingContext) = PartiqlAst.build { - val expr = visit(ctx.expr(), PartiqlAst.Expr::class) + val expr = visitExpr(ctx.expr()) val metas = ctx.symbolPrimitive().getSourceMetaContainer() letBinding_(expr, convertSymbolPrimitive(ctx.symbolPrimitive())!!, metas) } @@ -584,13 +625,13 @@ internal class PartiQLPigVisitor( */ override fun visitOrderByClause(ctx: PartiQLParser.OrderByClauseContext) = PartiqlAst.build { - val sortSpecs = visitOrEmpty(ctx.orderSortSpec(), PartiqlAst.SortSpec::class) + val sortSpecs = ctx.orderSortSpec().map { visitOrderSortSpec(it) } val metas = ctx.ORDER().getSourceMetaContainer() orderBy(sortSpecs, metas) } override fun visitOrderSortSpec(ctx: PartiQLParser.OrderSortSpecContext) = PartiqlAst.build { - val expr = visit(ctx.expr(), PartiqlAst.Expr::class) + val expr = visitExpr(ctx.expr()) val orderSpec = when { ctx.dir == null -> null ctx.dir.type == PartiQLParser.ASC -> asc() @@ -614,9 +655,9 @@ internal class PartiQLPigVisitor( override fun visitGroupClause(ctx: PartiQLParser.GroupClauseContext) = PartiqlAst.build { val strategy = if (ctx.PARTIAL() != null) groupPartial() else groupFull() - val keys = visitOrEmpty(ctx.groupKey(), PartiqlAst.GroupKey::class) + val keys = ctx.groupKey().map { visitGroupKey(it) } val keyList = groupKeyList(keys) - val alias = visitOrNull(ctx.groupAlias(), PartiqlAst.Expr.Id::class).toPigSymbolPrimitive() + val alias = ctx.groupAlias()?.let { visitGroupAlias(it).toPigSymbolPrimitive() } groupBy_(strategy, keyList = keyList, groupAsAlias = alias, ctx.GROUP().getSourceMetaContainer()) } @@ -628,7 +669,7 @@ internal class PartiQLPigVisitor( * This is to match the functionality of SqlParser, but this should likely be adjusted. */ override fun visitGroupKey(ctx: PartiQLParser.GroupKeyContext) = PartiqlAst.build { - val expr = visit(ctx.key, PartiqlAst.Expr::class) + val expr = visit(ctx.key) as PartiqlAst.Expr val possibleLiteral = when (expr) { is PartiqlAst.Expr.Pos -> expr.expr is PartiqlAst.Expr.Neg -> expr.expr @@ -643,7 +684,7 @@ internal class PartiQLPigVisitor( ErrorCode.PARSE_UNSUPPORTED_LITERALS_GROUPBY ) } - val alias = visitOrNull(ctx.symbolPrimitive(), PartiqlAst.Expr.Id::class).toPigSymbolPrimitive() + val alias = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it).toPigSymbolPrimitive() } groupKey_(expr, asAlias = alias, expr.metas) } @@ -654,8 +695,8 @@ internal class PartiQLPigVisitor( */ override fun visitIntersect(ctx: PartiQLParser.IntersectContext) = PartiqlAst.build { - val lhs = visit(ctx.lhs, PartiqlAst.Expr::class) - val rhs = visit(ctx.rhs, PartiqlAst.Expr::class) + val lhs = visit(ctx.lhs) as PartiqlAst.Expr + val rhs = visit(ctx.rhs) as PartiqlAst.Expr val quantifier = if (ctx.ALL() != null) all() else distinct() val (intersect, metas) = when (ctx.OUTER()) { null -> intersect() to ctx.INTERSECT().getSourceMetaContainer() @@ -665,8 +706,8 @@ internal class PartiQLPigVisitor( } override fun visitExcept(ctx: PartiQLParser.ExceptContext) = PartiqlAst.build { - val lhs = visit(ctx.lhs, PartiqlAst.Expr::class) - val rhs = visit(ctx.rhs, PartiqlAst.Expr::class) + val lhs = visit(ctx.lhs) as PartiqlAst.Expr + val rhs = visit(ctx.rhs) as PartiqlAst.Expr val quantifier = if (ctx.ALL() != null) all() else distinct() val (except, metas) = when (ctx.OUTER()) { null -> except() to ctx.EXCEPT().getSourceMetaContainer() @@ -676,8 +717,8 @@ internal class PartiQLPigVisitor( } override fun visitUnion(ctx: PartiQLParser.UnionContext) = PartiqlAst.build { - val lhs = visit(ctx.lhs, PartiqlAst.Expr::class) - val rhs = visit(ctx.rhs, PartiqlAst.Expr::class) + val lhs = visit(ctx.lhs) as PartiqlAst.Expr + val rhs = visit(ctx.rhs) as PartiqlAst.Expr val quantifier = if (ctx.ALL() != null) all() else distinct() val (union, metas) = when (ctx.OUTER()) { null -> union() to ctx.UNION().getSourceMetaContainer() @@ -693,21 +734,21 @@ internal class PartiQLPigVisitor( */ override fun visitGpmlPattern(ctx: PartiQLParser.GpmlPatternContext) = PartiqlAst.build { - val selector = visitOrNull(ctx.matchSelector(), PartiqlAst.GraphMatchSelector::class) + val selector = ctx.matchSelector()?.let { visit(it) as PartiqlAst.GraphMatchSelector } val pattern = visitMatchPattern(ctx.matchPattern()) gpmlPattern(selector, listOf(pattern)) } override fun visitGpmlPatternList(ctx: PartiQLParser.GpmlPatternListContext) = PartiqlAst.build { - val selector = visitOrNull(ctx.matchSelector(), PartiqlAst.GraphMatchSelector::class) + val selector = ctx.matchSelector()?.let { visit(it) as PartiqlAst.GraphMatchSelector } val patterns = ctx.matchPattern().map { pattern -> visitMatchPattern(pattern) } gpmlPattern(selector, patterns) } override fun visitMatchPattern(ctx: PartiQLParser.MatchPatternContext) = PartiqlAst.build { - val parts = visitOrEmpty(ctx.graphPart(), PartiqlAst.GraphMatchPatternPart::class) - val restrictor = visitOrNull(ctx.restrictor, PartiqlAst.GraphMatchRestrictor::class) - val variable = visitOrNull(ctx.variable, PartiqlAst.Expr.Id::class)?.name + val parts = ctx.graphPart().map { visit(it) as PartiqlAst.GraphMatchPatternPart } + val restrictor = ctx.restrictor?.let { visitPatternRestrictor(it) } + val variable = ctx.variable?.let { visitPatternPathVariable(it).name } graphMatchPattern_(parts = parts, restrictor = restrictor, variable = variable) } @@ -744,11 +785,11 @@ internal class PartiQLPigVisitor( visitSymbolPrimitive(ctx.symbolPrimitive()) override fun visitPattern(ctx: PartiQLParser.PatternContext) = PartiqlAst.build { - val restrictor = visitOrNull(ctx.restrictor, PartiqlAst.GraphMatchRestrictor::class) - val variable = visitOrNull(ctx.variable, PartiqlAst.Expr.Id::class)?.name - val prefilter = visitOrNull(ctx.where, PartiqlAst.Expr::class) - val quantifier = visitOrNull(ctx.quantifier, PartiqlAst.GraphMatchQuantifier::class) - val parts = visitOrEmpty(ctx.graphPart(), PartiqlAst.GraphMatchPatternPart::class) + val restrictor = ctx.restrictor?.let { visitPatternRestrictor(it) } + val variable = ctx.variable?.let { visitPatternPathVariable(it).name } + val prefilter = ctx.where?.let { visitWhereClause(it) } + val quantifier = ctx.quantifier?.let { visitPatternQuantifier(it) } + val parts = ctx.graphPart().map { visit(it) as PartiqlAst.GraphMatchPatternPart } pattern( graphMatchPattern_( parts = parts, @@ -762,21 +803,21 @@ internal class PartiQLPigVisitor( override fun visitEdgeAbbreviated(ctx: PartiQLParser.EdgeAbbreviatedContext) = PartiqlAst.build { val direction = visitEdgeAbbrev(ctx.edgeAbbrev()) - val quantifier = visitOrNull(ctx.quantifier, PartiqlAst.GraphMatchQuantifier::class) + val quantifier = ctx.quantifier?.let { visitPatternQuantifier(it) } edge(direction = direction, quantifier = quantifier) } override fun visitEdgeWithSpec(ctx: PartiQLParser.EdgeWithSpecContext) = PartiqlAst.build { - val quantifier = visitOrNull(ctx.quantifier, PartiqlAst.GraphMatchQuantifier::class) - val edge = visitOrNull(ctx.edgeWSpec(), PartiqlAst.GraphMatchPatternPart.Edge::class) + val quantifier = ctx.quantifier?.let { visitPatternQuantifier(it) } + val edge = ctx.edgeWSpec()?.let { visit(it) as PartiqlAst.GraphMatchPatternPart.Edge } edge!!.copy(quantifier = quantifier) } override fun visitEdgeSpec(ctx: PartiQLParser.EdgeSpecContext) = PartiqlAst.build { val placeholderDirection = edgeRight() - val variable = visitOrNull(ctx.symbolPrimitive(), PartiqlAst.Expr.Id::class)?.name - val prefilter = visitOrNull(ctx.whereClause(), PartiqlAst.Expr::class) - val label = visitOrNull(ctx.patternPartLabel(), PartiqlAst.Expr.Id::class)?.name + val variable = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it).name } + val prefilter = ctx.whereClause()?.let { visitWhereClause(it) } + val label = ctx.patternPartLabel()?.let { visitPatternPartLabel(it).name } edge_( direction = placeholderDirection, variable = variable, @@ -844,9 +885,9 @@ internal class PartiQLPigVisitor( } override fun visitNode(ctx: PartiQLParser.NodeContext) = PartiqlAst.build { - val variable = visitOrNull(ctx.symbolPrimitive(), PartiqlAst.Expr.Id::class)?.name - val prefilter = visitOrNull(ctx.whereClause(), PartiqlAst.Expr::class) - val label = visitOrNull(ctx.patternPartLabel(), PartiqlAst.Expr.Id::class)?.name + val variable = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it).name } + val prefilter = ctx.whereClause()?.let { visitWhereClause(it) } + val label = ctx.patternPartLabel()?.let { visitPatternPartLabel(it).name } node_(variable = variable, prefilter = prefilter, label = listOfNotNull(label)) } @@ -867,49 +908,37 @@ internal class PartiQLPigVisitor( */ override fun visitFromClause(ctx: PartiQLParser.FromClauseContext) = - visit(ctx.tableReference(), PartiqlAst.FromSource::class) + visit(ctx.tableReference()) as PartiqlAst.FromSource override fun visitTableBaseRefClauses(ctx: PartiQLParser.TableBaseRefClausesContext) = PartiqlAst.build { - val expr = visit(ctx.source, PartiqlAst.Expr::class) - val (asAlias, atAlias, byAlias) = visitNullableItems( - listOf(ctx.asIdent(), ctx.atIdent(), ctx.byIdent()), - PartiqlAst.Expr.Id::class - ) + val expr = visit(ctx.source) as PartiqlAst.Expr scan_( expr, - asAlias = asAlias.toPigSymbolPrimitive(), - byAlias = byAlias.toPigSymbolPrimitive(), - atAlias = atAlias.toPigSymbolPrimitive(), + asAlias = ctx.asIdent()?.let { visitAsIdent(it).toPigSymbolPrimitive() }, + atAlias = ctx.atIdent()?.let { visitAtIdent(it).toPigSymbolPrimitive() }, + byAlias = ctx.byIdent()?.let { visitByIdent(it).toPigSymbolPrimitive() }, metas = expr.metas ) } override fun visitTableBaseRefMatch(ctx: PartiQLParser.TableBaseRefMatchContext) = PartiqlAst.build { - val expr = visit(ctx.source, PartiqlAst.Expr::class) - val (asAlias, atAlias, byAlias) = visitNullableItems( - listOf(ctx.asIdent(), ctx.atIdent(), ctx.byIdent()), - PartiqlAst.Expr.Id::class - ) + val expr = visit(ctx.source) as PartiqlAst.Expr scan_( expr, - asAlias = asAlias.toPigSymbolPrimitive(), - byAlias = byAlias.toPigSymbolPrimitive(), - atAlias = atAlias.toPigSymbolPrimitive(), + asAlias = ctx.asIdent()?.let { visitAsIdent(it).toPigSymbolPrimitive() }, + atAlias = ctx.atIdent()?.let { visitAtIdent(it).toPigSymbolPrimitive() }, + byAlias = ctx.byIdent()?.let { visitByIdent(it).toPigSymbolPrimitive() }, metas = expr.metas ) } override fun visitFromClauseSimpleExplicit(ctx: PartiQLParser.FromClauseSimpleExplicitContext) = PartiqlAst.build { val expr = visitPathSimple(ctx.pathSimple()) - val (asAlias, atAlias, byAlias) = visitNullableItems( - listOf(ctx.asIdent(), ctx.atIdent(), ctx.byIdent()), - PartiqlAst.Expr.Id::class - ) scan_( expr, - asAlias = asAlias.toPigSymbolPrimitive(), - byAlias = byAlias.toPigSymbolPrimitive(), - atAlias = atAlias.toPigSymbolPrimitive(), + asAlias = ctx.asIdent()?.let { visitAsIdent(it).toPigSymbolPrimitive() }, + atAlias = ctx.atIdent()?.let { visitAtIdent(it).toPigSymbolPrimitive() }, + byAlias = ctx.byIdent()?.let { visitByIdent(it).toPigSymbolPrimitive() }, metas = expr.metas ) } @@ -917,44 +946,40 @@ internal class PartiQLPigVisitor( override fun visitTableUnpivot(ctx: PartiQLParser.TableUnpivotContext) = PartiqlAst.build { val expr = visitExpr(ctx.expr()) val metas = ctx.UNPIVOT().getSourceMetaContainer() - val (asAlias, atAlias, byAlias) = visitNullableItems( - listOf(ctx.asIdent(), ctx.atIdent(), ctx.byIdent()), - PartiqlAst.Expr.Id::class - ) unpivot_( expr, - asAlias = asAlias.toPigSymbolPrimitive(), - atAlias = atAlias.toPigSymbolPrimitive(), - byAlias = byAlias.toPigSymbolPrimitive(), + asAlias = ctx.asIdent()?.let { visitAsIdent(it).toPigSymbolPrimitive() }, + atAlias = ctx.atIdent()?.let { visitAtIdent(it).toPigSymbolPrimitive() }, + byAlias = ctx.byIdent()?.let { visitByIdent(it).toPigSymbolPrimitive() }, metas ) } override fun visitTableCrossJoin(ctx: PartiQLParser.TableCrossJoinContext) = PartiqlAst.build { - val lhs = visit(ctx.lhs, PartiqlAst.FromSource::class) + val lhs = visit(ctx.lhs) as PartiqlAst.FromSource val joinType = visitJoinType(ctx.joinType()) - val rhs = visit(ctx.rhs, PartiqlAst.FromSource::class) + val rhs = visit(ctx.rhs) as PartiqlAst.FromSource val metas = metaContainerOf(IsImplictJoinMeta.instance) + joinType.metas join(joinType, lhs, rhs, metas = metas) } override fun visitTableQualifiedJoin(ctx: PartiQLParser.TableQualifiedJoinContext) = PartiqlAst.build { - val lhs = visit(ctx.lhs, PartiqlAst.FromSource::class) + val lhs = visit(ctx.lhs) as PartiqlAst.FromSource val joinType = visitJoinType(ctx.joinType()) - val rhs = visit(ctx.rhs, PartiqlAst.FromSource::class) - val condition = visitOrNull(ctx.joinSpec(), PartiqlAst.Expr::class) + val rhs = visit(ctx.rhs) as PartiqlAst.FromSource + val condition = ctx.joinSpec()?.let { visitJoinSpec(it) } join(joinType, lhs, rhs, condition, metas = joinType.metas) } override fun visitTableBaseRefSymbol(ctx: PartiQLParser.TableBaseRefSymbolContext) = PartiqlAst.build { - val expr = visit(ctx.source, PartiqlAst.Expr::class) - val name = visitOrNull(ctx.symbolPrimitive(), PartiqlAst.Expr.Id::class) - scan_(expr, name.toPigSymbolPrimitive(), metas = expr.metas) + val expr = visit(ctx.source) as PartiqlAst.Expr + val name = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it).toPigSymbolPrimitive() } + scan_(expr, name, metas = expr.metas) } override fun visitFromClauseSimpleImplicit(ctx: PartiQLParser.FromClauseSimpleImplicitContext) = PartiqlAst.build { val path = visitPathSimple(ctx.pathSimple()) - val name = visitOrNull(ctx.symbolPrimitive(), PartiqlAst.Expr.Id::class)?.name + val name = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it).name } scan_(path, name, metas = path.metas) } @@ -977,7 +1002,7 @@ internal class PartiQLPigVisitor( } override fun visitJoinRhsTableJoined(ctx: PartiQLParser.JoinRhsTableJoinedContext) = - visit(ctx.tableReference(), PartiqlAst.FromSource::class) + visit(ctx.tableReference()) as PartiqlAst.FromSource /** * SIMPLE EXPRESSIONS @@ -1022,9 +1047,9 @@ internal class PartiQLPigVisitor( possibleRhs else list(possibleRhs, metas = possibleRhs.metas + metaContainerOf(IsListParenthesizedMeta)) } else { - visit(ctx.rhs, PartiqlAst.Expr::class) + visit(ctx.rhs) as PartiqlAst.Expr } - val lhs = visit(ctx.lhs, PartiqlAst.Expr::class) + val lhs = visit(ctx.lhs) as PartiqlAst.Expr val args = listOf(lhs, rhs) val inCollection = inCollection(args, ctx.IN().getSourceMetaContainer()) if (ctx.NOT() == null) return@build inCollection @@ -1032,23 +1057,23 @@ internal class PartiQLPigVisitor( } override fun visitPredicateIs(ctx: PartiQLParser.PredicateIsContext) = PartiqlAst.build { - val lhs = visit(ctx.lhs, PartiqlAst.Expr::class) - val rhs = visit(ctx.type(), PartiqlAst.Type::class) + val lhs = visit(ctx.lhs) as PartiqlAst.Expr + val rhs = visit(ctx.type()) as PartiqlAst.Type val isType = isType(lhs, rhs, ctx.IS().getSourceMetaContainer()) if (ctx.NOT() == null) return@build isType not(isType, ctx.NOT().getSourceMetaContainer() + metaContainerOf(LegacyLogicalNotMeta.instance)) } override fun visitPredicateBetween(ctx: PartiQLParser.PredicateBetweenContext) = PartiqlAst.build { - val args = visitOrEmpty(listOf(ctx.lhs, ctx.lower, ctx.upper), PartiqlAst.Expr::class) + val args = listOf(ctx.lhs, ctx.lower, ctx.upper).map { visit(it) as PartiqlAst.Expr } val between = between(args[0], args[1], args[2], ctx.BETWEEN().getSourceMetaContainer()) if (ctx.NOT() == null) return@build between not(between, ctx.NOT().getSourceMetaContainer() + metaContainerOf(LegacyLogicalNotMeta.instance)) } override fun visitPredicateLike(ctx: PartiQLParser.PredicateLikeContext) = PartiqlAst.build { - val args = visitOrEmpty(listOf(ctx.lhs, ctx.rhs), PartiqlAst.Expr::class) - val escape = visitOrNull(ctx.escape, PartiqlAst.Expr::class) + val args = listOf(ctx.lhs, ctx.rhs).map { visit(it) as PartiqlAst.Expr } + val escape = ctx.escape?.let { visitExpr(it) } val like = like(args[0], args[1], escape, ctx.LIKE().getSourceMetaContainer()) if (ctx.NOT() == null) return@build like not(like, metas = ctx.NOT().getSourceMetaContainer() + metaContainerOf(LegacyLogicalNotMeta.instance)) @@ -1061,7 +1086,7 @@ internal class PartiQLPigVisitor( */ override fun visitExprTermWrappedQuery(ctx: PartiQLParser.ExprTermWrappedQueryContext) = - visit(ctx.expr(), PartiqlAst.Expr::class) + visitExpr(ctx.expr()) override fun visitVariableIdentifier(ctx: PartiQLParser.VariableIdentifierContext): PartiqlAst.PartiqlAstNode = PartiqlAst.build { @@ -1086,7 +1111,7 @@ internal class PartiQLPigVisitor( } override fun visitSequenceConstructor(ctx: PartiQLParser.SequenceConstructorContext) = PartiqlAst.build { - val expressions = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val expressions = ctx.expr().map { visitExpr(it) } val metas = ctx.datatype.getSourceMetaContainer() when (ctx.datatype.type) { PartiQLParser.LIST -> list(expressions, metas) @@ -1102,7 +1127,7 @@ internal class PartiQLPigVisitor( } override fun visitPathStepIndexExpr(ctx: PartiQLParser.PathStepIndexExprContext) = PartiqlAst.build { - val expr = visit(ctx.key, PartiqlAst.Expr::class) + val expr = visitExpr(ctx.key) val metas = expr.metas + metaContainerOf(IsPathIndexMeta.instance) pathExpr(expr, PartiqlAst.CaseSensitivity.CaseSensitive(), metas) } @@ -1130,17 +1155,17 @@ internal class PartiQLPigVisitor( } override fun visitValues(ctx: PartiQLParser.ValuesContext) = PartiqlAst.build { - val rows = visitOrEmpty(ctx.valueRow(), PartiqlAst.Expr.List::class) + val rows = ctx.valueRow().map { visitValueRow(it) } bag(rows, ctx.VALUES().getSourceMetaContainer() + metaContainerOf(IsValuesExprMeta.instance)) } override fun visitValueRow(ctx: PartiQLParser.ValueRowContext) = PartiqlAst.build { - val expressions = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val expressions = ctx.expr().map { visitExpr(it) } list(expressions, metas = ctx.PAREN_LEFT().getSourceMetaContainer() + metaContainerOf(IsListParenthesizedMeta)) } override fun visitValueList(ctx: PartiQLParser.ValueListContext) = PartiqlAst.build { - val expressions = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val expressions = ctx.expr().map { visitExpr(it) } list(expressions, metas = ctx.PAREN_LEFT().getSourceMetaContainer() + metaContainerOf(IsListParenthesizedMeta)) } @@ -1158,7 +1183,7 @@ internal class PartiQLPigVisitor( } override fun visitCoalesce(ctx: PartiQLParser.CoalesceContext) = PartiqlAst.build { - val expressions = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val expressions = ctx.expr().map { visitExpr(it) } val metas = ctx.COALESCE().getSourceMetaContainer() coalesce(expressions, metas) } @@ -1167,7 +1192,7 @@ internal class PartiQLPigVisitor( val pairs = ctx.whens.indices.map { i -> exprPair(visitExpr(ctx.whens[i]), visitExpr(ctx.thens[i])) } - val elseExpr = visitOrNull(ctx.else_, PartiqlAst.Expr::class) + val elseExpr = ctx.else_?.let { visitExpr(it) } val caseMeta = ctx.CASE().getSourceMetaContainer() when (ctx.case_) { null -> searchedCase(exprPairList(pairs), elseExpr, metas = caseMeta) @@ -1177,35 +1202,35 @@ internal class PartiQLPigVisitor( override fun visitCast(ctx: PartiQLParser.CastContext) = PartiqlAst.build { val expr = visitExpr(ctx.expr()) - val type = visit(ctx.type(), PartiqlAst.Type::class) + val type = visit(ctx.type()) as PartiqlAst.Type val metas = ctx.CAST().getSourceMetaContainer() cast(expr, type, metas) } override fun visitCanCast(ctx: PartiQLParser.CanCastContext) = PartiqlAst.build { val expr = visitExpr(ctx.expr()) - val type = visit(ctx.type(), PartiqlAst.Type::class) + val type = visit(ctx.type()) as PartiqlAst.Type val metas = ctx.CAN_CAST().getSourceMetaContainer() canCast(expr, type, metas) } override fun visitCanLosslessCast(ctx: PartiQLParser.CanLosslessCastContext) = PartiqlAst.build { val expr = visitExpr(ctx.expr()) - val type = visit(ctx.type(), PartiqlAst.Type::class) + val type = visit(ctx.type()) as PartiqlAst.Type val metas = ctx.CAN_LOSSLESS_CAST().getSourceMetaContainer() canLosslessCast(expr, type, metas) } override fun visitFunctionCallIdent(ctx: PartiQLParser.FunctionCallIdentContext) = PartiqlAst.build { val name = ctx.name.getString().lowercase() - val args = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val args = ctx.expr().map { visitExpr(it) } val metas = ctx.name.getSourceMetaContainer() call(name, args = args, metas = metas) } override fun visitFunctionCallReserved(ctx: PartiQLParser.FunctionCallReservedContext) = PartiqlAst.build { val name = ctx.name.text.lowercase() - val args = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val args = ctx.expr().map { visitExpr(it) } val metas = ctx.name.getSourceMetaContainer() call(name, args = args, metas = metas) } @@ -1215,26 +1240,26 @@ internal class PartiQLPigVisitor( throw ctx.dt.err("Expected one of: ${DateTimePart.values()}", ErrorCode.PARSE_EXPECTED_DATE_TIME_PART) } val datetimePart = lit(ionSymbol(ctx.dt.text)) - val secondaryArgs = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val secondaryArgs = ctx.expr().map { visitExpr(it) } val args = listOf(datetimePart) + secondaryArgs val metas = ctx.func.getSourceMetaContainer() call(ctx.func.text.lowercase(), args, metas) } override fun visitSubstring(ctx: PartiQLParser.SubstringContext) = PartiqlAst.build { - val args = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val args = ctx.expr().map { visitExpr(it) } val metas = ctx.SUBSTRING().getSourceMetaContainer() call(ctx.SUBSTRING().text.lowercase(), args, metas) } override fun visitPosition(ctx: PartiQLParser.PositionContext) = PartiqlAst.build { - val args = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val args = ctx.expr().map { visitExpr(it) } val metas = ctx.POSITION().getSourceMetaContainer() call(ctx.POSITION().text.lowercase(), args, metas) } override fun visitOverlay(ctx: PartiQLParser.OverlayContext) = PartiqlAst.build { - val args = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val args = ctx.expr().map { visitExpr(it) } val metas = ctx.OVERLAY().getSourceMetaContainer() call(ctx.OVERLAY().text.lowercase(), args, metas) } @@ -1254,7 +1279,7 @@ internal class PartiQLPigVisitor( .err("Expected one of: ${DateTimePart.values()}", ErrorCode.PARSE_EXPECTED_DATE_TIME_PART) } val datetimePart = lit(ionSymbol(ctx.IDENTIFIER().text)) - val timeExpr = visit(ctx.rhs, PartiqlAst.Expr::class) + val timeExpr = visitExpr(ctx.rhs) val args = listOf(datetimePart, timeExpr) val metas = ctx.EXTRACT().getSourceMetaContainer() call(ctx.EXTRACT().text.lowercase(), args, metas) @@ -1320,7 +1345,7 @@ internal class PartiQLPigVisitor( */ override fun visitLagLeadFunction(ctx: PartiQLParser.LagLeadFunctionContext) = PartiqlAst.build { - val args = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val args = ctx.expr().map { visitExpr(it) } val over = visitOver(ctx.over()) // LAG and LEAD will require a Window ORDER BY if (over.orderBy == null) { @@ -1346,13 +1371,13 @@ internal class PartiQLPigVisitor( } override fun visitWindowPartitionList(ctx: PartiQLParser.WindowPartitionListContext) = PartiqlAst.build { - val args = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val args = ctx.expr().map { visitExpr(it) } val metas = ctx.PARTITION().getSourceMetaContainer() windowPartitionList(args, metas) } override fun visitWindowSortSpecList(ctx: PartiQLParser.WindowSortSpecListContext) = PartiqlAst.build { - val sortSpecList = visitOrEmpty(ctx.orderSortSpec(), PartiqlAst.SortSpec::class) + val sortSpecList = ctx.orderSortSpec().map { visitOrderSortSpec(it) } val metas = ctx.ORDER().getSourceMetaContainer() windowSortSpecList(sortSpecList, metas) } @@ -1364,7 +1389,7 @@ internal class PartiQLPigVisitor( */ override fun visitBag(ctx: PartiQLParser.BagContext) = PartiqlAst.build { - val exprList = visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class) + val exprList = ctx.expr().map { visitExpr(it) } bag(exprList, ctx.ANGLE_DOUBLE_LEFT().getSourceMetaContainer()) } @@ -1384,7 +1409,7 @@ internal class PartiQLPigVisitor( override fun visitArray(ctx: PartiQLParser.ArrayContext) = PartiqlAst.build { val metas = ctx.BRACKET_LEFT().getSourceMetaContainer() - list(visitOrEmpty(ctx.expr(), PartiqlAst.Expr::class), metas) + list(ctx.expr().map { visitExpr(it) }, metas) } override fun visitLiteralNull(ctx: PartiQLParser.LiteralNullContext) = PartiqlAst.build { @@ -1457,7 +1482,7 @@ internal class PartiQLPigVisitor( } override fun visitTuple(ctx: PartiQLParser.TupleContext) = PartiqlAst.build { - val pairs = visitOrEmpty(ctx.pair(), PartiqlAst.ExprPair::class) + val pairs = ctx.pair().map { visitPair(it) } val metas = ctx.BRACE_LEFT().getSourceMetaContainer() struct(pairs, metas) } @@ -1568,70 +1593,12 @@ internal class PartiQLPigVisitor( customType_(SymbolPrimitive(customName, metas), metas) } - /** - * NOT OVERRIDDEN - * Explicitly defining the override helps by showing the user (via the IDE) which methods remain to be overridden. - */ - - override fun visitTerminal(node: TerminalNode?): PartiqlAst.PartiqlAstNode = super.visitTerminal(node) - override fun shouldVisitNextChild(node: RuleNode?, currentResult: PartiqlAst.PartiqlAstNode?) = - super.shouldVisitNextChild(node, currentResult) - - override fun visitErrorNode(node: ErrorNode?): PartiqlAst.PartiqlAstNode = super.visitErrorNode(node) - override fun visitChildren(node: RuleNode?): PartiqlAst.PartiqlAstNode = super.visitChildren(node) - override fun visitExprPrimaryBase(ctx: PartiQLParser.ExprPrimaryBaseContext?): PartiqlAst.PartiqlAstNode = - super.visitExprPrimaryBase(ctx) - - override fun visitExprTermBase(ctx: PartiQLParser.ExprTermBaseContext?): PartiqlAst.PartiqlAstNode = - super.visitExprTermBase(ctx) - - override fun visitCollection(ctx: PartiQLParser.CollectionContext?): PartiqlAst.PartiqlAstNode = - super.visitCollection(ctx) - - override fun visitPredicateBase(ctx: PartiQLParser.PredicateBaseContext?): PartiqlAst.PartiqlAstNode = - super.visitPredicateBase(ctx) - - override fun visitTableNonJoin(ctx: PartiQLParser.TableNonJoinContext?): PartiqlAst.PartiqlAstNode = - super.visitTableNonJoin(ctx) - - override fun visitTableRefBase(ctx: PartiQLParser.TableRefBaseContext?): PartiqlAst.PartiqlAstNode = - super.visitTableRefBase(ctx) - - override fun visitJoinRhsBase(ctx: PartiQLParser.JoinRhsBaseContext?): PartiqlAst.PartiqlAstNode = - super.visitJoinRhsBase(ctx) - - override fun visitConflictTarget(ctx: PartiQLParser.ConflictTargetContext?): PartiqlAst.PartiqlAstNode = - super.visitConflictTarget(ctx) - /** * * HELPER METHODS * */ - private fun visitOrEmpty(ctx: List?, clazz: KClass): List = - when { - ctx.isNullOrEmpty() -> emptyList() - else -> ctx.map { clazz.cast(visit(it)) } - } - - private fun visitNullableItems( - ctx: List?, - clazz: KClass - ): List = when { - ctx.isNullOrEmpty() -> emptyList() - else -> ctx.map { visitOrNull(it, clazz) } - } - - private fun visitOrNull(ctx: ParserRuleContext?, clazz: KClass): T? = - when (ctx) { - null -> null - else -> clazz.cast(visit(ctx)) - } - - private fun visit(ctx: ParserRuleContext, clazz: KClass): T = - clazz.cast(visit(ctx)) - private fun TerminalNode?.getSourceMetaContainer(): MetaContainer { if (this == null) return emptyMetaContainer() val metas = this.getSourceMetas() @@ -1657,8 +1624,8 @@ internal class PartiQLPigVisitor( op: Token?, parent: ParserRuleContext? = null ) = PartiqlAst.build { - if (parent != null) return@build visit(parent, PartiqlAst.Expr::class) - val args = visitOrEmpty(listOf(lhs!!, rhs!!), PartiqlAst.Expr::class) + if (parent != null) return@build visit(parent) as PartiqlAst.Expr + val args = listOf(lhs!!, rhs!!).map { visit(it) as PartiqlAst.Expr } val metas = op.getSourceMetaContainer() when (op!!.type) { PartiQLParser.AND -> and(args, metas) @@ -1681,8 +1648,8 @@ internal class PartiQLPigVisitor( private fun visitUnaryOperation(operand: ParserRuleContext?, op: Token?, parent: ParserRuleContext? = null) = PartiqlAst.build { - if (parent != null) return@build visit(parent, PartiqlAst.Expr::class) - val arg = visit(operand!!, PartiqlAst.Expr::class) + if (parent != null) return@build visit(parent) as PartiqlAst.Expr + val arg = visit(operand!!) as PartiqlAst.Expr val metas = op.getSourceMetaContainer() when (op!!.type) { PartiQLParser.PLUS -> { @@ -1746,10 +1713,8 @@ internal class PartiQLPigVisitor( else -> throw token.err("Unable to get value", ErrorCode.PARSE_UNEXPECTED_TOKEN) } - private fun PartiqlAst.Expr.Id?.toPigSymbolPrimitive(): SymbolPrimitive? = when (this) { - null -> null - else -> this.name.copy(metas = this.metas) - } + private fun PartiqlAst.Expr.Id.toPigSymbolPrimitive(): SymbolPrimitive = + this.name.copy(metas = this.metas) private fun PartiqlAst.Expr.Id.toIdentifier(): PartiqlAst.Identifier { val name = this.name.text @@ -1908,12 +1873,6 @@ internal class PartiQLPigVisitor( else -> SymbolPrimitive(sym.getString(), sym.getSourceMetaContainer()) } - private fun convertProjectionItems(ctx: PartiQLParser.ProjectionItemsContext, metas: MetaContainer) = - PartiqlAst.build { - val projections = visitOrEmpty(ctx.projectionItem(), PartiqlAst.ProjectItem::class) - projectList(projections, metas) - } - private fun PartiQLParser.SelectClauseContext.getMetas(): MetaContainer = when (this) { is PartiQLParser.SelectAllContext -> this.SELECT().getSourceMetaContainer() is PartiQLParser.SelectItemsContext -> this.SELECT().getSourceMetaContainer()