From e120bb9e2749395f427007c4f048eea237bc5351 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 26 Aug 2024 21:48:48 -0700 Subject: [PATCH 01/15] feat: implement row_number window function --- src/gateway/converter/spark_functions.py | 7 ++ src/gateway/converter/spark_to_substrait.py | 117 +++++++++++++++++++- src/gateway/server.py | 3 + src/gateway/tests/test_dataframe_api.py | 26 +++++ 4 files changed, 152 insertions(+), 1 deletion(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index fa7b60d..4999e12 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -808,6 +808,13 @@ def __lt__(self, obj) -> bool: i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED) ), ), + "row_number": ExtensionFunction( + "/functions_arithmetic.yaml", + "row_number:any", + type_pb2.Type( + i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED) + ), + function_type=FunctionType.WINDOW), } diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index adb6358..1387dd8 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -618,6 +618,121 @@ def convert_cast_expression( raise NotImplementedError(f'unknown cast_to_type {cast.WhichOneof("cast_to_type")}') return algebra_pb2.Expression(cast=cast_rel) + def convert_frame_boundary( + self, + boundary: spark_exprs_pb2.Expression.Window.WindowFrame.FrameBoundary) \ + -> algebra_pb2.Expression.WindowFunction.Bound: + """Convert a Spark frame boundary into a Substrait window Bound.""" + bound = algebra_pb2.Expression.WindowFunction.Bound() + match boundary.WhichOneof("boundary"): + case "current_row": + bound.current_row.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.CurrentRow()) + case "unbounded": + bound.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded()) + case "value": + if boundary.value.WhichOneof("expr_type") != "literal": + raise ValueError("Boundary value expression must be a literal.") + offset_expr = self.convert_literal_expression(boundary.value.literal) + match offset_expr.literal.WhichOneof("literal_type"): + case "i8": + offset_value = offset_expr.i8.value + case "i16": + offset_value = offset_expr.i16.value + case "i32": + offset_value = offset_expr.i32.value + case "i64": + offset_value = offset_expr.i64.value + case _: + raise ValueError( + "Unexpected literal type: " + f"{offset_expr.literal.WhichOneof('literal_type')}") + if offset_value > 0: + boundary.following.offset = offset_value + elif offset_value < 0: + boundary.preceding.offset = -offset_value + else: + boundary.current_row.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.CurrentRow()) + case _: + raise ValueError(f"Unknown boundary type: {boundary.WhichOneof('boundary')}") + return bound + + def convert_order_spec( + self, + order: spark_exprs_pb2.Expression.SortOrder) -> algebra_pb2.SortField: + """Convert a Spark order specification into a Substrait sort field.""" + if order.direction == spark_exprs_pb2.Expression.SortOrder.SORT_DIRECTION_ASCENDING: + if (order.null_ordering == + spark_exprs_pb2.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST): + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST + else: + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST + else: + if (order.null_ordering == + spark_exprs_pb2.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST): + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST + else: + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST + return algebra_pb2.SortField( + expr=self.convert_expression(order.child), direction=sort_dir) + + def convert_window_expression( + self, window: spark_exprs_pb2.Expression.Window) -> algebra_pb2.Expression: + """Convert a Spark window expression into a Substrait window expression.""" + func = algebra_pb2.Expression.WindowFunction() + if window.window_function.WhichOneof("expr_type") != "unresolved_function": + raise NotImplementedError( + "Window functions which are not unresolved functions are not yet supported.") + function_def = self.lookup_function_by_name( + window.window_function.unresolved_function.function_name) + func.function_reference = function_def.anchor + for idx, arg in enumerate(window.window_function.unresolved_function.arguments): + if function_def.max_args is not None and idx >= function_def.max_args: + break + if (window.window_function.unresolved_function.function_name == "count" and + arg.WhichOneof("expr_type") == "unresolved_star"): + # Ignore all the rest of the arguments. + func.arguments.append( + algebra_pb2.FunctionArgument(value=bigint_literal(1))) + break + func.arguments.append( + algebra_pb2.FunctionArgument(value=self.convert_expression(arg))) + if function_def.options: + func.options.extend(function_def.options) + func.output_type.CopyFrom(function_def.output_type) + for sort in window.order_spec: + func.sorts.append(self.convert_order_spec(sort)) + if (function_def.function_type == FunctionType.WINDOW or + function_def.function_type == FunctionType.AGGREGATE): + func.invocation = algebra_pb2.AGGREGATION_PHASE_INITIAL_TO_RESULT + else: + raise ValueError( + f"Unexpected function type: {function_def.function_type}") + if window.partition_spec: + for partition in window.partition_spec: + func.partitions.append(self.convert_expression(partition)) + if window.HasField("frame_spec"): + match window.frame_spec.frame_type: + case spark_exprs_pb2.Expression.Window.WindowFrame.FRAME_TYPE_RANGE: + func.bounds_type = ( + algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE) + case spark_exprs_pb2.Expression.Window.WindowFrame.FRAME_TYPE_ROW: + func.bounds_type = ( + algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROW) + case _: + raise ValueError( + f"Unknown frame type: {window.frame_spec.frame_type}") + func.lower_bound.CopyFrom(self.convert_frame_boundary(window.frame_spec.lower)) + func.upper_bound.CopyFrom(self.convert_frame_boundary(window.frame_spec.upper)) + else: + func.bounds_type = ( + algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED) + func.lower_bound.unbounded.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.Unbounded()) + func.upper_bound.unbounded.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.Unbounded()) + return algebra_pb2.Expression(window_function=func) + def convert_extract_value( self, extract: spark_exprs_pb2.Expression.UnresolvedExtractValue ) -> algebra_pb2.Expression: @@ -684,7 +799,7 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex case "lambda_function": raise NotImplementedError("lambda_function expression type not supported") case "window": - raise NotImplementedError("window expression type not supported") + result = self.convert_window_expression(expr.window) case "unresolved_extract_value": result = self.convert_extract_value(expr.unresolved_extract_value) case "update_fields": diff --git a/src/gateway/server.py b/src/gateway/server.py index 6be7dc8..7519cf2 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -66,6 +66,9 @@ def convert_pyarrow_datatype_to_spark(arrow_type: pa.DataType) -> types_pb2.Data data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer()) elif arrow_type == pa.int64(): data_type = types_pb2.DataType(long=types_pb2.DataType.Long()) + elif arrow_type == pa.uint64(): + # TODO: Spark doesn't have unsigned types so come up with a way to handle overflow. + data_type = types_pb2.DataType(long=types_pb2.DataType.Long()) elif arrow_type == pa.float32(): data_type = types_pb2.DataType(float=types_pb2.DataType.Float()) elif arrow_type == pa.float64(): diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 3df6a05..0387d6f 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -52,6 +52,7 @@ replace, right, rlike, + row_number, rpad, rtrim, sqrt, @@ -64,6 +65,7 @@ upper, ) from pyspark.sql.types import DoubleType, StructField, StructType +from pyspark.sql.window import Window from pyspark.testing import assertDataFrameEqual from gateway.tests.conftest import find_tpch @@ -161,6 +163,9 @@ def mark_dataframe_tests_as_xfail(request): if source == "gateway-over-datafusion" and originalname == "test_try_divide": pytest.skip(reason="returns infinity instead of null") + if source == "gateway-over-duckdb" and originalname == "test_row_number": + pytest.skip(reason="window functions not yet implemented in DuckDB") + # ruff: noqa: E712 class TestDataFrameAPI: @@ -2748,3 +2753,24 @@ def test_multiple_measures_and_calculations(self, register_tpch_dataset, spark_s ) assertDataFrameEqual(outcome, expected) + + +class TestDataFrameWindowFunctions: + """Tests window functions of the dataframe side of SparkConnect.""" + + def test_row_number(self, users_dataframe): + expected = [ + Row(user_id="user705452451", name="Adrian Reyes", paid_for_service=False, row_number=1), + Row(user_id="user406489700", name="Alan Aguirre DVM", paid_for_service=False, + row_number=2), + Row(user_id="user965620978", name="Alan Whitaker", paid_for_service=False, + row_number=3), + ] + + with utilizes_valid_plans(users_dataframe): + window_spec = Window.partitionBy().orderBy("name") + outcome = users_dataframe.withColumn("row_number", + row_number().over(window_spec)).orderBy( + "row_number").limit(3) + + assertDataFrameEqual(outcome, expected) From 365aeb4ef115880494fa5c1d394d2ae86db110b5 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 00:10:57 -0700 Subject: [PATCH 02/15] add updates for Datafusion upgrade --- src/gateway/tests/test_tpch_with_dataframe_api.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index dcdc96a..6357467 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -18,11 +18,8 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue("source") originalname = request.keywords.node.originalname - if source == "gateway-over-duckdb": - if originalname == "test_query_15": - request.node.add_marker(pytest.mark.xfail(reason="No results (float vs decimal)")) - if originalname == "test_query_16": - request.node.add_marker(pytest.mark.xfail(reason="distinct not supported")) + if source == "gateway-over-duckdb" and originalname == "test_query_16": + request.node.add_marker(pytest.mark.xfail(reason="distinct not supported")) class TestTpchWithDataFrameAPI: From 18a5803efadfdefc73a3a34d44b4b407642fcf0f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 12:44:23 -0700 Subject: [PATCH 03/15] advance pyarrow version --- environment.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 98de933..79eeac9 100644 --- a/environment.yml +++ b/environment.yml @@ -17,7 +17,7 @@ dependencies: - pip: - adbc_driver_manager - cargo - - pyarrow >= 16.1.0 + - pyarrow >= 17.0.0 - duckdb >= 1.0.0 - datafusion == 40.1.* - grpcio-channelz diff --git a/pyproject.toml b/pyproject.toml index 7affda7..b24d1ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.8.1" -dependencies = ["protobuf >= 3.20", "datafusion >= 40.1.0", "pyarrow >= 15.0.2", "substrait == 0.21.0"] +dependencies = ["protobuf >= 3.20", "datafusion >= 40.1.0", "pyarrow >= 17.0.0", "substrait == 0.21.0"] dynamic = ["version"] [tool.setuptools_scm] From 3f11153e942acf23e2d4cac4846123b1008cf1ab Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 18:04:19 -0700 Subject: [PATCH 04/15] debugging --- src/backends/datafusion_backend.py | 2 ++ src/gateway/tests/test_tpcds_sql.py | 9 +++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/backends/datafusion_backend.py b/src/backends/datafusion_backend.py index 5aece59..bd3597a 100644 --- a/src/backends/datafusion_backend.py +++ b/src/backends/datafusion_backend.py @@ -52,6 +52,8 @@ def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against Datafusion.""" import datafusion.substrait + if len(plan.relations) != 1: + raise ValueError(f"Expected exactly one relation in the plan: {plan}") plan_data = plan.SerializeToString() substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_data) logical_plan = datafusion.substrait.Consumer.from_substrait_plan( diff --git a/src/gateway/tests/test_tpcds_sql.py b/src/gateway/tests/test_tpcds_sql.py index 51ae3fa..1e2976c 100644 --- a/src/gateway/tests/test_tpcds_sql.py +++ b/src/gateway/tests/test_tpcds_sql.py @@ -18,8 +18,8 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue("source") + path = request.getfixturevalue("path") if source == "gateway-over-duckdb": - path = request.getfixturevalue("path") if path.stem in ["01", "06", "10", "16", "30", "32", "35", "69", "81", "86", "92", "94"]: pytest.skip(reason="DuckDB needs Delim join") elif path.stem in [ @@ -119,7 +119,12 @@ def mark_tests_as_xfail(request): elif path.stem in ["95"]: pytest.skip(reason="Unsupported join comparison: !=") if source == "gateway-over-datafusion": - pytest.skip(reason="not yet ready to run SQL tests regularly") + if path.stem in ["02"]: + pytest.skip(reason="Null type without kind is not supported") + elif path.stem in ["09"]: + pytest.skip(reason="Aggregate function any_value is not supported: function anchor = 6") + else: + pytest.skip(reason="not yet ready to run SQL tests regularly") if source == "gateway-over-arrow": pytest.skip(reason="not yet ready to run SQL tests regularly") From cb576a0f96f6190e4b6ab549395dd46ac0a23b32 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 18:40:15 -0700 Subject: [PATCH 05/15] more debugging --- src/backends/datafusion_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/backends/datafusion_backend.py b/src/backends/datafusion_backend.py index bd3597a..bfdd675 100644 --- a/src/backends/datafusion_backend.py +++ b/src/backends/datafusion_backend.py @@ -32,6 +32,9 @@ def create_connection(self) -> None: @contextmanager def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]: """Modify the given Substrait plan for use with Datafusion.""" + if len(plan.relations) != 1: + raise ValueError(f"Expected exactly 1 relation in the plan: {plan}") + file_groups = ReplaceLocalFilesWithNamedTable().visit_plan(plan) registered_tables = set() for table_name, files in file_groups: From 0dc0555bc7bd5e1558cb407a9b073b54affc58f6 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 19:01:17 -0700 Subject: [PATCH 06/15] earlier debugging --- src/gateway/server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gateway/server.py b/src/gateway/server.py index 7519cf2..c8c23e9 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -286,6 +286,8 @@ def ExecutePlan( raise ValueError(f"Unknown plan type: {request.plan}") _LOGGER.debug(" as Substrait: %s", substrait) self._statistics.add_plan(substrait) + if len(substrait.relations) != 1: + raise ValueError(f"Expected exactly ONE relation in the plan: {request}") try: results = self._backend.execute(substrait) except Exception as err: From dae7310c61d573871cdbf874173004dd2a0acd70 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 19:42:35 -0700 Subject: [PATCH 07/15] spaced --- src/gateway/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gateway/server.py b/src/gateway/server.py index c8c23e9..4c656ee 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -288,6 +288,7 @@ def ExecutePlan( self._statistics.add_plan(substrait) if len(substrait.relations) != 1: raise ValueError(f"Expected exactly ONE relation in the plan: {request}") + try: results = self._backend.execute(substrait) except Exception as err: From d570620fbc2022298977d8d1eadd43138bcbc634 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 21:49:56 -0700 Subject: [PATCH 08/15] more debugging --- src/gateway/server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gateway/server.py b/src/gateway/server.py index 4c656ee..8e2135d 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -345,6 +345,8 @@ def AnalyzePlan(self, request, context): if request.schema: substrait = self._converter.convert_plan(request.schema.plan) self._statistics.add_plan(substrait) + if len(substrait.relations) != 1: + raise ValueError(f"Expected exactly _ONE_ relation in the plan: {request}") try: results = self._backend.execute(substrait) except Exception as err: From edf3566e19fe77aa7f8d7dd946b213fb7725389b Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 23:35:05 -0700 Subject: [PATCH 09/15] updated analyze_plan --- src/gateway/server.py | 89 +++++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 41 deletions(-) diff --git a/src/gateway/server.py b/src/gateway/server.py index 8e2135d..4a31aec 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -128,7 +128,7 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: def create_dataframe_view( - session_id: str, view: commands_pb2.CreateDataFrameViewCommand, backend + session_id: str, view: commands_pb2.CreateDataFrameViewCommand, backend ) -> None: """Register the temporary dataframe.""" read_data_source_relation = view.input.read.data_source @@ -235,7 +235,7 @@ def _ReinitializeExecution(self) -> None: return None def ExecutePlan( - self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext + self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext ) -> Generator[pb2.ExecutePlanResponse, None, None]: """Execute the given plan and return the results.""" self._statistics.execute_requests += 1 @@ -297,9 +297,9 @@ def ExecutePlan( _LOGGER.debug(" results are: %s", results) if ( - not self._options.implement_show_string - and request.plan.WhichOneof("op_type") == "root" - and request.plan.root.WhichOneof("rel_type") == "show_string" + not self._options.implement_show_string + and request.plan.WhichOneof("op_type") == "root" + and request.plan.root.WhichOneof("rel_type") == "show_string" ): yield pb2.ExecutePlanResponse( session_id=request.session_id, @@ -342,24 +342,31 @@ def AnalyzePlan(self, request, context): self._statistics.add_request(request) _LOGGER.info("AnalyzePlan: %s", request) self._InitializeExecution() - if request.schema: - substrait = self._converter.convert_plan(request.schema.plan) - self._statistics.add_plan(substrait) - if len(substrait.relations) != 1: - raise ValueError(f"Expected exactly _ONE_ relation in the plan: {request}") - try: - results = self._backend.execute(substrait) - except Exception as err: - self._ReinitializeExecution() - raise err - _LOGGER.debug(" results are: %s", results) - return pb2.AnalyzePlanResponse( - session_id=request.session_id, - schema=pb2.AnalyzePlanResponse.Schema( - schema=convert_pyarrow_schema_to_spark(results.schema) - ), - ) - raise NotImplementedError("AnalyzePlan not yet implemented for non-Schema requests.") + match request.WhichOneof("analyze"): + case 'schema': + request_plan = request.schema.plan + case 'is_streaming': + request_plan = request.is_streaming.plan + case _: + raise NotImplementedError( + "AnalyzePlan not yet implemented for non-Schema requests: " + f"{request.WhichOneof('analyze')}") + substrait = self._converter.convert_plan(request_plan) + self._statistics.add_plan(substrait) + if len(substrait.relations) != 1: + raise ValueError(f"Expected exactly _ONE_ relation in the plan: {request}") + try: + results = self._backend.execute(substrait) + except Exception as err: + self._ReinitializeExecution() + raise err + _LOGGER.debug(" results are: %s", results) + return pb2.AnalyzePlanResponse( + session_id=request.session_id, + schema=pb2.AnalyzePlanResponse.Schema( + schema=convert_pyarrow_schema_to_spark(results.schema) + ), + ) def Config(self, request, context): """Get or set the configuration of the server.""" @@ -390,7 +397,7 @@ def Config(self, request, context): elif key == "spark-substrait-gateway.plan_count": response.pairs.add(key=key, value=str(len(self._statistics.plans))) elif key.startswith("spark-substrait-gateway.plan."): - index = int(key[len("spark-substrait-gateway.plan.") :]) + index = int(key[len("spark-substrait-gateway.plan."):]) if 0 <= index - 1 < len(self._statistics.plans): response.pairs.add(key=key, value=self._statistics.plans[index - 1]) elif key == "spark.sql.session.timeZone": @@ -468,7 +475,7 @@ def Interrupt(self, request, context): return pb2.InterruptResponse() def ReattachExecute( - self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext + self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext ) -> Generator[pb2.ExecutePlanResponse, None, None]: """Reattach the execution of the given plan.""" self._statistics.reattach_requests += 1 @@ -485,13 +492,13 @@ def ReleaseExecute(self, request, context): def serve( - port: int, - wait: bool, - tls: list[str] | None = None, - enable_auth: bool = False, - jwt_audience: str | None = None, - secret_key: str | None = None, - log_level: str = "INFO", + port: int, + wait: bool, + tls: list[str] | None = None, + enable_auth: bool = False, + jwt_audience: str | None = None, + secret_key: str | None = None, + log_level: str = "INFO", ) -> grpc.Server: """Start the Spark Substrait Gateway server.""" logging.basicConfig(level=getattr(logging, log_level), encoding="utf-8") @@ -569,8 +576,8 @@ def serve( required=False, metavar=("CERTFILE", "KEYFILE"), help="Enable transport-level security (TLS/SSL). Provide a " - "Certificate file path, and a Key file path - separated by a space. " - "Example: tls/server.crt tls/server.key", + "Certificate file path, and a Key file path - separated by a space. " + "Example: tls/server.crt tls/server.key", ) @click.option( "--enable-auth/--no-enable-auth", @@ -601,13 +608,13 @@ def serve( help="The logging level to use for the server.", ) def click_serve( - port: int, - wait: bool, - tls: list[str], - enable_auth: bool, - jwt_audience: str, - secret_key: str, - log_level: str, + port: int, + wait: bool, + tls: list[str], + enable_auth: bool, + jwt_audience: str, + secret_key: str, + log_level: str, ) -> grpc.Server: """Provide a click interface for starting the Spark Substrait Gateway server.""" return serve(**locals()) From bcee3444121a11eef72a00ca20e91550b4aeac8e Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 27 Aug 2024 23:50:23 -0700 Subject: [PATCH 10/15] replace is_streaming logic with false always --- src/gateway/server.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/gateway/server.py b/src/gateway/server.py index 4a31aec..8519f84 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -344,29 +344,34 @@ def AnalyzePlan(self, request, context): self._InitializeExecution() match request.WhichOneof("analyze"): case 'schema': - request_plan = request.schema.plan + substrait = self._converter.convert_plan(request.schema.plan) + self._statistics.add_plan(substrait) + if len(substrait.relations) != 1: + raise ValueError(f"Expected exactly _ONE_ relation in the plan: {request}") + try: + results = self._backend.execute(substrait) + except Exception as err: + self._ReinitializeExecution() + raise err + _LOGGER.debug(" results are: %s", results) + return pb2.AnalyzePlanResponse( + session_id=request.session_id, + schema=pb2.AnalyzePlanResponse.Schema( + schema=convert_pyarrow_schema_to_spark(results.schema) + ), + ) case 'is_streaming': - request_plan = request.is_streaming.plan + # TODO -- Actually look at request.is_streaming.plan + return pb2.AnalyzePlanResponse( + session_id=request.session_id, + schema=pb2.AnalyzePlanResponse.IsStreaming( + is_streaming=False + ), + ) case _: raise NotImplementedError( "AnalyzePlan not yet implemented for non-Schema requests: " f"{request.WhichOneof('analyze')}") - substrait = self._converter.convert_plan(request_plan) - self._statistics.add_plan(substrait) - if len(substrait.relations) != 1: - raise ValueError(f"Expected exactly _ONE_ relation in the plan: {request}") - try: - results = self._backend.execute(substrait) - except Exception as err: - self._ReinitializeExecution() - raise err - _LOGGER.debug(" results are: %s", results) - return pb2.AnalyzePlanResponse( - session_id=request.session_id, - schema=pb2.AnalyzePlanResponse.Schema( - schema=convert_pyarrow_schema_to_spark(results.schema) - ), - ) def Config(self, request, context): """Get or set the configuration of the server.""" From c442a91bc6b3cc2f0e6ea8083daf30e28a44aaef Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 28 Aug 2024 00:05:53 -0700 Subject: [PATCH 11/15] fixed type --- src/gateway/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gateway/server.py b/src/gateway/server.py index 8519f84..0a4c878 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -361,10 +361,10 @@ def AnalyzePlan(self, request, context): ), ) case 'is_streaming': - # TODO -- Actually look at request.is_streaming.plan + # TODO -- Actually look at the plan (this path is used by pyspark.testing.utils). return pb2.AnalyzePlanResponse( session_id=request.session_id, - schema=pb2.AnalyzePlanResponse.IsStreaming( + is_streaming=pb2.AnalyzePlanResponse.IsStreaming( is_streaming=False ), ) From 2ad80d62c3ab2921c33bb9bdeaa51a81b9386869 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 28 Aug 2024 00:22:48 -0700 Subject: [PATCH 12/15] removed debugging code --- src/backends/datafusion_backend.py | 8 +----- src/gateway/server.py | 44 ++++++++++++++---------------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/src/backends/datafusion_backend.py b/src/backends/datafusion_backend.py index bfdd675..133a091 100644 --- a/src/backends/datafusion_backend.py +++ b/src/backends/datafusion_backend.py @@ -3,6 +3,7 @@ from collections.abc import Iterator from contextlib import contextmanager +import datafusion.substrait from pathlib import Path import pyarrow as pa @@ -32,9 +33,6 @@ def create_connection(self) -> None: @contextmanager def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]: """Modify the given Substrait plan for use with Datafusion.""" - if len(plan.relations) != 1: - raise ValueError(f"Expected exactly 1 relation in the plan: {plan}") - file_groups = ReplaceLocalFilesWithNamedTable().visit_plan(plan) registered_tables = set() for table_name, files in file_groups: @@ -53,10 +51,6 @@ def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]: def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against Datafusion.""" - import datafusion.substrait - - if len(plan.relations) != 1: - raise ValueError(f"Expected exactly one relation in the plan: {plan}") plan_data = plan.SerializeToString() substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_data) logical_plan = datafusion.substrait.Consumer.from_substrait_plan( diff --git a/src/gateway/server.py b/src/gateway/server.py index 0a4c878..92e782c 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -128,7 +128,7 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: def create_dataframe_view( - session_id: str, view: commands_pb2.CreateDataFrameViewCommand, backend + session_id: str, view: commands_pb2.CreateDataFrameViewCommand, backend ) -> None: """Register the temporary dataframe.""" read_data_source_relation = view.input.read.data_source @@ -235,7 +235,7 @@ def _ReinitializeExecution(self) -> None: return None def ExecutePlan( - self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext + self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext ) -> Generator[pb2.ExecutePlanResponse, None, None]: """Execute the given plan and return the results.""" self._statistics.execute_requests += 1 @@ -286,8 +286,6 @@ def ExecutePlan( raise ValueError(f"Unknown plan type: {request.plan}") _LOGGER.debug(" as Substrait: %s", substrait) self._statistics.add_plan(substrait) - if len(substrait.relations) != 1: - raise ValueError(f"Expected exactly ONE relation in the plan: {request}") try: results = self._backend.execute(substrait) @@ -297,9 +295,9 @@ def ExecutePlan( _LOGGER.debug(" results are: %s", results) if ( - not self._options.implement_show_string - and request.plan.WhichOneof("op_type") == "root" - and request.plan.root.WhichOneof("rel_type") == "show_string" + not self._options.implement_show_string + and request.plan.WhichOneof("op_type") == "root" + and request.plan.root.WhichOneof("rel_type") == "show_string" ): yield pb2.ExecutePlanResponse( session_id=request.session_id, @@ -402,7 +400,7 @@ def Config(self, request, context): elif key == "spark-substrait-gateway.plan_count": response.pairs.add(key=key, value=str(len(self._statistics.plans))) elif key.startswith("spark-substrait-gateway.plan."): - index = int(key[len("spark-substrait-gateway.plan."):]) + index = int(key[len("spark-substrait-gateway.plan.") :]) if 0 <= index - 1 < len(self._statistics.plans): response.pairs.add(key=key, value=self._statistics.plans[index - 1]) elif key == "spark.sql.session.timeZone": @@ -480,7 +478,7 @@ def Interrupt(self, request, context): return pb2.InterruptResponse() def ReattachExecute( - self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext + self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext ) -> Generator[pb2.ExecutePlanResponse, None, None]: """Reattach the execution of the given plan.""" self._statistics.reattach_requests += 1 @@ -497,13 +495,13 @@ def ReleaseExecute(self, request, context): def serve( - port: int, - wait: bool, - tls: list[str] | None = None, - enable_auth: bool = False, - jwt_audience: str | None = None, - secret_key: str | None = None, - log_level: str = "INFO", + port: int, + wait: bool, + tls: list[str] | None = None, + enable_auth: bool = False, + jwt_audience: str | None = None, + secret_key: str | None = None, + log_level: str = "INFO", ) -> grpc.Server: """Start the Spark Substrait Gateway server.""" logging.basicConfig(level=getattr(logging, log_level), encoding="utf-8") @@ -613,13 +611,13 @@ def serve( help="The logging level to use for the server.", ) def click_serve( - port: int, - wait: bool, - tls: list[str], - enable_auth: bool, - jwt_audience: str, - secret_key: str, - log_level: str, + port: int, + wait: bool, + tls: list[str], + enable_auth: bool, + jwt_audience: str, + secret_key: str, + log_level: str, ) -> grpc.Server: """Provide a click interface for starting the Spark Substrait Gateway server.""" return serve(**locals()) From e3a779c9c3685c371086057870656b591e700860 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 28 Aug 2024 00:24:15 -0700 Subject: [PATCH 13/15] one more indent fix --- src/gateway/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gateway/server.py b/src/gateway/server.py index 92e782c..d0ac4c5 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -579,8 +579,8 @@ def serve( required=False, metavar=("CERTFILE", "KEYFILE"), help="Enable transport-level security (TLS/SSL). Provide a " - "Certificate file path, and a Key file path - separated by a space. " - "Example: tls/server.crt tls/server.key", + "Certificate file path, and a Key file path - separated by a space. " + "Example: tls/server.crt tls/server.key", ) @click.option( "--enable-auth/--no-enable-auth", From c44537c7c6b9d193640ddb984d814cbf17d3b694 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 28 Aug 2024 00:24:55 -0700 Subject: [PATCH 14/15] removed extraneous linebreak --- src/gateway/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gateway/server.py b/src/gateway/server.py index d0ac4c5..d683b07 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -286,7 +286,6 @@ def ExecutePlan( raise ValueError(f"Unknown plan type: {request.plan}") _LOGGER.debug(" as Substrait: %s", substrait) self._statistics.add_plan(substrait) - try: results = self._backend.execute(substrait) except Exception as err: From a0b9017b1e90a2cf72ae294c2af2f6cab40468eb Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 28 Aug 2024 00:33:17 -0700 Subject: [PATCH 15/15] ruff --- src/backends/datafusion_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backends/datafusion_backend.py b/src/backends/datafusion_backend.py index 133a091..61004a4 100644 --- a/src/backends/datafusion_backend.py +++ b/src/backends/datafusion_backend.py @@ -3,9 +3,9 @@ from collections.abc import Iterator from contextlib import contextmanager -import datafusion.substrait from pathlib import Path +import datafusion.substrait import pyarrow as pa from substrait.gen.proto import plan_pb2