feat: add cube support #86

Merged · 8 commits · Sep 18, 2024
92 changes: 68 additions & 24 deletions src/gateway/converter/spark_to_substrait.py
@@ -7,6 +7,7 @@
import pathlib
import re
from enum import Enum
from itertools import combinations

import pyarrow as pa
import pyspark.sql.connect.proto.base_pb2 as spark_pb2
@@ -19,7 +20,11 @@
from substrait.gen.proto.extensions import extensions_pb2

from gateway.converter.conversion_options import ConversionOptions
from gateway.converter.spark_functions import ExtensionFunction, FunctionType, lookup_spark_function
from gateway.converter.spark_functions import (
ExtensionFunction,
FunctionType,
lookup_spark_function,
)
from gateway.converter.substrait_builder import (
add_function,
aggregate_relation,
@@ -1260,29 +1265,8 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge
self._next_under_aggregation_reference_id = 0
self._under_aggregation_projects = []

# TODO -- Deal with mixed groupings and measures.
grouping_expression_list = []
for idx, grouping in enumerate(rel.grouping_expressions):
grouping_expression_list.append(self.convert_expression(grouping))
symbol.generated_fields.append(self.determine_name_for_grouping(grouping))
self._top_level_projects.append(field_reference(idx))
aggregate.groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=grouping_expression_list)
)

self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL

for expr in rel.aggregate_expressions:
result = self.convert_expression(expr)
if result:
self._top_level_projects.append(result)
symbol.generated_fields.append(self.determine_expression_name(expr))
symbol.output_fields.clear()
symbol.output_fields.extend(symbol.generated_fields)
if len(rel.grouping_expressions) > 1:
# Hide the grouping source from the downstream relations.
for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)):
aggregate.common.emit.output_mapping.append(i)
# Handle different group by types
self.handle_grouping_and_measures(rel, aggregate, symbol)

self._expression_processing_mode = ExpressionProcessingMode.NORMAL

@@ -1312,6 +1296,66 @@

return algebra_pb2.Rel(aggregate=aggregate)

def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate,
aggregate: algebra_pb2.AggregateRel,
symbol):
"""Handle group by aggregations."""
grouping_expression_list = []
rel_grouping_expressions = rel.grouping_expressions
for idx, grouping in enumerate(rel_grouping_expressions):
grouping_expression_list.append(self.convert_expression(grouping))
symbol.generated_fields.append(self.determine_name_for_grouping(grouping))
self._top_level_projects.append(field_reference(idx))

match rel.group_type:
case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_GROUPBY:
aggregate.groupings.append(
algebra_pb2.AggregateRel.Grouping(
grouping_expressions=grouping_expression_list)
)
case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_CUBE:
# Generate and add all groupings required for CUBE
cube_groupings = self.create_cube_groupings(rel_grouping_expressions)
aggregate.groupings.extend(cube_groupings)
case _:
raise NotImplementedError(
"Only GROUPBY and CUBE group types are currently supported."
)

self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL

for expr in rel.aggregate_expressions:
result = self.convert_expression(expr)
if result:
self._top_level_projects.append(result)
symbol.generated_fields.append(self.determine_expression_name(expr))
symbol.output_fields.clear()
symbol.output_fields.extend(symbol.generated_fields)
if len(rel.grouping_expressions) > 1:
# Hide the grouping source from the downstream relations.
for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)):
aggregate.common.emit.output_mapping.append(i)

def create_cube_groupings(self, grouping_expressions):
"""Create all combinations of grouping expressions."""
num_expressions = len(grouping_expressions)
cube_groupings = []

# Generate all possible combinations of grouping expressions
for i in range(num_expressions + 1):
for combination in combinations(range(num_expressions), i):
# Create a list of the current combination of grouping expressions
converted_expressions = []
for j in combination:
converted_expression = self.convert_expression(grouping_expressions[j])
converted_expressions.append(converted_expression)
# Add the grouping for this combination
cube_groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=converted_expressions)
)

return cube_groupings

# pylint: disable=too-many-locals,pointless-string-statement
def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel:
"""Convert a show string relation into a Substrait subplan."""
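
For reference, the CUBE handling added above (create_cube_groupings) emits one Substrait Grouping per subset of the grouping expressions, i.e. the power set, in the same order as the nested combinations loop. A minimal standalone sketch of that expansion follows; the helper name and the column names are illustrative, not taken from the PR:

from itertools import combinations

def cube_grouping_sets(columns):
    """Return every subset of the grouping columns, from the empty (grand-total) set up to the full set."""
    grouping_sets = []
    for size in range(len(columns) + 1):
        for combo in combinations(columns, size):
            grouping_sets.append(list(combo))
    return grouping_sets

# Two grouping columns expand to the four grouping sets CUBE requires:
# [[], ['name'], ['age'], ['name', 'age']]
print(cube_grouping_sets(["name", "age"]))
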
96 changes: 88 additions & 8 deletions src/gateway/tests/test_dataframe_api.py
@@ -64,7 +64,7 @@
ucase,
upper,
)
from pyspark.sql.types import DoubleType, StructField, StructType
from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType
from pyspark.sql.window import Window
from pyspark.testing import assertDataFrameEqual

@@ -154,6 +154,13 @@ def mark_dataframe_tests_as_xfail(request):
pytest.skip(reason="inf vs -inf difference")
if source == "gateway-over-duckdb" and originalname in ["test_union", "test_unionall"]:
pytest.skip(reason="distinct not handled properly")
if source == "gateway-over-datafusion" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_cube":
pytest.skip(reason="cube aggregation not yet implemented in DuckDB")


# ruff: noqa: E712
class TestDataFrameAPI:
@@ -689,14 +696,14 @@ def test_exceptall(self, register_tpch_dataset, spark_session, caplog):
n_name="RUSSIA",
n_regionkey=3,
n_comment="uctions. furiously unusual instructions sleep furiously ironic "
"packages. slyly ",
"packages. slyly ",
),
Row(
n_nationkey=22,
n_name="RUSSIA",
n_regionkey=3,
n_comment="uctions. furiously unusual instructions sleep furiously ironic "
"packages. slyly ",
"packages. slyly ",
),
Row(
n_nationkey=23,
@@ -709,14 +716,14 @@
n_name="UNITED STATES",
n_regionkey=1,
n_comment="ly ironic requests along the slyly bold ideas hang after the "
"blithely special notornis; blithely even accounts",
"blithely special notornis; blithely even accounts",
),
Row(
n_nationkey=24,
n_name="UNITED STATES",
n_regionkey=1,
n_comment="ly ironic requests along the slyly bold ideas hang after the "
"blithely special notornis; blithely even accounts",
"blithely special notornis; blithely even accounts",
),
]

@@ -802,14 +809,14 @@ def test_subtract(self, register_tpch_dataset, spark_session):
n_name="RUSSIA",
n_regionkey=3,
n_comment="uctions. furiously unusual instructions sleep furiously "
"ironic packages. slyly ",
"ironic packages. slyly ",
),
Row(
n_nationkey=24,
n_name="UNITED STATES",
n_regionkey=1,
n_comment="ly ironic requests along the slyly bold ideas hang after "
"the blithely special notornis; blithely even accounts",
"the blithely special notornis; blithely even accounts",
),
]

@@ -2657,7 +2664,7 @@ def test_computation_with_two_aggregations(self, register_tpch_dataset, spark_se
assertDataFrameEqual(outcome, expected)

def test_computation_with_two_aggregations_and_internal_calculation(
self, register_tpch_dataset, spark_session
self, register_tpch_dataset, spark_session
):
expected = [
Row(l_suppkey=1, a=Decimal("3903113211864.3000")),
@@ -2762,3 +2769,76 @@ def test_row_number(self, users_dataframe):
"row_number").limit(3)

assertDataFrameEqual(outcome, expected)


@pytest.fixture(scope="class")
def userage_dataframe(spark_session_for_setup):
data = [
[1, "Alice"],
[2, "Bob"],
[3, "Alice"]
]

schema = StructType(
[
StructField("age", IntegerType(), True),
StructField("name", StringType(), True),
]
)

test_df = spark_session_for_setup.createDataFrame(data, schema)

test_df.createOrReplaceTempView("userage")
return spark_session_for_setup.table("userage")


class TestDataFrameDecisionSupport:
"""Tests data science methods of the dataframe side of SparkConnect."""

def test_groupby(self, userage_dataframe):
expected = [
Row(name='Alice', age=1, count=1),
Row(name='Alice', age=3, count=1),
Row(name='Bob', age=2, count=1)
]

with utilizes_valid_plans(userage_dataframe):
outcome = userage_dataframe.groupby("name", "age").count().orderBy("name",
"age").collect()

assertDataFrameEqual(outcome, expected)

def test_rollup(self, userage_dataframe):
expected = [
Row(name='Alice', age=1, count=1),
Row(name='Alice', age=3, count=1),
Row(name='Alice', age=None, count=2),
Row(name='Bob', age=2, count=1),
Row(name='Bob', age=None, count=1),
Row(name=None, age=None, count=3)
]

with utilizes_valid_plans(userage_dataframe):
outcome = userage_dataframe.rollup("name", "age").count().orderBy("name",
"age").collect()

assertDataFrameEqual(outcome, expected)

def test_cube(self, userage_dataframe):
expected = [
Row(name='Alice', age=1, count=1),
Row(name='Alice', age=3, count=1),
Row(name='Alice', age=None, count=2),
Row(name='Bob', age=2, count=1),
Row(name='Bob', age=None, count=1),
Row(name=None, age=1, count=1),
Row(name=None, age=2, count=1),
Row(name=None, age=3, count=1),
Row(name=None, age=None, count=3)
]

with utilizes_valid_plans(userage_dataframe):
outcome = userage_dataframe.cube("name", "age").count().orderBy("name",
"age").collect()

assertDataFrameEqual(outcome, expected)
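
The three tests above differ only in which grouping sets they produce over ("name", "age"): groupby keeps a single set, rollup keeps the hierarchical prefixes, and cube keeps every combination. A rough illustration follows, assuming the spark_session fixture and the userage temp view registered above are in scope; the SQL form is standard Spark syntax used for comparison, not code from this PR:

# groupby("name", "age") -> one grouping set:       (name, age)
# rollup("name", "age")  -> hierarchical prefixes:  (name, age), (name), ()
# cube("name", "age")    -> every combination:      (name, age), (name), (age), ()
# An equivalent Spark SQL formulation of the cube test (illustrative):
outcome = spark_session.sql(
    "SELECT name, age, count(*) AS count "
    "FROM userage "
    "GROUP BY CUBE(name, age) "
    "ORDER BY name, age"
).collect()
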