feat: implement row_number window function #82

Merged
merged 15 commits on Aug 28, 2024
2 changes: 1 addition & 1 deletion environment.yml
@@ -17,7 +17,7 @@ dependencies:
- pip:
- adbc_driver_manager
- cargo
- pyarrow >= 16.1.0
- pyarrow >= 17.0.0
- duckdb >= 1.0.0
- datafusion == 40.1.*
- grpcio-channelz
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "[email protected]
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.8.1"
dependencies = ["protobuf >= 3.20", "datafusion >= 40.1.0", "pyarrow >= 15.0.2", "substrait == 0.21.0"]
dependencies = ["protobuf >= 3.20", "datafusion >= 40.1.0", "pyarrow >= 17.0.0", "substrait == 0.21.0"]
dynamic = ["version"]

[tool.setuptools_scm]
3 changes: 1 addition & 2 deletions src/backends/datafusion_backend.py
@@ -5,6 +5,7 @@
from contextlib import contextmanager
from pathlib import Path

import datafusion.substrait
import pyarrow as pa
from substrait.gen.proto import plan_pb2

@@ -50,8 +51,6 @@ def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]:

def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table:
"""Execute the given Substrait plan against Datafusion."""
import datafusion.substrait

plan_data = plan.SerializeToString()
substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_data)
logical_plan = datafusion.substrait.Consumer.from_substrait_plan(
7 changes: 7 additions & 0 deletions src/gateway/converter/spark_functions.py
@@ -808,6 +808,13 @@ def __lt__(self, obj) -> bool:
i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED)
),
),
"row_number": ExtensionFunction(
"/functions_arithmetic.yaml",
"row_number:any",
type_pb2.Type(
i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED)
),
function_type=FunctionType.WINDOW),
}


117 changes: 116 additions & 1 deletion src/gateway/converter/spark_to_substrait.py
@@ -618,6 +618,121 @@ def convert_cast_expression(
raise NotImplementedError(f'unknown cast_to_type {cast.WhichOneof("cast_to_type")}')
return algebra_pb2.Expression(cast=cast_rel)

def convert_frame_boundary(
self,
boundary: spark_exprs_pb2.Expression.Window.WindowFrame.FrameBoundary) \
-> algebra_pb2.Expression.WindowFunction.Bound:
"""Convert a Spark frame boundary into a Substrait window Bound."""
bound = algebra_pb2.Expression.WindowFunction.Bound()
match boundary.WhichOneof("boundary"):
case "current_row":
bound.current_row.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.CurrentRow())
case "unbounded":
bound.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
case "value":
if boundary.value.WhichOneof("expr_type") != "literal":
raise ValueError("Boundary value expression must be a literal.")
offset_expr = self.convert_literal_expression(boundary.value.literal)
match offset_expr.literal.WhichOneof("literal_type"):
case "i8":
offset_value = offset_expr.literal.i8
case "i16":
offset_value = offset_expr.literal.i16
case "i32":
offset_value = offset_expr.literal.i32
case "i64":
offset_value = offset_expr.literal.i64
case _:
raise ValueError(
"Unexpected literal type: "
f"{offset_expr.literal.WhichOneof('literal_type')}")
if offset_value > 0:
bound.following.offset = offset_value
elif offset_value < 0:
bound.preceding.offset = -offset_value
else:
bound.current_row.CopyFrom(
algebra_pb2.Expression.WindowFunction.Bound.CurrentRow())
case _:
raise ValueError(f"Unknown boundary type: {boundary.WhichOneof('boundary')}")
return bound

def convert_order_spec(
self,
order: spark_exprs_pb2.Expression.SortOrder) -> algebra_pb2.SortField:
"""Convert a Spark order specification into a Substrait sort field."""
if order.direction == spark_exprs_pb2.Expression.SortOrder.SORT_DIRECTION_ASCENDING:
if (order.null_ordering ==
spark_exprs_pb2.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST):
sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST
else:
sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST
else:
if (order.null_ordering ==
spark_exprs_pb2.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST):
sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST
else:
sort_dir = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST
return algebra_pb2.SortField(
expr=self.convert_expression(order.child), direction=sort_dir)

def convert_window_expression(
self, window: spark_exprs_pb2.Expression.Window) -> algebra_pb2.Expression:
"""Convert a Spark window expression into a Substrait window expression."""
func = algebra_pb2.Expression.WindowFunction()
if window.window_function.WhichOneof("expr_type") != "unresolved_function":
raise NotImplementedError(
"Window functions which are not unresolved functions are not yet supported.")
function_def = self.lookup_function_by_name(
window.window_function.unresolved_function.function_name)
func.function_reference = function_def.anchor
for idx, arg in enumerate(window.window_function.unresolved_function.arguments):
if function_def.max_args is not None and idx >= function_def.max_args:
break
if (window.window_function.unresolved_function.function_name == "count" and
arg.WhichOneof("expr_type") == "unresolved_star"):
# Ignore all the rest of the arguments.
func.arguments.append(
algebra_pb2.FunctionArgument(value=bigint_literal(1)))
break
func.arguments.append(
algebra_pb2.FunctionArgument(value=self.convert_expression(arg)))
if function_def.options:
func.options.extend(function_def.options)
func.output_type.CopyFrom(function_def.output_type)
for sort in window.order_spec:
func.sorts.append(self.convert_order_spec(sort))
if (function_def.function_type == FunctionType.WINDOW or
function_def.function_type == FunctionType.AGGREGATE):
func.invocation = algebra_pb2.AGGREGATION_PHASE_INITIAL_TO_RESULT
else:
raise ValueError(
f"Unexpected function type: {function_def.function_type}")
if window.partition_spec:
for partition in window.partition_spec:
func.partitions.append(self.convert_expression(partition))
if window.HasField("frame_spec"):
match window.frame_spec.frame_type:
case spark_exprs_pb2.Expression.Window.WindowFrame.FRAME_TYPE_RANGE:
func.bounds_type = (
algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE)
case spark_exprs_pb2.Expression.Window.WindowFrame.FRAME_TYPE_ROW:
func.bounds_type = (
algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROW)
case _:
raise ValueError(
f"Unknown frame type: {window.frame_spec.frame_type}")
func.lower_bound.CopyFrom(self.convert_frame_boundary(window.frame_spec.lower))
func.upper_bound.CopyFrom(self.convert_frame_boundary(window.frame_spec.upper))
else:
func.bounds_type = (
algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED)
func.lower_bound.unbounded.CopyFrom(
algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
func.upper_bound.unbounded.CopyFrom(
algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
return algebra_pb2.Expression(window_function=func)

def convert_extract_value(
self, extract: spark_exprs_pb2.Expression.UnresolvedExtractValue
) -> algebra_pb2.Expression:
@@ -684,7 +799,7 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex
case "lambda_function":
raise NotImplementedError("lambda_function expression type not supported")
case "window":
raise NotImplementedError("window expression type not supported")
result = self.convert_window_expression(expr.window)
case "unresolved_extract_value":
result = self.convert_extract_value(expr.unresolved_extract_value)
case "update_fields":
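
For orientation, here is a rough sketch of the Substrait message that convert_window_expression builds for the simplest case, row_number() with only an ORDER BY and no frame_spec (the else branch above). This is a hedged sketch rather than the converter's literal output: the function anchor and sort direction are illustrative placeholders, and the ORDER BY expression itself is omitted.

from substrait.gen.proto import algebra_pb2, type_pb2

# No frame_spec: bounds_type stays unspecified and both bounds are marked unbounded.
func = algebra_pb2.Expression.WindowFunction()
func.function_reference = 42  # illustrative anchor; the real value comes from lookup_function_by_name
func.output_type.CopyFrom(
    type_pb2.Type(
        i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED)))
# The ORDER BY expression would normally come from convert_expression(order.child);
# only the direction is shown here.
func.sorts.append(
    algebra_pb2.SortField(
        direction=algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST))
func.bounds_type = algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED
func.lower_bound.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
func.upper_bound.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded())
window_expr = algebra_pb2.Expression(window_function=func)
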
49 changes: 33 additions & 16 deletions src/gateway/server.py
@@ -66,6 +66,9 @@ def convert_pyarrow_datatype_to_spark(arrow_type: pa.DataType) -> types_pb2.Data
data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer())
elif arrow_type == pa.int64():
data_type = types_pb2.DataType(long=types_pb2.DataType.Long())
elif arrow_type == pa.uint64():
# TODO: Spark doesn't have unsigned types so come up with a way to handle overflow.
data_type = types_pb2.DataType(long=types_pb2.DataType.Long())
elif arrow_type == pa.float32():
data_type = types_pb2.DataType(float=types_pb2.DataType.Float())
elif arrow_type == pa.float64():
@@ -336,22 +339,36 @@ def AnalyzePlan(self, request, context):
self._statistics.add_request(request)
_LOGGER.info("AnalyzePlan: %s", request)
self._InitializeExecution()
if request.schema:
substrait = self._converter.convert_plan(request.schema.plan)
self._statistics.add_plan(substrait)
try:
results = self._backend.execute(substrait)
except Exception as err:
self._ReinitializeExecution()
raise err
_LOGGER.debug(" results are: %s", results)
return pb2.AnalyzePlanResponse(
session_id=request.session_id,
schema=pb2.AnalyzePlanResponse.Schema(
schema=convert_pyarrow_schema_to_spark(results.schema)
),
)
raise NotImplementedError("AnalyzePlan not yet implemented for non-Schema requests.")
match request.WhichOneof("analyze"):
case 'schema':
substrait = self._converter.convert_plan(request.schema.plan)
self._statistics.add_plan(substrait)
if len(substrait.relations) != 1:
raise ValueError(f"Expected exactly _ONE_ relation in the plan: {request}")
try:
results = self._backend.execute(substrait)
except Exception as err:
self._ReinitializeExecution()
raise err
_LOGGER.debug(" results are: %s", results)
return pb2.AnalyzePlanResponse(
session_id=request.session_id,
schema=pb2.AnalyzePlanResponse.Schema(
schema=convert_pyarrow_schema_to_spark(results.schema)
),
)
case 'is_streaming':
# TODO -- Actually look at the plan (this path is used by pyspark.testing.utils).
return pb2.AnalyzePlanResponse(
session_id=request.session_id,
is_streaming=pb2.AnalyzePlanResponse.IsStreaming(
is_streaming=False
),
)
case _:
raise NotImplementedError(
"AnalyzePlan not yet implemented for non-Schema requests: "
f"{request.WhichOneof('analyze')}")

def Config(self, request, context):
"""Get or set the configuration of the server."""
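
One note on the uint64 branch added to convert_pyarrow_datatype_to_spark above: Spark's Long type is a signed 64-bit integer, so any uint64 value above 2**63 - 1 has no faithful representation, which is what the TODO is pointing at. A small pyarrow illustration (not part of this change) makes the overflow concrete:

import pyarrow as pa

# Values within int64's range cast cleanly from uint64...
pa.array([1, 2, 3], type=pa.uint64()).cast(pa.int64())
# ...but anything above 2**63 - 1 fails pyarrow's default (safe) cast with an overflow error.
pa.array([2**63], type=pa.uint64()).cast(pa.int64())
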
26 changes: 26 additions & 0 deletions src/gateway/tests/test_dataframe_api.py
@@ -52,6 +52,7 @@
replace,
right,
rlike,
row_number,
rpad,
rtrim,
sqrt,
@@ -64,6 +65,7 @@
upper,
)
from pyspark.sql.types import DoubleType, StructField, StructType
from pyspark.sql.window import Window
from pyspark.testing import assertDataFrameEqual

from gateway.tests.conftest import find_tpch
@@ -161,6 +163,9 @@ def mark_dataframe_tests_as_xfail(request):
if source == "gateway-over-datafusion" and originalname == "test_try_divide":
pytest.skip(reason="returns infinity instead of null")

if source == "gateway-over-duckdb" and originalname == "test_row_number":
pytest.skip(reason="window functions not yet implemented in DuckDB")


# ruff: noqa: E712
class TestDataFrameAPI:
@@ -2748,3 +2753,24 @@ def test_multiple_measures_and_calculations(self, register_tpch_dataset, spark_s
)

assertDataFrameEqual(outcome, expected)


class TestDataFrameWindowFunctions:
"""Tests window functions of the dataframe side of SparkConnect."""

def test_row_number(self, users_dataframe):
expected = [
Row(user_id="user705452451", name="Adrian Reyes", paid_for_service=False, row_number=1),
Row(user_id="user406489700", name="Alan Aguirre DVM", paid_for_service=False,
row_number=2),
Row(user_id="user965620978", name="Alan Whitaker", paid_for_service=False,
row_number=3),
]

with utilizes_valid_plans(users_dataframe):
window_spec = Window.partitionBy().orderBy("name")
outcome = users_dataframe.withColumn("row_number",
row_number().over(window_spec)).orderBy(
"row_number").limit(3)

assertDataFrameEqual(outcome, expected)
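
Because convert_window_expression also populates func.partitions, a partitioned variant of the test above would exercise that code path too. A rough sketch, not part of this PR's test suite, using the column names visible in the expected rows:

from pyspark.sql.functions import row_number
from pyspark.sql.window import Window

# Number users within each paid_for_service group, ordered by name.
window_spec = Window.partitionBy("paid_for_service").orderBy("name")
outcome = users_dataframe.withColumn(
    "row_number", row_number().over(window_spec)).orderBy("name").limit(3)
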
9 changes: 7 additions & 2 deletions src/gateway/tests/test_tpcds_sql.py
@@ -18,8 +18,8 @@
def mark_tests_as_xfail(request):
"""Marks a subset of tests as expected to be fail."""
source = request.getfixturevalue("source")
path = request.getfixturevalue("path")
if source == "gateway-over-duckdb":
path = request.getfixturevalue("path")
if path.stem in ["01", "06", "10", "16", "30", "32", "35", "69", "81", "86", "92", "94"]:
pytest.skip(reason="DuckDB needs Delim join")
elif path.stem in [
@@ -119,7 +119,12 @@ def mark_tests_as_xfail(request):
elif path.stem in ["95"]:
pytest.skip(reason="Unsupported join comparison: !=")
if source == "gateway-over-datafusion":
pytest.skip(reason="not yet ready to run SQL tests regularly")
if path.stem in ["02"]:
pytest.skip(reason="Null type without kind is not supported")
elif path.stem in ["09"]:
pytest.skip(reason="Aggregate function any_value is not supported: function anchor = 6")
else:
pytest.skip(reason="not yet ready to run SQL tests regularly")
if source == "gateway-over-arrow":
pytest.skip(reason="not yet ready to run SQL tests regularly")

7 changes: 2 additions & 5 deletions src/gateway/tests/test_tpch_with_dataframe_api.py
@@ -18,11 +18,8 @@ def mark_tests_as_xfail(request):
"""Marks a subset of tests as expected to be fail."""
source = request.getfixturevalue("source")
originalname = request.keywords.node.originalname
if source == "gateway-over-duckdb":
if originalname == "test_query_15":
request.node.add_marker(pytest.mark.xfail(reason="No results (float vs decimal)"))
if originalname == "test_query_16":
request.node.add_marker(pytest.mark.xfail(reason="distinct not supported"))
if source == "gateway-over-duckdb" and originalname == "test_query_16":
request.node.add_marker(pytest.mark.xfail(reason="distinct not supported"))


class TestTpchWithDataFrameAPI: