diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py
index 1387dd8..dec7f5e 100644
--- a/src/gateway/converter/spark_to_substrait.py
+++ b/src/gateway/converter/spark_to_substrait.py
@@ -7,6 +7,7 @@
 import pathlib
 import re
 from enum import Enum
+from itertools import combinations
 
 import pyarrow as pa
 import pyspark.sql.connect.proto.base_pb2 as spark_pb2
@@ -19,7 +20,11 @@
 from substrait.gen.proto.extensions import extensions_pb2
 
 from gateway.converter.conversion_options import ConversionOptions
-from gateway.converter.spark_functions import ExtensionFunction, FunctionType, lookup_spark_function
+from gateway.converter.spark_functions import (
+    ExtensionFunction,
+    FunctionType,
+    lookup_spark_function,
+)
 from gateway.converter.substrait_builder import (
     add_function,
     aggregate_relation,
@@ -1260,29 +1265,8 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge
         self._next_under_aggregation_reference_id = 0
         self._under_aggregation_projects = []
 
-        # TODO -- Deal with mixed groupings and measures.
-        grouping_expression_list = []
-        for idx, grouping in enumerate(rel.grouping_expressions):
-            grouping_expression_list.append(self.convert_expression(grouping))
-            symbol.generated_fields.append(self.determine_name_for_grouping(grouping))
-            self._top_level_projects.append(field_reference(idx))
-        aggregate.groupings.append(
-            algebra_pb2.AggregateRel.Grouping(grouping_expressions=grouping_expression_list)
-        )
-
-        self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL
-
-        for expr in rel.aggregate_expressions:
-            result = self.convert_expression(expr)
-            if result:
-                self._top_level_projects.append(result)
-                symbol.generated_fields.append(self.determine_expression_name(expr))
-        symbol.output_fields.clear()
-        symbol.output_fields.extend(symbol.generated_fields)
-        if len(rel.grouping_expressions) > 1:
-            # Hide the grouping source from the downstream relations.
-            for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)):
-                aggregate.common.emit.output_mapping.append(i)
+        # Handle the different group by types (currently GROUPBY and CUBE).
+        self.handle_grouping_and_measures(rel, aggregate, symbol)
 
         self._expression_processing_mode = ExpressionProcessingMode.NORMAL
 
@@ -1312,6 +1296,66 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge
 
         return algebra_pb2.Rel(aggregate=aggregate)
 
+    def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate,
+                                     aggregate: algebra_pb2.AggregateRel,
+                                     symbol):
+        """Convert the grouping expressions and measures of an aggregate relation."""
+        grouping_expression_list = []
+        rel_grouping_expressions = rel.grouping_expressions
+        for idx, grouping in enumerate(rel_grouping_expressions):
+            grouping_expression_list.append(self.convert_expression(grouping))
+            symbol.generated_fields.append(self.determine_name_for_grouping(grouping))
+            self._top_level_projects.append(field_reference(idx))
+
+        match rel.group_type:
+            case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_GROUPBY:
+                aggregate.groupings.append(
+                    algebra_pb2.AggregateRel.Grouping(
+                        grouping_expressions=grouping_expression_list)
+                )
+            case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_CUBE:
+                # Generate and add all groupings required for CUBE
+                cube_groupings = self.create_cube_groupings(rel_grouping_expressions)
+                aggregate.groupings.extend(cube_groupings)
+            case _:
+                raise NotImplementedError(
+                    "Only GROUPBY and CUBE group types are currently supported."
+                )
+
+        self._expression_processing_mode = ExpressionProcessingMode.AGGR_TOP_LEVEL
+
+        for expr in rel.aggregate_expressions:
+            result = self.convert_expression(expr)
+            if result:
+                self._top_level_projects.append(result)
+                symbol.generated_fields.append(self.determine_expression_name(expr))
+        symbol.output_fields.clear()
+        symbol.output_fields.extend(symbol.generated_fields)
+        if len(rel.grouping_expressions) > 1:
+            # Hide the grouping source from the downstream relations.
+            for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)):
+                aggregate.common.emit.output_mapping.append(i)
+
+    def create_cube_groupings(self, grouping_expressions):
+        """Create one grouping for every combination of the grouping expressions (CUBE)."""
+        num_expressions = len(grouping_expressions)
+        cube_groupings = []
+
+        # Generate all possible combinations of grouping expressions
+        for i in range(num_expressions + 1):
+            for combination in combinations(range(num_expressions), i):
+                # Create a list of the current combination of grouping expressions
+                converted_expressions = []
+                for j in combination:
+                    converted_expression = self.convert_expression(grouping_expressions[j])
+                    converted_expressions.append(converted_expression)
+                # Add the grouping for this combination
+                cube_groupings.append(
+                    algebra_pb2.AggregateRel.Grouping(grouping_expressions=converted_expressions)
+                )
+
+        return cube_groupings
+
     # pylint: disable=too-many-locals,pointless-string-statement
     def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel:
         """Convert a show string relation into a Substrait subplan."""
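Editorial note on the converter change above (not part of the patch): create_cube_groupings enumerates every subset of the grouping expressions, so a CUBE over n columns produces 2**n Substrait groupings. A minimal standalone sketch of that enumeration, using hypothetical column names in place of the gateway's Spark expression objects:

    from itertools import combinations

    grouping_columns = ["name", "age"]  # hypothetical stand-ins for the Spark grouping expressions

    cube_grouping_sets = [
        list(subset)
        for size in range(len(grouping_columns) + 1)
        for subset in combinations(grouping_columns, size)
    ]
    print(cube_grouping_sets)  # [[], ['name'], ['age'], ['name', 'age']]

A ROLLUP would instead keep only the leading prefixes ([], ['name'], ['name', 'age']); that variant is not generated here, which is why test_rollup in the test changes below is skipped for both gateway backends.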
diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py
index 4e97d2f..6e2a063 100644
--- a/src/gateway/tests/test_dataframe_api.py
+++ b/src/gateway/tests/test_dataframe_api.py
@@ -64,7 +64,7 @@
     ucase,
     upper,
 )
-from pyspark.sql.types import DoubleType, StructField, StructType
+from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType
 from pyspark.sql.window import Window
 from pyspark.testing import assertDataFrameEqual
 
@@ -154,6 +154,13 @@ def mark_dataframe_tests_as_xfail(request):
         pytest.skip(reason="inf vs -inf difference")
     if source == "gateway-over-duckdb" and originalname in ["test_union", "test_unionall"]:
         pytest.skip(reason="distinct not handled properly")
+    if source == "gateway-over-datafusion" and originalname == "test_rollup":
+        pytest.skip(reason="rollup aggregation not yet implemented in gateway")
+    if source == "gateway-over-duckdb" and originalname == "test_rollup":
+        pytest.skip(reason="rollup aggregation not yet implemented in gateway")
+    if source == "gateway-over-duckdb" and originalname == "test_cube":
+        pytest.skip(reason="cube aggregation not yet implemented in DuckDB")
+
 
 # ruff: noqa: E712
 class TestDataFrameAPI:
@@ -689,14 +696,14 @@ def test_exceptall(self, register_tpch_dataset, spark_session, caplog):
                 n_name="RUSSIA",
                 n_regionkey=3,
                 n_comment="uctions. furiously unusual instructions sleep furiously ironic "
-                          "packages. slyly ",
+                "packages. slyly ",
             ),
             Row(
                 n_nationkey=22,
                 n_name="RUSSIA",
                 n_regionkey=3,
                 n_comment="uctions. furiously unusual instructions sleep furiously ironic "
-                          "packages. slyly ",
+                "packages. slyly ",
             ),
             Row(
                 n_nationkey=23,
@@ -709,14 +716,14 @@ def test_exceptall(self, register_tpch_dataset, spark_session, caplog):
                 n_name="UNITED STATES",
                 n_regionkey=1,
                 n_comment="ly ironic requests along the slyly bold ideas hang after the "
-                          "blithely special notornis; blithely even accounts",
+                "blithely special notornis; blithely even accounts",
             ),
             Row(
                 n_nationkey=24,
                 n_name="UNITED STATES",
                 n_regionkey=1,
                 n_comment="ly ironic requests along the slyly bold ideas hang after the "
-                          "blithely special notornis; blithely even accounts",
+                "blithely special notornis; blithely even accounts",
             ),
         ]
 
@@ -802,14 +809,14 @@ def test_subtract(self, register_tpch_dataset, spark_session):
                 n_name="RUSSIA",
                 n_regionkey=3,
                 n_comment="uctions. furiously unusual instructions sleep furiously "
-                          "ironic packages. slyly ",
+                "ironic packages. slyly ",
             ),
             Row(
                 n_nationkey=24,
                 n_name="UNITED STATES",
                 n_regionkey=1,
                 n_comment="ly ironic requests along the slyly bold ideas hang after "
-                          "the blithely special notornis; blithely even accounts",
+                "the blithely special notornis; blithely even accounts",
             ),
         ]
 
@@ -2657,7 +2664,7 @@ def test_computation_with_two_aggregations(self, register_tpch_dataset, spark_se
         assertDataFrameEqual(outcome, expected)
 
     def test_computation_with_two_aggregations_and_internal_calculation(
-            self, register_tpch_dataset, spark_session
+        self, register_tpch_dataset, spark_session
     ):
         expected = [
             Row(l_suppkey=1, a=Decimal("3903113211864.3000")),
@@ -2762,3 +2769,76 @@ def test_row_number(self, users_dataframe):
             "row_number").limit(3)
 
         assertDataFrameEqual(outcome, expected)
+
+
+@pytest.fixture(scope="class")
+def userage_dataframe(spark_session_for_setup):
+    data = [
+        [1, "Alice"],
+        [2, "Bob"],
+        [3, "Alice"]
+    ]
+
+    schema = StructType(
+        [
+            StructField("age", IntegerType(), True),
+            StructField("name", StringType(), True),
+        ]
+    )
+
+    test_df = spark_session_for_setup.createDataFrame(data, schema)
+
+    test_df.createOrReplaceTempView("userage")
+    return spark_session_for_setup.table("userage")
+
+
+class TestDataFrameDecisionSupport:
+    """Tests data science methods of the dataframe side of SparkConnect."""
+
+    def test_groupby(self, userage_dataframe):
+        expected = [
+            Row(name='Alice', age=1, count=1),
+            Row(name='Alice', age=3, count=1),
+            Row(name='Bob', age=2, count=1)
+        ]
+
+        with utilizes_valid_plans(userage_dataframe):
+            outcome = userage_dataframe.groupby("name", "age").count().orderBy("name",
+                                                                               "age").collect()
+
+        assertDataFrameEqual(outcome, expected)
+
+    def test_rollup(self, userage_dataframe):
+        expected = [
+            Row(name='Alice', age=1, count=1),
+            Row(name='Alice', age=3, count=1),
+            Row(name='Alice', age=None, count=2),
+            Row(name='Bob', age=2, count=1),
+            Row(name='Bob', age=None, count=1),
+            Row(name=None, age=None, count=3)
+        ]
+
+        with utilizes_valid_plans(userage_dataframe):
+            outcome = userage_dataframe.rollup("name", "age").count().orderBy("name",
+                                                                              "age").collect()
+
+        assertDataFrameEqual(outcome, expected)
+
+    def test_cube(self, userage_dataframe):
+        expected = [
+            Row(name='Alice', age=1, count=1),
+            Row(name='Alice', age=3, count=1),
+            Row(name='Alice', age=None, count=2),
+            Row(name='Bob', age=2, count=1),
+            Row(name='Bob', age=None, count=1),
+            Row(name=None, age=1, count=1),
+            Row(name=None, age=2, count=1),
+            Row(name=None, age=3, count=1),
+            Row(name=None, age=None, count=3)
+        ]
+
+        with utilizes_valid_plans(userage_dataframe):
+            outcome = userage_dataframe.cube("name", "age").count().orderBy("name",
                                                                            "age").collect()
+
+        assertDataFrameEqual(outcome, expected)
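Editorial note on the new tests (not part of the patch): a self-contained sketch, independent of Spark and the gateway, that reproduces the groups and counts expected by test_cube by applying every CUBE grouping set to the userage fixture data. Columns omitted from a grouping set surface as None, matching the Row(name=None, ...) entries above:

    from itertools import combinations

    rows = [(1, "Alice"), (2, "Bob"), (3, "Alice")]  # (age, name), as in the userage fixture
    columns = ("name", "age")

    def group_key(age, name, kept):
        values = {"name": name, "age": age}
        # Columns outside the grouping set are reported as None (NULL).
        return tuple(values[col] if col in kept else None for col in columns)

    counts = {}
    for size in range(len(columns) + 1):  # CUBE: every subset of the grouping columns
        for kept in combinations(columns, size):
            for age, name in rows:
                key = group_key(age, name, kept)
                counts[key] = counts.get(key, 0) + 1

    for (name, age), count in sorted(counts.items(), key=repr):
        print(name, age, count)  # nine groups, matching the test_cube expectation

Dropping the single non-prefix grouping set (("age",) here) leaves the six groups expected by test_rollup.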