diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala index 79e199f598b0c..f415ae2543d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala @@ -275,7 +275,7 @@ class TransformWithStateInPandasPythonPreInitRunner( override def stop(): Unit = { super.stop() closeServerSocketChannelSilently(stateServerSocket) - daemonThread.stop() + daemonThread.interrupt() } private def startStateServer(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index a48d0bfd15034..fe1bbdd66ac17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -146,6 +146,9 @@ class TransformWithStateInPandasStateServer( while (listeningSocket.isConnected && statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) { + if (Thread.currentThread().isInterrupted) { + throw new InterruptedException("Thread was interrupted") + } try { val version = inputStream.readInt() if (version != -1) { @@ -159,6 +162,11 @@ class TransformWithStateInPandasStateServer( logWarning(log"No more data to read from the socket") statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) return + case _: InterruptedException => + logInfo(log"Thread interrupted, shutting down state server") + Thread.currentThread().interrupt() + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return case e: Exception => logError(log"Error reading message: ${MDC(LogKeys.ERROR, e.getMessage)}", e) sendResponse(1, e.getMessage)