Skip to content

Commit

Permalink
Rebase off error-reporting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 committed Oct 28, 2024
1 parent 6ae24a7 commit da868e1
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 56 deletions.
4 changes: 0 additions & 4 deletions partiql-ast/api/partiql-ast.api
Original file line number Diff line number Diff line change
Expand Up @@ -5743,11 +5743,7 @@ public abstract interface class org/partiql/ast/v1/AstVisitor {
public abstract fun visitTableRef (Lorg/partiql/ast/v1/FromTableRef;Ljava/lang/Object;)Ljava/lang/Object;
}

<<<<<<< Updated upstream
public class org/partiql/ast/v1/DataType : org/partiql/ast/v1/AstEnum {
=======
public class org/partiql/ast/v1/DataType : org/partiql/ast/v1/AstNode, org/partiql/ast/v1/Enum {
>>>>>>> Stashed changes
public static final field BAG I
public static final field BIGINT I
public static final field BINARY_LARGE_OBJECT I
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import org.jetbrains.annotations.NotNull;
import org.partiql.eval.Mode;
import org.partiql.eval.internal.compiler.StandardCompiler;
import org.partiql.eval.Statement;
import org.partiql.eval.internal.compiler.StandardCompiler;
import org.partiql.plan.Plan;
import org.partiql.spi.Context;

Expand Down
9 changes: 9 additions & 0 deletions partiql-parser/api/partiql-parser.api
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public abstract interface class org/partiql/parser/V1PartiQLParser {
public static final field Companion Lorg/partiql/parser/V1PartiQLParser$Companion;
public static fun builder ()Lorg/partiql/parser/V1PartiQLParserBuilder;
public abstract fun parse (Ljava/lang/String;)Lorg/partiql/parser/V1PartiQLParser$Result;
public abstract fun parse (Ljava/lang/String;Lorg/partiql/spi/Context;)Lorg/partiql/parser/V1PartiQLParser$Result;
public static fun standard ()Lorg/partiql/parser/V1PartiQLParser;
}

Expand All @@ -116,7 +117,12 @@ public final class org/partiql/parser/V1PartiQLParser$Companion {
public final fun standard ()Lorg/partiql/parser/V1PartiQLParser;
}

public final class org/partiql/parser/V1PartiQLParser$DefaultImpls {
public static fun parse (Lorg/partiql/parser/V1PartiQLParser;Ljava/lang/String;)Lorg/partiql/parser/V1PartiQLParser$Result;
}

public final class org/partiql/parser/V1PartiQLParser$Result {
public static final field Companion Lorg/partiql/parser/V1PartiQLParser$Result$Companion;
public fun <init> (Ljava/lang/String;Lorg/partiql/ast/v1/Statement;Lorg/partiql/parser/SourceLocations;)V
public final fun component1 ()Ljava/lang/String;
public final fun component2 ()Lorg/partiql/ast/v1/Statement;
Expand All @@ -131,6 +137,9 @@ public final class org/partiql/parser/V1PartiQLParser$Result {
public fun toString ()Ljava/lang/String;
}

public final class org/partiql/parser/V1PartiQLParser$Result$Companion {
}

public final class org/partiql/parser/V1PartiQLParserBuilder {
public fun <init> ()V
public final fun build ()Lorg/partiql/parser/V1PartiQLParser;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,51 @@

package org.partiql.parser

import org.partiql.ast.v1.Query
import org.partiql.ast.v1.Statement
import org.partiql.ast.v1.expr.ExprLit
import org.partiql.parser.internal.V1PartiQLParserDefault
import org.partiql.spi.Context
import org.partiql.spi.errors.PErrorListenerException
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.nullValue

public interface V1PartiQLParser {

@Throws(PartiQLSyntaxException::class, InterruptedException::class)
public fun parse(source: String): Result
/**
* Parses the [source] into an AST.
* @param source the user's input
* @param ctx a configuration object for the parser
* @throws PErrorListenerException when the [org.partiql.spi.errors.PErrorListener] defined in the [ctx] throws an
* [PErrorListenerException], this method halts execution and propagates the exception.
*/
@Throws(PErrorListenerException::class)
public fun parse(source: String, ctx: Context): Result

/**
* Parses the [source] into an AST.
* @param source the user's input
* @throws PErrorListenerException when the [org.partiql.spi.errors.PErrorListener] defined in the context throws an
* [PErrorListenerException], this method halts execution and propagates the exception.
*/
@Throws(PErrorListenerException::class)
public fun parse(source: String): Result {
return parse(source, Context.standard())
}

public data class Result(
val source: String,
val root: Statement,
val locations: SourceLocations,
)
) {
public companion object {
@OptIn(PartiQLValueExperimental::class)
internal fun empty(source: String): Result {
val locations = SourceLocations.Mutable().toMap()
return Result(source, Query(ExprLit(nullValue())), locations)
}
}
}

public companion object {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,17 @@ import org.partiql.ast.v1.graph.GraphRestrictor
import org.partiql.ast.v1.graph.GraphSelector
import org.partiql.parser.PartiQLLexerException
import org.partiql.parser.PartiQLParserException
import org.partiql.parser.PartiQLSyntaxException
import org.partiql.parser.SourceLocation
import org.partiql.parser.SourceLocations
import org.partiql.parser.V1PartiQLParser
import org.partiql.parser.internal.antlr.PartiQLParser
import org.partiql.parser.internal.antlr.PartiQLParserBaseVisitor
import org.partiql.parser.internal.util.DateTimeUtils
import org.partiql.spi.Context
import org.partiql.spi.errors.PError
import org.partiql.spi.errors.PErrorKind
import org.partiql.spi.errors.PErrorListener
import org.partiql.spi.errors.PErrorListenerException
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.boolValue
import org.partiql.value.dateValue
Expand Down Expand Up @@ -205,12 +209,16 @@ import org.partiql.parser.internal.antlr.PartiQLTokens as GeneratedLexer
*/
internal class V1PartiQLParserDefault : V1PartiQLParser {

@Throws(PartiQLSyntaxException::class, InterruptedException::class)
override fun parse(source: String): V1PartiQLParser.Result {
@Throws(PErrorListenerException::class)
override fun parse(source: String, ctx: Context): V1PartiQLParser.Result {
try {
return V1PartiQLParserDefault.parse(source)
return parse(source, ctx.errorListener)
} catch (e: PErrorListenerException) {
throw e
} catch (throwable: Throwable) {
throw PartiQLSyntaxException.wrap(throwable)
val error = PError.INTERNAL_ERROR(PErrorKind.SYNTAX(), null, throwable)
ctx.errorListener.report(error)
return V1PartiQLParser.Result.empty(source)
}
}

Expand All @@ -220,38 +228,38 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {
* To reduce latency costs, the [V1PartiQLParserDefault] attempts to use [PredictionMode.SLL] and falls back to
* [PredictionMode.LL] if a [ParseCancellationException] is thrown by the [BailErrorStrategy].
*/
private fun parse(source: String): V1PartiQLParser.Result = try {
parse(source, PredictionMode.SLL)
private fun parse(source: String, listener: PErrorListener): V1PartiQLParser.Result = try {
parse(source, PredictionMode.SLL, listener)
} catch (ex: ParseCancellationException) {
parse(source, PredictionMode.LL)
parse(source, PredictionMode.LL, listener)
}

/**
* Parses an input string [source] using the given prediction mode.
*/
private fun parse(source: String, mode: PredictionMode): V1PartiQLParser.Result {
val tokens = createTokenStream(source)
private fun parse(source: String, mode: PredictionMode, listener: PErrorListener): V1PartiQLParser.Result {
val tokens = createTokenStream(source, listener)
val parser = InterruptibleParser(tokens)
parser.reset()
parser.removeErrorListeners()
parser.interpreter.predictionMode = mode
when (mode) {
PredictionMode.SLL -> parser.errorHandler = BailErrorStrategy()
PredictionMode.LL -> parser.addErrorListener(ParseErrorListener())
PredictionMode.LL -> parser.addErrorListener(ParseErrorListener(listener))
else -> throw IllegalArgumentException("Unsupported parser mode: $mode")
}
val tree = parser.root()
return Visitor.translate(source, tokens, tree)
}

private fun createTokenStream(source: String): CountingTokenStream {
private fun createTokenStream(source: String, listener: PErrorListener): CountingTokenStream {
val queryStream = source.byteInputStream(StandardCharsets.UTF_8)
val inputStream = try {
CharStreams.fromStream(queryStream)
} catch (ex: ClosedByInterruptException) {
throw InterruptedException()
}
val handler = TokenizeErrorListener()
val handler = TokenizeErrorListener(listener)
val lexer = GeneratedLexer(inputStream)
lexer.removeErrorListeners()
lexer.addErrorListener(handler)
Expand All @@ -262,7 +270,7 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {
/**
* Catches Lexical errors (unidentified tokens) and throws a [PartiQLParserException]
*/
private class TokenizeErrorListener : BaseErrorListener() {
private class TokenizeErrorListener(private val listener: PErrorListener) : BaseErrorListener() {
@Throws(PartiQLParserException::class)
override fun syntaxError(
recognizer: Recognizer<*, *>?,
Expand All @@ -274,19 +282,9 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {
) {
if (offendingSymbol is Token) {
val token = offendingSymbol.text
val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type)
throw PartiQLLexerException(
token = token,
tokenType = tokenType,
message = msg,
cause = e,
location = SourceLocation(
line = line,
offset = charPositionInLine + 1,
length = token.length,
lengthLegacy = token.length,
),
)
val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong())
val error = PErrors.unrecognizedToken(location, token)
listener.report(error)
} else {
throw IllegalArgumentException("Offending symbol is not a Token.")
}
Expand All @@ -296,7 +294,7 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {
/**
* Catches Parser errors (malformed syntax) and throws a [PartiQLParserException]
*/
private class ParseErrorListener : BaseErrorListener() {
private class ParseErrorListener(private val listener: PErrorListener) : BaseErrorListener() {

private val rules = GeneratedParser.ruleNames.asList()

Expand All @@ -310,22 +308,12 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {
e: RecognitionException?,
) {
if (offendingSymbol is Token) {
val rule = e?.ctx?.toString(rules) ?: "UNKNOWN"
val rule = e?.ctx?.toString(rules) ?: "UNKNOWN" // TODO: Do we want to display the offending rule?
val token = offendingSymbol.text
val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type)
throw PartiQLParserException(
rule = rule,
token = token,
tokenType = tokenType,
message = msg,
cause = e,
location = SourceLocation(
line = line,
offset = charPositionInLine + 1,
length = msg.length,
lengthLegacy = offendingSymbol.text.length,
),
)
val location = org.partiql.spi.SourceLocation(line.toLong(), charPositionInLine + 1L, token.length.toLong())
val error = PErrors.unexpectedToken(location, tokenType, null)
listener.report(error)
} else {
throw IllegalArgumentException("Offending symbol is not a Token.")
}
Expand Down Expand Up @@ -1631,9 +1619,9 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {

override fun visitDateFunction(ctx: GeneratedParser.DateFunctionContext) = translate(ctx) {
try {
DatetimeField.valueOf(ctx.dt.text)
DatetimeField.parse(ctx.dt.text)
} catch (ex: IllegalArgumentException) {
throw error(ctx.dt, "Expected one of: ${DatetimeField.values().joinToString()}", ex)
throw error(ctx.dt, "Expected one of: ${DatetimeField.codes().joinToString()}", ex)
}
val lhs = visitExpr(ctx.expr(0))
val rhs = visitExpr(ctx.expr(1))
Expand Down Expand Up @@ -1709,9 +1697,11 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {

override fun visitExtract(ctx: GeneratedParser.ExtractContext) = translate(ctx) {
val field = try {
DatetimeField.valueOf(ctx.IDENTIFIER().text.uppercase())
DatetimeField.parse(ctx.IDENTIFIER().text.uppercase())
} catch (ex: IllegalArgumentException) {
throw error(ctx.IDENTIFIER().symbol, "Expected one of: ${DatetimeField.values().joinToString()}", ex)
// TODO decide if we want int codes here or actual text. If we want text here, then there should be a
// method to convert the code into text.
throw error(ctx.IDENTIFIER().symbol, "Expected one of: ${DatetimeField.codes().joinToString()}", ex)
}
val source = visitExpr(ctx.expr())
exprExtract(field, source)
Expand All @@ -1720,9 +1710,9 @@ internal class V1PartiQLParserDefault : V1PartiQLParser {
override fun visitTrimFunction(ctx: GeneratedParser.TrimFunctionContext) = translate(ctx) {
val spec = ctx.mod?.let {
try {
TrimSpec.valueOf(it.text.uppercase())
TrimSpec.parse(it.text.uppercase())
} catch (ex: IllegalArgumentException) {
throw error(it, "Expected on of: ${TrimSpec.values().joinToString()}", ex)
throw error(it, "Expected on of: ${TrimSpec.codes().joinToString()}", ex)
}
}
val (chars, value) = when (ctx.expr().size) {
Expand Down

0 comments on commit da868e1

Please sign in to comment.