feat: add handling for cube

richtia committed Sep 18, 2024
1 parent 89f2c50 commit 8c5056a
Showing 2 changed files with 56 additions and 25 deletions.
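
For context, Spark's DataFrame.cube computes aggregates over every combination of the given grouping columns. A minimal, illustrative PySpark snippet (not part of this commit; the repository's tests exercise the same operation through the Spark Connect gateway):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, "Alice"), (2, "Bob"), (3, "Alice")], ["age", "name"])

    # cube("name", "age") aggregates over all four grouping sets:
    # (name, age), (name,), (age,) and the grand total ().
    df.cube("name", "age").count().show()
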
60 changes: 40 additions & 20 deletions src/gateway/converter/spark_to_substrait.py
@@ -7,17 +7,13 @@
 import pathlib
 import re
 from enum import Enum
+from itertools import combinations
 
 import pyarrow as pa
 import pyspark.sql.connect.proto.base_pb2 as spark_pb2
 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2
 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2
 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2
-from google.protobuf.internal.wire_format import INT64_MAX
-from pyspark.sql.connect.proto import types_pb2
-from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2
-from substrait.gen.proto.extensions import extensions_pb2
-
 from gateway.converter.conversion_options import ConversionOptions
 from gateway.converter.spark_functions import ExtensionFunction, FunctionType, lookup_spark_function
 from gateway.converter.substrait_builder import (
@@ -53,6 +49,10 @@
     strlen,
 )
 from gateway.converter.symbol_table import SymbolTable
+from google.protobuf.internal.wire_format import INT64_MAX
+from pyspark.sql.connect.proto import types_pb2
+from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2
+from substrait.gen.proto.extensions import extensions_pb2
 
 
 class ExpressionProcessingMode(Enum):
@@ -1261,12 +1261,7 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> algebra_pb2.Rel:
         self._under_aggregation_projects = []
 
         # Handle different group by types
-        if rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_GROUPBY:
-            self.handle_group_by_aggregation(rel, aggregate, symbol)
-        elif rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_CUBE:
-            self.handle_cube_aggregation(rel, aggregate, symbol)
-        else:
-            raise NotImplementedError("Only GROUPBY and CUBE group types are currently supported.")
+        self.handle_group_by_aggregations(rel, aggregate, symbol)
 
         self._expression_processing_mode = ExpressionProcessingMode.NORMAL
 
@@ -1296,18 +1291,27 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> algebra_pb2.Rel:
 
         return algebra_pb2.Rel(aggregate=aggregate)
 
-    def handle_group_by_aggregation(self, rel: spark_relations_pb2.Aggregate, aggregate: algebra_pb2.AggregateRel,
-                                    symbol):
-        """Handle regular group by aggregation."""
+    def handle_group_by_aggregations(self, rel: spark_relations_pb2.Aggregate,
+                                     aggregate: algebra_pb2.AggregateRel,
+                                     symbol):
+        """Handle group by aggregations."""
         grouping_expression_list = []
         rel_grouping_expressions = rel.grouping_expressions
         for idx, grouping in enumerate(rel_grouping_expressions):
             grouping_expression_list.append(self.convert_expression(grouping))
             symbol.generated_fields.append(self.determine_name_for_grouping(grouping))
             self._top_level_projects.append(field_reference(idx))
-        aggregate.groupings.append(
-            algebra_pb2.AggregateRel.Grouping(grouping_expressions=grouping_expression_list)
-        )
+
+        if rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_GROUPBY:
+            aggregate.groupings.append(
+                algebra_pb2.AggregateRel.Grouping(grouping_expressions=grouping_expression_list)
+            )
+        elif rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_CUBE:
+            # Generate and add all groupings required for CUBE
+            cube_groupings = self.create_cube_groupings(rel_grouping_expressions)
+            aggregate.groupings.extend(cube_groupings)
+        else:
+            raise NotImplementedError("Only GROUPBY and CUBE group types are currently supported.")
 
         self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL
 
@@ -1323,9 +1327,25 @@ def handle_group_by_aggregation(self, rel: spark_relations_pb2.Aggregate, aggregate: algebra_pb2.AggregateRel,
         for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)):
             aggregate.common.emit.output_mapping.append(i)
 
-    def handle_cube_aggregation(self, rel: spark_relations_pb2.Aggregate, aggregate: algebra_pb2.AggregateRel, symbol):
-        """Handle cube aggregation."""
-        raise NotImplementedError("Only GROUPBY group type is currently supported.")
+    def create_cube_groupings(self, grouping_expressions):
+        """Create all combinations of grouping expressions."""
+        num_expressions = len(grouping_expressions)
+        cube_groupings = []
+
+        # Generate all possible combinations of grouping expressions
+        for i in range(num_expressions + 1):
+            for combination in combinations(range(num_expressions), i):
+                # Create a list of the current combination of grouping expressions
+                converted_expressions = []
+                for j in combination:
+                    converted_expression = self.convert_expression(grouping_expressions[j])
+                    converted_expressions.append(converted_expression)
+                # Add the grouping for this combination
+                cube_groupings.append(
+                    algebra_pb2.AggregateRel.Grouping(grouping_expressions=converted_expressions)
+                )
+
+        return cube_groupings
 
     # pylint: disable=too-many-locals,pointless-string-statement
     def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel:
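Aside: create_cube_groupings above enumerates the powerset of the grouping expressions, so n grouping columns produce 2**n Substrait groupings. A minimal sketch of that expansion, with plain strings standing in for the converted Substrait expressions (the real method calls self.convert_expression on each one):

    from itertools import combinations

    exprs = ["name", "age"]  # stand-ins for Spark grouping expressions

    cube_groupings = [
        [exprs[j] for j in combo]
        for i in range(len(exprs) + 1)
        for combo in combinations(range(len(exprs)), i)
    ]
    print(cube_groupings)  # [[], ['name'], ['age'], ['name', 'age']]
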
21 changes: 16 additions & 5 deletions src/gateway/tests/test_dataframe_api.py
@@ -165,6 +165,12 @@ def mark_dataframe_tests_as_xfail(request):
 
     if source == "gateway-over-duckdb" and originalname == "test_row_number":
         pytest.skip(reason="window functions not yet implemented in DuckDB")
+    if source == "gateway-over-datafusion" and originalname == "test_rollup":
+        pytest.skip(reason="rollup aggregation not yet implemented in gateway")
+    if source == "gateway-over-duckdb" and originalname == "test_rollup":
+        pytest.skip(reason="rollup aggregation not yet implemented in gateway")
+    if source == "gateway-over-duckdb" and originalname == "test_cube":
+        pytest.skip(reason="cube aggregation not yet implemented in DuckDB")
 
 
 # ruff: noqa: E712
@@ -2781,6 +2787,7 @@ def userage_dataframe(spark_session_for_setup):
     data = [
         [1, "Alice"],
         [2, "Bob"],
+        [3, "Alice"]
     ]
 
     schema = StructType(
@@ -2802,7 +2809,8 @@ class TestDataFrameDataScienceFunctions:
     def test_groupby(self, userage_dataframe):
         expected = [
             Row(name='Alice', age=1, count=1),
-            Row(name='Bob', age=2, count=1),
+            Row(name='Alice', age=3, count=1),
+            Row(name='Bob', age=2, count=1)
         ]
 
         with utilizes_valid_plans(userage_dataframe):
@@ -2814,10 +2822,11 @@ def test_groupby(self, userage_dataframe):
     def test_rollup(self, userage_dataframe):
         expected = [
             Row(name='Alice', age=1, count=1),
-            Row(name='Alice', age=None, count=1),
+            Row(name='Alice', age=3, count=1),
+            Row(name='Alice', age=None, count=2),
             Row(name='Bob', age=2, count=1),
             Row(name='Bob', age=None, count=1),
-            Row(name=None, age=None, count=2),
+            Row(name=None, age=None, count=3)
         ]
 
         with utilizes_valid_plans(userage_dataframe):
@@ -2829,12 +2838,14 @@ def test_rollup(self, userage_dataframe):
     def test_cube(self, userage_dataframe):
         expected = [
             Row(name='Alice', age=1, count=1),
-            Row(name='Alice', age=None, count=1),
+            Row(name='Alice', age=3, count=1),
+            Row(name='Alice', age=None, count=2),
             Row(name='Bob', age=2, count=1),
             Row(name='Bob', age=None, count=1),
             Row(name=None, age=1, count=1),
             Row(name=None, age=2, count=1),
-            Row(name=None, age=None, count=2)
+            Row(name=None, age=3, count=1),
+            Row(name=None, age=None, count=3)
         ]
 
         with utilizes_valid_plans(userage_dataframe):
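As a sanity check on the expected rows above (illustrative only, not part of the commit): with the fixture's three rows and CUBE over (name, age), every subset of the two columns is grouped on, with the excluded columns rolled up to None. This standalone snippet recomputes all nine expected rows:

    from collections import Counter
    from itertools import combinations

    rows = [
        {"name": "Alice", "age": 1},
        {"name": "Bob", "age": 2},
        {"name": "Alice", "age": 3},
    ]
    cols = ("name", "age")

    counts = Counter()
    for i in range(len(cols) + 1):
        for subset in combinations(cols, i):
            for row in rows:
                # Columns outside the grouping subset roll up to None, as in SQL CUBE.
                # (Merging keys across subsets is safe here only because the fixture
                # contains no actual None values.)
                counts[tuple(row[c] if c in subset else None for c in cols)] += 1

    for (name, age), count in sorted(counts.items(), key=str):
        print(f"Row(name={name!r}, age={age}, count={count})")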
