feat: add cube support (#86)
Co-authored-by: David Sisson <[email protected]>
richtia and EpsilonPrime authored Sep 18, 2024
1 parent 5dce4b8 commit 5213de9
Showing 2 changed files with 156 additions and 32 deletions.
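
This commit adds CUBE grouping support to the Spark-to-Substrait converter: a Spark Aggregate with group type GROUP_TYPE_CUBE is translated into a Substrait AggregateRel carrying one Grouping per subset of the grouping expressions. As a rough, standalone sketch (not part of the diff; names are illustrative), the expansion performed by the new create_cube_groupings helper behaves like a powerset built with itertools.combinations:

from itertools import combinations

def cube_grouping_sets(columns):
    # Every subset of the grouping columns, from the empty grouping up to all of them.
    return [list(subset)
            for size in range(len(columns) + 1)
            for subset in combinations(columns, size)]

print(cube_grouping_sets(["name", "age"]))
# [[], ['name'], ['age'], ['name', 'age']]  -- 2**n groupings for n columns

For n grouping expressions this yields 2**n groupings, which is why the CUBE results in the new tests contain a row for every combination of NULL and non-NULL key columns.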
92 changes: 68 additions & 24 deletions src/gateway/converter/spark_to_substrait.py
@@ -7,6 +7,7 @@
import pathlib
import re
from enum import Enum
from itertools import combinations

import pyarrow as pa
import pyspark.sql.connect.proto.base_pb2 as spark_pb2
@@ -19,7 +20,11 @@
from substrait.gen.proto.extensions import extensions_pb2

from gateway.converter.conversion_options import ConversionOptions
from gateway.converter.spark_functions import ExtensionFunction, FunctionType, lookup_spark_function
from gateway.converter.spark_functions import (
ExtensionFunction,
FunctionType,
lookup_spark_function,
)
from gateway.converter.substrait_builder import (
add_function,
aggregate_relation,
@@ -1260,29 +1265,8 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge
self._next_under_aggregation_reference_id = 0
self._under_aggregation_projects = []

# TODO -- Deal with mixed groupings and measures.
grouping_expression_list = []
for idx, grouping in enumerate(rel.grouping_expressions):
grouping_expression_list.append(self.convert_expression(grouping))
symbol.generated_fields.append(self.determine_name_for_grouping(grouping))
self._top_level_projects.append(field_reference(idx))
aggregate.groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=grouping_expression_list)
)

self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL

for expr in rel.aggregate_expressions:
result = self.convert_expression(expr)
if result:
self._top_level_projects.append(result)
symbol.generated_fields.append(self.determine_expression_name(expr))
symbol.output_fields.clear()
symbol.output_fields.extend(symbol.generated_fields)
if len(rel.grouping_expressions) > 1:
# Hide the grouping source from the downstream relations.
for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)):
aggregate.common.emit.output_mapping.append(i)
# Handle the different group-by types (GROUPBY, CUBE)
self.handle_grouping_and_measures(rel, aggregate, symbol)

self._expression_processing_mode = ExpressionProcessingMode.NORMAL

@@ -1312,6 +1296,66 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge

return algebra_pb2.Rel(aggregate=aggregate)

def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate,
aggregate: algebra_pb2.AggregateRel,
symbol):
"""Handle group by aggregations."""
grouping_expression_list = []
rel_grouping_expressions = rel.grouping_expressions
for idx, grouping in enumerate(rel_grouping_expressions):
grouping_expression_list.append(self.convert_expression(grouping))
symbol.generated_fields.append(self.determine_name_for_grouping(grouping))
self._top_level_projects.append(field_reference(idx))

match rel.group_type:
case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_GROUPBY:
aggregate.groupings.append(
algebra_pb2.AggregateRel.Grouping(
grouping_expressions=grouping_expression_list)
)
case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_CUBE:
# Generate and add all groupings required for CUBE
cube_groupings = self.create_cube_groupings(rel_grouping_expressions)
aggregate.groupings.extend(cube_groupings)
case _:
raise NotImplementedError(
"Only GROUPBY and CUBE group types are currently supported."
)

self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL

for expr in rel.aggregate_expressions:
result = self.convert_expression(expr)
if result:
self._top_level_projects.append(result)
symbol.generated_fields.append(self.determine_expression_name(expr))
symbol.output_fields.clear()
symbol.output_fields.extend(symbol.generated_fields)
if len(rel.grouping_expressions) > 1:
# Hide the grouping source from the downstream relations.
for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)):
aggregate.common.emit.output_mapping.append(i)

def create_cube_groupings(self, grouping_expressions):
"""Create all combinations of grouping expressions."""
num_expressions = len(grouping_expressions)
cube_groupings = []

# Generate all possible combinations of grouping expressions
for i in range(num_expressions + 1):
for combination in combinations(range(num_expressions), i):
# Create a list of the current combination of grouping expressions
converted_expressions = []
for j in combination:
converted_expression = self.convert_expression(grouping_expressions[j])
converted_expressions.append(converted_expression)
# Add the grouping for this combination
cube_groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=converted_expressions)
)

return cube_groupings

# pylint: disable=too-many-locals,pointless-string-statement
def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel:
"""Convert a show string relation into a Substrait subplan."""
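
The test changes below also add a test_rollup case, which is skipped for both backends: ROLLUP groups only by prefixes of the grouping list, and only the CUBE (all-subsets) expansion is implemented in this commit. A hedged sketch of the difference, with illustrative column names:

from itertools import combinations

columns = ["name", "age"]

cube_sets = [list(c) for n in range(len(columns) + 1) for c in combinations(columns, n)]
rollup_sets = [columns[:n] for n in range(len(columns), -1, -1)]

print(cube_sets)    # [[], ['name'], ['age'], ['name', 'age']]
print(rollup_sets)  # [['name', 'age'], ['name'], []]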
96 changes: 88 additions & 8 deletions src/gateway/tests/test_dataframe_api.py
@@ -64,7 +64,7 @@
ucase,
upper,
)
from pyspark.sql.types import DoubleType, StructField, StructType
from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType
from pyspark.sql.window import Window
from pyspark.testing import assertDataFrameEqual

@@ -154,6 +154,13 @@ def mark_dataframe_tests_as_xfail(request):
pytest.skip(reason="inf vs -inf difference")
if source == "gateway-over-duckdb" and originalname in ["test_union", "test_unionall"]:
pytest.skip(reason="distinct not handled properly")
if source == "gateway-over-datafusion" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_cube":
pytest.skip(reason="cube aggregation not yet implemented in DuckDB")


# ruff: noqa: E712
class TestDataFrameAPI:
@@ -689,14 +696,14 @@ def test_exceptall(self, register_tpch_dataset, spark_session, caplog):
n_name="RUSSIA",
n_regionkey=3,
n_comment="uctions. furiously unusual instructions sleep furiously ironic "
"packages. slyly ",
"packages. slyly ",
),
Row(
n_nationkey=22,
n_name="RUSSIA",
n_regionkey=3,
n_comment="uctions. furiously unusual instructions sleep furiously ironic "
"packages. slyly ",
"packages. slyly ",
),
Row(
n_nationkey=23,
@@ -709,14 +716,14 @@
n_name="UNITED STATES",
n_regionkey=1,
n_comment="ly ironic requests along the slyly bold ideas hang after the "
"blithely special notornis; blithely even accounts",
"blithely special notornis; blithely even accounts",
),
Row(
n_nationkey=24,
n_name="UNITED STATES",
n_regionkey=1,
n_comment="ly ironic requests along the slyly bold ideas hang after the "
"blithely special notornis; blithely even accounts",
"blithely special notornis; blithely even accounts",
),
]

@@ -802,14 +809,14 @@ def test_subtract(self, register_tpch_dataset, spark_session):
n_name="RUSSIA",
n_regionkey=3,
n_comment="uctions. furiously unusual instructions sleep furiously "
"ironic packages. slyly ",
"ironic packages. slyly ",
),
Row(
n_nationkey=24,
n_name="UNITED STATES",
n_regionkey=1,
n_comment="ly ironic requests along the slyly bold ideas hang after "
"the blithely special notornis; blithely even accounts",
"the blithely special notornis; blithely even accounts",
),
]

@@ -2657,7 +2664,7 @@ def test_computation_with_two_aggregations(self, register_tpch_dataset, spark_se
assertDataFrameEqual(outcome, expected)

def test_computation_with_two_aggregations_and_internal_calculation(
self, register_tpch_dataset, spark_session
self, register_tpch_dataset, spark_session
):
expected = [
Row(l_suppkey=1, a=Decimal("3903113211864.3000")),
@@ -2762,3 +2769,76 @@ def test_row_number(self, users_dataframe):
"row_number").limit(3)

assertDataFrameEqual(outcome, expected)


@pytest.fixture(scope="class")
def userage_dataframe(spark_session_for_setup):
data = [
[1, "Alice"],
[2, "Bob"],
[3, "Alice"]
]

schema = StructType(
[
StructField("age", IntegerType(), True),
StructField("name", StringType(), True),
]
)

test_df = spark_session_for_setup.createDataFrame(data, schema)

test_df.createOrReplaceTempView("userage")
return spark_session_for_setup.table("userage")


class TestDataFrameDecisionSupport:
"""Tests data science methods of the dataframe side of SparkConnect."""

def test_groupby(self, userage_dataframe):
expected = [
Row(name='Alice', age=1, count=1),
Row(name='Alice', age=3, count=1),
Row(name='Bob', age=2, count=1)
]

with utilizes_valid_plans(userage_dataframe):
outcome = userage_dataframe.groupby("name", "age").count().orderBy("name",
"age").collect()

assertDataFrameEqual(outcome, expected)

def test_rollup(self, userage_dataframe):
expected = [
Row(name='Alice', age=1, count=1),
Row(name='Alice', age=3, count=1),
Row(name='Alice', age=None, count=2),
Row(name='Bob', age=2, count=1),
Row(name='Bob', age=None, count=1),
Row(name=None, age=None, count=3)
]

with utilizes_valid_plans(userage_dataframe):
outcome = userage_dataframe.rollup("name", "age").count().orderBy("name",
"age").collect()

assertDataFrameEqual(outcome, expected)

def test_cube(self, userage_dataframe):
expected = [
Row(name='Alice', age=1, count=1),
Row(name='Alice', age=3, count=1),
Row(name='Alice', age=None, count=2),
Row(name='Bob', age=2, count=1),
Row(name='Bob', age=None, count=1),
Row(name=None, age=1, count=1),
Row(name=None, age=2, count=1),
Row(name=None, age=3, count=1),
Row(name=None, age=None, count=3)
]

with utilizes_valid_plans(userage_dataframe):
outcome = userage_dataframe.cube("name", "age").count().orderBy("name",
"age").collect()

assertDataFrameEqual(outcome, expected)
