diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py
index dec7f5e..720336c 100644
--- a/src/gateway/converter/spark_to_substrait.py
+++ b/src/gateway/converter/spark_to_substrait.py
@@ -1317,7 +1317,11 @@ def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate,
                 # Generate and add all groupings required for CUBE
                 cube_groupings = self.create_cube_groupings(rel_grouping_expressions)
                 aggregate.groupings.extend(cube_groupings)
+            case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_ROLLUP:
+                # Generate and add all groupings required for ROLLUP
+                rollup_groupings = self.create_rollup_groupings(rel_grouping_expressions)
+                aggregate.groupings.extend(rollup_groupings)
             case _:
                 raise NotImplementedError(
-                    "Only GROUPBY and CUBE group types are currently supported."
+                    "Only GROUPBY, CUBE, and ROLLUP group types are currently supported."
                 )
@@ -1356,6 +1360,33 @@ def create_cube_groupings(self, grouping_expressions):
 
         return cube_groupings
 
+    def create_rollup_groupings(self, grouping_expressions):
+        """Create the grouping sets for ROLLUP over the given expressions.
+
+        One grouping is produced for every non-empty prefix of
+        grouping_expressions (shortest prefix first), followed by an empty
+        grouping for the grand total.
+        """
+        num_expressions = len(grouping_expressions)
+        rollup_groupings = []
+
+        for i in range(num_expressions):
+            current_grouping = []
+            for j in range(i + 1):
+                converted_expression = self.convert_expression(grouping_expressions[j])
+                current_grouping.append(converted_expression)
+            rollup_groupings.append(
+                algebra_pb2.AggregateRel.Grouping(grouping_expressions=current_grouping)
+            )
+
+        # Add a final grouping with no expressions for the grand total.
+        # The grand total aggregates over all rows.
+        rollup_groupings.append(
+            algebra_pb2.AggregateRel.Grouping(grouping_expressions=[])
+        )
+
+        return rollup_groupings
+
     # pylint: disable=too-many-locals,pointless-string-statement
     def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel:
         """Convert a show string relation into a Substrait subplan."""
diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py
index 6e2a063..7fe9fa8 100644
--- a/src/gateway/tests/test_dataframe_api.py
+++ b/src/gateway/tests/test_dataframe_api.py
@@ -154,8 +154,6 @@ def mark_dataframe_tests_as_xfail(request):
         pytest.skip(reason="inf vs -inf difference")
     if source == "gateway-over-duckdb" and originalname in ["test_union", "test_unionall"]:
         pytest.skip(reason="distinct not handled properly")
-    if source == "gateway-over-datafusion" and originalname == "test_rollup":
-        pytest.skip(reason="rollup aggregation not yet implemented in gateway")
     if source == "gateway-over-duckdb" and originalname == "test_rollup":
         pytest.skip(reason="rollup aggregation not yet implemented in gateway")
     if source == "gateway-over-duckdb" and originalname == "test_cube":