Skip to content

Commit

Permalink
Support UnrecoverableException (#898)
Browse files Browse the repository at this point in the history
* Support UnrecoverableException

Signed-off-by: Louis Chu <[email protected]>

* Add UT and IT

Signed-off-by: Louis Chu <[email protected]>

---------

Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger authored Nov 15, 2024
1 parent f8a7501 commit 4f58bc8
Show file tree
Hide file tree
Showing 8 changed files with 396 additions and 88 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.exception

/**
* Represents an unrecoverable exception in session management and statement execution. This
* exception is used for errors that cannot be handled or recovered from.
*/
class UnrecoverableException private (message: String, cause: Throwable)
extends RuntimeException(message, cause) {

def this(cause: Throwable) =
this(cause.getMessage, cause)
}

object UnrecoverableException {
def apply(cause: Throwable): UnrecoverableException =
new UnrecoverableException(cause)

def apply(message: String, cause: Throwable): UnrecoverableException =
new UnrecoverableException(message, cause)
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class InteractiveSession(
val lastUpdateTime: Long,
val jobStartTime: Long = 0,
val excludedJobIds: Seq[String] = Seq.empty[String],
val error: Option[String] = None,
var error: Option[String] = None,
sessionContext: Map[String, Any] = Map.empty[String, Any])
extends ContextualDataStore
with Logging {
Expand All @@ -72,7 +72,7 @@ class InteractiveSession(
val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
val errorStr = error.getOrElse("None")
// Does not include context, which could contain sensitive information.
s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " +
s"InteractiveSession(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " +
s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,49 @@ import org.opensearch.OpenSearchStatusException
import org.opensearch.flint.OpenSearchSuite
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
import org.opensearch.flint.core.{FlintClient, FlintOptions}
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID}
import org.apache.spark.sql.exception.UnrecoverableException
import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_STATEMENT_MANAGER, DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID}
import org.apache.spark.sql.util.MockEnvironment
import org.apache.spark.util.ThreadUtils

/**
* A StatementExecutionManagerImpl that throws UnrecoverableException during statement execution.
* Used for testing error handling in FlintREPL.
*/
class FailingStatementExecutionManager(
private var spark: SparkSession,
private var sessionId: String)
extends StatementExecutionManager {

def this() = {
this(null, null)
}

override def prepareStatementExecution(): Either[String, Unit] = {
throw UnrecoverableException(new RuntimeException("Simulated execution failure"))
}

override def executeStatement(statement: FlintStatement): DataFrame = {
throw UnrecoverableException(new RuntimeException("Simulated execution failure"))
}

override def getNextStatement(): Option[FlintStatement] = {
throw UnrecoverableException(new RuntimeException("Simulated execution failure"))
}

override def updateStatement(statement: FlintStatement): Unit = {
throw UnrecoverableException(new RuntimeException("Simulated execution failure"))
}

override def terminateStatementExecution(): Unit = {
throw UnrecoverableException(new RuntimeException("Simulated execution failure"))
}
}

class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {

var flintClient: FlintClient = _
Expand Down Expand Up @@ -584,6 +618,27 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
}
}

test("REPL should handle unrecoverable exception from statement execution") {
// Note: This test sharing system property with other test cases so cannot run alone
System.setProperty(
CUSTOM_STATEMENT_MANAGER.key,
"org.apache.spark.sql.FailingStatementExecutionManager")
try {
createSession(jobRunId, "")
FlintREPL.main(Array(resultIndex))
fail("The REPL should throw an unrecoverable exception, but it succeeded instead.")
} catch {
case ex: UnrecoverableException =>
assert(
ex.getMessage.contains("Simulated execution failure"),
s"Unexpected exception message: ${ex.getMessage}")
case ex: Throwable =>
fail(s"Unexpected exception type: ${ex.getClass} with message: ${ex.getMessage}")
} finally {
System.setProperty(CUSTOM_STATEMENT_MANAGER.key, "")
}
}

/**
* JSON does not support raw newlines (\n) in string values. All newlines must be escaped or
* removed when inside a JSON string. The same goes for tab characters, which should be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.exception.UnrecoverableException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
import org.apache.spark.sql.types._
Expand All @@ -44,12 +45,13 @@ trait FlintJobExecutor {
this: Logging =>

val mapper = new ObjectMapper()
val throwableHandler = new ThrowableHandler()

var currentTimeProvider: TimeProvider = new RealTimeProvider()
var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory()
var environmentProvider: EnvironmentProvider = new RealEnvironment()
var enableHiveSupport: Boolean = true
// termiante JVM in the presence non-deamon thread before exiting
// terminate JVM in the presence non-daemon thread before exiting
var terminateJVM = true

// The enabled setting, which can be applied only to the top-level mapping definition and to object fields,
Expand Down Expand Up @@ -435,11 +437,13 @@ trait FlintJobExecutor {
}

private def handleQueryException(
e: Exception,
t: Throwable,
messagePrefix: String,
errorSource: Option[String] = None,
statusCode: Option[Int] = None): String = {
val errorMessage = s"$messagePrefix: ${e.getMessage}"
throwableHandler.setThrowable(t)

val errorMessage = s"$messagePrefix: ${t.getMessage}"
val errorDetails = new java.util.LinkedHashMap[String, String]()
errorDetails.put("Message", errorMessage)
errorSource.foreach(es => errorDetails.put("ErrorSource", es))
Expand All @@ -450,25 +454,25 @@ trait FlintJobExecutor {
// CustomLogging will call log4j logger.error() underneath
statusCode match {
case Some(code) =>
CustomLogging.logError(new OperationMessage(errorMessage, code), e)
CustomLogging.logError(new OperationMessage(errorMessage, code), t)
case None =>
CustomLogging.logError(errorMessage, e)
CustomLogging.logError(errorMessage, t)
}

errorJson
}

def getRootCause(e: Throwable): Throwable = {
if (e.getCause == null) e
else getRootCause(e.getCause)
def getRootCause(t: Throwable): Throwable = {
if (t.getCause == null) t
else getRootCause(t.getCause)
}

/**
* This method converts query exception into error string, which then persist to query result
* metadata
*/
def processQueryException(ex: Exception): String = {
getRootCause(ex) match {
def processQueryException(throwable: Throwable): String = {
getRootCause(throwable) match {
case r: ParseException =>
handleQueryException(r, ExceptionMessages.SyntaxErrorPrefix)
case r: AmazonS3Exception =>
Expand All @@ -495,15 +499,15 @@ trait FlintJobExecutor {
handleQueryException(r, ExceptionMessages.QueryAnalysisErrorPrefix)
case r: SparkException =>
handleQueryException(r, ExceptionMessages.SparkExceptionErrorPrefix)
case r: Exception =>
val rootCauseClassName = r.getClass.getName
val errMsg = r.getMessage
case t: Throwable =>
val rootCauseClassName = t.getClass.getName
val errMsg = t.getMessage
if (rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" &&
errMsg.contains("com.amazonaws.services.glue.model.AccessDeniedException")) {
val e = new SecurityException(ExceptionMessages.GlueAccessDeniedMessage)
handleQueryException(e, ExceptionMessages.QueryRunErrorPrefix)
} else {
handleQueryException(r, ExceptionMessages.QueryRunErrorPrefix)
handleQueryException(t, ExceptionMessages.QueryRunErrorPrefix)
}
}
}
Expand Down Expand Up @@ -532,6 +536,14 @@ trait FlintJobExecutor {
throw t
}

def checkAndThrowUnrecoverableExceptions(): Unit = {
throwableHandler.exceptionThrown.foreach {
case e: UnrecoverableException =>
throw e
case _ => // Do nothing for other types of exceptions
}
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (Strings.isNullOrEmpty(className)) {
defaultConstructor
Expand All @@ -551,5 +563,4 @@ trait FlintJobExecutor {
}
}
}

}
Loading

0 comments on commit 4f58bc8

Please sign in to comment.