diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index d67e90f..4a991f6 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -7,17 +7,13 @@ import pathlib import re from enum import Enum +from itertools import combinations import pyarrow as pa import pyspark.sql.connect.proto.base_pb2 as spark_pb2 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 -from google.protobuf.internal.wire_format import INT64_MAX -from pyspark.sql.connect.proto import types_pb2 -from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 -from substrait.gen.proto.extensions import extensions_pb2 - from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, FunctionType, lookup_spark_function from gateway.converter.substrait_builder import ( @@ -53,6 +49,10 @@ strlen, ) from gateway.converter.symbol_table import SymbolTable +from google.protobuf.internal.wire_format import INT64_MAX +from pyspark.sql.connect.proto import types_pb2 +from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 +from substrait.gen.proto.extensions import extensions_pb2 class ExpressionProcessingMode(Enum): @@ -1261,12 +1261,7 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge self._under_aggregation_projects = [] # Handle different group by types - if rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_GROUPBY: - self.handle_group_by_aggregation(rel, aggregate, symbol) - elif rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_CUBE: - self.handle_cube_aggregation(rel, aggregate, symbol) - else: - raise NotImplementedError("Only GROUPBY and CUBE group types are currently supported.") + self.handle_group_by_aggregations(rel, aggregate, symbol) self._expression_processing_mode = ExpressionProcessingMode.NORMAL @@ -1296,18 +1291,27 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge return algebra_pb2.Rel(aggregate=aggregate) - def handle_group_by_aggregation(self, rel: spark_relations_pb2.Aggregate, aggregate: algebra_pb2.AggregateRel, - symbol): - """Handle regular group by aggregation.""" + def handle_group_by_aggregations(self, rel: spark_relations_pb2.Aggregate, + aggregate: algebra_pb2.AggregateRel, + symbol): + """Handle group by aggregations.""" grouping_expression_list = [] rel_grouping_expressions = rel.grouping_expressions for idx, grouping in enumerate(rel_grouping_expressions): grouping_expression_list.append(self.convert_expression(grouping)) symbol.generated_fields.append(self.determine_name_for_grouping(grouping)) self._top_level_projects.append(field_reference(idx)) - aggregate.groupings.append( - algebra_pb2.AggregateRel.Grouping(grouping_expressions=grouping_expression_list) - ) + + if rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_GROUPBY: + aggregate.groupings.append( + algebra_pb2.AggregateRel.Grouping(grouping_expressions=grouping_expression_list) + ) + elif rel.group_type == spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_CUBE: + # Generate and add all groupings required for CUBE + cube_groupings = self.create_cube_groupings(rel_grouping_expressions) + aggregate.groupings.extend(cube_groupings) + else: + raise NotImplementedError("Only GROUPBY and CUBE group types are currently supported.") self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL @@ -1323,9 +1327,25 @@ def handle_group_by_aggregation(self, rel: spark_relations_pb2.Aggregate, aggreg for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)): aggregate.common.emit.output_mapping.append(i) - def handle_cube_aggregation(self, rel: spark_relations_pb2.Aggregate, aggregate: algebra_pb2.AggregateRel, symbol): - """Handle cube aggregation.""" - raise NotImplementedError("Only GROUPBY group type is currently supported.") + def create_cube_groupings(self, grouping_expressions): + """Create all combinations of grouping expressions.""" + num_expressions = len(grouping_expressions) + cube_groupings = [] + + # Generate all possible combinations of grouping expressions + for i in range(num_expressions + 1): + for combination in combinations(range(num_expressions), i): + # Create a list of the current combination of grouping expressions + converted_expressions = [] + for j in combination: + converted_expression = self.convert_expression(grouping_expressions[j]) + converted_expressions.append(converted_expression) + # Add the grouping for this combination + cube_groupings.append( + algebra_pb2.AggregateRel.Grouping(grouping_expressions=converted_expressions) + ) + + return cube_groupings # pylint: disable=too-many-locals,pointless-string-statement def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel: diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 5996c70..477361a 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -165,6 +165,12 @@ def mark_dataframe_tests_as_xfail(request): if source == "gateway-over-duckdb" and originalname == "test_row_number": pytest.skip(reason="window functions not yet implemented in DuckDB") + if source == "gateway-over-datafusion" and originalname == "test_rollup": + pytest.skip(reason="rollup aggregation not yet implemented in gateway") + if source == "gateway-over-duckdb" and originalname == "test_rollup": + pytest.skip(reason="rollup aggregation not yet implemented in gateway") + if source == "gateway-over-duckdb" and originalname == "test_cube": + pytest.skip(reason="cube aggregation not yet implemented in DuckDB") # ruff: noqa: E712 @@ -2781,6 +2787,7 @@ def userage_dataframe(spark_session_for_setup): data = [ [1, "Alice"], [2, "Bob"], + [3, "Alice"] ] schema = StructType( @@ -2802,7 +2809,8 @@ class TestDataFrameDataScienceFunctions: def test_groupby(self, userage_dataframe): expected = [ Row(name='Alice', age=1, count=1), - Row(name='Bob', age=2, count=1), + Row(name='Alice', age=3, count=1), + Row(name='Bob', age=2, count=1) ] with utilizes_valid_plans(userage_dataframe): @@ -2814,10 +2822,11 @@ def test_groupby(self, userage_dataframe): def test_rollup(self, userage_dataframe): expected = [ Row(name='Alice', age=1, count=1), - Row(name='Alice', age=None, count=1), + Row(name='Alice', age=3, count=1), + Row(name='Alice', age=None, count=2), Row(name='Bob', age=2, count=1), Row(name='Bob', age=None, count=1), - Row(name=None, age=None, count=2), + Row(name=None, age=None, count=3) ] with utilizes_valid_plans(userage_dataframe): @@ -2829,12 +2838,14 @@ def test_rollup(self, userage_dataframe): def test_cube(self, userage_dataframe): expected = [ Row(name='Alice', age=1, count=1), - Row(name='Alice', age=None, count=1), + Row(name='Alice', age=3, count=1), + Row(name='Alice', age=None, count=2), Row(name='Bob', age=2, count=1), Row(name='Bob', age=None, count=1), Row(name=None, age=1, count=1), Row(name=None, age=2, count=1), - Row(name=None, age=None, count=2) + Row(name=None, age=3, count=1), + Row(name=None, age=None, count=3) ] with utilizes_valid_plans(userage_dataframe):