[SPARK-50909][PYTHON][4.0] Setup faulthandler in PythonPlannerRunners #49635

Open · wants to merge 2 commits into base: branch-4.0

Changes from 1 commit
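
Each planner worker touched by this PR adopts the same setup/teardown around its main body. A condensed sketch of that pattern (not part of the diff; the names mirror the changes below):

import faulthandler
import os

def main(infile, outfile):
    # PYTHON_FAULTHANDLER_DIR is set by the JVM side when the faulthandler
    # conf is enabled; the worker dumps crash tracebacks to <dir>/<pid>.
    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
    try:
        if faulthandler_log_path:
            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
            faulthandler_log_file = open(faulthandler_log_path, "w")
            faulthandler.enable(file=faulthandler_log_file)

        ...  # existing worker body (check_python_version, planning work, etc.)
    finally:
        if faulthandler_log_path:
            # Tear down and remove the per-PID dump file so nothing lingers
            # when the worker exits without crashing.
            faulthandler.disable()
            faulthandler_log_file.close()
            os.remove(faulthandler_log_path)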
@@ -33,7 +33,7 @@ import org.apache.spark._
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.TASK_NAME
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python}
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
import org.apache.spark.internal.config.Python._
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
@@ -90,11 +90,11 @@ private[spark] object PythonEvalType {
}
}

private object BasePythonRunner {
private[spark] object BasePythonRunner {

private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
private[spark] lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")

private def faultHandlerLogPath(pid: Int): Path = {
private[spark] def faultHandlerLogPath(pid: Int): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}
}
@@ -574,15 +574,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
JavaFiles.deleteIfExists(path)
throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", e)

case eof: EOFException if !faultHandlerEnabled =>
case e: IOException if !faultHandlerEnabled =>
throw new SparkException(
s"Python worker exited unexpectedly (crashed). " +
"Consider setting 'spark.sql.execution.pyspark.udf.faulthandler.enabled' or" +
s"'${Python.PYTHON_WORKER_FAULTHANLDER_ENABLED.key}' configuration to 'true' for" +
"the better Python traceback.", eof)
s"'${PYTHON_WORKER_FAULTHANLDER_ENABLED.key}' configuration to 'true' for " +
"the better Python traceback.", e)

case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
case e: IOException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", e)
}
}

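The tests below force a native crash inside each worker and assert on the resulting error: with the conf enabled, the faulthandler dump ("Segmentation fault") is expected; with it disabled, the hint message added above. A minimal illustration of the crash trigger they all use:

import ctypes

# Reading a C string at address 0 dereferences a NULL pointer in the worker
# process, killing it with SIGSEGV; faulthandler (when enabled) writes the
# Python traceback to the per-PID log that the JVM reads back into the error.
ctypes.string_at(0)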
120 changes: 120 additions & 0 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -508,6 +508,126 @@ def write(self, iterator):
):
df.write.format("test").mode("append").saveAsTable("test_table")

def test_data_source_segfault(self):
import ctypes

for enabled, expected in [
(True, "Segmentation fault"),
(False, "Consider setting .* for the better Python traceback."),
]:
with self.subTest(enabled=enabled), self.sql_conf(
{"spark.sql.execution.pyspark.udf.faulthandler.enabled": enabled}
):
with self.subTest(worker="pyspark.sql.worker.create_data_source"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return ctypes.string_at(0)

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.read.format("test").load().show()

with self.subTest(worker="pyspark.sql.worker.plan_data_source_read"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return "x string"

def reader(self, schema):
return TestReader()

class TestReader(DataSourceReader):
def partitions(self):
ctypes.string_at(0)
return []

def read(self, partition):
return []

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.read.format("test").load().show()

with self.subTest(worker="pyspark.worker"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return "x string"

def reader(self, schema):
return TestReader()

class TestReader(DataSourceReader):
def read(self, partition):
ctypes.string_at(0)
yield "x",

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.read.format("test").load().show()

with self.subTest(worker="pyspark.sql.worker.write_into_data_source"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def writer(self, schema, overwrite):
return TestWriter()

class TestWriter(DataSourceWriter):
def write(self, iterator):
ctypes.string_at(0)
return WriterCommitMessage()

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.range(10).write.format("test").mode("append").saveAsTable(
"test_table"
)

with self.subTest(worker="pyspark.sql.worker.commit_data_source_write"):

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def writer(self, schema, overwrite):
return TestWriter()

class TestWriter(DataSourceWriter):
def write(self, iterator):
return WriterCommitMessage()

def commit(self, messages):
ctypes.string_at(0)

self.spark.dataSource.register(TestDataSource)

with self.assertRaisesRegex(Exception, expected):
self.spark.range(10).write.format("test").mode("append").saveAsTable(
"test_table"
)


class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
...
37 changes: 37 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
@@ -2761,6 +2761,43 @@ def eval(self, n):
res = self.spark.sql("select i, to_json(v['v1']) from test_udtf_struct(8)")
assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n in range(8)])

def test_udtf_segfault(self):
for enabled, expected in [
(True, "Segmentation fault"),
(False, "Consider setting .* for the better Python traceback."),
]:
with self.subTest(enabled=enabled), self.sql_conf(
{"spark.sql.execution.pyspark.udf.faulthandler.enabled": enabled}
):
with self.subTest(method="eval"):

class TestUDTF:
def eval(self):
import ctypes

yield ctypes.string_at(0),

self._check_result_or_exception(
TestUDTF, "x: string", expected, err_type=Exception
)

with self.subTest(method="analyze"):

class TestUDTFWithAnalyze:
@staticmethod
def analyze():
import ctypes

ctypes.string_at(0)
return AnalyzeResult(StructType().add("x", StringType()))

def eval(self):
yield "x",

self._check_result_or_exception(
TestUDTFWithAnalyze, None, expected, err_type=Exception
)


class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod
12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/analyze_udtf.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import faulthandler
import inspect
import os
import sys
@@ -106,7 +107,13 @@ def main(infile: IO, outfile: IO) -> None:
in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -247,6 +254,11 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/commit_data_source_write.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import faulthandler
import os
import sys
from typing import IO
@@ -47,7 +48,13 @@ def main(infile: IO, outfile: IO) -> None:
responsible for invoking either the `commit` or the `abort` method on a data source
writer instance, given a list of commit messages.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -93,6 +100,11 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/create_data_source.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import faulthandler
import inspect
import os
import sys
@@ -60,7 +61,13 @@ def main(infile: IO, outfile: IO) -> None:
This process then creates a `DataSource` instance using the above information and
sends the pickled instance as well as the schema back to the JVM.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -158,6 +165,11 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/lookup_data_sources.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import faulthandler
from importlib import import_module
from pkgutil import iter_modules
import os
@@ -50,7 +51,13 @@ def main(infile: IO, outfile: IO) -> None:
This is responsible for searching the available Python Data Sources so they can be
statically registered automatically.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -78,6 +85,11 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)

12 changes: 12 additions & 0 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import faulthandler
import os
import sys
import functools
@@ -187,7 +188,13 @@ def main(infile: IO, outfile: IO) -> None:
The partition values and the Arrow Batch are then serialized and sent back to the JVM
via the socket.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -351,6 +358,11 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)

send_accumulator_updates(outfile)
