diff --git a/environment.yml b/environment.yml
index 98de933..79eeac9 100644
--- a/environment.yml
+++ b/environment.yml
@@ -17,7 +17,7 @@ dependencies:
   - pip:
     - adbc_driver_manager
     - cargo
-    - pyarrow >= 16.1.0
+    - pyarrow >= 17.0.0
     - duckdb >= 1.0.0
     - datafusion == 40.1.*
     - grpcio-channelz
diff --git a/pyproject.toml b/pyproject.toml
index 7affda7..b24d1ba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com
 license = {text = "Apache-2.0"}
 readme = "README.md"
 requires-python = ">=3.8.1"
-dependencies = ["protobuf >= 3.20", "datafusion >= 40.1.0", "pyarrow >= 15.0.2", "substrait == 0.21.0"]
+dependencies = ["protobuf >= 3.20", "datafusion >= 40.1.0", "pyarrow >= 17.0.0", "substrait == 0.21.0"]
 dynamic = ["version"]
 
 [tool.setuptools_scm]
diff --git a/src/backends/datafusion_backend.py b/src/backends/datafusion_backend.py
index 5aece59..61004a4 100644
--- a/src/backends/datafusion_backend.py
+++ b/src/backends/datafusion_backend.py
@@ -5,6 +5,7 @@
 from contextlib import contextmanager
 from pathlib import Path
 
+import datafusion.substrait
 import pyarrow as pa
 from substrait.gen.proto import plan_pb2
 
@@ -50,8 +51,6 @@ def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]:
 
     def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table:
         """Execute the given Substrait plan against Datafusion."""
-        import datafusion.substrait
-
         plan_data = plan.SerializeToString()
         substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_data)
         logical_plan = datafusion.substrait.Consumer.from_substrait_plan(
diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py
index fa7b60d..4999e12 100644
--- a/src/gateway/converter/spark_functions.py
+++ b/src/gateway/converter/spark_functions.py
@@ -808,6 +808,13 @@ def __lt__(self, obj) -> bool:
             i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED)
         ),
     ),
+    "row_number": ExtensionFunction(
+        "/functions_arithmetic.yaml",
+        "row_number:any",
+        type_pb2.Type(
+            i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED)
+        ),
+        function_type=FunctionType.WINDOW),
 }
 
 
diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py
index adb6358..1387dd8 100644
--- a/src/gateway/converter/spark_to_substrait.py
+++ b/src/gateway/converter/spark_to_substrait.py
@@ -618,6 +618,121 @@ def convert_cast_expression(
             raise NotImplementedError(f'unknown cast_to_type {cast.WhichOneof("cast_to_type")}')
         return algebra_pb2.Expression(cast=cast_rel)
 
+    def convert_frame_boundary(
+            self,
+            boundary: spark_exprs_pb2.Expression.Window.WindowFrame.FrameBoundary) \
+            -> algebra_pb2.Expression.WindowFunction.Bound:
+        """Convert a Spark frame boundary into a Substrait window Bound."""
+        bound = algebra_pb2.Expression.WindowFunction.Bound()
+        match boundary.WhichOneof("boundary"):
+            case "current_row":
+                bound.current_row.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.CurrentRow())
+            case "unbounded":
+                bound.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
+            case "value":
+                if boundary.value.WhichOneof("expr_type") != "literal":
+                    raise ValueError("Boundary value expression must be a literal.")
+                offset_expr = self.convert_literal_expression(boundary.value.literal)
+                match offset_expr.literal.WhichOneof("literal_type"):
+                    case "i8":
+                        offset_value = offset_expr.literal.i8
+                    case "i16":
+                        offset_value = offset_expr.literal.i16
case "i32": + offset_value = offset_expr.i32.value + case "i64": + offset_value = offset_expr.i64.value + case _: + raise ValueError( + "Unexpected literal type: " + f"{offset_expr.literal.WhichOneof('literal_type')}") + if offset_value > 0: + boundary.following.offset = offset_value + elif offset_value < 0: + boundary.preceding.offset = -offset_value + else: + boundary.current_row.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.CurrentRow()) + case _: + raise ValueError(f"Unknown boundary type: {boundary.WhichOneof('boundary')}") + return bound + + def convert_order_spec( + self, + order: spark_exprs_pb2.Expression.SortOrder) -> algebra_pb2.SortField: + """Convert a Spark order specification into a Substrait sort field.""" + if order.direction == spark_exprs_pb2.Expression.SortOrder.SORT_DIRECTION_ASCENDING: + if (order.null_ordering == + spark_exprs_pb2.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST): + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST + else: + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST + else: + if (order.null_ordering == + spark_exprs_pb2.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST): + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST + else: + sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST + return algebra_pb2.SortField( + expr=self.convert_expression(order.child), direction=sort_dir) + + def convert_window_expression( + self, window: spark_exprs_pb2.Expression.Window) -> algebra_pb2.Expression: + """Convert a Spark window expression into a Substrait window expression.""" + func = algebra_pb2.Expression.WindowFunction() + if window.window_function.WhichOneof("expr_type") != "unresolved_function": + raise NotImplementedError( + "Window functions which are not unresolved functions are not yet supported.") + function_def = self.lookup_function_by_name( + window.window_function.unresolved_function.function_name) + func.function_reference = function_def.anchor + for idx, arg in enumerate(window.window_function.unresolved_function.arguments): + if function_def.max_args is not None and idx >= function_def.max_args: + break + if (window.window_function.unresolved_function.function_name == "count" and + arg.WhichOneof("expr_type") == "unresolved_star"): + # Ignore all the rest of the arguments. 
+                func.arguments.append(
+                    algebra_pb2.FunctionArgument(value=bigint_literal(1)))
+                break
+            func.arguments.append(
+                algebra_pb2.FunctionArgument(value=self.convert_expression(arg)))
+        if function_def.options:
+            func.options.extend(function_def.options)
+        func.output_type.CopyFrom(function_def.output_type)
+        for sort in window.order_spec:
+            func.sorts.append(self.convert_order_spec(sort))
+        if (function_def.function_type == FunctionType.WINDOW or
+                function_def.function_type == FunctionType.AGGREGATE):
+            func.invocation = algebra_pb2.AGGREGATION_PHASE_INITIAL_TO_RESULT
+        else:
+            raise ValueError(
+                f"Unexpected function type: {function_def.function_type}")
+        if window.partition_spec:
+            for partition in window.partition_spec:
+                func.partitions.append(self.convert_expression(partition))
+        if window.HasField("frame_spec"):
+            match window.frame_spec.frame_type:
+                case spark_exprs_pb2.Expression.Window.WindowFrame.FRAME_TYPE_RANGE:
+                    func.bounds_type = (
+                        algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE)
+                case spark_exprs_pb2.Expression.Window.WindowFrame.FRAME_TYPE_ROW:
+                    func.bounds_type = (
+                        algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROW)
+                case _:
+                    raise ValueError(
+                        f"Unknown frame type: {window.frame_spec.frame_type}")
+            func.lower_bound.CopyFrom(self.convert_frame_boundary(window.frame_spec.lower))
+            func.upper_bound.CopyFrom(self.convert_frame_boundary(window.frame_spec.upper))
+        else:
+            func.bounds_type = (
+                algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED)
+            func.lower_bound.unbounded.CopyFrom(
+                algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
+            func.upper_bound.unbounded.CopyFrom(
+                algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
+        return algebra_pb2.Expression(window_function=func)
+
     def convert_extract_value(
         self, extract: spark_exprs_pb2.Expression.UnresolvedExtractValue
     ) -> algebra_pb2.Expression:
@@ -684,7 +799,7 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex
             case "lambda_function":
                 raise NotImplementedError("lambda_function expression type not supported")
             case "window":
-                raise NotImplementedError("window expression type not supported")
+                result = self.convert_window_expression(expr.window)
             case "unresolved_extract_value":
                 result = self.convert_extract_value(expr.unresolved_extract_value)
             case "update_fields":
diff --git a/src/gateway/server.py b/src/gateway/server.py
index 6be7dc8..d683b07 100644
--- a/src/gateway/server.py
+++ b/src/gateway/server.py
@@ -66,6 +66,9 @@ def convert_pyarrow_datatype_to_spark(arrow_type: pa.DataType) -> types_pb2.Data
         data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer())
     elif arrow_type == pa.int64():
         data_type = types_pb2.DataType(long=types_pb2.DataType.Long())
+    elif arrow_type == pa.uint64():
+        # TODO: Spark doesn't have unsigned types so come up with a way to handle overflow.
+        data_type = types_pb2.DataType(long=types_pb2.DataType.Long())
     elif arrow_type == pa.float32():
         data_type = types_pb2.DataType(float=types_pb2.DataType.Float())
     elif arrow_type == pa.float64():
@@ -336,22 +339,36 @@ def AnalyzePlan(self, request, context):
         self._statistics.add_request(request)
         _LOGGER.info("AnalyzePlan: %s", request)
         self._InitializeExecution()
-        if request.schema:
-            substrait = self._converter.convert_plan(request.schema.plan)
-            self._statistics.add_plan(substrait)
-            try:
-                results = self._backend.execute(substrait)
-            except Exception as err:
-                self._ReinitializeExecution()
-                raise err
-            _LOGGER.debug("  results are: %s", results)
-            return pb2.AnalyzePlanResponse(
-                session_id=request.session_id,
-                schema=pb2.AnalyzePlanResponse.Schema(
-                    schema=convert_pyarrow_schema_to_spark(results.schema)
-                ),
-            )
-        raise NotImplementedError("AnalyzePlan not yet implemented for non-Schema requests.")
+        match request.WhichOneof("analyze"):
+            case 'schema':
+                substrait = self._converter.convert_plan(request.schema.plan)
+                self._statistics.add_plan(substrait)
+                if len(substrait.relations) != 1:
+                    raise ValueError(f"Expected exactly _ONE_ relation in the plan: {request}")
+                try:
+                    results = self._backend.execute(substrait)
+                except Exception as err:
+                    self._ReinitializeExecution()
+                    raise err
+                _LOGGER.debug("  results are: %s", results)
+                return pb2.AnalyzePlanResponse(
+                    session_id=request.session_id,
+                    schema=pb2.AnalyzePlanResponse.Schema(
+                        schema=convert_pyarrow_schema_to_spark(results.schema)
+                    ),
+                )
+            case 'is_streaming':
+                # TODO -- Actually look at the plan (this path is used by pyspark.testing.utils).
+                return pb2.AnalyzePlanResponse(
+                    session_id=request.session_id,
+                    is_streaming=pb2.AnalyzePlanResponse.IsStreaming(
+                        is_streaming=False
+                    ),
+                )
+            case _:
+                raise NotImplementedError(
+                    "AnalyzePlan not yet implemented for non-Schema requests: "
+                    f"{request.WhichOneof('analyze')}")
 
     def Config(self, request, context):
         """Get or set the configuration of the server."""
diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py
index 3df6a05..0387d6f 100644
--- a/src/gateway/tests/test_dataframe_api.py
+++ b/src/gateway/tests/test_dataframe_api.py
@@ -52,6 +52,7 @@
     replace,
     right,
    rlike,
+    row_number,
     rpad,
     rtrim,
     sqrt,
@@ -64,6 +65,7 @@
     upper,
 )
 from pyspark.sql.types import DoubleType, StructField, StructType
+from pyspark.sql.window import Window
 from pyspark.testing import assertDataFrameEqual
 
 from gateway.tests.conftest import find_tpch
@@ -161,6 +163,9 @@ def mark_dataframe_tests_as_xfail(request):
     if source == "gateway-over-datafusion" and originalname == "test_try_divide":
         pytest.skip(reason="returns infinity instead of null")
 
+    if source == "gateway-over-duckdb" and originalname == "test_row_number":
+        pytest.skip(reason="window functions not yet implemented in DuckDB")
+
 
 # ruff: noqa: E712
 class TestDataFrameAPI:
@@ -2748,3 +2753,24 @@ def test_multiple_measures_and_calculations(self, register_tpch_dataset, spark_s
         )
 
         assertDataFrameEqual(outcome, expected)
+
+
+class TestDataFrameWindowFunctions:
+    """Tests window functions of the dataframe side of SparkConnect."""
+
+    def test_row_number(self, users_dataframe):
+        expected = [
+            Row(user_id="user705452451", name="Adrian Reyes", paid_for_service=False, row_number=1),
+            Row(user_id="user406489700", name="Alan Aguirre DVM", paid_for_service=False,
+                row_number=2),
+            Row(user_id="user965620978", name="Alan Whitaker", paid_for_service=False,
+                row_number=3),
+        ]
+
+        with utilizes_valid_plans(users_dataframe):
+            window_spec = Window.partitionBy().orderBy("name")
+            outcome = users_dataframe.withColumn("row_number",
+                                                 row_number().over(window_spec)).orderBy(
+                "row_number").limit(3)
+
+        assertDataFrameEqual(outcome, expected)
diff --git a/src/gateway/tests/test_tpcds_sql.py b/src/gateway/tests/test_tpcds_sql.py
index 51ae3fa..1e2976c 100644
--- a/src/gateway/tests/test_tpcds_sql.py
+++ b/src/gateway/tests/test_tpcds_sql.py
@@ -18,8 +18,8 @@ def mark_tests_as_xfail(request):
     """Marks a subset of tests as expected to be fail."""
     source = request.getfixturevalue("source")
+    path = request.getfixturevalue("path")
     if source == "gateway-over-duckdb":
-        path = request.getfixturevalue("path")
         if path.stem in ["01", "06", "10", "16", "30", "32", "35", "69", "81", "86", "92", "94"]:
             pytest.skip(reason="DuckDB needs Delim join")
         elif path.stem in [
@@ -119,7 +119,12 @@ def mark_tests_as_xfail(request):
         elif path.stem in ["95"]:
             pytest.skip(reason="Unsupported join comparison: !=")
     if source == "gateway-over-datafusion":
-        pytest.skip(reason="not yet ready to run SQL tests regularly")
+        if path.stem in ["02"]:
+            pytest.skip(reason="Null type without kind is not supported")
+        elif path.stem in ["09"]:
+            pytest.skip(reason="Aggregate function any_value is not supported: function anchor = 6")
+        else:
+            pytest.skip(reason="not yet ready to run SQL tests regularly")
     if source == "gateway-over-arrow":
         pytest.skip(reason="not yet ready to run SQL tests regularly")
 
diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py
index dcdc96a..6357467 100644
--- a/src/gateway/tests/test_tpch_with_dataframe_api.py
+++ b/src/gateway/tests/test_tpch_with_dataframe_api.py
@@ -18,11 +18,8 @@ def mark_tests_as_xfail(request):
     """Marks a subset of tests as expected to be fail."""
     source = request.getfixturevalue("source")
     originalname = request.keywords.node.originalname
-    if source == "gateway-over-duckdb":
-        if originalname == "test_query_15":
-            request.node.add_marker(pytest.mark.xfail(reason="No results (float vs decimal)"))
-        if originalname == "test_query_16":
-            request.node.add_marker(pytest.mark.xfail(reason="distinct not supported"))
+    if source == "gateway-over-duckdb" and originalname == "test_query_16":
+        request.node.add_marker(pytest.mark.xfail(reason="distinct not supported"))
 
 
 class TestTpchWithDataFrameAPI:
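
A note on the frame-boundary offsets handled in convert_frame_boundary: in the generated algebra_pb2 bindings the integer variants of Expression.Literal sit in the literal_type oneof as plain scalar fields, so the offset is read directly from the literal rather than through a nested message. A minimal standalone sketch (illustrative only, not part of the diff):

    from substrait.gen.proto import algebra_pb2

    # Build a literal expression the same way convert_literal_expression would return one.
    expr = algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(i32=5))
    assert expr.literal.WhichOneof("literal_type") == "i32"
    assert expr.literal.i32 == 5  # scalar oneof field; there is no nested .value message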