Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for rollup #87

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/gateway/converter/spark_to_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,10 @@ def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate,
# Generate and add all groupings required for CUBE
cube_groupings = self.create_cube_groupings(rel_grouping_expressions)
aggregate.groupings.extend(cube_groupings)
case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_ROLLUP:
# Generate and add all groupings required for ROLLUP
rollup_groupings = self.create_rollup_groupings(rel_grouping_expressions)
aggregate.groupings.extend(rollup_groupings)
case _:
raise NotImplementedError(
"Only GROUPBY and CUBE group types are currently supported."
Expand Down Expand Up @@ -1356,6 +1360,28 @@ def create_cube_groupings(self, grouping_expressions):

return cube_groupings

def create_rollup_groupings(self, grouping_expressions):
"""Create all combinations of grouping expressions for rollup."""
num_expressions = len(grouping_expressions)
rollup_groupings = []

for i in range(num_expressions):
current_grouping = []
for j in range(i + 1):
converted_expression = self.convert_expression(grouping_expressions[j])
current_grouping.append(converted_expression)
rollup_groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=current_grouping)
)

# Add a final grouping with no expressions for the grand total.
# The grand total aggregates over all rows.
rollup_groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=[])
)

return rollup_groupings

# pylint: disable=too-many-locals,pointless-string-statement
def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel:
"""Convert a show string relation into a Substrait subplan."""
Expand Down
2 changes: 0 additions & 2 deletions src/gateway/tests/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ def mark_dataframe_tests_as_xfail(request):
pytest.skip(reason="inf vs -inf difference")
if source == "gateway-over-duckdb" and originalname in ["test_union", "test_unionall"]:
pytest.skip(reason="distinct not handled properly")
if source == "gateway-over-datafusion" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_cube":
Expand Down
Loading