From 8595e7b48c7417fe10d4c793321aa1c535bcffd7 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 26 Mar 2024 12:53:31 -0700 Subject: [PATCH 01/58] feat: implement show_string using relations (#24) --- .github/workflows/test.yml | 4 +- src/gateway/converter/conversion_options.py | 2 +- src/gateway/converter/label_relations.py | 2 +- src/gateway/converter/simplify_casts.py | 4 +- src/gateway/converter/spark_functions.py | 44 +++ src/gateway/converter/spark_to_substrait.py | 230 ++++++++++++--- .../converter/spark_to_substrait_test.py | 4 +- src/gateway/converter/substrait_builder.py | 263 ++++++++++++++++++ src/gateway/server.py | 44 ++- src/gateway/tests/test_server.py | 30 +- 10 files changed, 582 insertions(+), 45 deletions(-) create mode 100644 src/gateway/converter/substrait_builder.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 739741f..d977976 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,8 +31,8 @@ jobs: export PIP_INTEROP_ENABLED=compatible export PIP_SCRIPT=scripts/pip_verbose.sh $CONDA/bin/conda env update --file environment.yml --name base - python -m pip install --upgrade pip - python -m pip install ".[test]" + $PIP_SCRIPT install --upgrade pip + $PIP_SCRIPT install ".[test]" - name: Run tests run: | $CONDA/bin/pytest diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index cc02010..63e2d56 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -27,7 +27,7 @@ def __init__(self, backend: BackendOptions = None): self.return_names_with_types = False - self.implement_show_string = False + self.implement_show_string = True self.backend = backend diff --git a/src/gateway/converter/label_relations.py b/src/gateway/converter/label_relations.py index 9b230d1..aee273a 100644 --- a/src/gateway/converter/label_relations.py +++ b/src/gateway/converter/label_relations.py @@ -55,7 +55,7 @@ def get_common_section(rel: algebra_pb2.Rel) -> algebra_pb2.RelCommon: result = rel.expand.common case _: raise NotImplementedError('Finding the common section for type ' - f'{rel.WhichOneof('rel_type')} is not implemented') + f'{rel.WhichOneof("rel_type")} is not implemented') return result diff --git a/src/gateway/converter/simplify_casts.py b/src/gateway/converter/simplify_casts.py index cb5ef4c..7b546ae 100644 --- a/src/gateway/converter/simplify_casts.py +++ b/src/gateway/converter/simplify_casts.py @@ -60,7 +60,7 @@ def find_single_input(rel: algebra_pb2.Rel) -> algebra_pb2.Rel: return rel.extension_single.input case _: raise NotImplementedError('Finding single inputs of relations with type ' - f'{rel.WhichOneof('rel_type')} are not implemented') + f'{rel.WhichOneof("rel_type")} are not implemented') @staticmethod def replace_single_input(rel: algebra_pb2.Rel, new_input: algebra_pb2.Rel): @@ -80,7 +80,7 @@ def replace_single_input(rel: algebra_pb2.Rel, new_input: algebra_pb2.Rel): rel.extension_single.input.CopyFrom(new_input) case _: raise NotImplementedError('Modifying inputs of relations with type ' - f'{rel.WhichOneof('rel_type')} are not implemented') + f'{rel.WhichOneof("rel_type")} are not implemented') def update_field_references(self, plan_id: int) -> None: """Uses the field references using the specified portion of the plan.""" diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 2a423d5..4d834c9 100644 --- a/src/gateway/converter/spark_functions.py +++ 
b/src/gateway/converter/spark_functions.py @@ -40,6 +40,18 @@ def __lt__(self, obj) -> bool: '/functions_comparison.yaml', 'equal:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '>=': ExtensionFunction( + '/functions_comparison.yaml', 'gte:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '>': ExtensionFunction( + '/functions_comparison.yaml', 'gt:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '-': ExtensionFunction( + '/functions_arithmetic.yaml', 'subtract:i64_i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'array_contains': ExtensionFunction( '/functions_set.yaml', 'index_in:str_list', type_pb2.Type( bool=type_pb2.Type.Boolean( @@ -60,6 +72,38 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'char_length:str', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'max': ExtensionFunction( + '/functions_aggregate.yaml', 'max:i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'string_agg': ExtensionFunction( + '/functions_string.yaml', 'string_agg:str', type_pb2.Type( + string=type_pb2.Type.String( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'least': ExtensionFunction( + '/functions_comparison.yaml', 'least:i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'greatest': ExtensionFunction( + '/functions_comparison.yaml', 'greatest:i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'concat': ExtensionFunction( + '/functions_string.yaml', 'concat:str_str', type_pb2.Type( + string=type_pb2.Type.String( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'repeat': ExtensionFunction( + '/functions_string.yaml', 'repeat:str_i64', type_pb2.Type( + string=type_pb2.Type.String( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'lpad': ExtensionFunction( + '/functions_string.yaml', 'lpad:str_i64_str', type_pb2.Type( + string=type_pb2.Type.String( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'rpad': ExtensionFunction( + '/functions_string.yaml', 'rpad:str_i64_str', type_pb2.Type( + string=type_pb2.Type.String( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'count': ExtensionFunction( '/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type( i64=type_pb2.Type.I64( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index cb636d2..9913ac6 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -2,21 +2,25 @@ """Routines to convert SparkConnect plans to Substrait plans.""" import json import operator -from typing import Dict, Optional +from typing import Dict, Optional, List import pyarrow -from substrait.gen.proto import plan_pb2 -from substrait.gen.proto import algebra_pb2 -from substrait.gen.proto import type_pb2 -from substrait.gen.proto.extensions import extensions_pb2 - import pyspark.sql.connect.proto.base_pb2 as spark_pb2 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import 
pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 +from substrait.gen.proto import algebra_pb2 +from substrait.gen.proto import plan_pb2 +from substrait.gen.proto import type_pb2 +from substrait.gen.proto.extensions import extensions_pb2 from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function +from gateway.converter.substrait_builder import field_reference, cast_operation, string_type, \ + project_relation, strlen, concat, fetch_relation, join_relation, aggregate_relation, \ + max_agg_function, string_literal, flatten, repeat_function, \ + least_function, greatest_function, bigint_literal, lpad_function, string_concat_agg_function, \ + if_then_else_operation, greater_function, minus_function from gateway.converter.symbol_table import SymbolTable @@ -133,15 +137,15 @@ def convert_literal_expression( result = algebra_pb2.Expression.Literal() case _: raise NotImplementedError( - f'Unexpected literal type: {literal.WhichOneof('literal_type')}') + f'Unexpected literal type: {literal.WhichOneof("literal_type")}') return algebra_pb2.Expression(literal=result) def convert_unresolved_attribute( self, attr: spark_exprs_pb2.Expression.UnresolvedAttribute) -> algebra_pb2.Expression: """Converts a Spark unresolved attribute into a Substrait field reference.""" - field_reference = self.find_field_by_name(attr.unparsed_identifier) - if field_reference is None: + field_ref = self.find_field_by_name(attr.unparsed_identifier) + if field_ref is None: raise ValueError( f'could not locate field named {attr.unparsed_identifier} in plan id ' f'{self._current_plan_id}') @@ -149,7 +153,7 @@ def convert_unresolved_attribute( return algebra_pb2.Expression(selection=algebra_pb2.Expression.FieldReference( direct_reference=algebra_pb2.Expression.ReferenceSegment( struct_field=algebra_pb2.Expression.ReferenceSegment.StructField( - field=field_reference)), + field=field_ref)), root_reference=algebra_pb2.Expression.FieldReference.RootReference())) def convert_unresolved_function( @@ -211,7 +215,7 @@ def convert_cast_expression( cast_rel.type.CopyFrom(self.convert_type_str(cast.type_str)) case _: raise NotImplementedError( - f'unknown cast_to_type {cast.WhichOneof('cast_to_type')}' + f'unknown cast_to_type {cast.WhichOneof("cast_to_type")}' ) return algebra_pb2.Expression(cast=cast_rel) @@ -364,16 +368,20 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al local.items.append(file_or_files) return algebra_pb2.Rel(read=algebra_pb2.ReadRel(base_schema=schema, local_files=local)) - def create_common_relation(self) -> algebra_pb2.RelCommon: + def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: """Creates the common metadata relation used by all relations.""" if not self._conversion_options.use_emits_instead_of_direct: return algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct()) symbol = self._symbol_table.get_symbol(self._current_plan_id) emit = algebra_pb2.RelCommon.Emit() - field_number = 0 - for _ in symbol.output_fields: - emit.output_mapping.append(field_number) - field_number += 1 + if emit_overrides: + for field_number in emit_overrides: + emit.output_mapping.append(field_number) + else: + field_number = 0 + for _ in symbol.output_fields: + emit.output_mapping.append(field_number) + field_number += 1 return algebra_pb2.RelCommon(emit=emit) def convert_read_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Rel: @@ -455,31 +463,188 @@ def 
convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge symbol.output_fields.extend(symbol.generated_fields) return algebra_pb2.Rel(aggregate=aggregate) + # pylint: disable=too-many-locals,pointless-string-statement def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel: - """Converts a show string relation into a Substrait project relation.""" + """Converts a show string relation into a Substrait subplan.""" if not self._conversion_options.implement_show_string: result = self.convert_relation(rel.input) self.update_field_references(rel.input.common.plan_id) return result - # TODO -- Implement using num_rows by wrapping the input in a fetch relation. + if rel.vertical: + raise NotImplementedError('vertical show strings are not yet implemented') - # TODO -- Implement what happens if truncate is not set or less than two. - # TODO -- Implement what happens when rel.vertical is true. + if rel.truncate < 3: + raise NotImplementedError( + 'show_string values of truncate of less than 3 not yet implemented') + + """ + The subplan implementing the show_string relation has this flow: + + Input -> Fetch1 -> Project1 -> Aggregate1 -> Project2 -> Project3 + Fetch1 + Aggregate1 -> Join1 + Join1 -> Project4 -> Aggregate2 + Project3 + Aggregate2 -> Join2 + Join2 -> Project5 + + Input - The plan to run the show_string on. + Fetch1 - Restricts the input to the number of rows (if needed). + Project1 - Finds the length of each column of the remaining rows. + Aggregate1 - Finds the maximum length of each column. + Project2 - Uses the best of truncate, column name length, and max length. + Project3 - Constructs the header and the footer based on the lines. + Join1 - Combines the original rows with the maximum lengths. + Project4 - Combines all of the columns for each row into a single string. + Aggregate2 - Combines all the strings into the body of the result. + Join2 - Combines the header and footer along with the body of the result. + Project5 - Organizes the header, footer, and body in the right order. + """ + + # Find the functions we'll need. + strlen_func = self.lookup_function_by_name('length') + max_func = self.lookup_function_by_name('max') + string_concat_func = self.lookup_function_by_name('string_agg') + concat_func = self.lookup_function_by_name('concat') + repeat_func = self.lookup_function_by_name('repeat') + lpad_func = self.lookup_function_by_name('lpad') + least_func = self.lookup_function_by_name('least') + greatest_func = self.lookup_function_by_name('greatest') + greater_func = self.lookup_function_by_name('>') + minus_func = self.lookup_function_by_name('-') + + # Get the input and restrict it to the number of requested rows if necessary. input_rel = self.convert_relation(rel.input) + if rel.num_rows > 0: + input_rel = fetch_relation(input_rel, rel.num_rows) + + # Now that we've processed the input, do the bookkeeping. self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) - # TODO -- Pull the columns from symbol.input_fields. - # TODO -- Use string_agg to aggregate all of the column input into a single string. - # TODO -- Use a project to output a single field with the table info in it. + + # Find the length of each column in every row. 
+ project1 = project_relation( + input_rel, + [strlen(strlen_func, cast_operation(field_reference(column_number), string_type())) for + column_number in range(len(symbol.input_fields))]) + + # Find the maximum width of each column (for the rows in that we will display). + aggregate1 = aggregate_relation( + project1, + measures=[ + max_agg_function(max_func, column_number) + for column_number in range(len(symbol.input_fields))]) + + # Find the maximum we will use based on the truncate, max size, and column name length. + project2 = project_relation( + aggregate1, + [greatest_function(greatest_func, + least_function(least_func, field_reference(column_number), + bigint_literal(rel.truncate)), + strlen(strlen_func, + string_literal(symbol.input_fields[column_number]))) for + column_number in range(len(symbol.input_fields))]) + + def field_header_fragment(field_number: int) -> List[algebra_pb2.Expression]: + return [string_literal('|'), + lpad_function(lpad_func, string_literal(symbol.input_fields[field_number]), + field_reference(field_number))] + + def field_line_fragment(field_number: int) -> List[algebra_pb2.Expression]: + return [string_literal('+'), + repeat_function(repeat_func, '-', field_reference(field_number))] + + def field_body_fragment(field_number: int) -> List[algebra_pb2.Expression]: + return [string_literal('|'), + if_then_else_operation( + greater_function(greater_func, + strlen(strlen_func, + cast_operation( + field_reference(field_number), + string_type())), + field_reference( + field_number + len(symbol.input_fields))), + concat(concat_func, + [lpad_function(lpad_func, field_reference(field_number), + minus_function(minus_func, field_reference( + field_number + len(symbol.input_fields)), + bigint_literal(3))), + string_literal('...')]), + lpad_function(lpad_func, field_reference(field_number), + field_reference( + field_number + len(symbol.input_fields))), + + )] + + def header_line(fields: List[str]) -> List[algebra_pb2.Expression]: + return [concat(concat_func, + flatten([ + field_header_fragment(field_number) for field_number in + range(len(fields)) + ]) + [ + string_literal('|\n'), + ])] + + def full_line(fields: List[str]) -> List[algebra_pb2.Expression]: + return [concat(concat_func, + flatten([ + field_line_fragment(field_number) for field_number in + range(len(fields)) + ]) + [ + string_literal('+\n'), + ])] + + # Construct the header and footer lines. + project3 = project_relation(project2, [ + concat(concat_func, + full_line(symbol.input_fields) + + header_line(symbol.input_fields) + + full_line(symbol.input_fields))] + + full_line(symbol.input_fields)) + + # Combine the original rows with the maximum lengths we are using. + join1 = join_relation(input_rel, project2) + + # Construct the body of the result row by row. + project4 = project_relation(join1, [ + concat(concat_func, + flatten([field_body_fragment(field_number) for field_number in + range(len(symbol.input_fields)) + ]) + [ + string_literal('|\n'), + ] + ), + ]) + + # Merge all of the rows of the result body into a single string. + aggregate2 = aggregate_relation(project4, measures=[ + string_concat_agg_function(string_concat_func, 0)]) + + # Create one row with the header, the body, and the footer in it. 
+ join2 = join_relation(project3, aggregate2) + + symbol = self._symbol_table.get_symbol(self._current_plan_id) symbol.output_fields.clear() symbol.output_fields.append('show_string') - project = algebra_pb2.ProjectRel(input=input_rel) - project.expressions.append( - algebra_pb2.Expression(literal=self.convert_string_literal('hiya'))) - project.common.emit.output_mapping.append(len(symbol.input_fields)) - return algebra_pb2.Rel(project=project) + def compute_row_count_footer(num_rows: int) -> str: + if num_rows == 1: + return 'only showing top 1 row\n' + if num_rows < 20: + return f'only showing top {num_rows} rows\n' + return '' + + # Combine the header, body, and footer into the final result. + project5 = project_relation(join2, [ + concat(concat_func, [ + field_reference(0), + field_reference(2), + field_reference(1), + ] + [string_literal( + compute_row_count_footer(rel.num_rows)) if rel.num_rows else None + ]), + ]) + project5.project.common.emit.output_mapping.append(len(symbol.input_fields)) + return project5 def convert_with_columns_relation( self, rel: spark_relations_pb2.WithColumns) -> algebra_pb2.Rel: @@ -544,7 +709,8 @@ def convert_arrow_to_literal(self, val: pyarrow.Scalar) -> algebra_pb2.Expressio f'Conversion from arrow type {val.type} not yet implemented.') return literal - def convert_arrow_data_to_virtual_table(self, data: bytes) -> algebra_pb2.ReadRel.VirtualTable: + def convert_arrow_data_to_virtual_table(self, + data: bytes) -> algebra_pb2.ReadRel.VirtualTable: """Converts a Spark local relation into a virtual table.""" table = algebra_pb2.ReadRel.VirtualTable() # use Pyarrow to convert the bytes into an arrow structure @@ -559,7 +725,8 @@ def convert_arrow_data_to_virtual_table(self, data: bytes) -> algebra_pb2.ReadRe def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> algebra_pb2.Rel: """Converts a Spark local relation into a virtual table.""" - read = algebra_pb2.ReadRel(virtual_table=self.convert_arrow_data_to_virtual_table(rel.data)) + read = algebra_pb2.ReadRel( + virtual_table=self.convert_arrow_data_to_virtual_table(rel.data)) schema = self.convert_schema(rel.schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: @@ -594,7 +761,8 @@ def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel case 'local_relation': result = self.convert_local_relation(rel.local_relation) case _: - raise ValueError(f'Unexpected Spark plan rel_type: {rel.WhichOneof("rel_type")}') + raise ValueError( + f'Unexpected Spark plan rel_type: {rel.WhichOneof("rel_type")}') self._current_plan_id = old_plan_id return result diff --git a/src/gateway/converter/spark_to_substrait_test.py b/src/gateway/converter/spark_to_substrait_test.py index 1230c49..9d35f5b 100644 --- a/src/gateway/converter/spark_to_substrait_test.py +++ b/src/gateway/converter/spark_to_substrait_test.py @@ -41,7 +41,9 @@ def test_plan_conversion(request, path): splan_prototext = file.read() substrait_plan = text_format.Parse(splan_prototext, plan_pb2.Plan()) - convert = SparkSubstraitConverter(duck_db()) + options = duck_db() + options.implement_show_string = False + convert = SparkSubstraitConverter(options) substrait = convert.convert_plan(spark_plan) if request.config.getoption('rebuild_goldens'): diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py new file mode 100644 index 0000000..8dc81e1 --- /dev/null +++ b/src/gateway/converter/substrait_builder.py @@ -0,0 +1,263 @@ +# 
SPDX-License-Identifier: Apache-2.0 +"""Convenience builder for constructing Substrait plans.""" +import itertools +from typing import List, Any + +from substrait.gen.proto import algebra_pb2, type_pb2 + +from gateway.converter.spark_functions import ExtensionFunction + + +def flatten(list_of_lists: List[List[Any]]) -> List[Any]: + """Flattens a list of lists into a list.""" + return list(itertools.chain.from_iterable(list_of_lists)) + + +# pylint: disable=E1101 + +def fetch_relation(input_relation: algebra_pb2.Rel, num_rows: int) -> algebra_pb2.Rel: + """Constructs a Substrait fetch plan node.""" + fetch = algebra_pb2.Rel(fetch=algebra_pb2.FetchRel(input=input_relation, count=num_rows)) + + return fetch + + +def project_relation(input_relation: algebra_pb2.Rel, + expressions: List[algebra_pb2.Expression]) -> algebra_pb2.Rel: + """Constructs a Substrait project plan node.""" + return algebra_pb2.Rel( + project=algebra_pb2.ProjectRel(input=input_relation, expressions=expressions)) + + +# pylint: disable=fixme +def aggregate_relation(input_relation: algebra_pb2.Rel, + measures: List[algebra_pb2.AggregateFunction]) -> algebra_pb2.Rel: + """Constructs a Substrait aggregate plan node.""" + aggregate = algebra_pb2.Rel( + aggregate=algebra_pb2.AggregateRel( + common=algebra_pb2.RelCommon(emit=algebra_pb2.RelCommon.Emit( + output_mapping=range(len(measures)))), + input=input_relation)) + # TODO -- Add support for groupings. + for measure in measures: + aggregate.aggregate.measures.append( + algebra_pb2.AggregateRel.Measure(measure=measure)) + return aggregate + + +def join_relation(left: algebra_pb2.Rel, right: algebra_pb2.Rel) -> algebra_pb2.Rel: + """Constructs a Substrait join plan node.""" + return algebra_pb2.Rel( + join=algebra_pb2.JoinRel(common=algebra_pb2.RelCommon(), left=left, right=right, + expression=algebra_pb2.Expression( + literal=algebra_pb2.Expression.Literal(boolean=True)), + type=algebra_pb2.JoinRel.JoinType.JOIN_TYPE_INNER)) + + +def concat(function_info: ExtensionFunction, + expressions: List[algebra_pb2.Expression]) -> algebra_pb2.Expression: + """Constructs a Substrait concat expression.""" + return algebra_pb2.Expression( + scalar_function=algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expression) for expression in expressions] + )) + + +def strlen(function_info: ExtensionFunction, + expression: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Constructs a Substrait concat expression.""" + return algebra_pb2.Expression( + scalar_function=algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expression)])) + + +def cast_operation(expression: algebra_pb2.Expression, + output_type: type_pb2.Type) -> algebra_pb2.Expression: + """Constructs a Substrait cast expression.""" + return algebra_pb2.Expression( + cast=algebra_pb2.Expression.Cast(input=expression, type=output_type) + ) + + +def if_then_else_operation(if_expr: algebra_pb2.Expression, then_expr: algebra_pb2.Expression, + else_expr: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Constructs a simplistic Substrait if-then-else expression.""" + return algebra_pb2.Expression( + if_then=algebra_pb2.Expression.IfThen( + **{'ifs': [ + algebra_pb2.Expression.IfThen.IfClause(**{'if': if_expr, 'then': then_expr})], + 'else': else_expr}) + ) + + +def field_reference(field_number: 
int) -> algebra_pb2.Expression: + """Constructs a Substrait field reference expression.""" + return algebra_pb2.Expression( + selection=algebra_pb2.Expression.FieldReference( + direct_reference=algebra_pb2.Expression.ReferenceSegment( + struct_field=algebra_pb2.Expression.ReferenceSegment.StructField( + field=field_number)))) + + +def max_agg_function(function_info: ExtensionFunction, + field_number: int) -> algebra_pb2.AggregateFunction: + """Constructs a Substrait max aggregate function.""" + # TODO -- Reorganize all functions to belong to a class which determines the info. + return algebra_pb2.AggregateFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number))]) + + +def string_concat_agg_function(function_info: ExtensionFunction, + field_number: int, + separator: str = '') -> algebra_pb2.AggregateFunction: + """Constructs a Substrait string concat aggregate function.""" + return algebra_pb2.AggregateFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number)), + algebra_pb2.FunctionArgument(value=string_literal(separator))]) + + +def least_function(function_info: ExtensionFunction, + expr1: algebra_pb2.Expression, + expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Constructs a Substrait min expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expr1), + algebra_pb2.FunctionArgument(value=expr2)])) + + +def greatest_function(function_info: ExtensionFunction, + expr1: algebra_pb2.Expression, + expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Constructs a Substrait min expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expr1), + algebra_pb2.FunctionArgument(value=expr2)])) + + +def greater_or_equal_function(function_info: ExtensionFunction, + expr1: algebra_pb2.Expression, + expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Constructs a Substrait min expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expr1), + algebra_pb2.FunctionArgument(value=expr2)])) + + +def greater_function(function_info: ExtensionFunction, + expr1: algebra_pb2.Expression, + expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Constructs a Substrait min expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expr1), + algebra_pb2.FunctionArgument(value=expr2)])) + + +def minus_function(function_info: ExtensionFunction, + expr1: algebra_pb2.Expression, + expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Constructs a Substrait min expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + 
output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expr1), + algebra_pb2.FunctionArgument(value=expr2)])) + + +def repeat_function(function_info: ExtensionFunction, + string: str, + count: algebra_pb2.Expression) -> algebra_pb2.AggregateFunction: + """Constructs a Substrait concat expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=string_literal(string)), + algebra_pb2.FunctionArgument(value=count)])) + + +def lpad_function(function_info: ExtensionFunction, + expression: algebra_pb2.Expression, count: algebra_pb2.Expression, + pad_string: str = ' ') -> algebra_pb2.AggregateFunction: + """Constructs a Substrait concat expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[ + algebra_pb2.FunctionArgument(value=cast_operation(expression, varchar_type())), + algebra_pb2.FunctionArgument(value=cast_operation(count, integer_type())), + algebra_pb2.FunctionArgument( + value=cast_operation(string_literal(pad_string), varchar_type()))])) + + +def rpad_function(function_info: ExtensionFunction, + expression: algebra_pb2.Expression, count: algebra_pb2.Expression, + pad_string: str = ' ') -> algebra_pb2.AggregateFunction: + """Constructs a Substrait concat expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[ + algebra_pb2.FunctionArgument(value=cast_operation(expression, varchar_type())), + algebra_pb2.FunctionArgument(value=cast_operation(count, integer_type())), + algebra_pb2.FunctionArgument( + value=cast_operation(string_literal(pad_string), varchar_type()))])) + + +def string_literal(val: str) -> algebra_pb2.Expression: + """Constructs a Substrait string literal expression.""" + return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(string=val)) + + +def bigint_literal(val: int) -> algebra_pb2.Expression: + """Constructs a Substrait string literal expression.""" + return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(i64=val)) + + +def string_type(required: bool = True) -> type_pb2.Type: + """Constructs a Substrait string type.""" + if required: + nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED + else: + nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE + return type_pb2.Type(string=type_pb2.Type.String(nullability=nullability)) + + +def varchar_type(required: bool = True) -> type_pb2.Type: + """Constructs a Substrait varchar type.""" + if required: + nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED + else: + nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE + return type_pb2.Type(varchar=type_pb2.Type.VarChar(nullability=nullability)) + + +def integer_type(required: bool = True) -> type_pb2.Type: + """Constructs a Substrait i32 type.""" + if required: + nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED + else: + nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE + return type_pb2.Type(i32=type_pb2.Type.I32(nullability=nullability)) diff --git a/src/gateway/server.py b/src/gateway/server.py index d9de057..75ecea7 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -9,6 +9,7 @@ 
import pyarrow import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc import pyspark.sql.connect.proto.base_pb2 as pb2 +from pyspark.sql.connect.proto import types_pb2 from gateway.converter.conversion_options import duck_db, datafusion from gateway.converter.spark_to_substrait import SparkSubstraitConverter @@ -42,6 +43,27 @@ def batch_to_bytes(batch: pyarrow.RecordBatch, schema: pyarrow.Schema) -> bytes: return buffer.getvalue() +# pylint: disable=E1101 +def convert_pyarrow_schema_to_spark(schema: pyarrow.Schema) -> types_pb2.DataType: + """Converts a PyArrow schema to a SparkConnect DataType.Struct schema.""" + fields = [] + for field in schema: + if field.type == pyarrow.bool_(): + data_type = types_pb2.DataType(boolean=types_pb2.DataType.Boolean()) + elif field.type == pyarrow.int64(): + data_type = types_pb2.DataType(long=types_pb2.DataType.Long()) + elif field.type == pyarrow.string(): + data_type = types_pb2.DataType(string=types_pb2.DataType.String()) + else: + raise ValueError( + f'Unsupported arrow schema to Spark schema conversion type: {field.type}') + + struct_field = types_pb2.DataType.StructField(name=field.name, data_type=data_type) + fields.append(struct_field) + + return types_pb2.DataType(struct=types_pb2.DataType.Struct(fields=fields)) + + # pylint: disable=E1101,fixme class SparkConnectService(pb2_grpc.SparkConnectServiceServicer): """Provides the SparkConnect service.""" @@ -72,15 +94,25 @@ def ExecutePlan( 'rel_type') == 'show_string': yield pb2.ExecutePlanResponse( session_id=request.session_id, - arrow_batch=pb2.ExecutePlanResponse.ArrowBatch(row_count=results.num_rows, - data=show_string(results))) + arrow_batch=pb2.ExecutePlanResponse.ArrowBatch( + row_count=results.num_rows, + data=show_string(results)), + schema=types_pb2.DataType(struct=types_pb2.DataType.Struct( + fields=[types_pb2.DataType.StructField( + name='show_string', + data_type=types_pb2.DataType(string=types_pb2.DataType.String()))] + )), + ) return for batch in results.to_batches(): - yield pb2.ExecutePlanResponse(session_id=request.session_id, - arrow_batch=pb2.ExecutePlanResponse.ArrowBatch( - row_count=batch.num_rows, - data=batch_to_bytes(batch, results.schema))) + yield pb2.ExecutePlanResponse( + session_id=request.session_id, + arrow_batch=pb2.ExecutePlanResponse.ArrowBatch( + row_count=batch.num_rows, + data=batch_to_bytes(batch, results.schema)), + schema=convert_pyarrow_schema_to_spark(results.schema), + ) # TODO -- When spark 3.4.0 support is not required, yield a ResultComplete message here. 
def AnalyzePlan(self, request, context): diff --git a/src/gateway/tests/test_server.py b/src/gateway/tests/test_server.py index 731d1a4..299244d 100644 --- a/src/gateway/tests/test_server.py +++ b/src/gateway/tests/test_server.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the Spark to Substrait Gateway server.""" +from hamcrest import assert_that, equal_to from pyspark.sql.functions import col, substring from pyspark.testing import assertDataFrameEqual @@ -15,6 +16,20 @@ def test_filter(self, users_dataframe): # pylint: disable=singleton-comparison def test_filter_with_show(self, users_dataframe, capsys): + expected = '''+-------------+---------------+----------------+ +| user_id| name|paid_for_service| ++-------------+---------------+----------------+ +|user669344115| Joshua Brown| true| +|user282427709|Michele Carroll| true| ++-------------+---------------+----------------+ + +''' + users_dataframe.filter(col('paid_for_service') == True).limit(2).show() + outcome = capsys.readouterr().out + assert_that(outcome, equal_to(expected)) + + # pylint: disable=singleton-comparison + def test_filter_with_show_with_limit(self, users_dataframe, capsys): expected = '''+-------------+------------+----------------+ | user_id| name|paid_for_service| +-------------+------------+----------------+ @@ -25,7 +40,20 @@ def test_filter_with_show(self, users_dataframe, capsys): ''' users_dataframe.filter(col('paid_for_service') == True).show(1) outcome = capsys.readouterr().out - assert outcome == expected + assert_that(outcome, equal_to(expected)) + + # pylint: disable=singleton-comparison + def test_filter_with_show_and_truncate(self, users_dataframe, capsys): + expected = '''+----------+----------+----------------+ +| user_id| name|paid_for_service| ++----------+----------+----------------+ +|user669...|Joshua ...| true| ++----------+----------+----------------+ + +''' + users_dataframe.filter(col('paid_for_service') == True).limit(1).show(truncate=10) + outcome = capsys.readouterr().out + assert_that(outcome, equal_to(expected)) def test_count(self, users_dataframe): outcome = users_dataframe.count() From 636e208fd9e40ff9c9955ca7c2a6f6359f67dd8c Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 28 Mar 2024 20:40:44 -0700 Subject: [PATCH 02/58] feat: fix the behavior of with_column (#25) With this PR the with_column tests in DataFusion now pass (ordering properly happens and the tests don't make extra calls to the server unnecessarily). The DuckDB version has an apparent internal error that needs to be investigated however. 
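The core of the fix is the emit bookkeeping in convert_with_columns_relation: every alias appends one expression to the project, and when an alias reuses an existing column name the emit output_mapping points that column's slot at the appended expression (the real conversion only emits a mapping when a column is actually replaced). A simplified standalone sketch of that bookkeeping, in plain Python with no protobuf and an illustrative helper name:

    from typing import List


    def with_columns_output_mapping(input_fields: List[str], aliases: List[str]) -> List[int]:
        """Returns the emit output_mapping for a project appending one expression per alias."""
        mapping = list(range(len(input_fields)))  # pass-through mapping for existing columns
        for position, name in enumerate(aliases):
            expression_slot = len(input_fields) + position  # where the appended expression lands
            if name in input_fields:
                # Replacing an existing column: emit the new expression in its old position.
                mapping[input_fields.index(name)] = expression_slot
            else:
                # A genuinely new column: emit it after the existing ones.
                mapping.append(expression_slot)
        return mapping


    # withColumn('user_id', ...) over the three-column users table keeps the column order
    # but emits the appended expression (field 3) in place of the original user_id field.
    assert with_columns_output_mapping(
        ['user_id', 'name', 'paid_for_service'], ['user_id']) == [3, 1, 2]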
--- src/gateway/adbc/backend.py | 43 +++++++++++++------ src/gateway/converter/conversion_options.py | 4 -- src/gateway/converter/data/00001.splan | 36 ++++++++++++++-- src/gateway/converter/spark_to_substrait.py | 42 ++++++++---------- .../converter/substrait_plan_visitor.py | 13 +++--- src/gateway/converter/validation_test.py | 7 ++- src/gateway/tests/test_server.py | 27 ++++++------ 7 files changed, 105 insertions(+), 67 deletions(-) diff --git a/src/gateway/adbc/backend.py b/src/gateway/adbc/backend.py index 7c094fc..91599e8 100644 --- a/src/gateway/adbc/backend.py +++ b/src/gateway/adbc/backend.py @@ -13,6 +13,11 @@ from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable +# pylint: disable=protected-access +def _import(handle): + return pyarrow.RecordBatchReader._import_from_c(handle.address) + + # pylint: disable=fixme class AdbcBackend: """Provides methods for contacting an ADBC backend via Substrait.""" @@ -23,10 +28,12 @@ def __init__(self): def execute_with_duckdb_over_adbc(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: """Executes the given Substrait plan against DuckDB using ADBC.""" with adbc_driver_duckdb.dbapi.connect() as conn, conn.cursor() as cur: + cur.execute("LOAD substrait;") plan_data = plan.SerializeToString() cur.adbc_statement.set_substrait_plan(plan_data) - tbl = cur.fetch_arrow_table() - return tbl + res = cur.adbc_statement.execute_query() + table = _import(res[0]).read_all() + return table # pylint: disable=import-outside-toplevel def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: @@ -37,28 +44,40 @@ def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: ctx = datafusion.SessionContext() file_groups = ReplaceLocalFilesWithNamedTable().visit_plan(plan) + registered_tables = set() for files in file_groups: + table_name = files[0] for file in files[1]: - ctx.register_parquet(files[0], file) + if table_name not in registered_tables: + ctx.register_parquet(table_name, file) + registered_tables.add(files[0]) - plan_data = plan.SerializeToString() - substrait_plan = datafusion.substrait.substrait.serde.deserialize_bytes(plan_data) - logical_plan = datafusion.substrait.substrait.consumer.from_substrait_plan( - ctx, substrait_plan - ) + try: + plan_data = plan.SerializeToString() + substrait_plan = datafusion.substrait.substrait.serde.deserialize_bytes(plan_data) + logical_plan = datafusion.substrait.substrait.consumer.from_substrait_plan( + ctx, substrait_plan + ) - # Create a DataFrame from a deserialized logical plan - df_result = ctx.create_dataframe_from_logical_plan(logical_plan) - return df_result.to_arrow_table() + # Create a DataFrame from a deserialized logical plan + df_result = ctx.create_dataframe_from_logical_plan(logical_plan) + return df_result.to_arrow_table() + finally: + for table_name in registered_tables: + ctx.deregister_table(table_name) def execute_with_duckdb(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: """Executes the given Substrait plan against DuckDB.""" con = duckdb.connect(config={'max_memory': '100GB', + "allow_unsigned_extensions": "true", 'temp_directory': str(Path('.').absolute())}) con.install_extension('substrait') con.load_extension('substrait') plan_data = plan.SerializeToString() - query_result = con.from_substrait(proto=plan_data) + try: + query_result = con.from_substrait(proto=plan_data) + except Exception as err: + raise ValueError(f'DuckDB Execution Error: {err}') from err df = query_result.df() return pyarrow.Table.from_pandas(df=df) 
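Pulling the pieces above together, the DuckDB-over-ADBC path is a load/bind/execute/import round trip. The following is only a sketch assembled from the calls in this diff; it assumes the DuckDB ADBC driver and its substrait extension are available, that plan_data holds a serialized plan_pb2.Plan, and run_plan_over_adbc is an illustrative name:

    import adbc_driver_duckdb.dbapi
    import pyarrow


    def run_plan_over_adbc(plan_data: bytes) -> pyarrow.Table:
        """Executes serialized Substrait plan bytes against DuckDB via ADBC."""
        with adbc_driver_duckdb.dbapi.connect() as conn, conn.cursor() as cur:
            cur.execute("LOAD substrait;")
            cur.adbc_statement.set_substrait_plan(plan_data)
            stream_handle, _rows_affected = cur.adbc_statement.execute_query()
            # The result wraps an Arrow C stream; import it as a RecordBatchReader.
            reader = pyarrow.RecordBatchReader._import_from_c(stream_handle.address)
            return reader.read_all()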
diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 63e2d56..871534f 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -20,9 +20,6 @@ class ConversionOptions: def __init__(self, backend: BackendOptions = None): self.use_named_table_workaround = False self.needs_scheme_in_path_uris = False - self.use_project_emit_workaround = False - self.use_project_emit_workaround2 = False - self.use_project_emit_workaround3 = False self.use_emits_instead_of_direct = False self.return_names_with_types = False @@ -42,5 +39,4 @@ def duck_db(): """Standard options to connect to a DuckDB backend.""" options = ConversionOptions(backend=BackendOptions(Backend.DUCKDB)) options.return_names_with_types = True - options.use_project_emit_workaround3 = False return options diff --git a/src/gateway/converter/data/00001.splan b/src/gateway/converter/data/00001.splan index d2564f4..c736777 100644 --- a/src/gateway/converter/data/00001.splan +++ b/src/gateway/converter/data/00001.splan @@ -77,19 +77,49 @@ relations { input { project { common { - direct { + emit { + output_mapping: 0 + output_mapping: 1 + output_mapping: 2 + output_mapping: 3 + output_mapping: 4 + output_mapping: 5 + output_mapping: 6 + output_mapping: 7 + output_mapping: 8 + output_mapping: 10 } } input { project { common { - direct { + emit { + output_mapping: 0 + output_mapping: 1 + output_mapping: 2 + output_mapping: 3 + output_mapping: 4 + output_mapping: 5 + output_mapping: 6 + output_mapping: 10 + output_mapping: 8 + output_mapping: 9 } } input { project { common { - direct { + emit { + output_mapping: 0 + output_mapping: 1 + output_mapping: 2 + output_mapping: 3 + output_mapping: 4 + output_mapping: 5 + output_mapping: 10 + output_mapping: 7 + output_mapping: 8 + output_mapping: 9 } } input { diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 9913ac6..96e82ba 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -653,33 +653,27 @@ def convert_with_columns_relation( project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) - field_number = 0 - if self._conversion_options.use_project_emit_workaround: - for _ in symbol.output_fields: - project.expressions.append(algebra_pb2.Expression( - selection=algebra_pb2.Expression.FieldReference( - direct_reference=algebra_pb2.Expression.ReferenceSegment( - struct_field=algebra_pb2.Expression.ReferenceSegment.StructField( - field=field_number))))) - field_number += 1 + remapped = False + mapping = list(range(len(symbol.input_fields))) + field_number = len(symbol.input_fields) for alias in rel.aliases: - # TODO -- Handle the common.emit.output_mapping columns correctly. + if len(alias.name) != 1: + raise ValueError('every column alias must have exactly one name') + name = alias.name[0] project.expressions.append(self.convert_expression(alias.expr)) - # TODO -- Add unique intermediate names. 
- symbol.generated_fields.append('intermediate') - symbol.output_fields.append('intermediate') - project.common.CopyFrom(self.create_common_relation()) - if (self._conversion_options.use_project_emit_workaround or - self._conversion_options.use_project_emit_workaround2): - field_number = 0 - for _ in symbol.output_fields: - project.common.emit.output_mapping.append(field_number) - field_number += 1 - if (self._conversion_options.use_project_emit_workaround or - self._conversion_options.use_project_emit_workaround3): - for _ in rel.aliases: - project.common.emit.output_mapping.append(field_number) + if name in symbol.input_fields: + remapped = True + mapping[symbol.input_fields.index(name)] = len(symbol.input_fields) + ( + len(project.expressions)) - 1 + else: + mapping.append(field_number) field_number += 1 + symbol.generated_fields.append(name) + symbol.output_fields.append(name) + project.common.CopyFrom(self.create_common_relation()) + if remapped: + for item in mapping: + project.common.emit.output_mapping.append(item) return algebra_pb2.Rel(project=project) def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.Rel: diff --git a/src/gateway/converter/substrait_plan_visitor.py b/src/gateway/converter/substrait_plan_visitor.py index 1f7e394..313de94 100644 --- a/src/gateway/converter/substrait_plan_visitor.py +++ b/src/gateway/converter/substrait_plan_visitor.py @@ -180,7 +180,7 @@ def visit_record(self, record: algebra_pb2.Expression.MultiOrList.Record) -> Any def visit_if_value(self, if_clause: algebra_pb2.Expression.SwitchExpression.IfValue) -> Any: """Visits an if value.""" if if_clause.HasField('if'): - self.visit_literal(if_clause.if_) + self.visit_expression(getattr(if_clause, 'if')) if if_clause.HasField('then'): self.visit_expression(if_clause.then) @@ -245,8 +245,8 @@ def visit_if_then(self, if_then: algebra_pb2.Expression.IfThen) -> Any: """Visits an if then.""" for if_then_if in if_then.ifs: self.visit_if_value(if_then_if) - if if_then.HasField('else_'): - self.visit_expression(if_then.else_) + if if_then.HasField('else'): + self.visit_expression(getattr(if_then, 'else')) def visit_switch_expression(self, expression: algebra_pb2.Expression.SwitchExpression) -> Any: """Visits a switch expression.""" @@ -255,7 +255,7 @@ def visit_switch_expression(self, expression: algebra_pb2.Expression.SwitchExpre for if_then_if in expression.ifs: self.visit_if_value(if_then_if) if expression.HasField('else'): - self.visit_expression(expression.else_) + self.visit_expression(getattr(expression, 'else')) def visit_singular_or_list(self, singular_or_list: algebra_pb2.Expression.SingularOrList) -> Any: @@ -418,7 +418,7 @@ def visit_expression(self, expression: algebra_pb2.Expression) -> Any: def visit_mask_expression(self, expression: algebra_pb2.Expression.MaskExpression) -> Any: """Visits a mask expression.""" - if expression.HasField('has_select'): + if expression.HasField('select'): self.visit_struct_select(expression.select) def visit_virtual_table(self, table: algebra_pb2.ReadRel.VirtualTable) -> Any: @@ -722,8 +722,7 @@ def visit_expand_relation(self, rel: algebra_pb2.ExpandRel) -> Any: return self.visit_relation(rel.input) for field in rel.fields: return self.visit_expand_field(field) - if rel.HasField('advanced_extension'): - return self.visit_advanced_extension(rel.advanced_extension) + # ExpandRel does not have an advanced_extension like other relations do. 
def visit_relation(self, rel: algebra_pb2.Rel) -> Any: """Visits a Substrait relation.""" diff --git a/src/gateway/converter/validation_test.py b/src/gateway/converter/validation_test.py index 9560ac0..4d5edd2 100644 --- a/src/gateway/converter/validation_test.py +++ b/src/gateway/converter/validation_test.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 """Validation for the Spark to Substrait plan conversion routines.""" -import os from pathlib import Path from google.protobuf import text_format @@ -9,11 +8,11 @@ import substrait_validator -test_case_directory = Path(os.path.dirname(os.path.realpath(__file__))) / 'data' +test_case_directory = Path(__file__).resolve().parent / 'data' -test_case_paths = [f for f in test_case_directory.iterdir() if f.name.endswith('.splan')] +test_case_paths = [f for f in test_case_directory.iterdir() if f.suffix == '.splan'] -test_case_names = [os.path.basename(p).removesuffix('.splan') for p in test_case_paths] +test_case_names = [p.stem for p in test_case_paths] # pylint: disable=E1101,fixme diff --git a/src/gateway/tests/test_server.py b/src/gateway/tests/test_server.py index 299244d..225a201 100644 --- a/src/gateway/tests/test_server.py +++ b/src/gateway/tests/test_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the Spark to Substrait Gateway server.""" from hamcrest import assert_that, equal_to +from pyspark import Row from pyspark.sql.functions import col, substring from pyspark.testing import assertDataFrameEqual @@ -59,26 +60,26 @@ def test_count(self, users_dataframe): outcome = users_dataframe.count() assert outcome == 100 - def test_limit(self, users_dataframe, spark_session): - expected = spark_session.createDataFrame( - data=[('user849118289', 'Brooke Jones', False), - ('user954079192', 'Collin Frank', False)], - schema=['user_id', 'name', 'paid_for_service']) + def test_limit(self, users_dataframe): + expected = [ + Row(user_id='user849118289', name='Brooke Jones', paid_for_service=False), + Row(user_id='user954079192', name='Collin Frank', paid_for_service=False), + ] outcome = users_dataframe.limit(2).collect() assertDataFrameEqual(outcome, expected) - def test_with_column(self, users_dataframe, spark_session): - expected = spark_session.createDataFrame( - data=[('user849118289', 'Brooke Jones', False)], - schema=['user_id', 'name', 'paid_for_service']) + def test_with_column(self, users_dataframe): + expected = [ + Row(user_id='user849118289', name='Brooke Jones', paid_for_service=False), + ] outcome = users_dataframe.withColumn( 'user_id', col('user_id')).limit(1).collect() assertDataFrameEqual(outcome, expected) - def test_cast(self, users_dataframe, spark_session): - expected = spark_session.createDataFrame( - data=[(849, 'Brooke Jones', False)], - schema=['user_id', 'name', 'paid_for_service']) + def test_cast(self, users_dataframe): + expected = [ + Row(user_id=849, name='Brooke Jones', paid_for_service=False), + ] outcome = users_dataframe.withColumn( 'user_id', substring(col('user_id'), 5, 3).cast('integer')).limit(1).collect() From 634b00f482c87ae3f46e033723ff6c5473ff2d81 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 5 Apr 2024 00:31:42 -0700 Subject: [PATCH 03/58] feat: add ruff formatting check (#27) --- .github/workflows/ruff.yml | 26 ++++++++++++++++++++++++++ environment.yml | 1 + pyproject.toml | 2 ++ src/gateway/demo/client_demo.py | 1 + src/gateway/tests/test_server.py | 1 + 5 files changed, 31 insertions(+) create mode 100644 .github/workflows/ruff.yml diff --git 
a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..d5de025 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,26 @@ +name: Ruff + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + uses: conda-incubator/setup-miniconda@v3 + with: + miniforge-version: "latest" + activate-environment: base + environment-file: environment.yml + python-version: ${{ matrix.python }} + auto-activate-base: true + - name: Analyzing the code using ruff + uses: chartboost/ruff-action@v1 diff --git a/environment.yml b/environment.yml index 9b7b60a..dc3ff03 100644 --- a/environment.yml +++ b/environment.yml @@ -18,6 +18,7 @@ dependencies: - Faker - pip: - adbc_driver_manager + - cargo - pyarrow - duckdb == 0.10.1 - datafusion diff --git a/pyproject.toml b/pyproject.toml index 894f691..462d8d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ target-version = "py310" # never autoformat upstream or generated code exclude = ["third_party/", "src/spark/connect"] # do not autofix the following (will still get flagged in lint) + +[lint] unfixable = [ "F401", # unused imports "T201", # print statements diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index f802f71..f25012e 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -12,6 +12,7 @@ # pylint: disable=fixme +# ruff: noqa: E712 def execute_query(spark_session: SparkSession) -> None: """Runs a single sample query against the gateway.""" users_location = str(Path('users.parquet').absolute()) diff --git a/src/gateway/tests/test_server.py b/src/gateway/tests/test_server.py index 225a201..83ee15b 100644 --- a/src/gateway/tests/test_server.py +++ b/src/gateway/tests/test_server.py @@ -7,6 +7,7 @@ # pylint: disable=missing-function-docstring +# ruff: noqa: E712 class TestDataFrameAPI: """Tests of the dataframe side of SparkConnect.""" From 8b2e2b884096dbd692bf17d2d4848fb9945a1856 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 5 Apr 2024 00:52:43 -0700 Subject: [PATCH 04/58] feat: fix the remaining datafusion test failures (#26) While this fixes the datafusion tests locally, there is a packaging issue that causes CI to fail. The datafusion tests have been left in an xfail state so the change can be available for development while the packaging issue is researched. 
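The DataFusion-specific work here is mostly name plumbing: the new RenameFunctions plan visitor (added below) rewrites extension function names into the spellings DataFusion expects, and the backend renames the result columns back to the Spark plan's root names. A minimal sketch of how the visitor is used, assuming only a bare plan carrying a single extension declaration:

    from substrait.gen.proto import plan_pb2
    from substrait.gen.proto.extensions import extensions_pb2

    from gateway.converter.rename_functions import RenameFunctions

    # Build a plan that references 'substring', which DataFusion spells 'substr'.
    plan = plan_pb2.Plan()
    plan.extensions.append(extensions_pb2.SimpleExtensionDeclaration(
        extension_function=extensions_pb2.SimpleExtensionDeclaration.ExtensionFunction(
            function_anchor=1, name='substring')))

    RenameFunctions().visit_plan(plan)  # rewrites the plan in place
    assert plan.extensions[0].extension_function.name == 'substr'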
--- .github/workflows/test.yml | 21 +++++----- environment.yml | 9 ++-- pyproject.toml | 4 +- scripts/pip_verbose.sh | 3 -- src/gateway/adbc/backend.py | 11 ++++- src/gateway/converter/rename_functions.py | 22 ++++++++++ src/gateway/converter/spark_to_substrait.py | 6 +-- src/gateway/converter/substrait_builder.py | 46 ++++++++++----------- src/gateway/server.py | 7 +++- src/gateway/tests/conftest.py | 4 +- 10 files changed, 81 insertions(+), 52 deletions(-) delete mode 100755 scripts/pip_verbose.sh create mode 100644 src/gateway/converter/rename_functions.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d977976..9d34e23 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,24 +15,25 @@ jobs: strategy: matrix: os: [macos-latest, ubuntu-latest] - python: ["3.12"] + python: ["3.10"] runs-on: ${{ matrix.os }} steps: - name: Checkout code uses: actions/checkout@v4 with: submodules: recursive - - name: Set up Python - uses: actions/setup-python@v5 + - name: Install packages and test dependencies + uses: conda-incubator/setup-miniconda@v3 with: + activate-environment: spark-substrait-gateway-env + environment-file: environment.yml python-version: ${{ matrix.python }} - - name: Install package and test dependencies + auto-activate-base: false + - name: Build + shell: bash -el {0} run: | - export PIP_INTEROP_ENABLED=compatible - export PIP_SCRIPT=scripts/pip_verbose.sh - $CONDA/bin/conda env update --file environment.yml --name base - $PIP_SCRIPT install --upgrade pip - $PIP_SCRIPT install ".[test]" + pip install -e . - name: Run tests + shell: bash -el {0} run: | - $CONDA/bin/pytest + pytest diff --git a/environment.yml b/environment.yml index dc3ff03..19b2fa7 100644 --- a/environment.yml +++ b/environment.yml @@ -4,12 +4,11 @@ channels: dependencies: - pip - pre-commit - - protobuf >= 4.21.6 + - protobuf >= 4.21.6, < 5.0.0 - grpcio >= 1.48.1 - grpcio-status - grpcio-tools - - pytest >= 7.0.0 - - python >= 3.8.1 + - pytest >= 8.0.0 - setuptools >= 61.0.0 - setuptools_scm >= 6.2.0 - python-substrait >= 0.14.1 @@ -19,9 +18,9 @@ dependencies: - pip: - adbc_driver_manager - cargo - - pyarrow + - pyarrow >= 13.0.0 - duckdb == 0.10.1 - - datafusion + - datafusion >= 36.0.0 - pyspark - pandas >= 1.0.5 - pyhamcrest diff --git a/pyproject.toml b/pyproject.toml index 462d8d1..003541e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.8.1" -dependencies = ["protobuf >= 3.20", "datafusion"] +dependencies = ["protobuf >= 3.20", "datafusion >= 36.0.0", "pyarrow >= 15.0.2"] dynamic = ["version"] [tool.setuptools_scm] @@ -29,12 +29,12 @@ respect-gitignore = true target-version = "py310" # never autoformat upstream or generated code exclude = ["third_party/", "src/spark/connect"] -# do not autofix the following (will still get flagged in lint) [lint] unfixable = [ "F401", # unused imports "T201", # print statements + "E712", # truth comparison checks ] [tool.pylint.MASTER] diff --git a/scripts/pip_verbose.sh b/scripts/pip_verbose.sh deleted file mode 100755 index 5fd640a..0000000 --- a/scripts/pip_verbose.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -python -m pip "$@" -v - diff --git a/src/gateway/adbc/backend.py b/src/gateway/adbc/backend.py index 91599e8..ee537c1 100644 --- a/src/gateway/adbc/backend.py +++ b/src/gateway/adbc/backend.py @@ -10,6 +10,7 @@ from substrait.gen.proto import 
plan_pb2 from gateway.adbc.backend_options import BackendOptions, Backend +from gateway.converter.rename_functions import RenameFunctions from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable @@ -38,7 +39,6 @@ def execute_with_duckdb_over_adbc(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Ta # pylint: disable=import-outside-toplevel def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: """Executes the given Substrait plan against Datafusion.""" - import datafusion import datafusion.substrait ctx = datafusion.SessionContext() @@ -52,6 +52,8 @@ def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: ctx.register_parquet(table_name, file) registered_tables.add(files[0]) + RenameFunctions().visit_plan(plan) + try: plan_data = plan.SerializeToString() substrait_plan = datafusion.substrait.substrait.serde.deserialize_bytes(plan_data) @@ -59,8 +61,13 @@ def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: ctx, substrait_plan ) - # Create a DataFrame from a deserialized logical plan + # Create a DataFrame from a deserialized logical plan. df_result = ctx.create_dataframe_from_logical_plan(logical_plan) + for column_number, column_name in enumerate(df_result.schema().names): + df_result = df_result.with_column_renamed( + column_name, + plan.relations[0].root.names[column_number] + ) return df_result.to_arrow_table() finally: for table_name in registered_tables: diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py new file mode 100644 index 0000000..958c6d3 --- /dev/null +++ b/src/gateway/converter/rename_functions.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A library to search Substrait plan for local files.""" +from substrait.gen.proto import plan_pb2 + +from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor + + +# pylint: disable=no-member,fixme +class RenameFunctions(SubstraitPlanVisitor): + """Renames Substrait functions to match what Datafusion expects.""" + + def visit_plan(self, plan: plan_pb2.Plan) -> None: + """Modifies the provided plan so that functions are Datafusion compatible.""" + super().visit_plan(plan) + + for extension in plan.extensions: + if extension.WhichOneof('mapping_type') != 'extension_function': + continue + + # TODO -- Take the URI references into account. + if extension.extension_function.name == 'substring': + extension.extension_function.name = 'substr' diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 96e82ba..a54afce 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -507,8 +507,6 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a concat_func = self.lookup_function_by_name('concat') repeat_func = self.lookup_function_by_name('repeat') lpad_func = self.lookup_function_by_name('lpad') - least_func = self.lookup_function_by_name('least') - greatest_func = self.lookup_function_by_name('greatest') greater_func = self.lookup_function_by_name('>') minus_func = self.lookup_function_by_name('-') @@ -537,8 +535,8 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a # Find the maximum we will use based on the truncate, max size, and column name length. 
project2 = project_relation( aggregate1, - [greatest_function(greatest_func, - least_function(least_func, field_reference(column_number), + [greatest_function(greater_func, + least_function(greater_func, field_reference(column_number), bigint_literal(rel.truncate)), strlen(strlen_func, string_literal(symbol.input_fields[column_number]))) for diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 8dc81e1..eb7b718 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -124,28 +124,24 @@ def string_concat_agg_function(function_info: ExtensionFunction, algebra_pb2.FunctionArgument(value=string_literal(separator))]) -def least_function(function_info: ExtensionFunction, - expr1: algebra_pb2.Expression, +def least_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: """Constructs a Substrait min expression.""" - return algebra_pb2.Expression(scalar_function= - algebra_pb2.Expression.ScalarFunction( - function_reference=function_info.anchor, - output_type=function_info.output_type, - arguments=[algebra_pb2.FunctionArgument(value=expr1), - algebra_pb2.FunctionArgument(value=expr2)])) + return if_then_else_operation( + greater_function(greater_function_info, expr1, expr2), + expr2, + expr1 + ) -def greatest_function(function_info: ExtensionFunction, - expr1: algebra_pb2.Expression, +def greatest_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a Substrait min expression.""" - return algebra_pb2.Expression(scalar_function= - algebra_pb2.Expression.ScalarFunction( - function_reference=function_info.anchor, - output_type=function_info.output_type, - arguments=[algebra_pb2.FunctionArgument(value=expr1), - algebra_pb2.FunctionArgument(value=expr2)])) + """Constructs a Substrait max expression.""" + return if_then_else_operation( + greater_function(greater_function_info, expr1, expr2), + expr1, + expr2 + ) def greater_or_equal_function(function_info: ExtensionFunction, @@ -200,30 +196,34 @@ def lpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, pad_string: str = ' ') -> algebra_pb2.AggregateFunction: """Constructs a Substrait concat expression.""" + # TODO -- Avoid a cast if we don't need it. + cast_type = string_type() return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, output_type=function_info.output_type, arguments=[ - algebra_pb2.FunctionArgument(value=cast_operation(expression, varchar_type())), + algebra_pb2.FunctionArgument(value=cast_operation(expression, cast_type)), algebra_pb2.FunctionArgument(value=cast_operation(count, integer_type())), algebra_pb2.FunctionArgument( - value=cast_operation(string_literal(pad_string), varchar_type()))])) + value=cast_operation(string_literal(pad_string), cast_type))])) def rpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, pad_string: str = ' ') -> algebra_pb2.AggregateFunction: """Constructs a Substrait concat expression.""" + # TODO -- Avoid a cast if we don't need it. 
+ cast_type = string_type() return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, output_type=function_info.output_type, arguments=[ - algebra_pb2.FunctionArgument(value=cast_operation(expression, varchar_type())), + algebra_pb2.FunctionArgument(value=cast_operation(expression, cast_type)), algebra_pb2.FunctionArgument(value=cast_operation(count, integer_type())), algebra_pb2.FunctionArgument( - value=cast_operation(string_literal(pad_string), varchar_type()))])) + value=cast_operation(string_literal(pad_string), cast_type))])) def string_literal(val: str) -> algebra_pb2.Expression: @@ -245,13 +245,13 @@ def string_type(required: bool = True) -> type_pb2.Type: return type_pb2.Type(string=type_pb2.Type.String(nullability=nullability)) -def varchar_type(required: bool = True) -> type_pb2.Type: +def varchar_type(length: int = 1000, required: bool = True) -> type_pb2.Type: """Constructs a Substrait varchar type.""" if required: nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED else: nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE - return type_pb2.Type(varchar=type_pb2.Type.VarChar(nullability=nullability)) + return type_pb2.Type(varchar=type_pb2.Type.VarChar(length=length, nullability=nullability)) def integer_type(required: bool = True) -> type_pb2.Type: diff --git a/src/gateway/server.py b/src/gateway/server.py index 75ecea7..4f5ff9b 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -50,13 +50,16 @@ def convert_pyarrow_schema_to_spark(schema: pyarrow.Schema) -> types_pb2.DataTyp for field in schema: if field.type == pyarrow.bool_(): data_type = types_pb2.DataType(boolean=types_pb2.DataType.Boolean()) + elif field.type == pyarrow.int32(): + data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer()) elif field.type == pyarrow.int64(): data_type = types_pb2.DataType(long=types_pb2.DataType.Long()) elif field.type == pyarrow.string(): data_type = types_pb2.DataType(string=types_pb2.DataType.String()) else: - raise ValueError( - f'Unsupported arrow schema to Spark schema conversion type: {field.type}') + raise NotImplementedError( + 'Conversion from Arrow schema to Spark schema not yet implemented ' + f'for type: {field.type}') struct_field = types_pb2.DataType.StructField(name=field.name, data_type=data_type) fields.append(struct_field) diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index a9c65fc..65c7fc3 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -71,7 +71,9 @@ def schema_users(): @pytest.fixture(scope='module', params=['spark', pytest.param('gateway-over-duckdb', marks=pytest.mark.xfail), - pytest.param('gateway-over-datafusion', marks=pytest.mark.xfail)]) + pytest.param('gateway-over-datafusion', + marks=pytest.mark.xfail( + reason='Datafusion Substrait missing in CI'))]) def spark_session(request): """Provides spark sessions connecting to various backends.""" match request.param: From dcf18e683874fae266a67f15762d04cc17fcb2e5 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 5 Apr 2024 16:33:05 -0700 Subject: [PATCH 05/58] feat: remove pylint check now that we have ruff working (#28) --- .github/workflows/pylint.yml | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 .github/workflows/pylint.yml diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml deleted file mode 100644 index 37c47f6..0000000 --- a/.github/workflows/pylint.yml +++ /dev/null @@ -1,24 
+0,0 @@ -name: Pylint - -on: [push] - -jobs: - build: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.12"] - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - $CONDA/bin/conda env update --file environment.yml --name base - python -m pip install --upgrade pip - $CONDA/bin/conda install pylint">=3.1.0" - - name: Analysing the code with pylint - run: | - $CONDA/bin/pylint $(git ls-files '*.py') From ce8556409f8c83acf0ac05acfa836a5bc934dba6 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Sat, 6 Apr 2024 19:50:19 -0700 Subject: [PATCH 06/58] chore: add license check (#30) --- .github/workflows/ruff.yml | 2 ++ .licenserc.yaml | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 .licenserc.yaml diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index d5de025..4fbe69f 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -10,6 +10,8 @@ jobs: python-version: ["3.12"] steps: - uses: actions/checkout@v4 + - name: Check License Header + uses: apache/skywalking-eyes/header@v0.4.0 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: diff --git a/.licenserc.yaml b/.licenserc.yaml new file mode 100644 index 0000000..d9ae273 --- /dev/null +++ b/.licenserc.yaml @@ -0,0 +1,25 @@ +header: + license: + spdx-id: Apache-2.0 + content: | + SPDX-License-Identifier: Apache-2.0 + + paths-ignore: + - '.github' + - '.gitignore' + - '.gitmodules' + - '.licenserc.yaml' + - '.pre-commit-config.yaml' + - 'environment.yml' + - 'pyproject.toml' + - 'LICENSE' + - '**/__init__.py' + - '**/*.golden' + - '**/*.md' + - '**/*.json' + - '**/*.spark' + - '**/*.splan' + - '**/*.sql' + - '**/*.sql-splan' + + comment: never From a18c16ee56da7b42f89ef4ae2c31b682ded7eec1 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Sun, 7 Apr 2024 00:12:03 -0700 Subject: [PATCH 07/58] feat: add TPCH-parquet files (#31) This uses the parquet files through the github.com:tripl-ai/tpch repository as a submodule. --- .gitmodules | 3 +++ CONTRIBUTING.md | 6 ++++++ third_party/tpch | 1 + 3 files changed, 10 insertions(+) create mode 100644 .gitmodules create mode 160000 third_party/tpch diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e3d78a5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/tpch"] + path = third_party/tpch + url = git@github.com:tripl-ai/tpch.git diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 461f074..ee36344 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,6 +6,12 @@ git clone --recursive https://github.com//spark-substrait-gateway.git cd spark-substrait-gateway ``` +## Update the submodules +``` +git submodule init +git submodule update --recursive +``` + ## Conda env Create a conda environment with developer dependencies. 
``` diff --git a/third_party/tpch b/third_party/tpch new file mode 160000 index 0000000..74f5a64 --- /dev/null +++ b/third_party/tpch @@ -0,0 +1 @@ +Subproject commit 74f5a64df4268626cba2224a135c5664062c196b From c8fb2dfe11dc269231908611c1475adc4e7ff63c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 7 Apr 2024 22:58:44 -0700 Subject: [PATCH 08/58] chore(deps): bump apache/skywalking-eyes from 0.4.0 to 0.6.0 (#32) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ruff.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 4fbe69f..db245f7 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Check License Header - uses: apache/skywalking-eyes/header@v0.4.0 + uses: apache/skywalking-eyes/header@v0.6.0 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: From 3d83b8799ea79a1b3fccc59bc068b89af757d823 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 8 Apr 2024 22:10:44 -0700 Subject: [PATCH 09/58] feat: alter client demo to use TPCH dataset (#33) --- src/gateway/converter/spark_to_substrait.py | 14 ++++- src/gateway/converter/sql_to_substrait.py | 3 +- src/gateway/demo/client_demo.py | 62 ++++++++++++++------- src/gateway/server.py | 7 ++- 4 files changed, 64 insertions(+), 22 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index a54afce..2c0c73e 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -309,16 +309,28 @@ def convert_schema(self, schema_str: str) -> Optional[type_pb2.NamedStruct]: match field.get('type'): case 'boolean': field_type = type_pb2.Type(bool=type_pb2.Type.Boolean(nullability=nullability)) + case 'byte': + field_type = type_pb2.Type(i8=type_pb2.Type.I8(nullability=nullability)) case 'short': field_type = type_pb2.Type(i16=type_pb2.Type.I16(nullability=nullability)) case 'integer': field_type = type_pb2.Type(i32=type_pb2.Type.I32(nullability=nullability)) case 'long': field_type = type_pb2.Type(i64=type_pb2.Type.I64(nullability=nullability)) + case 'float': + field_type = type_pb2.Type(fp32=type_pb2.Type.FP32(nullability=nullability)) + case 'double': + field_type = type_pb2.Type(fp64=type_pb2.Type.FP64(nullability=nullability)) + case 'decimal': + field_type = type_pb2.Type( + decimal=type_pb2.Type.Decimal(nullability=nullability)) case 'string': field_type = type_pb2.Type(string=type_pb2.Type.String(nullability=nullability)) + case 'binary': + field_type = type_pb2.Type(binary=type_pb2.Type.Binary(nullability=nullability)) case _: - raise NotImplementedError(f'Unexpected field type: {field.get("type")}') + raise NotImplementedError( + f'Schema field type not yet implemented: {field.get("type")}') schema.struct.types.append(field_type) return schema diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index e3cfd6f..ee0e2ba 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -6,7 +6,7 @@ from substrait.gen.proto import plan_pb2 -# pylint: disable=E1101,too-few-public-methods +# pylint: disable=E1101,too-few-public-methods,fixme class SqlConverter: """Converts SQL to a Substrait plan.""" @@ -18,6 +18,7 @@ 
def convert_sql(self, sql: str) -> plan_pb2.Plan: con.install_extension('substrait') con.load_extension('substrait') + # TODO -- Rely on the client to register their own named tables. con.execute("CREATE TABLE users AS SELECT * FROM 'users.parquet'") proto_bytes = con.get_substrait(query=sql).fetchone()[0] diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index f25012e..6ead088 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -1,40 +1,64 @@ # SPDX-License-Identifier: Apache-2.0 """A PySpark client that can send sample queries to the gateway.""" -import atexit from pathlib import Path -from pyspark.sql import SparkSession +import pyarrow from pyspark.sql.functions import col +from pyspark.sql import SparkSession, DataFrame from pyspark.sql.pandas.types import from_arrow_schema -from gateway.demo.mystream_database import create_mystream_database, delete_mystream_database -from gateway.demo.mystream_database import get_mystream_schema +USE_GATEWAY = True + + +# pylint: disable=fixme +def future_get_customer_database(spark_session: SparkSession) -> DataFrame: + # TODO -- Use this when server-side schema evaluation is available. + location_customer = str(Path('../../third_party/tpch/parquet/customer').absolute()) + + return spark_session.read.parquet(location_customer, + mergeSchema=False) + + +def get_customer_database(spark_session: SparkSession) -> DataFrame: + location_customer = str(Path('../../third_party/tpch/parquet/customer').absolute()) + + schema_customer = pyarrow.schema([ + pyarrow.field('c_custkey', pyarrow.int64(), False), + pyarrow.field('c_name', pyarrow.string(), False), + pyarrow.field('c_address', pyarrow.string(), False), + pyarrow.field('c_nationkey', pyarrow.int64(), False), + pyarrow.field('c_phone', pyarrow.string(), False), + pyarrow.field('c_acctbal', pyarrow.float64(), False), + pyarrow.field('c_mktsegment', pyarrow.string(), False), + pyarrow.field('c_comment', pyarrow.string(), False), + ]) + + return (spark_session.read.format('parquet') + .schema(from_arrow_schema(schema_customer)) + .load(location_customer + '/*.parquet')) # pylint: disable=fixme -# ruff: noqa: E712 def execute_query(spark_session: SparkSession) -> None: """Runs a single sample query against the gateway.""" - users_location = str(Path('users.parquet').absolute()) - schema_users = get_mystream_schema('users') + df_customer = get_customer_database(spark_session) - df_users = spark_session.read.format('parquet') \ - .schema(from_arrow_schema(schema_users)) \ - .parquet(users_location) + # TODO -- Enable after named table registration is implemented. + # df_customer.createOrReplaceTempView('customer') # pylint: disable=singleton-comparison - df_users2 = df_users \ - .filter(col('paid_for_service') == True) \ - .sort(col('user_id')) \ + df_result = df_customer \ + .filter(col('c_mktsegment') == 'FURNITURE') \ + .sort(col('c_name')) \ .limit(10) - df_users2.show() + df_result.show() if __name__ == '__main__': - atexit.register(delete_mystream_database) - path = create_mystream_database() - - # TODO -- Make this configurable. - spark = SparkSession.builder.remote('sc://localhost:50051').getOrCreate() + if USE_GATEWAY: + # TODO -- Make the port configurable. 
+ spark = SparkSession.builder.remote('sc://localhost:50051').getOrCreate() + else: + spark = SparkSession.builder.master('local').getOrCreate() execute_query(spark) diff --git a/src/gateway/server.py b/src/gateway/server.py index 4f5ff9b..fc830b1 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -85,7 +85,12 @@ def ExecutePlan( case 'root': substrait = convert.convert_plan(request.plan) case 'command': - substrait = SqlConverter().convert_sql(request.plan.command.sql_command) + match request.plan.command.WhichOneof('command_type'): + case 'sql_command': + substrait = SqlConverter().convert_sql(request.plan.command.sql_command.sql) + case _: + raise NotImplementedError( + f'Unsupported command type: {request.plan.command.WhichOneof("command_type")}') case _: raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) From ef869138ccfa7615aff196b46dc9daf7b8db6959 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Tue, 9 Apr 2024 18:15:55 -0700 Subject: [PATCH 10/58] feat: resolve schema via adbc (#34) This PR adds server side schema evaluation for a single parquet file. TODO: Support for multiple files Supports for any adbc backend Support for merging schemas --- src/gateway/converter/spark_to_substrait.py | 55 +++++++++++++++++++++ src/gateway/demo/client_demo.py | 29 ++--------- 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 2c0c73e..27c1a3e 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -4,7 +4,9 @@ import operator from typing import Dict, Optional, List +import adbc_driver_duckdb.dbapi import pyarrow +import pyarrow.parquet import pyspark.sql.connect.proto.base_pb2 as spark_pb2 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 @@ -24,6 +26,22 @@ from gateway.converter.symbol_table import SymbolTable +DUCKDB_TABLE = "duckdb_table" + + +def fetch_schema_with_adbc(path): + """Fetch the arrow schema via ADBC.""" + + with adbc_driver_duckdb.dbapi.connect() as conn, conn.cursor() as cur: + # TODO: Support multiple paths. 
+ reader = pyarrow.parquet.ParquetFile(path) + cur.adbc_ingest(DUCKDB_TABLE, reader.iter_batches(), mode="create") + schema = conn.adbc_get_table_schema(DUCKDB_TABLE) + cur.execute(f"DROP TABLE {DUCKDB_TABLE}") + + return schema + + # pylint: disable=E1101,fixme,too-many-public-methods class SparkSubstraitConverter: """Converts SparkConnect plans to Substrait plans.""" @@ -335,10 +353,47 @@ def convert_schema(self, schema_str: str) -> Optional[type_pb2.NamedStruct]: schema.struct.types.append(field_type) return schema + def convert_arrow_schema(self, arrow_schema: pyarrow.Schema) -> type_pb2.NamedStruct: + schema = type_pb2.NamedStruct() + schema.struct.nullability = type_pb2.Type.NULLABILITY_REQUIRED + + for field_idx in range(len(arrow_schema)): + field = arrow_schema[field_idx] + schema.names.append(field.name) + if field.nullable: + nullability = type_pb2.Type.NULLABILITY_NULLABLE + else: + nullability = type_pb2.Type.NULLABILITY_REQUIRED + + match str(field.type): + case 'bool': + field_type = type_pb2.Type(bool=type_pb2.Type.Boolean(nullability=nullability)) + case 'int16': + field_type = type_pb2.Type(i16=type_pb2.Type.I16(nullability=nullability)) + case 'int32': + field_type = type_pb2.Type(i32=type_pb2.Type.I32(nullability=nullability)) + case 'int64': + field_type = type_pb2.Type(i64=type_pb2.Type.I64(nullability=nullability)) + case 'float': + field_type = type_pb2.Type(fp32=type_pb2.Type.FP32(nullability=nullability)) + case 'double': + field_type = type_pb2.Type(fp64=type_pb2.Type.FP64(nullability=nullability)) + case 'string': + field_type = type_pb2.Type(string=type_pb2.Type.String(nullability=nullability)) + case _: + raise NotImplementedError(f'Unexpected field type: {field.type}') + + schema.struct.types.append(field_type) + return schema + def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Rel: """Converts a read data source relation into a Substrait relation.""" local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) + if not schema: + path = rel.paths[0] + arrow_schema = fetch_schema_with_adbc(path) + schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: symbol.output_fields.append(field_name) diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index 6ead088..a9c6f95 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -2,42 +2,21 @@ """A PySpark client that can send sample queries to the gateway.""" from pathlib import Path -import pyarrow -from pyspark.sql.functions import col from pyspark.sql import SparkSession, DataFrame -from pyspark.sql.pandas.types import from_arrow_schema +from pyspark.sql.functions import col USE_GATEWAY = True # pylint: disable=fixme -def future_get_customer_database(spark_session: SparkSession) -> DataFrame: - # TODO -- Use this when server-side schema evaluation is available. - location_customer = str(Path('../../third_party/tpch/parquet/customer').absolute()) +def get_customer_database(spark_session: SparkSession) -> DataFrame: + # TODO -- Support reading schema from multiple files. 
+ location_customer = str(Path('../../../third_party/tpch/parquet/customer/part-0.parquet').absolute()) return spark_session.read.parquet(location_customer, mergeSchema=False) -def get_customer_database(spark_session: SparkSession) -> DataFrame: - location_customer = str(Path('../../third_party/tpch/parquet/customer').absolute()) - - schema_customer = pyarrow.schema([ - pyarrow.field('c_custkey', pyarrow.int64(), False), - pyarrow.field('c_name', pyarrow.string(), False), - pyarrow.field('c_address', pyarrow.string(), False), - pyarrow.field('c_nationkey', pyarrow.int64(), False), - pyarrow.field('c_phone', pyarrow.string(), False), - pyarrow.field('c_acctbal', pyarrow.float64(), False), - pyarrow.field('c_mktsegment', pyarrow.string(), False), - pyarrow.field('c_comment', pyarrow.string(), False), - ]) - - return (spark_session.read.format('parquet') - .schema(from_arrow_schema(schema_customer)) - .load(location_customer + '/*.parquet')) - - # pylint: disable=fixme def execute_query(spark_session: SparkSession) -> None: """Runs a single sample query against the gateway.""" From e95af5860ad6ece369b415222454b2fcfc3cdae2 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Wed, 10 Apr 2024 19:47:14 -0400 Subject: [PATCH 11/58] feat: support for passing folder when resolving schema (#36) * feat: support for passing folder when resolving schema * fix: add all files as part of convert_read_data_source_relation * fix: resolve path * fix: avoid passing entire read relation as input to function * fix: add type hints --- src/gateway/converter/spark_to_substrait.py | 22 +++++++++++++++------ src/gateway/demo/client_demo.py | 3 +-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 27c1a3e..dcb82bb 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to convert SparkConnect plans to Substrait plans.""" +import glob import json import operator +import pathlib from typing import Dict, Optional, List import adbc_driver_duckdb.dbapi @@ -25,16 +27,21 @@ if_then_else_operation, greater_function, minus_function from gateway.converter.symbol_table import SymbolTable - DUCKDB_TABLE = "duckdb_table" -def fetch_schema_with_adbc(path): +def fetch_schema_with_adbc(file_path: str, ext: str) -> pyarrow.Schema: """Fetch the arrow schema via ADBC.""" + file_paths = list(pathlib.Path(file_path).glob(f'*.{ext}')) + if len(file_paths) > 0: + # We sort the files because the later partitions don't have enough data to construct a schema. + file_paths = sorted([str(fp) for fp in file_paths]) + file_path = file_paths[0] + with adbc_driver_duckdb.dbapi.connect() as conn, conn.cursor() as cur: # TODO: Support multiple paths. 
- reader = pyarrow.parquet.ParquetFile(path) + reader = pyarrow.parquet.ParquetFile(file_path) cur.adbc_ingest(DUCKDB_TABLE, reader.iter_batches(), mode="create") schema = conn.adbc_get_table_schema(DUCKDB_TABLE) cur.execute(f"DROP TABLE {DUCKDB_TABLE}") @@ -391,8 +398,7 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) if not schema: - path = rel.paths[0] - arrow_schema = fetch_schema_with_adbc(path) + arrow_schema = fetch_schema_with_adbc(rel.paths[0], rel.format) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: @@ -402,7 +408,11 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al read=algebra_pb2.ReadRel(base_schema=schema, named_table=algebra_pb2.ReadRel.NamedTable( names=['demotable']))) - for path in rel.paths: + if pathlib.Path(rel.paths[0]).is_dir(): + file_paths = glob.glob(f'{rel.paths[0]}/*{rel.format}') + else: + file_paths = rel.paths + for path in file_paths: uri_path = path if self._conversion_options.needs_scheme_in_path_uris: if uri_path.startswith('/'): diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index a9c6f95..c889d36 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -10,8 +10,7 @@ # pylint: disable=fixme def get_customer_database(spark_session: SparkSession) -> DataFrame: - # TODO -- Support reading schema from multiple files. - location_customer = str(Path('../../../third_party/tpch/parquet/customer/part-0.parquet').absolute()) + location_customer = str(Path('../../../third_party/tpch/parquet/customer').resolve()) return spark_session.read.parquet(location_customer, mergeSchema=False) From 2352a7f46ec7f7a0d945795ef480ca81e29b1db1 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 11 Apr 2024 13:22:16 -0700 Subject: [PATCH 12/58] feat: add simple SQL side by side tests using TPC-H dataset (#35) - Fixes handling of sql as part of a plan (as compared to a sql command). - Adds automatic discovery of the TPC-H dataset directory to make it easier to run tests from anywhere in the repository. - Added a workaround to always register the TPC-H datasets on any substrait execution (or SQL conversion). 
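As a condensed sketch, the registration workaround amounts to walking up from the working
directory until third_party/tpch/parquet is found and then creating one DuckDB table per
TPC-H parquet directory. The helper below mirrors the find_tpch()/register_table()
additions in this patch; the standalone loop and variable names are only for illustration.

    import duckdb
    from pathlib import Path

    def find_tpch() -> Path:
        # Walk upward until the checked-in TPC-H parquet directory is found.
        current = Path('.').resolve()
        while current != Path('/'):
            candidate = current / 'third_party' / 'tpch' / 'parquet'
            if candidate.exists():
                return candidate.resolve()
            current = current.parent
        raise ValueError('TPCH dataset not found')

    con = duckdb.connect()
    for table_name in ['customer', 'lineitem', 'nation', 'orders',
                       'part', 'partsupp', 'region', 'supplier']:
        # Back each named table with the sorted parquet files in its directory.
        files = sorted(str(f) for f in (find_tpch() / table_name).glob('*.parquet'))
        file_list = ', '.join(f"'{f}'" for f in files)
        con.execute(f'CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{file_list}])')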
--- src/gateway/adbc/backend.py | 13 ++ src/gateway/converter/data/count.sql | 2 +- src/gateway/converter/data/count.sql-splan | 40 +++++- src/gateway/converter/data/select.sql | 1 + src/gateway/converter/data/select.sql-splan | 120 ++++++++++++++++++ src/gateway/converter/spark_to_substrait.py | 21 +++ .../converter/spark_to_substrait_test.py | 4 +- src/gateway/converter/sql_to_substrait.py | 68 +++++++--- src/gateway/demo/client_demo.py | 14 +- src/gateway/server.py | 17 ++- src/gateway/tests/conftest.py | 30 +++-- .../{test_server.py => test_dataframe_api.py} | 0 src/gateway/tests/test_sql_api.py | 43 +++++++ 13 files changed, 326 insertions(+), 47 deletions(-) create mode 100644 src/gateway/converter/data/select.sql create mode 100644 src/gateway/converter/data/select.sql-splan rename src/gateway/tests/{test_server.py => test_dataframe_api.py} (100%) create mode 100644 src/gateway/tests/test_sql_api.py diff --git a/src/gateway/adbc/backend.py b/src/gateway/adbc/backend.py index ee537c1..d46cb26 100644 --- a/src/gateway/adbc/backend.py +++ b/src/gateway/adbc/backend.py @@ -12,6 +12,7 @@ from gateway.adbc.backend_options import BackendOptions, Backend from gateway.converter.rename_functions import RenameFunctions from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable +from gateway.converter.sql_to_substrait import register_table, find_tpch # pylint: disable=protected-access @@ -81,6 +82,18 @@ def execute_with_duckdb(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: con.install_extension('substrait') con.load_extension('substrait') plan_data = plan.SerializeToString() + + # TODO -- Rely on the client to register their own named tables. + tpch_location = find_tpch() + register_table(con, 'customer', tpch_location / 'customer') + register_table(con, 'lineitem', tpch_location / 'lineitem') + register_table(con, 'nation', tpch_location / 'nation') + register_table(con, 'orders', tpch_location / 'orders') + register_table(con, 'part', tpch_location / 'part') + register_table(con, 'partsupp', tpch_location / 'partsupp') + register_table(con, 'region', tpch_location / 'region') + register_table(con, 'supplier', tpch_location / 'supplier') + try: query_result = con.from_substrait(proto=plan_data) except Exception as err: diff --git a/src/gateway/converter/data/count.sql b/src/gateway/converter/data/count.sql index cf514b4..dd250a5 100644 --- a/src/gateway/converter/data/count.sql +++ b/src/gateway/converter/data/count.sql @@ -1 +1 @@ -SELECT COUNT(*) FROM users +SELECT COUNT(*) FROM customer diff --git a/src/gateway/converter/data/count.sql-splan b/src/gateway/converter/data/count.sql-splan index 7b20fe4..49eb7f5 100644 --- a/src/gateway/converter/data/count.sql-splan +++ b/src/gateway/converter/data/count.sql-splan @@ -13,10 +13,20 @@ relations { input { read { base_schema { - names: "user_id" - names: "name" - names: "paid_for_service" + names: "c_custkey" + names: "c_name" + names: "c_address" + names: "c_nationkey" + names: "c_phone" + names: "c_acctbal" + names: "c_mktsegment" + names: "c_comment" struct { + types { + i64 { + nullability: NULLABILITY_NULLABLE + } + } types { string { nullability: NULLABILITY_NULLABLE @@ -28,7 +38,27 @@ relations { } } types { - bool { + i64 { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + types { + fp64 { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { nullability: NULLABILITY_NULLABLE } 
} @@ -36,7 +66,7 @@ relations { } } named_table { - names: "users" + names: "customer" } } } diff --git a/src/gateway/converter/data/select.sql b/src/gateway/converter/data/select.sql new file mode 100644 index 0000000..9a298ba --- /dev/null +++ b/src/gateway/converter/data/select.sql @@ -0,0 +1 @@ +SELECT c_custkey, c_phone, c_mktsegment FROM customer diff --git a/src/gateway/converter/data/select.sql-splan b/src/gateway/converter/data/select.sql-splan new file mode 100644 index 0000000..2014c28 --- /dev/null +++ b/src/gateway/converter/data/select.sql-splan @@ -0,0 +1,120 @@ +relations { + root { + input { + project { + input { + read { + base_schema { + names: "c_custkey" + names: "c_name" + names: "c_address" + names: "c_nationkey" + names: "c_phone" + names: "c_acctbal" + names: "c_mktsegment" + names: "c_comment" + struct { + types { + i64 { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + types { + i64 { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + types { + fp64 { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + types { + string { + nullability: NULLABILITY_NULLABLE + } + } + nullability: NULLABILITY_REQUIRED + } + } + projection { + select { + struct_items { + } + struct_items { + field: 4 + } + struct_items { + field: 6 + } + } + maintain_singular_struct: true + } + named_table { + names: "customer" + } + } + } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + root_reference { + } + } + } + } + } + names: "c_custkey" + names: "c_phone" + names: "c_mktsegment" + } +} +version { + minor_number: 39 + producer: "DuckDB" +} diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index dcb82bb..b9d1008 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -20,6 +20,7 @@ from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function +from gateway.converter.sql_to_substrait import convert_sql from gateway.converter.substrait_builder import field_reference, cast_operation, string_type, \ project_relation, strlen, concat, fetch_relation, join_relation, aggregate_relation, \ max_agg_function, string_literal, flatten, repeat_function, \ @@ -60,6 +61,8 @@ def __init__(self, options: ConversionOptions): self._symbol_table = SymbolTable() self._conversion_options = options self._seen_generated_names = {} + self._saved_extension_uris = {} + self._saved_extensions = {} def lookup_function_by_name(self, name: str) -> ExtensionFunction: """Finds the function reference for a given Spark function name.""" @@ -804,6 +807,19 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge read.common.CopyFrom(self.create_common_relation()) return algebra_pb2.Rel(read=read) + def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: + """Converts a Spark SQL relation into a Substrait relation.""" + plan = convert_sql(rel.query) + symbol = 
self._symbol_table.get_symbol(self._current_plan_id) + for field_name in plan.relations[0].root.names: + symbol.output_fields.append(field_name) + # TODO -- Correctly capture all the used functions and extensions. + self._saved_extension_uris = plan.extension_uris + self._saved_extensions = plan.extensions + # TODO -- Merge those references into the current context. + # TODO -- Renumber all of the functions/extensions in the captured subplan. + return plan.relations[0].root.input + def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: """Converts a Spark relation into a Substrait one.""" self._symbol_table.add_symbol(rel.common.plan_id, parent=self._current_plan_id, @@ -829,6 +845,8 @@ def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel result = self.convert_to_df_relation(rel.to_df) case 'local_relation': result = self.convert_local_relation(rel.local_relation) + case 'sql': + result = self.convert_sql_relation(rel.sql) case _: raise ValueError( f'Unexpected Spark plan rel_type: {rel.WhichOneof("rel_type")}') @@ -855,4 +873,7 @@ def convert_plan(self, plan: spark_pb2.Plan) -> plan_pb2.Plan: extension_function=extensions_pb2.SimpleExtensionDeclaration.ExtensionFunction( extension_uri_reference=self._function_uris.get(f.uri), function_anchor=f.anchor, name=f.name))) + # As a workaround use the saved extensions and URIs without fixing them. + result.extension_uris.extend(self._saved_extension_uris) + result.extensions.extend(self._saved_extensions) return result diff --git a/src/gateway/converter/spark_to_substrait_test.py b/src/gateway/converter/spark_to_substrait_test.py index 9d35f5b..c0d2605 100644 --- a/src/gateway/converter/spark_to_substrait_test.py +++ b/src/gateway/converter/spark_to_substrait_test.py @@ -9,7 +9,7 @@ from gateway.converter.conversion_options import duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter -from gateway.converter.sql_to_substrait import SqlConverter +from gateway.converter.sql_to_substrait import convert_sql from gateway.demo.mystream_database import create_mystream_database, delete_mystream_database test_case_directory = Path(__file__).resolve().parent / 'data' @@ -81,7 +81,7 @@ def test_sql_conversion(request, path): splan_prototext = file.read() substrait_plan = text_format.Parse(splan_prototext, plan_pb2.Plan()) - substrait = SqlConverter().convert_sql(str(sql)) + substrait = convert_sql(str(sql)) if request.config.getoption('rebuild_goldens'): if substrait != substrait_plan: diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index ee0e2ba..3ef55fe 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -1,26 +1,60 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to convert SparkConnect plans to Substrait plans.""" from pathlib import Path +from typing import List import duckdb from substrait.gen.proto import plan_pb2 +def _expand_location(location: Path | str) -> List[str]: + """Expands the location of a file or directory into a list of files.""" + # TODO -- Handle more than just Parquet files (here and below). 
+ files = Path(location).resolve().glob('*.parquet') + return sorted(str(f) for f in files) + + +def find_tpch() -> Path: + """Finds the location of the TPCH dataset.""" + current_location = Path('.').resolve() + while current_location != Path('/'): + location = current_location / 'third_party' / 'tpch' / 'parquet' + if location.exists(): + return location.resolve() + current_location = current_location.parent + raise ValueError('TPCH dataset not found') + + +def register_table(con: duckdb.DuckDBPyConnection, table_name, location: Path | str) -> None: + files = _expand_location(location) + if not files: + raise ValueError(f"No parquet files found at {location}") + files_str = ', '.join([f"'{f}'" for f in files]) + files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" + + con.execute(files_sql) + + # pylint: disable=E1101,too-few-public-methods,fixme -class SqlConverter: - """Converts SQL to a Substrait plan.""" - - def convert_sql(self, sql: str) -> plan_pb2.Plan: - """Converts SQL into a Substrait plan.""" - result = plan_pb2.Plan() - con = duckdb.connect(config={'max_memory': '100GB', - 'temp_directory': str(Path('.').absolute())}) - con.install_extension('substrait') - con.load_extension('substrait') - - # TODO -- Rely on the client to register their own named tables. - con.execute("CREATE TABLE users AS SELECT * FROM 'users.parquet'") - - proto_bytes = con.get_substrait(query=sql).fetchone()[0] - result.ParseFromString(proto_bytes) - return result +def convert_sql(sql: str) -> plan_pb2.Plan: + """Converts SQL into a Substrait plan.""" + result = plan_pb2.Plan() + con = duckdb.connect(config={'max_memory': '100GB', + 'temp_directory': str(Path('.').resolve())}) + con.install_extension('substrait') + con.load_extension('substrait') + + # TODO -- Rely on the client to register their own named tables. 
+ tpch_location = find_tpch() + register_table(con, 'customer', tpch_location / 'customer') + register_table(con, 'lineitem', tpch_location / 'lineitem') + register_table(con, 'nation', tpch_location / 'nation') + register_table(con, 'orders', tpch_location / 'orders') + register_table(con, 'part', tpch_location / 'part') + register_table(con, 'partsupp', tpch_location / 'partsupp') + register_table(con, 'region', tpch_location / 'region') + register_table(con, 'supplier', tpch_location / 'supplier') + + proto_bytes = con.get_substrait(query=sql).fetchone()[0] + result.ParseFromString(proto_bytes) + return result diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index c889d36..1afdeb8 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -1,19 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 """A PySpark client that can send sample queries to the gateway.""" -from pathlib import Path -from pyspark.sql import SparkSession, DataFrame from pyspark.sql.functions import col +from pyspark.sql import SparkSession, DataFrame + +from gateway.converter.sql_to_substrait import find_tpch USE_GATEWAY = True # pylint: disable=fixme def get_customer_database(spark_session: SparkSession) -> DataFrame: - location_customer = str(Path('../../../third_party/tpch/parquet/customer').resolve()) + location_customer = str(find_tpch() / 'customer') - return spark_session.read.parquet(location_customer, - mergeSchema=False) + return spark_session.read.parquet(location_customer, mergeSchema=False) # pylint: disable=fixme @@ -32,6 +32,10 @@ def execute_query(spark_session: SparkSession) -> None: df_result.show() + sql_results = spark_session.sql( + 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() + print(sql_results) + if __name__ == '__main__': if USE_GATEWAY: diff --git a/src/gateway/server.py b/src/gateway/server.py index fc830b1..37cd668 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -14,7 +14,7 @@ from gateway.converter.conversion_options import duck_db, datafusion from gateway.converter.spark_to_substrait import SparkSubstraitConverter from gateway.adbc.backend import AdbcBackend -from gateway.converter.sql_to_substrait import SqlConverter +from gateway.converter.sql_to_substrait import convert_sql _LOGGER = logging.getLogger(__name__) @@ -80,14 +80,14 @@ def ExecutePlan( self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: _LOGGER.info('ExecutePlan: %s', request) - convert = SparkSubstraitConverter(self._options) match request.plan.WhichOneof('op_type'): case 'root': + convert = SparkSubstraitConverter(self._options) substrait = convert.convert_plan(request.plan) case 'command': match request.plan.command.WhichOneof('command_type'): case 'sql_command': - substrait = SqlConverter().convert_sql(request.plan.command.sql_command.sql) + substrait = convert_sql(request.plan.command.sql_command.sql) case _: raise NotImplementedError( f'Unsupported command type: {request.plan.command.WhichOneof("command_type")}') @@ -98,7 +98,8 @@ def ExecutePlan( results = backend.execute(substrait, self._options.backend) _LOGGER.debug(' results are: %s', results) - if not self._options.implement_show_string and request.plan.root.WhichOneof( + if not self._options.implement_show_string and request.plan.WhichOneof( + 'op_type') == 'root' and request.plan.root.WhichOneof( 'rel_type') == 'show_string': yield pb2.ExecutePlanResponse( session_id=request.session_id, @@ -121,7 
+122,13 @@ def ExecutePlan( data=batch_to_bytes(batch, results.schema)), schema=convert_pyarrow_schema_to_spark(results.schema), ) - # TODO -- When spark 3.4.0 support is not required, yield a ResultComplete message here. + + for option in request.request_options: + if option.reattach_options.reattachable: + yield pb2.ExecutePlanResponse( + session_id=request.session_id, + result_complete=pb2.ExecutePlanResponse.ResultComplete()) + return def AnalyzePlan(self, request, context): _LOGGER.info('AnalyzePlan: %s', request) diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 65c7fc3..ec463f8 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -2,8 +2,8 @@ """Test fixtures for pytest of the gateway server.""" from pathlib import Path -from pyspark.sql import SparkSession from pyspark.sql.pandas.types import from_arrow_schema +from pyspark.sql.session import SparkSession import pytest from gateway.demo.mystream_database import create_mystream_database, delete_mystream_database @@ -11,12 +11,12 @@ from gateway.server import serve -def _create_local_spark_session(): +def _create_local_spark_session() -> SparkSession: """Creates a local spark session for testing.""" spark = ( SparkSession .builder - .master('local') + .master('local[*]') .config("spark.driver.bindAddress", "127.0.0.1") .appName('gateway') .getOrCreate() @@ -25,7 +25,7 @@ def _create_local_spark_session(): spark.stop() -def _create_gateway_session(backend: str): +def _create_gateway_session(backend: str) -> SparkSession: """Creates a local gateway session for testing.""" spark_gateway = ( SparkSession @@ -34,7 +34,7 @@ def _create_gateway_session(backend: str): .config("spark.driver.bindAddress", "127.0.0.1") .config("spark-substrait-gateway.backend", backend) .appName('gateway') - .getOrCreate() + .create() ) yield spark_gateway spark_gateway.stop() @@ -48,7 +48,7 @@ def manage_database() -> None: delete_mystream_database() -@pytest.fixture(scope='module', autouse=True) +@pytest.fixture(scope='session', autouse=True) def gateway_server(): """Starts up a spark to substrait gateway service.""" server = serve(50052, wait=False) @@ -57,9 +57,9 @@ def gateway_server(): @pytest.fixture(scope='session') -def users_location(): +def users_location() -> str: """Provides the location of the users database.""" - return str(Path('users.parquet').absolute()) + return str(Path('users.parquet').resolve()) @pytest.fixture(scope='session') @@ -68,15 +68,21 @@ def schema_users(): return get_mystream_schema('users') -@pytest.fixture(scope='module', +@pytest.fixture(scope='session', params=['spark', pytest.param('gateway-over-duckdb', marks=pytest.mark.xfail), pytest.param('gateway-over-datafusion', marks=pytest.mark.xfail( reason='Datafusion Substrait missing in CI'))]) -def spark_session(request): +def source(request) -> str: + """Provides the source (backend) to be used.""" + return request.param + + +@pytest.fixture(scope='session') +def spark_session(source): """Provides spark sessions connecting to various backends.""" - match request.param: + match source: case 'spark': session_generator = _create_local_spark_session() case 'gateway-over-datafusion': @@ -84,7 +90,7 @@ def spark_session(request): case 'gateway-over-duckdb': session_generator = _create_gateway_session('duckdb') case _: - raise NotImplementedError(f'No such session implemented: {request.param}') + raise NotImplementedError(f'No such session implemented: {source}') yield from session_generator diff --git 
a/src/gateway/tests/test_server.py b/src/gateway/tests/test_dataframe_api.py similarity index 100% rename from src/gateway/tests/test_server.py rename to src/gateway/tests/test_dataframe_api.py diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py new file mode 100644 index 0000000..c199558 --- /dev/null +++ b/src/gateway/tests/test_sql_api.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the Spark to Substrait Gateway server.""" +import pytest +from hamcrest import assert_that, equal_to +from pyspark import Row +from pyspark.sql.session import SparkSession +from pyspark.testing import assertDataFrameEqual + +from gateway.converter.sql_to_substrait import find_tpch + + +@pytest.fixture(scope='function') +def spark_session_with_customer_database(spark_session: SparkSession, source: str) -> SparkSession: + """Creates a temporary view of the customer database.""" + if source == 'spark': + customer_location = find_tpch() / 'customer' + spark_session.sql( + 'CREATE OR REPLACE TEMPORARY VIEW customer USING org.apache.spark.sql.parquet ' + f'OPTIONS ( path "{customer_location}" )') + return spark_session + + +# pylint: disable=missing-function-docstring +# ruff: noqa: E712 +class TestSqlAPI: + """Tests of the SQL side of SparkConnect.""" + + def test_count(self, spark_session_with_customer_database): + outcome = spark_session_with_customer_database.sql( + 'SELECT COUNT(*) FROM customer').collect() + assert_that(outcome[0][0], equal_to(149999)) + + def test_limit(self, spark_session_with_customer_database): + expected = [ + Row(c_custkey=2, c_phone='23-768-687-3665', c_mktsegment='AUTOMOBILE'), + Row(c_custkey=3, c_phone='11-719-748-3364', c_mktsegment='AUTOMOBILE'), + Row(c_custkey=4, c_phone='14-128-190-5944', c_mktsegment='MACHINERY'), + Row(c_custkey=5, c_phone='13-750-942-6364', c_mktsegment='HOUSEHOLD'), + Row(c_custkey=6, c_phone='30-114-968-4951', c_mktsegment='AUTOMOBILE'), + ] + outcome = spark_session_with_customer_database.sql( + 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() + assertDataFrameEqual(outcome, expected) From 3ef6338bc35ff36046456c62ba5de340fa77c7ef Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 11 Apr 2024 18:59:05 -0700 Subject: [PATCH 13/58] feat: add side by side tests with the TPC-H (#37) The SQL used for the TPC-H queries is in the DuckDB dialect. 
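For reference, one of the checked-in queries can be exercised against a running gateway
roughly as follows; the run_tpch_query helper and the hard-coded port are illustrative
only, and the actual tests drive the queries through the pytest fixtures in conftest.py.

    from pathlib import Path
    from pyspark.sql.session import SparkSession

    def run_tpch_query(session: SparkSession, query_number: int):
        # The queries live in src/gateway/tests/data as 01.sql through 22.sql.
        sql_file = Path('src/gateway/tests/data') / f'{query_number:02d}.sql'
        return session.sql(sql_file.read_text(encoding='utf-8')).collect()

    spark = SparkSession.builder.remote('sc://localhost:50051').getOrCreate()
    print(run_tpch_query(spark, 6))  # TPC-H Q6: a single aggregate over lineitem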
--- .../converter/spark_to_substrait_test.py | 2 +- src/gateway/server.py | 10 +++++ src/gateway/tests/conftest.py | 10 ++++- src/gateway/tests/data/01.sql | 21 +++++++++ src/gateway/tests/data/02.sql | 43 +++++++++++++++++++ src/gateway/tests/data/03.sql | 23 ++++++++++ src/gateway/tests/data/04.sql | 20 +++++++++ src/gateway/tests/data/05.sql | 24 +++++++++++ src/gateway/tests/data/06.sql | 10 +++++ src/gateway/tests/data/07.sql | 38 ++++++++++++++++ src/gateway/tests/data/08.sql | 38 ++++++++++++++++ src/gateway/tests/data/09.sql | 30 +++++++++++++ src/gateway/tests/data/10.sql | 32 ++++++++++++++ src/gateway/tests/data/11.sql | 27 ++++++++++++ src/gateway/tests/data/12.sql | 30 +++++++++++++ src/gateway/tests/data/13.sql | 19 ++++++++ src/gateway/tests/data/14.sql | 14 ++++++ src/gateway/tests/data/15.sql | 37 ++++++++++++++++ src/gateway/tests/data/16.sql | 29 +++++++++++++ src/gateway/tests/data/17.sql | 16 +++++++ src/gateway/tests/data/18.sql | 33 ++++++++++++++ src/gateway/tests/data/19.sql | 29 +++++++++++++ src/gateway/tests/data/20.sql | 34 +++++++++++++++ src/gateway/tests/data/21.sql | 38 ++++++++++++++++ src/gateway/tests/data/22.sql | 31 +++++++++++++ src/gateway/tests/test_sql_api.py | 40 +++++++++++++++-- 26 files changed, 672 insertions(+), 6 deletions(-) create mode 100644 src/gateway/tests/data/01.sql create mode 100644 src/gateway/tests/data/02.sql create mode 100644 src/gateway/tests/data/03.sql create mode 100644 src/gateway/tests/data/04.sql create mode 100644 src/gateway/tests/data/05.sql create mode 100644 src/gateway/tests/data/06.sql create mode 100644 src/gateway/tests/data/07.sql create mode 100644 src/gateway/tests/data/08.sql create mode 100644 src/gateway/tests/data/09.sql create mode 100644 src/gateway/tests/data/10.sql create mode 100644 src/gateway/tests/data/11.sql create mode 100644 src/gateway/tests/data/12.sql create mode 100644 src/gateway/tests/data/13.sql create mode 100644 src/gateway/tests/data/14.sql create mode 100644 src/gateway/tests/data/15.sql create mode 100644 src/gateway/tests/data/16.sql create mode 100644 src/gateway/tests/data/17.sql create mode 100644 src/gateway/tests/data/18.sql create mode 100644 src/gateway/tests/data/19.sql create mode 100644 src/gateway/tests/data/20.sql create mode 100644 src/gateway/tests/data/21.sql create mode 100644 src/gateway/tests/data/22.sql diff --git a/src/gateway/converter/spark_to_substrait_test.py b/src/gateway/converter/spark_to_substrait_test.py index c0d2605..df1c03f 100644 --- a/src/gateway/converter/spark_to_substrait_test.py +++ b/src/gateway/converter/spark_to_substrait_test.py @@ -71,7 +71,7 @@ def manage_database() -> None: ) def test_sql_conversion(request, path): """Test the conversion of SQL to a Substrait plan.""" - # Read the Spark plan to convert. + # Read the SQL to run. 
with open(path, "rb") as file: sql_bytes = file.read() sql = sql_bytes.decode('utf-8') diff --git a/src/gateway/server.py b/src/gateway/server.py index 37cd668..2ca07d3 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -50,12 +50,22 @@ def convert_pyarrow_schema_to_spark(schema: pyarrow.Schema) -> types_pb2.DataTyp for field in schema: if field.type == pyarrow.bool_(): data_type = types_pb2.DataType(boolean=types_pb2.DataType.Boolean()) + elif field.type == pyarrow.int8(): + data_type = types_pb2.DataType(byte=types_pb2.DataType.Byte()) + elif field.type == pyarrow.int16(): + data_type = types_pb2.DataType(integer=types_pb2.DataType.Short()) elif field.type == pyarrow.int32(): data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer()) elif field.type == pyarrow.int64(): data_type = types_pb2.DataType(long=types_pb2.DataType.Long()) + elif field.type == pyarrow.float32(): + data_type = types_pb2.DataType(float=types_pb2.DataType.Float()) + elif field.type == pyarrow.float64(): + data_type = types_pb2.DataType(double=types_pb2.DataType.Double()) elif field.type == pyarrow.string(): data_type = types_pb2.DataType(string=types_pb2.DataType.String()) + elif field.type == pyarrow.timestamp('us'): + data_type = types_pb2.DataType(timestamp=types_pb2.DataType.Timestamp()) else: raise NotImplementedError( 'Conversion from Arrow schema to Spark schema not yet implemented ' diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index ec463f8..63d1d64 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -17,10 +17,18 @@ def _create_local_spark_session() -> SparkSession: SparkSession .builder .master('local[*]') - .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.memory", "2g") .appName('gateway') .getOrCreate() ) + + conf = spark.sparkContext.getConf() + # Dump the configuration settings for debug purposes. 
+ print("==== BEGIN SPARK CONFIG ====") + for k, v in sorted(conf.getAll()): + print(f"{k} = {v}") + print("===== END SPARK CONFIG =====") + yield spark spark.stop() diff --git a/src/gateway/tests/data/01.sql b/src/gateway/tests/data/01.sql new file mode 100644 index 0000000..1513b8c --- /dev/null +++ b/src/gateway/tests/data/01.sql @@ -0,0 +1,21 @@ +SELECT + l_returnflag, + l_linestatus, + sum(l_quantity) AS sum_qty, + sum(l_extendedprice) AS sum_base_price, + sum(l_extendedprice * (1 - l_discount)) AS sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge, + avg(l_quantity) AS avg_qty, + avg(l_extendedprice) AS avg_price, + avg(l_discount) AS avg_disc, + count(*) AS count_order +FROM + lineitem +WHERE + l_shipdate <= CAST('1998-09-02' AS date) +GROUP BY + l_returnflag, + l_linestatus +ORDER BY + l_returnflag, + l_linestatus; diff --git a/src/gateway/tests/data/02.sql b/src/gateway/tests/data/02.sql new file mode 100644 index 0000000..206a472 --- /dev/null +++ b/src/gateway/tests/data/02.sql @@ -0,0 +1,43 @@ +SELECT + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +FROM + part, + supplier, + partsupp, + nation, + region +WHERE + p_partkey = ps_partkey + AND s_suppkey = ps_suppkey + AND p_size = 15 + AND p_type LIKE '%BRASS' + AND s_nationkey = n_nationkey + AND n_regionkey = r_regionkey + AND r_name = 'EUROPE' + AND ps_supplycost = ( + SELECT + min(ps_supplycost) + FROM + partsupp, + supplier, + nation, + region + WHERE + p_partkey = ps_partkey + AND s_suppkey = ps_suppkey + AND s_nationkey = n_nationkey + AND n_regionkey = r_regionkey + AND r_name = 'EUROPE') +ORDER BY + s_acctbal DESC, + n_name, + s_name, + p_partkey +LIMIT 100; diff --git a/src/gateway/tests/data/03.sql b/src/gateway/tests/data/03.sql new file mode 100644 index 0000000..b05a73b --- /dev/null +++ b/src/gateway/tests/data/03.sql @@ -0,0 +1,23 @@ +SELECT + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) AS revenue, + o_orderdate, + o_shippriority +FROM + customer, + orders, + lineitem +WHERE + c_mktsegment = 'BUILDING' + AND c_custkey = o_custkey + AND l_orderkey = o_orderkey + AND o_orderdate < CAST('1995-03-15' AS date) + AND l_shipdate > CAST('1995-03-15' AS date) +GROUP BY + l_orderkey, + o_orderdate, + o_shippriority +ORDER BY + revenue DESC, + o_orderdate +LIMIT 10; diff --git a/src/gateway/tests/data/04.sql b/src/gateway/tests/data/04.sql new file mode 100644 index 0000000..8965f4d --- /dev/null +++ b/src/gateway/tests/data/04.sql @@ -0,0 +1,20 @@ +SELECT + o_orderpriority, + count(*) AS order_count +FROM + orders +WHERE + o_orderdate >= CAST('1993-07-01' AS date) + AND o_orderdate < CAST('1993-10-01' AS date) + AND EXISTS ( + SELECT + * + FROM + lineitem + WHERE + l_orderkey = o_orderkey + AND l_commitdate < l_receiptdate) +GROUP BY + o_orderpriority +ORDER BY + o_orderpriority; \ No newline at end of file diff --git a/src/gateway/tests/data/05.sql b/src/gateway/tests/data/05.sql new file mode 100644 index 0000000..411596e --- /dev/null +++ b/src/gateway/tests/data/05.sql @@ -0,0 +1,24 @@ +SELECT + n_name, + sum(l_extendedprice * (1 - l_discount)) AS revenue +FROM + customer, + orders, + lineitem, + supplier, + nation, + region +WHERE + c_custkey = o_custkey + AND l_orderkey = o_orderkey + AND l_suppkey = s_suppkey + AND c_nationkey = s_nationkey + AND s_nationkey = n_nationkey + AND n_regionkey = r_regionkey + AND r_name = 'ASIA' + AND o_orderdate >= CAST('1994-01-01' AS date) + AND o_orderdate < CAST('1995-01-01' AS date) +GROUP BY + 
n_name +ORDER BY + revenue DESC; diff --git a/src/gateway/tests/data/06.sql b/src/gateway/tests/data/06.sql new file mode 100644 index 0000000..162e516 --- /dev/null +++ b/src/gateway/tests/data/06.sql @@ -0,0 +1,10 @@ +SELECT + sum(l_extendedprice * l_discount) AS revenue +FROM + lineitem +WHERE + l_shipdate >= CAST('1994-01-01' AS date) + AND l_shipdate < CAST('1995-01-01' AS date) + AND l_discount BETWEEN 0.05 + AND 0.07 + AND l_quantity < 24; diff --git a/src/gateway/tests/data/07.sql b/src/gateway/tests/data/07.sql new file mode 100644 index 0000000..695b953 --- /dev/null +++ b/src/gateway/tests/data/07.sql @@ -0,0 +1,38 @@ +SELECT + supp_nation, + cust_nation, + l_year, + sum(volume) AS revenue +FROM ( + SELECT + n1.n_name AS supp_nation, + n2.n_name AS cust_nation, + extract(year FROM l_shipdate) AS l_year, + l_extendedprice * (1 - l_discount) AS volume + FROM + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + WHERE + s_suppkey = l_suppkey + AND o_orderkey = l_orderkey + AND c_custkey = o_custkey + AND s_nationkey = n1.n_nationkey + AND c_nationkey = n2.n_nationkey + AND ((n1.n_name = 'FRANCE' + AND n2.n_name = 'GERMANY') + OR (n1.n_name = 'GERMANY' + AND n2.n_name = 'FRANCE')) + AND l_shipdate BETWEEN CAST('1995-01-01' AS date) + AND CAST('1996-12-31' AS date)) AS shipping +GROUP BY + supp_nation, + cust_nation, + l_year +ORDER BY + supp_nation, + cust_nation, + l_year; diff --git a/src/gateway/tests/data/08.sql b/src/gateway/tests/data/08.sql new file mode 100644 index 0000000..6c80028 --- /dev/null +++ b/src/gateway/tests/data/08.sql @@ -0,0 +1,38 @@ +SELECT + o_year, + sum( + CASE WHEN nation = 'BRAZIL' THEN + volume + ELSE + 0 + END) / sum(volume) AS mkt_share +FROM ( + SELECT + extract(year FROM o_orderdate) AS o_year, + l_extendedprice * (1 - l_discount) AS volume, + n2.n_name AS nation + FROM + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + WHERE + p_partkey = l_partkey + AND s_suppkey = l_suppkey + AND l_orderkey = o_orderkey + AND o_custkey = c_custkey + AND c_nationkey = n1.n_nationkey + AND n1.n_regionkey = r_regionkey + AND r_name = 'AMERICA' + AND s_nationkey = n2.n_nationkey + AND o_orderdate BETWEEN CAST('1995-01-01' AS date) + AND CAST('1996-12-31' AS date) + AND p_type = 'ECONOMY ANODIZED STEEL') AS all_nations +GROUP BY + o_year +ORDER BY + o_year; diff --git a/src/gateway/tests/data/09.sql b/src/gateway/tests/data/09.sql new file mode 100644 index 0000000..4c16659 --- /dev/null +++ b/src/gateway/tests/data/09.sql @@ -0,0 +1,30 @@ +SELECT + nation, + o_year, + sum(amount) AS sum_profit +FROM ( + SELECT + n_name AS nation, + extract(year FROM o_orderdate) AS o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity AS amount + FROM + part, + supplier, + lineitem, + partsupp, + orders, + nation + WHERE + s_suppkey = l_suppkey + AND ps_suppkey = l_suppkey + AND ps_partkey = l_partkey + AND p_partkey = l_partkey + AND o_orderkey = l_orderkey + AND s_nationkey = n_nationkey + AND p_name LIKE '%green%') AS profit +GROUP BY + nation, + o_year +ORDER BY + nation, + o_year DESC; diff --git a/src/gateway/tests/data/10.sql b/src/gateway/tests/data/10.sql new file mode 100644 index 0000000..ae37ab9 --- /dev/null +++ b/src/gateway/tests/data/10.sql @@ -0,0 +1,32 @@ +SELECT + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) AS revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +FROM + customer, + orders, + lineitem, + nation +WHERE + c_custkey = o_custkey + AND 
l_orderkey = o_orderkey + AND o_orderdate >= CAST('1993-10-01' AS date) + AND o_orderdate < CAST('1994-01-01' AS date) + AND l_returnflag = 'R' + AND c_nationkey = n_nationkey +GROUP BY + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +ORDER BY + revenue DESC +LIMIT 20; diff --git a/src/gateway/tests/data/11.sql b/src/gateway/tests/data/11.sql new file mode 100644 index 0000000..527e2d1 --- /dev/null +++ b/src/gateway/tests/data/11.sql @@ -0,0 +1,27 @@ +SELECT + ps_partkey, + sum(ps_supplycost * ps_availqty) AS value +FROM + partsupp, + supplier, + nation +WHERE + ps_suppkey = s_suppkey + AND s_nationkey = n_nationkey + AND n_name = 'GERMANY' +GROUP BY + ps_partkey +HAVING + sum(ps_supplycost * ps_availqty) > ( + SELECT + sum(ps_supplycost * ps_availqty) * 0.0001000000 + FROM + partsupp, + supplier, + nation + WHERE + ps_suppkey = s_suppkey + AND s_nationkey = n_nationkey + AND n_name = 'GERMANY') +ORDER BY + value DESC; diff --git a/src/gateway/tests/data/12.sql b/src/gateway/tests/data/12.sql new file mode 100644 index 0000000..479cd3e --- /dev/null +++ b/src/gateway/tests/data/12.sql @@ -0,0 +1,30 @@ +SELECT + l_shipmode, + sum( + CASE WHEN o_orderpriority = '1-URGENT' + OR o_orderpriority = '2-HIGH' THEN + 1 + ELSE + 0 + END) AS high_line_count, + sum( + CASE WHEN o_orderpriority <> '1-URGENT' + AND o_orderpriority <> '2-HIGH' THEN + 1 + ELSE + 0 + END) AS low_line_count +FROM + orders, + lineitem +WHERE + o_orderkey = l_orderkey + AND l_shipmode IN ('MAIL', 'SHIP') + AND l_commitdate < l_receiptdate + AND l_shipdate < l_commitdate + AND l_receiptdate >= CAST('1994-01-01' AS date) + AND l_receiptdate < CAST('1995-01-01' AS date) +GROUP BY + l_shipmode +ORDER BY + l_shipmode; diff --git a/src/gateway/tests/data/13.sql b/src/gateway/tests/data/13.sql new file mode 100644 index 0000000..e78f212 --- /dev/null +++ b/src/gateway/tests/data/13.sql @@ -0,0 +1,19 @@ +SELECT + c_count, + count(*) AS custdist +FROM ( + SELECT + c_custkey, + count(o_orderkey) + FROM + customer + LEFT OUTER JOIN orders ON c_custkey = o_custkey + AND o_comment NOT LIKE '%special%requests%' +GROUP BY + c_custkey) AS c_orders (c_custkey, + c_count) +GROUP BY + c_count +ORDER BY + custdist DESC, + c_count DESC; diff --git a/src/gateway/tests/data/14.sql b/src/gateway/tests/data/14.sql new file mode 100644 index 0000000..327d5cd --- /dev/null +++ b/src/gateway/tests/data/14.sql @@ -0,0 +1,14 @@ +SELECT + 100.00 * sum( + CASE WHEN p_type LIKE 'PROMO%' THEN + l_extendedprice * (1 - l_discount) + ELSE + 0 + END) / sum(l_extendedprice * (1 - l_discount)) AS promo_revenue +FROM + lineitem, + part +WHERE + l_partkey = p_partkey + AND l_shipdate >= date '1995-09-01' + AND l_shipdate < CAST('1995-10-01' AS date); diff --git a/src/gateway/tests/data/15.sql b/src/gateway/tests/data/15.sql new file mode 100644 index 0000000..4e533dc --- /dev/null +++ b/src/gateway/tests/data/15.sql @@ -0,0 +1,37 @@ +SELECT + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +FROM + supplier, + ( + SELECT + l_suppkey AS supplier_no, + sum(l_extendedprice * (1 - l_discount)) AS total_revenue + FROM + lineitem + WHERE + l_shipdate >= CAST('1996-01-01' AS date) + AND l_shipdate < CAST('1996-04-01' AS date) + GROUP BY + supplier_no) revenue0 +WHERE + s_suppkey = supplier_no + AND total_revenue = ( + SELECT + max(total_revenue) + FROM ( + SELECT + l_suppkey AS supplier_no, + sum(l_extendedprice * (1 - l_discount)) AS total_revenue + FROM + lineitem + WHERE + l_shipdate >= CAST('1996-01-01' AS date) + AND 
l_shipdate < CAST('1996-04-01' AS date) + GROUP BY + supplier_no) revenue1) +ORDER BY + s_suppkey; diff --git a/src/gateway/tests/data/16.sql b/src/gateway/tests/data/16.sql new file mode 100644 index 0000000..bfb8c74 --- /dev/null +++ b/src/gateway/tests/data/16.sql @@ -0,0 +1,29 @@ +SELECT + p_brand, + p_type, + p_size, + count(DISTINCT ps_suppkey) AS supplier_cnt +FROM + partsupp, + part +WHERE + p_partkey = ps_partkey + AND p_brand <> 'Brand#45' + AND p_type NOT LIKE 'MEDIUM POLISHED%' + AND p_size IN (49, 14, 23, 45, 19, 3, 36, 9) + AND ps_suppkey NOT IN ( + SELECT + s_suppkey + FROM + supplier + WHERE + s_comment LIKE '%Customer%Complaints%') +GROUP BY + p_brand, + p_type, + p_size +ORDER BY + supplier_cnt DESC, + p_brand, + p_type, + p_size; diff --git a/src/gateway/tests/data/17.sql b/src/gateway/tests/data/17.sql new file mode 100644 index 0000000..964fea1 --- /dev/null +++ b/src/gateway/tests/data/17.sql @@ -0,0 +1,16 @@ +SELECT + sum(l_extendedprice) / 7.0 AS avg_yearly +FROM + lineitem, + part +WHERE + p_partkey = l_partkey + AND p_brand = 'Brand#23' + AND p_container = 'MED BOX' + AND l_quantity < ( + SELECT + 0.2 * avg(l_quantity) + FROM + lineitem + WHERE + l_partkey = p_partkey); diff --git a/src/gateway/tests/data/18.sql b/src/gateway/tests/data/18.sql new file mode 100644 index 0000000..3f6886a --- /dev/null +++ b/src/gateway/tests/data/18.sql @@ -0,0 +1,33 @@ +SELECT + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +FROM + customer, + orders, + lineitem +WHERE + o_orderkey IN ( + SELECT + l_orderkey + FROM + lineitem + GROUP BY + l_orderkey + HAVING + sum(l_quantity) > 300) + AND c_custkey = o_custkey + AND o_orderkey = l_orderkey +GROUP BY + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +ORDER BY + o_totalprice DESC, + o_orderdate +LIMIT 100; diff --git a/src/gateway/tests/data/19.sql b/src/gateway/tests/data/19.sql new file mode 100644 index 0000000..1046501 --- /dev/null +++ b/src/gateway/tests/data/19.sql @@ -0,0 +1,29 @@ +SELECT + sum(l_extendedprice * (1 - l_discount)) AS revenue +FROM + lineitem, + part +WHERE (p_partkey = l_partkey + AND p_brand = 'Brand#12' + AND p_container IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + AND l_quantity >= 1 + AND l_quantity <= 1 + 10 + AND p_size BETWEEN 1 AND 5 + AND l_shipmode IN ('AIR', 'AIR REG') + AND l_shipinstruct = 'DELIVER IN PERSON') + OR (p_partkey = l_partkey + AND p_brand = 'Brand#23' + AND p_container IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + AND l_quantity >= 10 + AND l_quantity <= 10 + 10 + AND p_size BETWEEN 1 AND 10 + AND l_shipmode IN ('AIR', 'AIR REG') + AND l_shipinstruct = 'DELIVER IN PERSON') + OR (p_partkey = l_partkey + AND p_brand = 'Brand#34' + AND p_container IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + AND l_quantity >= 20 + AND l_quantity <= 20 + 10 + AND p_size BETWEEN 1 AND 15 + AND l_shipmode IN ('AIR', 'AIR REG') + AND l_shipinstruct = 'DELIVER IN PERSON'); diff --git a/src/gateway/tests/data/20.sql b/src/gateway/tests/data/20.sql new file mode 100644 index 0000000..8760453 --- /dev/null +++ b/src/gateway/tests/data/20.sql @@ -0,0 +1,34 @@ +SELECT + s_name, + s_address +FROM + supplier, + nation +WHERE + s_suppkey IN ( + SELECT + ps_suppkey + FROM + partsupp + WHERE + ps_partkey IN ( + SELECT + p_partkey + FROM + part + WHERE + p_name LIKE 'forest%') + AND ps_availqty > ( + SELECT + 0.5 * sum(l_quantity) + FROM + lineitem + WHERE + l_partkey = ps_partkey + AND l_suppkey = ps_suppkey + AND l_shipdate >= CAST('1994-01-01' AS 
date) + AND l_shipdate < CAST('1995-01-01' AS date))) + AND s_nationkey = n_nationkey + AND n_name = 'CANADA' + ORDER BY + s_name; diff --git a/src/gateway/tests/data/21.sql b/src/gateway/tests/data/21.sql new file mode 100644 index 0000000..19d18fd --- /dev/null +++ b/src/gateway/tests/data/21.sql @@ -0,0 +1,38 @@ +SELECT + s_name, + count(*) AS numwait +FROM + supplier, + lineitem l1, + orders, + nation +WHERE + s_suppkey = l1.l_suppkey + AND o_orderkey = l1.l_orderkey + AND o_orderstatus = 'F' + AND l1.l_receiptdate > l1.l_commitdate + AND EXISTS ( + SELECT + * + FROM + lineitem l2 + WHERE + l2.l_orderkey = l1.l_orderkey + AND l2.l_suppkey <> l1.l_suppkey) + AND NOT EXISTS ( + SELECT + * + FROM + lineitem l3 + WHERE + l3.l_orderkey = l1.l_orderkey + AND l3.l_suppkey <> l1.l_suppkey + AND l3.l_receiptdate > l3.l_commitdate) + AND s_nationkey = n_nationkey + AND n_name = 'SAUDI ARABIA' +GROUP BY + s_name +ORDER BY + numwait DESC, + s_name +LIMIT 100; diff --git a/src/gateway/tests/data/22.sql b/src/gateway/tests/data/22.sql new file mode 100644 index 0000000..7ce29a5 --- /dev/null +++ b/src/gateway/tests/data/22.sql @@ -0,0 +1,31 @@ +SELECT + cntrycode, + count(*) AS numcust, + sum(c_acctbal) AS totacctbal +FROM ( + SELECT + substring(c_phone FROM 1 FOR 2) AS cntrycode, + c_acctbal + FROM + customer + WHERE + substring(c_phone FROM 1 FOR 2) IN ('13', '31', '23', '29', '30', '18', '17') + AND c_acctbal > ( + SELECT + avg(c_acctbal) + FROM + customer + WHERE + c_acctbal > 0.00 + AND substring(c_phone FROM 1 FOR 2) IN ('13', '31', '23', '29', '30', '18', '17')) + AND NOT EXISTS ( + SELECT + * + FROM + orders + WHERE + o_custkey = c_custkey)) AS custsale +GROUP BY + cntrycode +ORDER BY + cntrycode; diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index c199558..4b90a43 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the Spark to Substrait Gateway server.""" +from pathlib import Path + import pytest from hamcrest import assert_that, equal_to from pyspark import Row @@ -8,15 +10,32 @@ from gateway.converter.sql_to_substrait import find_tpch +test_case_directory = Path(__file__).resolve().parent / 'data' + +sql_test_case_paths = [f for f in sorted(test_case_directory.iterdir()) if f.suffix == '.sql'] + +sql_test_case_names = [p.stem for p in sql_test_case_paths] + + +def _register_table(spark_session: SparkSession, name: str) -> None: + location = find_tpch() / name + spark_session.sql( + f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' + f'OPTIONS ( path "{location}" )') + @pytest.fixture(scope='function') def spark_session_with_customer_database(spark_session: SparkSession, source: str) -> SparkSession: """Creates a temporary view of the customer database.""" if source == 'spark': - customer_location = find_tpch() / 'customer' - spark_session.sql( - 'CREATE OR REPLACE TEMPORARY VIEW customer USING org.apache.spark.sql.parquet ' - f'OPTIONS ( path "{customer_location}" )') + _register_table(spark_session, 'customer') + _register_table(spark_session, 'lineitem') + _register_table(spark_session, 'nation') + _register_table(spark_session, 'orders') + _register_table(spark_session, 'part') + _register_table(spark_session, 'partsupp') + _register_table(spark_session, 'region') + _register_table(spark_session, 'supplier') return spark_session @@ -41,3 +60,16 @@ def test_limit(self, spark_session_with_customer_database): outcome = 
spark_session_with_customer_database.sql( 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() assertDataFrameEqual(outcome, expected) + + @pytest.mark.parametrize( + 'path', + sql_test_case_paths, + ids=sql_test_case_names, + ) + def test_tpch(self, spark_session_with_customer_database, path): + """Test the TPC-H queries.""" + # Read the SQL to run. + with open(path, "rb") as file: + sql_bytes = file.read() + sql = sql_bytes.decode('utf-8') + spark_session_with_customer_database.sql(sql).collect() From 7946ab14601d8fddedeb458ca92b69fa980f3604 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Mon, 15 Apr 2024 19:58:04 -0400 Subject: [PATCH 14/58] feat: use generic adbc driver to setup backend db (#38) Instead of using the duckdb_adbc_driver we now use the adbc_driver_manager so that we can easily switch between different backends. --- src/gateway/converter/spark_to_substrait.py | 32 +++++++++++++++------ 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index b9d1008..b03ed70 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -6,18 +6,20 @@ import pathlib from typing import Dict, Optional, List -import adbc_driver_duckdb.dbapi +import duckdb import pyarrow import pyarrow.parquet import pyspark.sql.connect.proto.base_pb2 as spark_pb2 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 +from adbc_driver_manager import dbapi from substrait.gen.proto import algebra_pb2 from substrait.gen.proto import plan_pb2 from substrait.gen.proto import type_pb2 from substrait.gen.proto.extensions import extensions_pb2 +from gateway.adbc.backend_options import BackendOptions, Backend from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function from gateway.converter.sql_to_substrait import convert_sql @@ -28,10 +30,22 @@ if_then_else_operation, greater_function, minus_function from gateway.converter.symbol_table import SymbolTable -DUCKDB_TABLE = "duckdb_table" +TABLE_NAME = "my_table" -def fetch_schema_with_adbc(file_path: str, ext: str) -> pyarrow.Schema: +def get_backend_driver(options: BackendOptions) -> tuple[str, str]: + """Gets the driver and entry point for the specified backend.""" + match options.backend: + case Backend.DUCKDB: + driver = duckdb.duckdb.__file__ + entry_point = "duckdb_adbc_init" + case _: + raise ValueError(f'Unknown backend type: {options.backend}') + + return driver, entry_point + + +def fetch_schema_with_adbc(file_path: str, ext: str, options: BackendOptions) -> pyarrow.Schema: """Fetch the arrow schema via ADBC.""" file_paths = list(pathlib.Path(file_path).glob(f'*.{ext}')) @@ -40,12 +54,14 @@ def fetch_schema_with_adbc(file_path: str, ext: str) -> pyarrow.Schema: file_paths = sorted([str(fp) for fp in file_paths]) file_path = file_paths[0] - with adbc_driver_duckdb.dbapi.connect() as conn, conn.cursor() as cur: + driver, entry_point = get_backend_driver(options) + + with dbapi.connect(driver=driver, entrypoint=entry_point) as conn, conn.cursor() as cur: # TODO: Support multiple paths. 
reader = pyarrow.parquet.ParquetFile(file_path) - cur.adbc_ingest(DUCKDB_TABLE, reader.iter_batches(), mode="create") - schema = conn.adbc_get_table_schema(DUCKDB_TABLE) - cur.execute(f"DROP TABLE {DUCKDB_TABLE}") + cur.adbc_ingest(TABLE_NAME, reader.iter_batches(), mode="create") + schema = conn.adbc_get_table_schema(TABLE_NAME) + cur.execute(f"DROP TABLE {TABLE_NAME}") return schema @@ -401,7 +417,7 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) if not schema: - arrow_schema = fetch_schema_with_adbc(rel.paths[0], rel.format) + arrow_schema = fetch_schema_with_adbc(rel.paths[0], rel.format, self._conversion_options.backend) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: From 97c6d76c2da31c3aabcff84e596e10e8cc61cf5f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 16 Apr 2024 15:49:46 -0700 Subject: [PATCH 15/58] feat: refactor the backend access logic into a class hierarchy (#39) --- environment.yml | 1 + src/gateway/adbc/backend.py | 123 ------------------ src/gateway/backends/__init__.py | 0 src/gateway/backends/adbc_backend.py | 71 ++++++++++ src/gateway/backends/arrow_backend.py | 24 ++++ src/gateway/backends/backend.py | 76 +++++++++++ .../{adbc => backends}/backend_options.py | 0 src/gateway/backends/backend_selector.py | 22 ++++ src/gateway/backends/datafusion_backend.py | 64 +++++++++ src/gateway/backends/duckdb_backend.py | 54 ++++++++ src/gateway/converter/conversion_options.py | 2 +- src/gateway/converter/spark_to_substrait.py | 47 ++----- src/gateway/converter/sql_to_substrait.py | 53 +------- src/gateway/server.py | 8 +- src/gateway/tests/test_sql_api.py | 5 +- 15 files changed, 335 insertions(+), 215 deletions(-) delete mode 100644 src/gateway/adbc/backend.py create mode 100644 src/gateway/backends/__init__.py create mode 100644 src/gateway/backends/adbc_backend.py create mode 100644 src/gateway/backends/arrow_backend.py create mode 100644 src/gateway/backends/backend.py rename src/gateway/{adbc => backends}/backend_options.py (100%) create mode 100644 src/gateway/backends/backend_selector.py create mode 100644 src/gateway/backends/datafusion_backend.py create mode 100644 src/gateway/backends/duckdb_backend.py diff --git a/environment.yml b/environment.yml index 19b2fa7..9d04cd4 100644 --- a/environment.yml +++ b/environment.yml @@ -25,3 +25,4 @@ dependencies: - pandas >= 1.0.5 - pyhamcrest - substrait-validator + - pytest-timeout diff --git a/src/gateway/adbc/backend.py b/src/gateway/adbc/backend.py deleted file mode 100644 index d46cb26..0000000 --- a/src/gateway/adbc/backend.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Will eventually provide client access to an ADBC backend.""" -from pathlib import Path - -import adbc_driver_duckdb.dbapi -import duckdb -import pyarrow -from pyarrow import substrait - -from substrait.gen.proto import plan_pb2 - -from gateway.adbc.backend_options import BackendOptions, Backend -from gateway.converter.rename_functions import RenameFunctions -from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable -from gateway.converter.sql_to_substrait import register_table, find_tpch - - -# pylint: disable=protected-access -def _import(handle): - return pyarrow.RecordBatchReader._import_from_c(handle.address) - - -# pylint: disable=fixme -class AdbcBackend: - """Provides methods 
for contacting an ADBC backend via Substrait.""" - - def __init__(self): - pass - - def execute_with_duckdb_over_adbc(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: - """Executes the given Substrait plan against DuckDB using ADBC.""" - with adbc_driver_duckdb.dbapi.connect() as conn, conn.cursor() as cur: - cur.execute("LOAD substrait;") - plan_data = plan.SerializeToString() - cur.adbc_statement.set_substrait_plan(plan_data) - res = cur.adbc_statement.execute_query() - table = _import(res[0]).read_all() - return table - - # pylint: disable=import-outside-toplevel - def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: - """Executes the given Substrait plan against Datafusion.""" - import datafusion.substrait - - ctx = datafusion.SessionContext() - - file_groups = ReplaceLocalFilesWithNamedTable().visit_plan(plan) - registered_tables = set() - for files in file_groups: - table_name = files[0] - for file in files[1]: - if table_name not in registered_tables: - ctx.register_parquet(table_name, file) - registered_tables.add(files[0]) - - RenameFunctions().visit_plan(plan) - - try: - plan_data = plan.SerializeToString() - substrait_plan = datafusion.substrait.substrait.serde.deserialize_bytes(plan_data) - logical_plan = datafusion.substrait.substrait.consumer.from_substrait_plan( - ctx, substrait_plan - ) - - # Create a DataFrame from a deserialized logical plan. - df_result = ctx.create_dataframe_from_logical_plan(logical_plan) - for column_number, column_name in enumerate(df_result.schema().names): - df_result = df_result.with_column_renamed( - column_name, - plan.relations[0].root.names[column_number] - ) - return df_result.to_arrow_table() - finally: - for table_name in registered_tables: - ctx.deregister_table(table_name) - - def execute_with_duckdb(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: - """Executes the given Substrait plan against DuckDB.""" - con = duckdb.connect(config={'max_memory': '100GB', - "allow_unsigned_extensions": "true", - 'temp_directory': str(Path('.').absolute())}) - con.install_extension('substrait') - con.load_extension('substrait') - plan_data = plan.SerializeToString() - - # TODO -- Rely on the client to register their own named tables. 
- tpch_location = find_tpch() - register_table(con, 'customer', tpch_location / 'customer') - register_table(con, 'lineitem', tpch_location / 'lineitem') - register_table(con, 'nation', tpch_location / 'nation') - register_table(con, 'orders', tpch_location / 'orders') - register_table(con, 'part', tpch_location / 'part') - register_table(con, 'partsupp', tpch_location / 'partsupp') - register_table(con, 'region', tpch_location / 'region') - register_table(con, 'supplier', tpch_location / 'supplier') - - try: - query_result = con.from_substrait(proto=plan_data) - except Exception as err: - raise ValueError(f'DuckDB Execution Error: {err}') from err - df = query_result.df() - return pyarrow.Table.from_pandas(df=df) - - def execute_with_arrow(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table: - """Executes the given Substrait plan against Acero.""" - plan_data = plan.SerializeToString() - reader = substrait.run_query(plan_data) - query_result = reader.read_all() - return query_result - - def execute(self, plan: 'plan_pb2.Plan', options: BackendOptions) -> pyarrow.lib.Table: - """Executes the given Substrait plan.""" - match options.backend: - case Backend.ARROW: - return self.execute_with_arrow(plan) - case Backend.DATAFUSION: - return self.execute_with_datafusion(plan) - case Backend.DUCKDB: - if options.use_adbc: - return self.execute_with_duckdb_over_adbc(plan) - return self.execute_with_duckdb(plan) - case _: - raise ValueError('unknown backend requested') diff --git a/src/gateway/backends/__init__.py b/src/gateway/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gateway/backends/adbc_backend.py b/src/gateway/backends/adbc_backend.py new file mode 100644 index 0000000..70f8add --- /dev/null +++ b/src/gateway/backends/adbc_backend.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Provides access to a generic ADBC backend.""" +from pathlib import Path + +import duckdb +import pyarrow +from adbc_driver_manager import dbapi +from substrait.gen.proto import plan_pb2 + +from gateway.backends.backend import Backend +from gateway.backends.backend_options import BackendOptions + + +def _import(handle): + return pyarrow.RecordBatchReader._import_from_c(handle.address) + + +def _get_backend_driver(options: BackendOptions) -> tuple[str, str]: + """Gets the driver and entry point for the specified backend.""" + match options.backend: + case Backend.DUCKDB: + driver = duckdb.duckdb.__file__ + entry_point = "duckdb_adbc_init" + case _: + raise ValueError(f'Unknown backend type: {options.backend}') + + return driver, entry_point + + +class AdbcBackend(Backend): + """Provides access to send ADBC backends Substrait plans.""" + + def __init__(self, options: BackendOptions): + self._options = options + super().__init__(options) + self.create_connection() + + def create_connection(self) -> None: + driver, entry_point = _get_backend_driver(self._options) + self._connection = dbapi.connect(driver=driver, entrypoint=entry_point) + + # pylint: disable=import-outside-toplevel + def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: + """Executes the given Substrait plan against an ADBC backend.""" + with self._connection.cursor() as cur: + cur.execute("LOAD substrait;") + plan_data = plan.SerializeToString() + cur.adbc_statement.set_substrait_plan(plan_data) + res = cur.adbc_statement.execute_query() + table = _import(res[0]).read_all() + return table + + def register_table(self, name: str, path: Path, extension: str = 'parquet') -> None: + """Registers the 
given table with the backend.""" + file_paths = sorted(Path(path).glob(f'*.{extension}')) + if len(file_paths) > 0: + # Sort the files because the later ones don't have enough data to construct a schema. + file_paths = sorted([str(fp) for fp in file_paths]) + # TODO: Support multiple paths. + reader = pyarrow.parquet.ParquetFile(file_paths[0]) + self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode="create") + + def describe_table(self, table_name: str): + """Asks the backend to describe the given table.""" + return self._connection.adbc_get_table_schema(table_name) + + def drop_table(self, table_name: str): + """Asks the backend to drop the given table.""" + with self._connection.cursor() as cur: + # TODO -- Use an explicit ADBC call here. + cur.execute(f'DROP TABLE {table_name}') diff --git a/src/gateway/backends/arrow_backend.py b/src/gateway/backends/arrow_backend.py new file mode 100644 index 0000000..5e26a6f --- /dev/null +++ b/src/gateway/backends/arrow_backend.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Provides access to Acero.""" +from pathlib import Path + +import pyarrow +from substrait.gen.proto import plan_pb2 + +from gateway.backends.backend import Backend + + +class ArrowBackend(Backend): + """Provides access to send Acero Substrait plans.""" + + # pylint: disable=import-outside-toplevel + def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: + """Executes the given Substrait plan against Acero.""" + plan_data = plan.SerializeToString() + reader = pyarrow.substrait.run_query(plan_data) + query_result = reader.read_all() + return query_result + + def register_table(self, name: str, path: Path) -> None: + """Registers the given table with the backend.""" + raise NotImplementedError() diff --git a/src/gateway/backends/backend.py b/src/gateway/backends/backend.py new file mode 100644 index 0000000..c23f1c4 --- /dev/null +++ b/src/gateway/backends/backend.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +"""The base class for all Substrait backends.""" +from pathlib import Path +from typing import List, Any + +import pyarrow +from substrait.gen.proto import plan_pb2 + +from gateway.backends.backend_options import BackendOptions + + +class Backend: + """Base class providing methods for contacting a backend utilizing Substrait.""" + + def __init__(self, options: BackendOptions): + self._connection = None + + def create_connection(self) -> None: + raise NotImplementedError() + + def get_connection(self) -> Any: + """Returns the connection to the backend.""" + if self._connection is None: + self._connection = self.create_connection() + return self._connection + + # pylint: disable=import-outside-toplevel + def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: + """Executes the given Substrait plan against Datafusion.""" + raise NotImplementedError() + + def register_table(self, name: str, path: Path | str, extension: str = 'parquet') -> None: + """Registers the given table with the backend.""" + raise NotImplementedError() + + def describe_table(self, name: str): + """Asks the backend to describe the given table.""" + raise NotImplementedError() + + def drop_table(self, name: str): + """Asks the backend to drop the given table.""" + raise NotImplementedError() + + @staticmethod + def expand_location(location: Path | str) -> List[str]: + """Expands the location of a file or directory into a list of files.""" + # TODO -- Handle more than just Parquet files. 
+ path = Path(location) + if path.is_dir(): + files = Path(location).resolve().glob('*.parquet') + else: + files = [path] + return sorted(str(f) for f in files) + + @staticmethod + def find_tpch() -> Path: + """Finds the location of the TPCH dataset.""" + current_location = Path('.').resolve() + while current_location != Path('/'): + location = current_location / 'third_party' / 'tpch' / 'parquet' + if location.exists(): + return location.resolve() + current_location = current_location.parent + raise ValueError('TPCH dataset not found') + + def register_tpch(self): + """Convenience function to register the entire TPC-H dataset.""" + tpch_location = Backend.find_tpch() + self.register_table('customer', tpch_location / 'customer') + self.register_table('lineitem', tpch_location / 'lineitem') + self.register_table('nation', tpch_location / 'nation') + self.register_table('orders', tpch_location / 'orders') + self.register_table('part', tpch_location / 'part') + self.register_table('partsupp', tpch_location / 'partsupp') + self.register_table('region', tpch_location / 'region') + self.register_table('supplier', tpch_location / 'supplier') diff --git a/src/gateway/adbc/backend_options.py b/src/gateway/backends/backend_options.py similarity index 100% rename from src/gateway/adbc/backend_options.py rename to src/gateway/backends/backend_options.py diff --git a/src/gateway/backends/backend_selector.py b/src/gateway/backends/backend_selector.py new file mode 100644 index 0000000..d78d72f --- /dev/null +++ b/src/gateway/backends/backend_selector.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Given a backend enum, returns an instance of the correct Backend descendant.""" +from gateway.backends import backend +from gateway.backends.adbc_backend import AdbcBackend +from gateway.backends.arrow_backend import ArrowBackend +from gateway.backends.backend_options import BackendOptions, Backend +from gateway.backends.datafusion_backend import DatafusionBackend +from gateway.backends.duckdb_backend import DuckDBBackend + + +def find_backend(options: BackendOptions) -> backend.Backend: + match options.backend: + case Backend.ARROW: + return ArrowBackend(options) + case Backend.DATAFUSION: + return DatafusionBackend(options) + case Backend.DUCKDB: + if options.use_adbc: + return AdbcBackend(options) + return DuckDBBackend(options) + case _: + raise ValueError(f'Unknown backend {options.backend} requested.') diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py new file mode 100644 index 0000000..2528440 --- /dev/null +++ b/src/gateway/backends/datafusion_backend.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Provides access to Datafusion.""" +from pathlib import Path + +import pyarrow +from substrait.gen.proto import plan_pb2 + +from gateway.backends.backend import Backend +from gateway.converter.rename_functions import RenameFunctions +from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable + + +# pylint: disable=import-outside-toplevel +class DatafusionBackend(Backend): + """Provides access to send Substrait plans to Datafusion.""" + + def __init__(self, options): + super().__init__(options) + self.create_connection() + + def create_connection(self) -> None: + """Creates a connection to the backend.""" + import datafusion + self._connection = datafusion.SessionContext() + + def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: + """Executes the given Substrait plan against Datafusion.""" + import 
datafusion.substrait + + self.register_tpch() + + file_groups = ReplaceLocalFilesWithNamedTable().visit_plan(plan) + registered_tables = set() + for files in file_groups: + table_name = files[0] + for file in files[1]: + if table_name not in registered_tables: + self.register_table(table_name, file) + registered_tables.add(files[0]) + + RenameFunctions().visit_plan(plan) + + try: + plan_data = plan.SerializeToString() + substrait_plan = datafusion.substrait.substrait.serde.deserialize_bytes(plan_data) + logical_plan = datafusion.substrait.substrait.consumer.from_substrait_plan( + self._connection, substrait_plan + ) + + # Create a DataFrame from a deserialized logical plan. + df_result = self._connection.create_dataframe_from_logical_plan(logical_plan) + for column_number, column_name in enumerate(df_result.schema().names): + df_result = df_result.with_column_renamed( + column_name, + plan.relations[0].root.names[column_number] + ) + return df_result.to_arrow_table() + finally: + for table_name in registered_tables: + self._connection.deregister_table(table_name) + + def register_table(self, name: str, path: Path) -> None: + files = Backend.expand_location(path) + self._connection.register_parquet(name, files[0]) diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py new file mode 100644 index 0000000..6cb6fef --- /dev/null +++ b/src/gateway/backends/duckdb_backend.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Provides access to DuckDB.""" +from pathlib import Path + +import duckdb +import pyarrow +from substrait.gen.proto import plan_pb2 + +from gateway.backends.backend import Backend + + +# pylint: disable=fixme +class DuckDBBackend(Backend): + """Provides access to send Substrait plans to DuckDB.""" + + def __init__(self, options): + super().__init__(options) + self.create_connection() + + def create_connection(self): + if self._connection is not None: + return self._connection + + self._connection = duckdb.connect(config={'max_memory': '100GB', + "allow_unsigned_extensions": "true", + 'temp_directory': str(Path('.').resolve())}) + self._connection.install_extension('substrait') + self._connection.load_extension('substrait') + + return self._connection + + def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: + """Executes the given Substrait plan against DuckDB.""" + plan_data = plan.SerializeToString() + + # TODO -- Rely on the client to register their own named tables. 
+ self.register_tpch() + + try: + query_result = self._connection.from_substrait(proto=plan_data) + except Exception as err: + raise ValueError(f'DuckDB Execution Error: {err}') from err + df = query_result.df() + return pyarrow.Table.from_pandas(df=df) + + def register_table(self, table_name: str, location: Path) -> None: + """Registers the given table with the backend.""" + files = Backend.expand_location(location) + if not files: + raise ValueError(f"No parquet files found at {location}") + files_str = ', '.join([f"'{f}'" for f in files]) + files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" + + self._connection.execute(files_sql) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 871534f..373cead 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -2,7 +2,7 @@ """Tracks conversion related options.""" import dataclasses -from gateway.adbc.backend_options import BackendOptions, Backend +from gateway.backends.backend_options import BackendOptions, Backend # pylint: disable=too-many-instance-attributes diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index b03ed70..13ffdcc 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -6,20 +6,19 @@ import pathlib from typing import Dict, Optional, List -import duckdb import pyarrow import pyarrow.parquet import pyspark.sql.connect.proto.base_pb2 as spark_pb2 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 -from adbc_driver_manager import dbapi from substrait.gen.proto import algebra_pb2 from substrait.gen.proto import plan_pb2 from substrait.gen.proto import type_pb2 from substrait.gen.proto.extensions import extensions_pb2 -from gateway.adbc.backend_options import BackendOptions, Backend +from gateway.backends.backend_options import BackendOptions +from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function from gateway.converter.sql_to_substrait import convert_sql @@ -33,39 +32,6 @@ TABLE_NAME = "my_table" -def get_backend_driver(options: BackendOptions) -> tuple[str, str]: - """Gets the driver and entry point for the specified backend.""" - match options.backend: - case Backend.DUCKDB: - driver = duckdb.duckdb.__file__ - entry_point = "duckdb_adbc_init" - case _: - raise ValueError(f'Unknown backend type: {options.backend}') - - return driver, entry_point - - -def fetch_schema_with_adbc(file_path: str, ext: str, options: BackendOptions) -> pyarrow.Schema: - """Fetch the arrow schema via ADBC.""" - - file_paths = list(pathlib.Path(file_path).glob(f'*.{ext}')) - if len(file_paths) > 0: - # We sort the files because the later partitions don't have enough data to construct a schema. - file_paths = sorted([str(fp) for fp in file_paths]) - file_path = file_paths[0] - - driver, entry_point = get_backend_driver(options) - - with dbapi.connect(driver=driver, entrypoint=entry_point) as conn, conn.cursor() as cur: - # TODO: Support multiple paths. 
- reader = pyarrow.parquet.ParquetFile(file_path) - cur.adbc_ingest(TABLE_NAME, reader.iter_batches(), mode="create") - schema = conn.adbc_get_table_schema(TABLE_NAME) - cur.execute(f"DROP TABLE {TABLE_NAME}") - - return schema - - # pylint: disable=E1101,fixme,too-many-public-methods class SparkSubstraitConverter: """Converts SparkConnect plans to Substrait plans.""" @@ -417,8 +383,13 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) if not schema: - arrow_schema = fetch_schema_with_adbc(rel.paths[0], rel.format, self._conversion_options.backend) - schema = self.convert_arrow_schema(arrow_schema) + backend = find_backend(BackendOptions(self._conversion_options.backend.backend, True)) + try: + backend.register_table(TABLE_NAME, rel.paths[0], rel.format) + arrow_schema = backend.describe_table(TABLE_NAME) + schema = self.convert_arrow_schema(arrow_schema) + finally: + backend.drop_table(TABLE_NAME) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: symbol.output_fields.append(field_name) diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index 3ef55fe..1f4a6cf 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -1,60 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to convert SparkConnect plans to Substrait plans.""" -from pathlib import Path -from typing import List - -import duckdb from substrait.gen.proto import plan_pb2 - -def _expand_location(location: Path | str) -> List[str]: - """Expands the location of a file or directory into a list of files.""" - # TODO -- Handle more than just Parquet files (here and below). - files = Path(location).resolve().glob('*.parquet') - return sorted(str(f) for f in files) - - -def find_tpch() -> Path: - """Finds the location of the TPCH dataset.""" - current_location = Path('.').resolve() - while current_location != Path('/'): - location = current_location / 'third_party' / 'tpch' / 'parquet' - if location.exists(): - return location.resolve() - current_location = current_location.parent - raise ValueError('TPCH dataset not found') - - -def register_table(con: duckdb.DuckDBPyConnection, table_name, location: Path | str) -> None: - files = _expand_location(location) - if not files: - raise ValueError(f"No parquet files found at {location}") - files_str = ', '.join([f"'{f}'" for f in files]) - files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" - - con.execute(files_sql) +from gateway.backends import backend_selector +from gateway.backends.backend_options import BackendOptions, Backend -# pylint: disable=E1101,too-few-public-methods,fixme def convert_sql(sql: str) -> plan_pb2.Plan: """Converts SQL into a Substrait plan.""" result = plan_pb2.Plan() - con = duckdb.connect(config={'max_memory': '100GB', - 'temp_directory': str(Path('.').resolve())}) - con.install_extension('substrait') - con.load_extension('substrait') - # TODO -- Rely on the client to register their own named tables. 
- tpch_location = find_tpch() - register_table(con, 'customer', tpch_location / 'customer') - register_table(con, 'lineitem', tpch_location / 'lineitem') - register_table(con, 'nation', tpch_location / 'nation') - register_table(con, 'orders', tpch_location / 'orders') - register_table(con, 'part', tpch_location / 'part') - register_table(con, 'partsupp', tpch_location / 'partsupp') - register_table(con, 'region', tpch_location / 'region') - register_table(con, 'supplier', tpch_location / 'supplier') + backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) + backend.register_tpch() - proto_bytes = con.get_substrait(query=sql).fetchone()[0] + connection = backend.get_connection() + proto_bytes = connection.get_substrait(query=sql).fetchone()[0] result.ParseFromString(proto_bytes) return result diff --git a/src/gateway/server.py b/src/gateway/server.py index 2ca07d3..d611380 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -7,13 +7,13 @@ import grpc import pyarrow -import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc import pyspark.sql.connect.proto.base_pb2 as pb2 +import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc from pyspark.sql.connect.proto import types_pb2 +from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import duck_db, datafusion from gateway.converter.spark_to_substrait import SparkSubstraitConverter -from gateway.adbc.backend import AdbcBackend from gateway.converter.sql_to_substrait import convert_sql _LOGGER = logging.getLogger(__name__) @@ -104,8 +104,8 @@ def ExecutePlan( case _: raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) - backend = AdbcBackend() - results = backend.execute(substrait, self._options.backend) + backend = find_backend(self._options.backend) + results = backend.execute(substrait) _LOGGER.debug(' results are: %s', results) if not self._options.implement_show_string and request.plan.WhichOneof( diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 4b90a43..9aee812 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -8,7 +8,7 @@ from pyspark.sql.session import SparkSession from pyspark.testing import assertDataFrameEqual -from gateway.converter.sql_to_substrait import find_tpch +from gateway.backends.backend import Backend test_case_directory = Path(__file__).resolve().parent / 'data' @@ -18,7 +18,7 @@ def _register_table(spark_session: SparkSession, name: str) -> None: - location = find_tpch() / name + location = Backend.find_tpch() / name spark_session.sql( f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' f'OPTIONS ( path "{location}" )') @@ -61,6 +61,7 @@ def test_limit(self, spark_session_with_customer_database): 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() assertDataFrameEqual(outcome, expected) + @pytest.mark.timeout(60) @pytest.mark.parametrize( 'path', sql_test_case_paths, From a989e9be65df3660959462e3e939690cef49a296 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Tue, 16 Apr 2024 20:56:57 -0400 Subject: [PATCH 16/58] fix: imports needed to be updated based on latest refactor (#40) This PR fixed some imports so the client demo should work again --- src/gateway/backends/adbc_backend.py | 3 ++- src/gateway/demo/client_demo.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/gateway/backends/adbc_backend.py 
b/src/gateway/backends/adbc_backend.py index 70f8add..167ab0f 100644 --- a/src/gateway/backends/adbc_backend.py +++ b/src/gateway/backends/adbc_backend.py @@ -9,6 +9,7 @@ from gateway.backends.backend import Backend from gateway.backends.backend_options import BackendOptions +from gateway.backends.backend_options import Backend as backend_engine def _import(handle): @@ -18,7 +19,7 @@ def _import(handle): def _get_backend_driver(options: BackendOptions) -> tuple[str, str]: """Gets the driver and entry point for the specified backend.""" match options.backend: - case Backend.DUCKDB: + case backend_engine.DUCKDB: driver = duckdb.duckdb.__file__ entry_point = "duckdb_adbc_init" case _: diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index 1afdeb8..98275b0 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -4,14 +4,14 @@ from pyspark.sql.functions import col from pyspark.sql import SparkSession, DataFrame -from gateway.converter.sql_to_substrait import find_tpch +from gateway.backends.backend import Backend USE_GATEWAY = True # pylint: disable=fixme def get_customer_database(spark_session: SparkSession) -> DataFrame: - location_customer = str(find_tpch() / 'customer') + location_customer = str(Backend.find_tpch() / 'customer') return spark_session.read.parquet(location_customer, mergeSchema=False) From 4c84176f4245590a016378ae1c04651a6114f75f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 16 Apr 2024 21:28:52 -0700 Subject: [PATCH 17/58] feat: fix a number of datafusion tests (#41) --- src/gateway/converter/rename_functions.py | 15 +++++++++++++++ src/gateway/converter/substrait_plan_visitor.py | 2 +- src/gateway/server.py | 2 ++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index 958c6d3..448be22 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -17,6 +17,21 @@ def visit_plan(self, plan: plan_pb2.Plan) -> None: if extension.WhichOneof('mapping_type') != 'extension_function': continue + if ':' in extension.extension_function.name: + extension.extension_function.name = extension.extension_function.name.split(':')[0] + # TODO -- Take the URI references into account. 
if extension.extension_function.name == 'substring': extension.extension_function.name = 'substr' + elif extension.extension_function.name == '*': + extension.extension_function.name = 'multiply' + elif extension.extension_function.name == '-': + extension.extension_function.name = 'subtract' + elif extension.extension_function.name == '+': + extension.extension_function.name = 'add' + elif extension.extension_function.name == '/': + extension.extension_function.name = 'divide' + elif extension.extension_function.name == 'contains': + extension.extension_function.name = 'instr' + elif extension.extension_function.name == 'extract': + extension.extension_function.name = 'date_part' diff --git a/src/gateway/converter/substrait_plan_visitor.py b/src/gateway/converter/substrait_plan_visitor.py index 313de94..d441d21 100644 --- a/src/gateway/converter/substrait_plan_visitor.py +++ b/src/gateway/converter/substrait_plan_visitor.py @@ -35,7 +35,7 @@ def visit_subquery_set_comparison( if subquery.HasField('left'): self.visit_expression(subquery.left) if subquery.HasField('right'): - self.visit_expression(subquery.right) + self.visit_relation(subquery.right) def visit_nested_struct(self, structure: algebra_pb2.Expression.Nested.Struct) -> Any: """Visits a nested struct.""" diff --git a/src/gateway/server.py b/src/gateway/server.py index d611380..d3c78ad 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -66,6 +66,8 @@ def convert_pyarrow_schema_to_spark(schema: pyarrow.Schema) -> types_pb2.DataTyp data_type = types_pb2.DataType(string=types_pb2.DataType.String()) elif field.type == pyarrow.timestamp('us'): data_type = types_pb2.DataType(timestamp=types_pb2.DataType.Timestamp()) + elif field.type == pyarrow.date32(): + data_type = types_pb2.DataType(date=types_pb2.DataType.Date()) else: raise NotImplementedError( 'Conversion from Arrow schema to Spark schema not yet implemented ' From 07260992be882800cac4c853aa1aa86ae24e16f3 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 17 Apr 2024 17:16:33 -0700 Subject: [PATCH 18/58] chore: add more checks to ruff check (#43) The current set of checks wasn't nearly exhaustive enough, the new configuration will ensure more consistency. 
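As a rough illustration (not part of this change, and using hypothetical names), the newly enabled RET (flake8-return) and SIM (flake8-simplify) families flag patterns like the first helper below and steer toward the tightened form that several files in this patch adopt:

from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq


def read_first_file(location: Path) -> pa.Table:
    # Before: ruff reports SIM108 (use a ternary operator) and
    # RET504 (unnecessary assignment immediately before return).
    if location.is_dir():
        files = sorted(location.glob("*.parquet"))
    else:
        files = [location]
    table = pq.ParquetFile(files[0]).read()
    return table


def read_first_file_simplified(location: Path) -> pa.Table:
    # After: the equivalent form the lint rules suggest.
    files = sorted(location.glob("*.parquet")) if location.is_dir() else [location]
    return pq.ParquetFile(files[0]).read()

The per-file ignores in the new configuration exempt tests, __init__.py files, and the converter package from the docstring (D) rules only.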
--- pyproject.toml | 36 +++++++- src/gateway/backends/adbc_backend.py | 22 ++--- src/gateway/backends/arrow_backend.py | 13 ++- src/gateway/backends/backend.py | 29 +++---- src/gateway/backends/backend_options.py | 3 + src/gateway/backends/backend_selector.py | 3 +- src/gateway/backends/datafusion_backend.py | 11 ++- src/gateway/backends/duckdb_backend.py | 14 +-- src/gateway/converter/conversion_options.py | 6 +- src/gateway/converter/label_relations.py | 4 +- .../output_field_tracking_visitor.py | 7 +- src/gateway/converter/rename_functions.py | 3 +- src/gateway/converter/replace_local_files.py | 15 ++-- src/gateway/converter/simplify_casts.py | 11 ++- src/gateway/converter/spark_functions.py | 9 +- src/gateway/converter/spark_to_substrait.py | 86 +++++++++++-------- .../converter/spark_to_substrait_test.py | 11 ++- src/gateway/converter/sql_to_substrait.py | 5 +- src/gateway/converter/substrait_builder.py | 17 ++-- .../converter/substrait_plan_visitor.py | 10 +-- src/gateway/converter/symbol_table.py | 19 ++-- .../tools/duckdb_substrait_to_arrow.py | 8 +- .../tools/tests/simplify_casts_test.py | 7 +- src/gateway/converter/validation_test.py | 5 +- src/gateway/demo/client_demo.py | 9 +- src/gateway/demo/mystream_database.py | 75 ++++++++-------- src/gateway/server.py | 69 ++++++++------- src/gateway/tests/conftest.py | 13 +-- src/gateway/tests/test_sql_api.py | 3 +- 29 files changed, 287 insertions(+), 236 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 003541e..d930ee4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.8.1" -dependencies = ["protobuf >= 3.20", "datafusion >= 36.0.0", "pyarrow >= 15.0.2"] +dependencies = ["protobuf >= 3.20", "datafusion >= 35.0.0", "pyarrow >= 15.0.2"] dynamic = ["version"] [tool.setuptools_scm] @@ -24,18 +24,48 @@ requires = ["setuptools>=61.0.0", "setuptools_scm[toml]>=6.2.0"] build-backend = "setuptools.build_meta" [tool.ruff] +line-length = 100 respect-gitignore = true # should target minimum supported version target-version = "py310" # never autoformat upstream or generated code -exclude = ["third_party/", "src/spark/connect"] +exclude = ["third_party/"] -[lint] +[tool.ruff.lint] +select = [ + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C4", # comprehensions + "D", # pydocstyle + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "ICN", # flake8-import-conventions + "ISC", # flake8-implicit-str-concat + "PGH", # pygrep-hooks + "PLC", # pylint + "PLE", # pylint + "PLW", # pylint + "RET", # flake8-return + "RUF", # ruff-specific rules + "SIM", # flake8-simplify + "T10", # flake8-debugger + "T20", # flake8-print + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle + "YTT", # flake8-2020 +] unfixable = [ "F401", # unused imports "T201", # print statements "E712", # truth comparison checks ] +[tool.ruff.lint.per-file-ignores] +"*test*.py" = ["D"] # ignore all docstring lints in tests +"__init__.py" = ["D"] +"src/gateway/converter/*" = ["D"] + [tool.pylint.MASTER] extension-pkg-allow-list = ["pyarrow.lib"] diff --git a/src/gateway/backends/adbc_backend.py b/src/gateway/backends/adbc_backend.py index 167ab0f..874760f 100644 --- a/src/gateway/backends/adbc_backend.py +++ b/src/gateway/backends/adbc_backend.py @@ -3,21 +3,21 @@ from pathlib import Path import duckdb -import pyarrow +import pyarrow as pa from adbc_driver_manager import dbapi 
from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend -from gateway.backends.backend_options import BackendOptions from gateway.backends.backend_options import Backend as backend_engine +from gateway.backends.backend_options import BackendOptions def _import(handle): - return pyarrow.RecordBatchReader._import_from_c(handle.address) + return pa.RecordBatchReader._import_from_c(handle.address) def _get_backend_driver(options: BackendOptions) -> tuple[str, str]: - """Gets the driver and entry point for the specified backend.""" + """Get the driver and entry point for the specified backend.""" match options.backend: case backend_engine.DUCKDB: driver = duckdb.duckdb.__file__ @@ -32,33 +32,35 @@ class AdbcBackend(Backend): """Provides access to send ADBC backends Substrait plans.""" def __init__(self, options: BackendOptions): + """Initialize the ADBC backend.""" + self._connection = None self._options = options super().__init__(options) self.create_connection() def create_connection(self) -> None: + """Create a connection to the ADBC backend.""" driver, entry_point = _get_backend_driver(self._options) self._connection = dbapi.connect(driver=driver, entrypoint=entry_point) # pylint: disable=import-outside-toplevel - def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: - """Executes the given Substrait plan against an ADBC backend.""" + def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: + """Execute the given Substrait plan against an ADBC backend.""" with self._connection.cursor() as cur: cur.execute("LOAD substrait;") plan_data = plan.SerializeToString() cur.adbc_statement.set_substrait_plan(plan_data) res = cur.adbc_statement.execute_query() - table = _import(res[0]).read_all() - return table + return _import(res[0]).read_all() def register_table(self, name: str, path: Path, extension: str = 'parquet') -> None: - """Registers the given table with the backend.""" + """Register the given table with the backend.""" file_paths = sorted(Path(path).glob(f'*.{extension}')) if len(file_paths) > 0: # Sort the files because the later ones don't have enough data to construct a schema. file_paths = sorted([str(fp) for fp in file_paths]) # TODO: Support multiple paths. 
- reader = pyarrow.parquet.ParquetFile(file_paths[0]) + reader = pa.parquet.ParquetFile(file_paths[0]) self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode="create") def describe_table(self, table_name: str): diff --git a/src/gateway/backends/arrow_backend.py b/src/gateway/backends/arrow_backend.py index 5e26a6f..2c8019e 100644 --- a/src/gateway/backends/arrow_backend.py +++ b/src/gateway/backends/arrow_backend.py @@ -2,7 +2,7 @@ """Provides access to Acero.""" from pathlib import Path -import pyarrow +import pyarrow as pa from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend @@ -12,13 +12,12 @@ class ArrowBackend(Backend): """Provides access to send Acero Substrait plans.""" # pylint: disable=import-outside-toplevel - def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: - """Executes the given Substrait plan against Acero.""" + def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: + """Execute the given Substrait plan against Acero.""" plan_data = plan.SerializeToString() - reader = pyarrow.substrait.run_query(plan_data) - query_result = reader.read_all() - return query_result + reader = pa.substrait.run_query(plan_data) + return reader.read_all() def register_table(self, name: str, path: Path) -> None: - """Registers the given table with the backend.""" + """Register the given table with the backend.""" raise NotImplementedError() diff --git a/src/gateway/backends/backend.py b/src/gateway/backends/backend.py index c23f1c4..f3fd3ea 100644 --- a/src/gateway/backends/backend.py +++ b/src/gateway/backends/backend.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """The base class for all Substrait backends.""" from pathlib import Path -from typing import List, Any +from typing import Any -import pyarrow +import pyarrow as pa from substrait.gen.proto import plan_pb2 from gateway.backends.backend_options import BackendOptions @@ -13,24 +13,26 @@ class Backend: """Base class providing methods for contacting a backend utilizing Substrait.""" def __init__(self, options: BackendOptions): + """Initialize the backend.""" self._connection = None def create_connection(self) -> None: + """Create a connection to the backend.""" raise NotImplementedError() def get_connection(self) -> Any: - """Returns the connection to the backend.""" + """Return the connection to the backend (creating one if necessary).""" if self._connection is None: - self._connection = self.create_connection() + self.create_connection() return self._connection # pylint: disable=import-outside-toplevel - def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: - """Executes the given Substrait plan against Datafusion.""" + def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: + """Execute the given Substrait plan against Datafusion.""" raise NotImplementedError() def register_table(self, name: str, path: Path | str, extension: str = 'parquet') -> None: - """Registers the given table with the backend.""" + """Register the given table with the backend.""" raise NotImplementedError() def describe_table(self, name: str): @@ -42,19 +44,16 @@ def drop_table(self, name: str): raise NotImplementedError() @staticmethod - def expand_location(location: Path | str) -> List[str]: - """Expands the location of a file or directory into a list of files.""" + def expand_location(location: Path | str) -> list[str]: + """Expand the location of a file or directory into a list of files.""" # TODO -- Handle more than just Parquet files. 
path = Path(location) - if path.is_dir(): - files = Path(location).resolve().glob('*.parquet') - else: - files = [path] + files = Path(location).resolve().glob('*.parquet') if path.is_dir() else [path] return sorted(str(f) for f in files) @staticmethod def find_tpch() -> Path: - """Finds the location of the TPCH dataset.""" + """Find the location of the TPCH dataset.""" current_location = Path('.').resolve() while current_location != Path('/'): location = current_location / 'third_party' / 'tpch' / 'parquet' @@ -64,7 +63,7 @@ def find_tpch() -> Path: raise ValueError('TPCH dataset not found') def register_tpch(self): - """Convenience function to register the entire TPC-H dataset.""" + """Register the entire TPC-H dataset.""" tpch_location = Backend.find_tpch() self.register_table('customer', tpch_location / 'customer') self.register_table('lineitem', tpch_location / 'lineitem') diff --git a/src/gateway/backends/backend_options.py b/src/gateway/backends/backend_options.py index 55e833c..7039acb 100644 --- a/src/gateway/backends/backend_options.py +++ b/src/gateway/backends/backend_options.py @@ -6,6 +6,7 @@ class Backend(Enum): """Represents the different backends we have support for.""" + ARROW = 1 DATAFUSION = 2 DUCKDB = 3 @@ -14,9 +15,11 @@ class Backend(Enum): @dataclasses.dataclass class BackendOptions: """Holds all the possible backend options.""" + backend: Backend use_adbc: bool def __init__(self, backend: Backend, use_adbc: bool = False): + """Create a BackendOptions structure.""" self.backend = backend self.use_adbc = use_adbc diff --git a/src/gateway/backends/backend_selector.py b/src/gateway/backends/backend_selector.py index d78d72f..fa88625 100644 --- a/src/gateway/backends/backend_selector.py +++ b/src/gateway/backends/backend_selector.py @@ -3,12 +3,13 @@ from gateway.backends import backend from gateway.backends.adbc_backend import AdbcBackend from gateway.backends.arrow_backend import ArrowBackend -from gateway.backends.backend_options import BackendOptions, Backend +from gateway.backends.backend_options import Backend, BackendOptions from gateway.backends.datafusion_backend import DatafusionBackend from gateway.backends.duckdb_backend import DuckDBBackend def find_backend(options: BackendOptions) -> backend.Backend: + """Given a backend enum, returns an instance of the correct Backend descendant.""" match options.backend: case Backend.ARROW: return ArrowBackend(options) diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index 2528440..c462839 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -2,7 +2,7 @@ """Provides access to Datafusion.""" from pathlib import Path -import pyarrow +import pyarrow as pa from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend @@ -15,16 +15,18 @@ class DatafusionBackend(Backend): """Provides access to send Substrait plans to Datafusion.""" def __init__(self, options): + """Initialize the Datafusion backend.""" + self._connection = None super().__init__(options) self.create_connection() def create_connection(self) -> None: - """Creates a connection to the backend.""" + """Create a connection to the backend.""" import datafusion self._connection = datafusion.SessionContext() - def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: - """Executes the given Substrait plan against Datafusion.""" + def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: + """Execute the given Substrait plan against 
Datafusion.""" import datafusion.substrait self.register_tpch() @@ -60,5 +62,6 @@ def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: self._connection.deregister_table(table_name) def register_table(self, name: str, path: Path) -> None: + """Register the given table with the backend.""" files = Backend.expand_location(path) self._connection.register_parquet(name, files[0]) diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index 6cb6fef..fb9825c 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -3,7 +3,7 @@ from pathlib import Path import duckdb -import pyarrow +import pyarrow as pa from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend @@ -14,10 +14,13 @@ class DuckDBBackend(Backend): """Provides access to send Substrait plans to DuckDB.""" def __init__(self, options): + """Initialize the DuckDB backend.""" + self._connection = None super().__init__(options) self.create_connection() def create_connection(self): + """Create a connection to the backend.""" if self._connection is not None: return self._connection @@ -29,8 +32,9 @@ def create_connection(self): return self._connection - def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: - """Executes the given Substrait plan against DuckDB.""" + # ruff: noqa: BLE001 + def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: + """Execute the given Substrait plan against DuckDB.""" plan_data = plan.SerializeToString() # TODO -- Rely on the client to register their own named tables. @@ -41,10 +45,10 @@ def execute(self, plan: plan_pb2.Plan) -> pyarrow.lib.Table: except Exception as err: raise ValueError(f'DuckDB Execution Error: {err}') from err df = query_result.df() - return pyarrow.Table.from_pandas(df=df) + return pa.Table.from_pandas(df=df) def register_table(self, table_name: str, location: Path) -> None: - """Registers the given table with the backend.""" + """Register the given table with the backend.""" files = Backend.expand_location(location) if not files: raise ValueError(f"No parquet files found at {location}") diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 373cead..d849633 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -2,13 +2,14 @@ """Tracks conversion related options.""" import dataclasses -from gateway.backends.backend_options import BackendOptions, Backend +from gateway.backends.backend_options import Backend, BackendOptions # pylint: disable=too-many-instance-attributes @dataclasses.dataclass class ConversionOptions: """Holds all the possible conversion options.""" + use_named_table_workaround: bool needs_scheme_in_path_uris: bool use_project_emit_workaround: bool @@ -31,8 +32,7 @@ def __init__(self, backend: BackendOptions = None): def datafusion(): """Standard options to connect to a Datafusion backend.""" - options = ConversionOptions(backend=BackendOptions(Backend.DATAFUSION)) - return options + return ConversionOptions(backend=BackendOptions(Backend.DATAFUSION)) def duck_db(): diff --git a/src/gateway/converter/label_relations.py b/src/gateway/converter/label_relations.py index aee273a..3e618eb 100644 --- a/src/gateway/converter/label_relations.py +++ b/src/gateway/converter/label_relations.py @@ -2,9 +2,8 @@ """A library to search Substrait plan for local files.""" from typing import Any -from substrait.gen.proto import algebra_pb2 - from 
gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor +from substrait.gen.proto import algebra_pb2 # pylint: disable=E1101 @@ -62,6 +61,7 @@ def get_common_section(rel: algebra_pb2.Rel) -> algebra_pb2.RelCommon: # pylint: disable=E1101,no-member,fixme class LabelRelations(SubstraitPlanVisitor): """Replaces all cast expressions with projects of casts instead.""" + _seen_relations: int def __init__(self): diff --git a/src/gateway/converter/output_field_tracking_visitor.py b/src/gateway/converter/output_field_tracking_visitor.py index bdc9ed7..a424e79 100644 --- a/src/gateway/converter/output_field_tracking_visitor.py +++ b/src/gateway/converter/output_field_tracking_visitor.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 """A library to search Substrait plan for local files.""" -from typing import Any, Optional - -from substrait.gen.proto import algebra_pb2, plan_pb2 +from typing import Any from gateway.converter.label_relations import get_common_section from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor from gateway.converter.symbol_table import SymbolTable +from substrait.gen.proto import algebra_pb2, plan_pb2 # pylint: disable=E1101 @@ -30,7 +29,7 @@ class OutputFieldTrackingVisitor(SubstraitPlanVisitor): def __init__(self): super().__init__() - self._current_plan_id: Optional[int] = None # The relation currently being processed. + self._current_plan_id: int | None = None # The relation currently being processed. self._symbol_table = SymbolTable() def update_field_references(self, plan_id: int) -> None: diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index 448be22..a4dc126 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """A library to search Substrait plan for local files.""" -from substrait.gen.proto import plan_pb2 - from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor +from substrait.gen.proto import plan_pb2 # pylint: disable=no-member,fixme diff --git a/src/gateway/converter/replace_local_files.py b/src/gateway/converter/replace_local_files.py index 0510506..dc20c68 100644 --- a/src/gateway/converter/replace_local_files.py +++ b/src/gateway/converter/replace_local_files.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """A library to search Substrait plan for local files.""" -from typing import Any, List, Tuple - -from substrait.gen.proto import algebra_pb2, plan_pb2 +from typing import Any from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor +from substrait.gen.proto import algebra_pb2, plan_pb2 # pylint: disable=no-member @@ -12,12 +11,12 @@ class ReplaceLocalFilesWithNamedTable(SubstraitPlanVisitor): """Replaces all of the local file instances with named tables.""" def __init__(self): - self._file_groups: List[Tuple[str, List[str]]] = [] + self._file_groups: list[tuple[str, list[str]]] = [] super().__init__() def visit_local_files(self, local_files: algebra_pb2.ReadRel.LocalFiles) -> Any: - """Visits a local files node.""" + """Visit a local files node.""" files = [] for item in local_files.items: files.append(item.uri_file) @@ -25,13 +24,13 @@ def visit_local_files(self, local_files: algebra_pb2.ReadRel.LocalFiles) -> Any: self._file_groups.append(('possible_table_name', files)) def visit_read_relation(self, rel: algebra_pb2.ReadRel) -> Any: - """Visits a read relation node.""" + """Visit a read relation node.""" 
super().visit_read_relation(rel) if rel.HasField('local_files'): rel.ClearField('local_files') rel.named_table.names.append(self._file_groups[-1][0]) - def visit_plan(self, plan: plan_pb2.Plan) -> List[Tuple[str, List[str]]]: - """Modifies the provided plan so that Local Files are replaced with Named Tables.""" + def visit_plan(self, plan: plan_pb2.Plan) -> list[tuple[str, list[str]]]: + """Modify the provided plan so that Local Files are replaced with Named Tables.""" super().visit_plan(plan) return self._file_groups diff --git a/src/gateway/converter/simplify_casts.py b/src/gateway/converter/simplify_casts.py index 7b546ae..abbe86c 100644 --- a/src/gateway/converter/simplify_casts.py +++ b/src/gateway/converter/simplify_casts.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 """A library to search Substrait plan for local files.""" -from typing import Any, List, Optional - -from substrait.gen.proto import algebra_pb2 +from typing import Any from gateway.converter.output_field_tracking_visitor import get_plan_id from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor from gateway.converter.symbol_table import SymbolTable +from substrait.gen.proto import algebra_pb2 # pylint: disable=no-member,fixme @@ -16,10 +15,10 @@ class SimplifyCasts(SubstraitPlanVisitor): def __init__(self, symbol_table: SymbolTable): super().__init__() self._symbol_table = symbol_table - self._current_plan_id: Optional[int] = None # The relation currently being processed. + self._current_plan_id: int | None = None # The relation currently being processed. - self._rewrite_expressions: List[algebra_pb2.Expression] = [] - self._previous_rewrite_expressions: List[List[algebra_pb2.Expression]] = [] + self._rewrite_expressions: list[algebra_pb2.Expression] = [] + self._previous_rewrite_expressions: list[list[algebra_pb2.Expression]] = [] def visit_cast(self, cast: algebra_pb2.Expression.Cast) -> Any: """Visits a cast node.""" diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 4d834c9..f495414 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -1,25 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 """Provides the mapping of Spark functions to Substrait.""" import dataclasses -from typing import Optional - -from substrait.gen.proto import type_pb2 from gateway.converter.conversion_options import ConversionOptions +from substrait.gen.proto import type_pb2 # pylint: disable=E1101 @dataclasses.dataclass class ExtensionFunction: """Represents a Substrait function.""" + uri: str name: str output_type: type_pb2.Type anchor: int - max_args: Optional[int] + max_args: int | None def __init__(self, uri: str, name: str, output_type: type_pb2.Type, - max_args: Optional[int] = None): + max_args: int | None = None): self.uri = uri self.name = name self.output_type = output_type diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 13ffdcc..afacb6d 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -4,42 +4,57 @@ import json import operator import pathlib -from typing import Dict, Optional, List -import pyarrow +import pyarrow as pa import pyarrow.parquet import pyspark.sql.connect.proto.base_pb2 as spark_pb2 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as 
spark_types_pb2 -from substrait.gen.proto import algebra_pb2 -from substrait.gen.proto import plan_pb2 -from substrait.gen.proto import type_pb2 -from substrait.gen.proto.extensions import extensions_pb2 - from gateway.backends.backend_options import BackendOptions from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function from gateway.converter.sql_to_substrait import convert_sql -from gateway.converter.substrait_builder import field_reference, cast_operation, string_type, \ - project_relation, strlen, concat, fetch_relation, join_relation, aggregate_relation, \ - max_agg_function, string_literal, flatten, repeat_function, \ - least_function, greatest_function, bigint_literal, lpad_function, string_concat_agg_function, \ - if_then_else_operation, greater_function, minus_function +from gateway.converter.substrait_builder import ( + aggregate_relation, + bigint_literal, + cast_operation, + concat, + fetch_relation, + field_reference, + flatten, + greater_function, + greatest_function, + if_then_else_operation, + join_relation, + least_function, + lpad_function, + max_agg_function, + minus_function, + project_relation, + repeat_function, + string_concat_agg_function, + string_literal, + string_type, + strlen, +) from gateway.converter.symbol_table import SymbolTable +from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 +from substrait.gen.proto.extensions import extensions_pb2 TABLE_NAME = "my_table" # pylint: disable=E1101,fixme,too-many-public-methods +# ruff: noqa: RUF005 class SparkSubstraitConverter: """Converts SparkConnect plans to Substrait plans.""" def __init__(self, options: ConversionOptions): - self._function_uris: Dict[str, int] = {} - self._functions: Dict[str, ExtensionFunction] = {} - self._current_plan_id: Optional[int] = None # The relation currently being processed. + self._function_uris: dict[str, int] = {} + self._functions: dict[str, ExtensionFunction] = {} + self._current_plan_id: int | None = None # The relation currently being processed. self._symbol_table = SymbolTable() self._conversion_options = options self._seen_generated_names = {} @@ -66,7 +81,7 @@ def update_field_references(self, plan_id: int) -> None: current_symbol.input_fields.extend(source_symbol.output_fields) current_symbol.output_fields.extend(current_symbol.input_fields) - def find_field_by_name(self, field_name: str) -> Optional[int]: + def find_field_by_name(self, field_name: str) -> int | None: """Looks up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) try: @@ -174,10 +189,8 @@ def convert_unresolved_function( func = algebra_pb2.Expression.ScalarFunction() function_def = self.lookup_function_by_name(unresolved_function.function_name) func.function_reference = function_def.anchor - curr_arg_count = 0 - for arg in unresolved_function.arguments: - curr_arg_count += 1 - if function_def.max_args is not None and curr_arg_count > function_def.max_args: + for idx, arg in enumerate(unresolved_function.arguments): + if function_def.max_args is not None and idx >= function_def.max_args: break func.arguments.append( algebra_pb2.FunctionArgument(value=self.convert_expression(arg))) @@ -193,7 +206,7 @@ def convert_alias_expression( # TODO -- Utilize the alias name. 
return self.convert_expression(alias.expr) - def convert_type_str(self, spark_type_str: Optional[str]) -> type_pb2.Type: + def convert_type_str(self, spark_type_str: str | None) -> type_pb2.Type: """Converts a Spark type string into a Substrait type.""" # TODO -- Properly handle nullability. match spark_type_str: @@ -302,7 +315,7 @@ def convert_read_named_table_relation(self, rel: spark_relations_pb2.Read) -> al """Converts a read named table relation to a Substrait relation.""" raise NotImplementedError('named tables are not yet implemented') - def convert_schema(self, schema_str: str) -> Optional[type_pb2.NamedStruct]: + def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: """Converts the Spark JSON schema string into a Substrait named type structure.""" if not schema_str: return None @@ -345,7 +358,7 @@ def convert_schema(self, schema_str: str) -> Optional[type_pb2.NamedStruct]: schema.struct.types.append(field_type) return schema - def convert_arrow_schema(self, arrow_schema: pyarrow.Schema) -> type_pb2.NamedStruct: + def convert_arrow_schema(self, arrow_schema: pa.Schema) -> type_pb2.NamedStruct: schema = type_pb2.NamedStruct() schema.struct.nullability = type_pb2.Type.NULLABILITY_REQUIRED @@ -404,9 +417,8 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al file_paths = rel.paths for path in file_paths: uri_path = path - if self._conversion_options.needs_scheme_in_path_uris: - if uri_path.startswith('/'): - uri_path = "file:" + uri_path + if self._conversion_options.needs_scheme_in_path_uris and uri_path.startswith('/'): + uri_path = "file:" + uri_path file_or_files = algebra_pb2.ReadRel.LocalFiles.FileOrFiles(uri_file=uri_path) match rel.format: case 'parquet': @@ -500,7 +512,7 @@ def convert_limit_relation(self, rel: spark_relations_pb2.Limit) -> algebra_pb2. 
count=rel.limit) return algebra_pb2.Rel(fetch=fetch) - def determine_expression_name(self, expr: spark_exprs_pb2.Expression) -> Optional[str]: + def determine_expression_name(self, expr: spark_exprs_pb2.Expression) -> str | None: """Determines the name of the expression.""" if expr.HasField('alias'): return expr.alias.name[0] @@ -609,16 +621,16 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a string_literal(symbol.input_fields[column_number]))) for column_number in range(len(symbol.input_fields))]) - def field_header_fragment(field_number: int) -> List[algebra_pb2.Expression]: + def field_header_fragment(field_number: int) -> list[algebra_pb2.Expression]: return [string_literal('|'), lpad_function(lpad_func, string_literal(symbol.input_fields[field_number]), field_reference(field_number))] - def field_line_fragment(field_number: int) -> List[algebra_pb2.Expression]: + def field_line_fragment(field_number: int) -> list[algebra_pb2.Expression]: return [string_literal('+'), repeat_function(repeat_func, '-', field_reference(field_number))] - def field_body_fragment(field_number: int) -> List[algebra_pb2.Expression]: + def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: return [string_literal('|'), if_then_else_operation( greater_function(greater_func, @@ -640,7 +652,7 @@ def field_body_fragment(field_number: int) -> List[algebra_pb2.Expression]: )] - def header_line(fields: List[str]) -> List[algebra_pb2.Expression]: + def header_line(fields: list[str]) -> list[algebra_pb2.Expression]: return [concat(concat_func, flatten([ field_header_fragment(field_number) for field_number in @@ -649,7 +661,7 @@ def header_line(fields: List[str]) -> List[algebra_pb2.Expression]: string_literal('|\n'), ])] - def full_line(fields: List[str]) -> List[algebra_pb2.Expression]: + def full_line(fields: list[str]) -> list[algebra_pb2.Expression]: return [concat(concat_func, flatten([ field_line_fragment(field_number) for field_number in @@ -756,12 +768,12 @@ def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.R project.common.CopyFrom(self.create_common_relation()) return algebra_pb2.Rel(project=project) - def convert_arrow_to_literal(self, val: pyarrow.Scalar) -> algebra_pb2.Expression.Literal: + def convert_arrow_to_literal(self, val: pa.Scalar) -> algebra_pb2.Expression.Literal: """Converts an Arrow scalar into a Substrait literal.""" literal = algebra_pb2.Expression.Literal() - if isinstance(val, pyarrow.BooleanScalar): + if isinstance(val, pa.BooleanScalar): literal.boolean = val.as_py() - elif isinstance(val, pyarrow.StringScalar): + elif isinstance(val, pa.StringScalar): literal.string = val.as_py() else: raise NotImplementedError( @@ -772,8 +784,8 @@ def convert_arrow_data_to_virtual_table(self, data: bytes) -> algebra_pb2.ReadRel.VirtualTable: """Converts a Spark local relation into a virtual table.""" table = algebra_pb2.ReadRel.VirtualTable() - # use Pyarrow to convert the bytes into an arrow structure - with pyarrow.ipc.open_stream(data) as arrow: + # Use pyarrow to convert the bytes into an arrow structure. 
+ with pa.ipc.open_stream(data) as arrow: for batch in arrow.iter_batches_with_custom_metadata(): for row_number in range(batch.batch.num_rows): values = algebra_pb2.Expression.Literal.Struct() diff --git a/src/gateway/converter/spark_to_substrait_test.py b/src/gateway/converter/spark_to_substrait_test.py index df1c03f..685c71d 100644 --- a/src/gateway/converter/spark_to_substrait_test.py +++ b/src/gateway/converter/spark_to_substrait_test.py @@ -2,15 +2,14 @@ """Tests for the Spark to Substrait plan conversion routines.""" from pathlib import Path -from google.protobuf import text_format import pytest -from pyspark.sql.connect.proto import base_pb2 as spark_base_pb2 -from substrait.gen.proto import plan_pb2 - from gateway.converter.conversion_options import duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter from gateway.converter.sql_to_substrait import convert_sql from gateway.demo.mystream_database import create_mystream_database, delete_mystream_database +from google.protobuf import text_format +from pyspark.sql.connect.proto import base_pb2 as spark_base_pb2 +from substrait.gen.proto import plan_pb2 test_case_directory = Path(__file__).resolve().parent / 'data' @@ -48,7 +47,7 @@ def test_plan_conversion(request, path): if request.config.getoption('rebuild_goldens'): if substrait != substrait_plan: - with open(path.with_suffix('.splan'), "wt", encoding='utf-8') as file: + with open(path.with_suffix('.splan'), "w", encoding='utf-8') as file: file.write(text_format.MessageToString(substrait)) return @@ -85,7 +84,7 @@ def test_sql_conversion(request, path): if request.config.getoption('rebuild_goldens'): if substrait != substrait_plan: - with open(path.with_suffix('.sql-splan'), "wt", encoding='utf-8') as file: + with open(path.with_suffix('.sql-splan'), "w", encoding='utf-8') as file: file.write(text_format.MessageToString(substrait)) return diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index 1f4a6cf..2833c1f 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to convert SparkConnect plans to Substrait plans.""" -from substrait.gen.proto import plan_pb2 - from gateway.backends import backend_selector -from gateway.backends.backend_options import BackendOptions, Backend +from gateway.backends.backend_options import Backend, BackendOptions +from substrait.gen.proto import plan_pb2 def convert_sql(sql: str) -> plan_pb2.Plan: diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index eb7b718..c8510b2 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -1,14 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 """Convenience builder for constructing Substrait plans.""" import itertools -from typing import List, Any - -from substrait.gen.proto import algebra_pb2, type_pb2 +from typing import Any from gateway.converter.spark_functions import ExtensionFunction +from substrait.gen.proto import algebra_pb2, type_pb2 -def flatten(list_of_lists: List[List[Any]]) -> List[Any]: +def flatten(list_of_lists: list[list[Any]]) -> list[Any]: """Flattens a list of lists into a list.""" return list(itertools.chain.from_iterable(list_of_lists)) @@ -17,13 +16,11 @@ def flatten(list_of_lists: List[List[Any]]) -> List[Any]: def fetch_relation(input_relation: algebra_pb2.Rel, num_rows: int) -> algebra_pb2.Rel: 
"""Constructs a Substrait fetch plan node.""" - fetch = algebra_pb2.Rel(fetch=algebra_pb2.FetchRel(input=input_relation, count=num_rows)) - - return fetch + return algebra_pb2.Rel(fetch=algebra_pb2.FetchRel(input=input_relation, count=num_rows)) def project_relation(input_relation: algebra_pb2.Rel, - expressions: List[algebra_pb2.Expression]) -> algebra_pb2.Rel: + expressions: list[algebra_pb2.Expression]) -> algebra_pb2.Rel: """Constructs a Substrait project plan node.""" return algebra_pb2.Rel( project=algebra_pb2.ProjectRel(input=input_relation, expressions=expressions)) @@ -31,7 +28,7 @@ def project_relation(input_relation: algebra_pb2.Rel, # pylint: disable=fixme def aggregate_relation(input_relation: algebra_pb2.Rel, - measures: List[algebra_pb2.AggregateFunction]) -> algebra_pb2.Rel: + measures: list[algebra_pb2.AggregateFunction]) -> algebra_pb2.Rel: """Constructs a Substrait aggregate plan node.""" aggregate = algebra_pb2.Rel( aggregate=algebra_pb2.AggregateRel( @@ -55,7 +52,7 @@ def join_relation(left: algebra_pb2.Rel, right: algebra_pb2.Rel) -> algebra_pb2. def concat(function_info: ExtensionFunction, - expressions: List[algebra_pb2.Expression]) -> algebra_pb2.Expression: + expressions: list[algebra_pb2.Expression]) -> algebra_pb2.Expression: """Constructs a Substrait concat expression.""" return algebra_pb2.Expression( scalar_function=algebra_pb2.Expression.ScalarFunction( diff --git a/src/gateway/converter/substrait_plan_visitor.py b/src/gateway/converter/substrait_plan_visitor.py index d441d21..1a83ede 100644 --- a/src/gateway/converter/substrait_plan_visitor.py +++ b/src/gateway/converter/substrait_plan_visitor.py @@ -2,9 +2,7 @@ """Abstract visitor class for Substrait plans.""" from typing import Any -from substrait.gen.proto import plan_pb2 -from substrait.gen.proto import algebra_pb2 -from substrait.gen.proto import type_pb2 +from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 from substrait.gen.proto.extensions import extensions_pb2 @@ -717,11 +715,11 @@ def visit_exchange_relation(self, rel: algebra_pb2.ExchangeRel) -> Any: def visit_expand_relation(self, rel: algebra_pb2.ExpandRel) -> Any: """Visits an expand relation.""" if rel.HasField('common'): - return self.visit_relation_common(rel.common) + self.visit_relation_common(rel.common) if rel.HasField('input'): - return self.visit_relation(rel.input) + self.visit_relation(rel.input) for field in rel.fields: - return self.visit_expand_field(field) + self.visit_expand_field(field) # ExpandRel does not have an advanced_extension like other relations do. 
def visit_relation(self, rel: algebra_pb2.Rel) -> Any: diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index ebb1197..d173370 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -1,18 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to convert SparkConnect plans to Substrait plans.""" import dataclasses -from typing import Optional, List, Dict @dataclasses.dataclass class PlanMetadata: """Tracks various information about a specific plan id.""" + plan_id: int - type: Optional[str] - parent_plan_id: Optional[int] - input_fields: List[str] # And maybe type - generated_fields: List[str] - output_fields: List[str] + type: str | None + parent_plan_id: int | None + input_fields: list[str] # And maybe type + generated_fields: list[str] + output_fields: list[str] def __init__(self, plan_id: int): self.plan_id = plan_id @@ -25,13 +25,14 @@ def __init__(self, plan_id: int): class SymbolTable: """Manages metadata related to symbols and provides easy lookup.""" - _symbols: Dict[int, PlanMetadata] + + _symbols: dict[int, PlanMetadata] def __init__(self): self._symbols = {} # pylint: disable=E1101 - def add_symbol(self, plan_id: int, parent: Optional[int], symbol_type: Optional[str]): + def add_symbol(self, plan_id: int, parent: int | None, symbol_type: str | None): """Creates a new symbol and returns it.""" symbol = PlanMetadata(plan_id) symbol.symbol_type = symbol_type @@ -39,6 +40,6 @@ def add_symbol(self, plan_id: int, parent: Optional[int], symbol_type: Optional[ self._symbols[plan_id] = symbol return symbol - def get_symbol(self, plan_id: int) -> Optional[PlanMetadata]: + def get_symbol(self, plan_id: int) -> PlanMetadata | None: """Fetches the symbol with the requested plan id.""" return self._symbols.get(plan_id) diff --git a/src/gateway/converter/tools/duckdb_substrait_to_arrow.py b/src/gateway/converter/tools/duckdb_substrait_to_arrow.py index 7287bcb..8187aeb 100644 --- a/src/gateway/converter/tools/duckdb_substrait_to_arrow.py +++ b/src/gateway/converter/tools/duckdb_substrait_to_arrow.py @@ -2,12 +2,11 @@ """Converts the provided plans from the DuckDB Substrait dialect to Acero's.""" import sys -from google.protobuf import json_format -from substrait.gen.proto import plan_pb2 - from gateway.converter.label_relations import LabelRelations, UnlabelRelations from gateway.converter.output_field_tracking_visitor import OutputFieldTrackingVisitor from gateway.converter.simplify_casts import SimplifyCasts +from google.protobuf import json_format +from substrait.gen.proto import plan_pb2 # pylint: disable=E1101 @@ -27,6 +26,7 @@ def simplify_casts(substrait_plan: plan_pb2.Plan) -> plan_pb2.Plan: # pylint: disable=E1101 +# ruff: noqa: T201 def main(): """Converts the provided plans from the DuckDB Substrait dialect to Acero's.""" args = sys.argv[1:] @@ -40,7 +40,7 @@ def main(): arrow_plan = simplify_casts(duckdb_plan) - with open(args[1], "wt", encoding='utf-8') as file: + with open(args[1], "w", encoding='utf-8') as file: file.write(json_format.MessageToJson(arrow_plan)) diff --git a/src/gateway/converter/tools/tests/simplify_casts_test.py b/src/gateway/converter/tools/tests/simplify_casts_test.py index aabe461..d9bb9a7 100644 --- a/src/gateway/converter/tools/tests/simplify_casts_test.py +++ b/src/gateway/converter/tools/tests/simplify_casts_test.py @@ -2,13 +2,12 @@ """Tests for the Spark to Substrait plan conversion routines.""" from pathlib import Path -from google.protobuf import json_format, 
text_format import pytest +from gateway.converter.tools.duckdb_substrait_to_arrow import simplify_casts +from google.protobuf import json_format, text_format from hamcrest import assert_that, equal_to from substrait.gen.proto import plan_pb2 -from gateway.converter.tools.duckdb_substrait_to_arrow import simplify_casts - test_case_directory = Path(__file__).resolve().parent / 'data' test_case_paths = [f for f in test_case_directory.iterdir() if f.suffix == '.json'] @@ -38,7 +37,7 @@ def test_simplify_casts(request, path): if request.config.getoption('rebuild_goldens'): if arrow_plan != expected_plan: - with open(path.with_suffix('.golden'), "wt", encoding='utf-8') as file: + with open(path.with_suffix('.golden'), "w", encoding='utf-8') as file: file.write(json_format.MessageToJson(arrow_plan)) return diff --git a/src/gateway/converter/validation_test.py b/src/gateway/converter/validation_test.py index 4d5edd2..391eb69 100644 --- a/src/gateway/converter/validation_test.py +++ b/src/gateway/converter/validation_test.py @@ -2,11 +2,10 @@ """Validation for the Spark to Substrait plan conversion routines.""" from pathlib import Path -from google.protobuf import text_format import pytest -from substrait.gen.proto import plan_pb2 import substrait_validator - +from google.protobuf import text_format +from substrait.gen.proto import plan_pb2 test_case_directory = Path(__file__).resolve().parent / 'data' diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index 98275b0..40bb5e2 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -1,24 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 """A PySpark client that can send sample queries to the gateway.""" -from pyspark.sql.functions import col -from pyspark.sql import SparkSession, DataFrame - from gateway.backends.backend import Backend +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import col USE_GATEWAY = True # pylint: disable=fixme def get_customer_database(spark_session: SparkSession) -> DataFrame: + """Register the TPC-H customer database.""" location_customer = str(Backend.find_tpch() / 'customer') return spark_session.read.parquet(location_customer, mergeSchema=False) # pylint: disable=fixme +# ruff: noqa: T201 def execute_query(spark_session: SparkSession) -> None: - """Runs a single sample query against the gateway.""" + """Run a single sample query against the gateway.""" df_customer = get_customer_database(spark_session) # TODO -- Enable after named table registration is implemented. 
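
The client_demo.py changes above trim the demo to a single gateway-driven query. As a reading aid, the following is a minimal, hypothetical sketch of what driving the gateway from a PySpark Spark Connect client looks like; the port number (50051) and the specific filter are illustrative assumptions and are not part of this patch.

# Hypothetical usage sketch (not part of the patch): connect a Spark Connect
# client to a locally running gateway and query the TPC-H customer table,
# mirroring get_customer_database/execute_query in client_demo.py.
# The port 50051 and the FURNITURE filter are assumptions for illustration.
from gateway.backends.backend import Backend
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.remote('sc://localhost:50051').getOrCreate()

# Read the TPC-H customer table the same way get_customer_database does.
location_customer = str(Backend.find_tpch() / 'customer')
df_customer = spark.read.parquet(location_customer, mergeSchema=False)

# Run a simple filter/projection through the gateway and show the result.
df_customer.filter(col('c_mktsegment') == 'FURNITURE').select('c_name').show()
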
diff --git a/src/gateway/demo/mystream_database.py b/src/gateway/demo/mystream_database.py index 44f0c25..e9da821 100644 --- a/src/gateway/demo/mystream_database.py +++ b/src/gateway/demo/mystream_database.py @@ -1,70 +1,72 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to create a fake mystream database for testing.""" +import contextlib from pathlib import Path -import pyarrow +import pyarrow as pa from faker import Faker from pyarrow import parquet TABLE_SCHEMAS = { - 'users': pyarrow.schema([ - pyarrow.field('user_id', pyarrow.string(), False), - pyarrow.field('name', pyarrow.string(), False), - pyarrow.field('paid_for_service', pyarrow.bool_(), False), + 'users': pa.schema([ + pa.field('user_id', pa.string(), False), + pa.field('name', pa.string(), False), + pa.field('paid_for_service', pa.bool_(), False), ], metadata={'user_id': 'A unique user id.', 'name': 'The user\'s name.', 'paid_for_service': 'Whether the user is considered up to date on payment.'}), - 'channels': pyarrow.schema([ - pyarrow.field('creator_id', pyarrow.string(), False), - pyarrow.field('channel_id', pyarrow.string(), False), - pyarrow.field('name', pyarrow.string(), False), - pyarrow.field('primary_category', pyarrow.string(), True), + 'channels': pa.schema([ + pa.field('creator_id', pa.string(), False), + pa.field('channel_id', pa.string(), False), + pa.field('name', pa.string(), False), + pa.field('primary_category', pa.string(), True), ]), - 'subscriptions': pyarrow.schema([ - pyarrow.field('user_id', pyarrow.string(), False), - pyarrow.field('channel_id', pyarrow.string(), False), + 'subscriptions': pa.schema([ + pa.field('user_id', pa.string(), False), + pa.field('channel_id', pa.string(), False), ]), - 'streams': pyarrow.schema([ - pyarrow.field('stream_id', pyarrow.string(), False), - pyarrow.field('channel_id', pyarrow.string(), False), - pyarrow.field('name', pyarrow.string(), False), + 'streams': pa.schema([ + pa.field('stream_id', pa.string(), False), + pa.field('channel_id', pa.string(), False), + pa.field('name', pa.string(), False), ]), - 'categories': pyarrow.schema([ - pyarrow.field('category_id', pyarrow.string(), False), - pyarrow.field('name', pyarrow.string(), False), - pyarrow.field('language', pyarrow.string(), False), + 'categories': pa.schema([ + pa.field('category_id', pa.string(), False), + pa.field('name', pa.string(), False), + pa.field('language', pa.string(), False), ]), - 'watches': pyarrow.schema([ - pyarrow.field('user_id', pyarrow.string(), False), - pyarrow.field('channel_id', pyarrow.string(), False), - pyarrow.field('stream_id', pyarrow.string(), False), - pyarrow.field('start_time', pyarrow.string(), False), - pyarrow.field('end_time', pyarrow.string(), True), + 'watches': pa.schema([ + pa.field('user_id', pa.string(), False), + pa.field('channel_id', pa.string(), False), + pa.field('stream_id', pa.string(), False), + pa.field('start_time', pa.string(), False), + pa.field('end_time', pa.string(), True), ]), } -def get_mystream_schema(name: str) -> pyarrow.Schema: - """Fetches the schema for the table with the requested name.""" +def get_mystream_schema(name: str) -> pa.Schema: + """Fetch the schema for the mystream table with the requested name.""" return TABLE_SCHEMAS[name] # pylint: disable=fixme,line-too-long def make_users_database(): - """Constructs the users table.""" + """Construct the users table.""" fake = Faker(['en_US']) rows = [] - # TODO -- Make the number of users, the uniqueness of userids, and the density of paid customers configurable. 
+ # TODO -- Make the number and uniqueness of userids configurable. + # TODO -- Make the density of paid customers configurable. for _ in range(100): rows.append({'name': fake.name(), 'user_id': f'user{fake.unique.pyint(max_value=999999999):>09}', 'paid_for_service': fake.pybool(truth_probability=21)}) - table = pyarrow.Table.from_pylist(rows, schema=get_mystream_schema('users')) + table = pa.Table.from_pylist(rows, schema=get_mystream_schema('users')) parquet.write_table(table, 'users.parquet', version='2.4', flavor='spark', compression='NONE') def create_mystream_database() -> Path: - """Creates all the tables that make up the mystream database.""" + """Create all the tables that make up the mystream database.""" Faker.seed(9999) # Build all the tables in sorted order. make_users_database() @@ -72,10 +74,7 @@ def create_mystream_database() -> Path: def delete_mystream_database() -> None: - """Deletes all the tables related to the mystream database.""" + """Delete all the tables related to the mystream database.""" for table_name in TABLE_SCHEMAS: - try: + with contextlib.suppress(FileNotFoundError): Path(table_name + '.parquet').unlink() - except FileNotFoundError: - # We don't care if the file doesn't exist. - pass diff --git a/src/gateway/server.py b/src/gateway/server.py index d3c78ad..c4e1978 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -2,71 +2,71 @@ """SparkConnect server that drives a backend using Substrait.""" import io import logging +from collections.abc import Generator from concurrent import futures -from typing import Generator import grpc -import pyarrow +import pyarrow as pa import pyspark.sql.connect.proto.base_pb2 as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc from pyspark.sql.connect.proto import types_pb2 from gateway.backends.backend_selector import find_backend -from gateway.converter.conversion_options import duck_db, datafusion +from gateway.converter.conversion_options import datafusion, duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter from gateway.converter.sql_to_substrait import convert_sql _LOGGER = logging.getLogger(__name__) -def show_string(table: pyarrow.lib.Table) -> bytes: - """Converts a table into a byte serialized single row string column Arrow Table.""" +def show_string(table: pa.lib.Table) -> bytes: + """Convert a table into a byte serialized single row string column Arrow Table.""" results_str = str(table) - schema = pyarrow.schema([('show_string', pyarrow.string())]) - array = pyarrow.array([results_str]) - batch = pyarrow.RecordBatch.from_arrays([array], schema=schema) - result_table = pyarrow.Table.from_batches([batch]) + schema = pa.schema([('show_string', pa.string())]) + array = pa.array([results_str]) + batch = pa.RecordBatch.from_arrays([array], schema=schema) + result_table = pa.Table.from_batches([batch]) buffer = io.BytesIO() - stream = pyarrow.RecordBatchStreamWriter(buffer, schema) + stream = pa.RecordBatchStreamWriter(buffer, schema) stream.write_table(result_table) stream.close() return buffer.getvalue() -def batch_to_bytes(batch: pyarrow.RecordBatch, schema: pyarrow.Schema) -> bytes: - """Serializes a RecordBatch into a bytes.""" - result_table = pyarrow.Table.from_batches(batches=[batch]) +def batch_to_bytes(batch: pa.RecordBatch, schema: pa.Schema) -> bytes: + """Serialize a RecordBatch into a bytes.""" + result_table = pa.Table.from_batches(batches=[batch]) buffer = io.BytesIO() - stream = pyarrow.RecordBatchStreamWriter(buffer, schema) + stream = 
pa.RecordBatchStreamWriter(buffer, schema) stream.write_table(result_table) stream.close() return buffer.getvalue() # pylint: disable=E1101 -def convert_pyarrow_schema_to_spark(schema: pyarrow.Schema) -> types_pb2.DataType: - """Converts a PyArrow schema to a SparkConnect DataType.Struct schema.""" +def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: + """Convert a pyarrow schema to a SparkConnect DataType.Struct schema.""" fields = [] for field in schema: - if field.type == pyarrow.bool_(): + if field.type == pa.bool_(): data_type = types_pb2.DataType(boolean=types_pb2.DataType.Boolean()) - elif field.type == pyarrow.int8(): + elif field.type == pa.int8(): data_type = types_pb2.DataType(byte=types_pb2.DataType.Byte()) - elif field.type == pyarrow.int16(): + elif field.type == pa.int16(): data_type = types_pb2.DataType(integer=types_pb2.DataType.Short()) - elif field.type == pyarrow.int32(): + elif field.type == pa.int32(): data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer()) - elif field.type == pyarrow.int64(): + elif field.type == pa.int64(): data_type = types_pb2.DataType(long=types_pb2.DataType.Long()) - elif field.type == pyarrow.float32(): + elif field.type == pa.float32(): data_type = types_pb2.DataType(float=types_pb2.DataType.Float()) - elif field.type == pyarrow.float64(): + elif field.type == pa.float64(): data_type = types_pb2.DataType(double=types_pb2.DataType.Double()) - elif field.type == pyarrow.string(): + elif field.type == pa.string(): data_type = types_pb2.DataType(string=types_pb2.DataType.String()) - elif field.type == pyarrow.timestamp('us'): + elif field.type == pa.timestamp('us'): data_type = types_pb2.DataType(timestamp=types_pb2.DataType.Timestamp()) - elif field.type == pyarrow.date32(): + elif field.type == pa.date32(): data_type = types_pb2.DataType(date=types_pb2.DataType.Date()) else: raise NotImplementedError( @@ -85,12 +85,14 @@ class SparkConnectService(pb2_grpc.SparkConnectServiceServicer): # pylint: disable=unused-argument def __init__(self, *args, **kwargs): + """Initialize the SparkConnect service.""" # This is the central point for configuring the behavior of the service. 
self._options = duck_db() def ExecutePlan( self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: + """Execute the given plan and return the results.""" _LOGGER.info('ExecutePlan: %s', request) match request.plan.WhichOneof('op_type'): case 'root': @@ -101,8 +103,8 @@ def ExecutePlan( case 'sql_command': substrait = convert_sql(request.plan.command.sql_command.sql) case _: - raise NotImplementedError( - f'Unsupported command type: {request.plan.command.WhichOneof("command_type")}') + type = request.plan.command.WhichOneof("command_type") + raise NotImplementedError(f'Unsupported command type: {type}') case _: raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) @@ -112,7 +114,7 @@ def ExecutePlan( if not self._options.implement_show_string and request.plan.WhichOneof( 'op_type') == 'root' and request.plan.root.WhichOneof( - 'rel_type') == 'show_string': + 'rel_type') == 'show_string': yield pb2.ExecutePlanResponse( session_id=request.session_id, arrow_batch=pb2.ExecutePlanResponse.ArrowBatch( @@ -143,10 +145,12 @@ def ExecutePlan( return def AnalyzePlan(self, request, context): + """Analyze the given plan and return the results.""" _LOGGER.info('AnalyzePlan: %s', request) return pb2.AnalyzePlanResponse(session_id=request.session_id) def Config(self, request, context): + """Get or set the configuration of the server.""" _LOGGER.info('Config: %s', request) response = pb2.ConfigResponse(session_id=request.session_id) match request.operation.WhichOneof('op_type'): @@ -167,32 +171,37 @@ def Config(self, request, context): return response def AddArtifacts(self, request_iterator, context): + """Add the given artifacts to the server.""" _LOGGER.info('AddArtifacts') return pb2.AddArtifactsResponse() def ArtifactStatus(self, request, context): + """Get the status of the given artifact.""" _LOGGER.info('ArtifactStatus') return pb2.ArtifactStatusesResponse() def Interrupt(self, request, context): + """Interrupt the execution of the given plan.""" _LOGGER.info('Interrupt') return pb2.InterruptResponse() def ReattachExecute( self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: + """Reattach the execution of the given plan.""" _LOGGER.info('ReattachExecute') yield pb2.ExecutePlanResponse( session_id=request.session_id, result_complete=pb2.ExecutePlanResponse.ResultComplete()) def ReleaseExecute(self, request, context): + """Release the execution of the given plan.""" _LOGGER.info('ReleaseExecute') return pb2.ReleaseExecuteResponse() def serve(port: int, wait: bool = True): - """Starts the SparkConnect to Substrait gateway server.""" + """Start the SparkConnect to Substrait gateway server.""" server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) pb2_grpc.add_SparkConnectServiceServicer_to_server(SparkConnectService(), server) server.add_insecure_port(f'[::]:{port}') diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 63d1d64..1f3a466 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -2,15 +2,18 @@ """Test fixtures for pytest of the gateway server.""" from pathlib import Path -from pyspark.sql.pandas.types import from_arrow_schema -from pyspark.sql.session import SparkSession import pytest - -from gateway.demo.mystream_database import create_mystream_database, delete_mystream_database -from gateway.demo.mystream_database import get_mystream_schema +from 
gateway.demo.mystream_database import ( + create_mystream_database, + delete_mystream_database, + get_mystream_schema, +) from gateway.server import serve +from pyspark.sql.pandas.types import from_arrow_schema +from pyspark.sql.session import SparkSession +# ruff: noqa: T201 def _create_local_spark_session() -> SparkSession: """Creates a local spark session for testing.""" spark = ( diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 9aee812..48deb01 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -3,13 +3,12 @@ from pathlib import Path import pytest +from gateway.backends.backend import Backend from hamcrest import assert_that, equal_to from pyspark import Row from pyspark.sql.session import SparkSession from pyspark.testing import assertDataFrameEqual -from gateway.backends.backend import Backend - test_case_directory = Path(__file__).resolve().parent / 'data' sql_test_case_paths = [f for f in sorted(test_case_directory.iterdir()) if f.suffix == '.sql'] From d24f4ada332838c731aa4cb8ae3ea46a9d4b74ae Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 17 Apr 2024 20:10:09 -0700 Subject: [PATCH 19/58] chore: fix the rest of the documentation related ruff issues (#44) --- pyproject.toml | 1 - src/gateway/converter/conversion_options.py | 5 +- src/gateway/converter/label_relations.py | 3 +- .../output_field_tracking_visitor.py | 21 +++-- src/gateway/converter/rename_functions.py | 2 +- src/gateway/converter/replace_local_files.py | 1 + src/gateway/converter/simplify_casts.py | 11 +-- src/gateway/converter/spark_functions.py | 4 +- src/gateway/converter/spark_to_substrait.py | 77 ++++++++++--------- src/gateway/converter/sql_to_substrait.py | 2 +- src/gateway/converter/substrait_builder.py | 50 ++++++------ src/gateway/converter/symbol_table.py | 6 +- .../tools/duckdb_substrait_to_arrow.py | 8 +- .../tools/tests/simplify_casts_test.py | 4 +- 14 files changed, 105 insertions(+), 90 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d930ee4..78b1a05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ unfixable = [ [tool.ruff.lint.per-file-ignores] "*test*.py" = ["D"] # ignore all docstring lints in tests "__init__.py" = ["D"] -"src/gateway/converter/*" = ["D"] [tool.pylint.MASTER] extension-pkg-allow-list = ["pyarrow.lib"] diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index d849633..0064bf6 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -19,6 +19,7 @@ class ConversionOptions: return_names_with_types: bool def __init__(self, backend: BackendOptions = None): + """Initialize the conversion options.""" self.use_named_table_workaround = False self.needs_scheme_in_path_uris = False self.use_emits_instead_of_direct = False @@ -31,12 +32,12 @@ def __init__(self, backend: BackendOptions = None): def datafusion(): - """Standard options to connect to a Datafusion backend.""" + """Return standard options to connect to a Datafusion backend.""" return ConversionOptions(backend=BackendOptions(Backend.DATAFUSION)) def duck_db(): - """Standard options to connect to a DuckDB backend.""" + """Return standard options to connect to a DuckDB backend.""" options = ConversionOptions(backend=BackendOptions(Backend.DUCKDB)) options.return_names_with_types = True return options diff --git a/src/gateway/converter/label_relations.py b/src/gateway/converter/label_relations.py index 3e618eb..b287d7c 
100644 --- a/src/gateway/converter/label_relations.py +++ b/src/gateway/converter/label_relations.py @@ -8,7 +8,7 @@ # pylint: disable=E1101 def get_common_section(rel: algebra_pb2.Rel) -> algebra_pb2.RelCommon: - """Finds the single input to the relation.""" + """Find the single input to the relation.""" match rel.WhichOneof('rel_type'): case 'read': result = rel.read.common @@ -65,6 +65,7 @@ class LabelRelations(SubstraitPlanVisitor): _seen_relations: int def __init__(self): + """Initialize the LabelRelations visitor.""" super().__init__() self._seen_relations = 0 diff --git a/src/gateway/converter/output_field_tracking_visitor.py b/src/gateway/converter/output_field_tracking_visitor.py index a424e79..6ba2b90 100644 --- a/src/gateway/converter/output_field_tracking_visitor.py +++ b/src/gateway/converter/output_field_tracking_visitor.py @@ -10,7 +10,7 @@ # pylint: disable=E1101 def get_plan_id_from_common(common: algebra_pb2.RelCommon) -> int: - """Gets the plan ID from the common section.""" + """Get the plan ID from the common section.""" ref_rel = algebra_pb2.ReferenceRel() common.advanced_extension.optimization.Unpack(ref_rel) return ref_rel.subtree_ordinal @@ -18,22 +18,23 @@ def get_plan_id_from_common(common: algebra_pb2.RelCommon) -> int: # pylint: disable=E1101 def get_plan_id(rel: algebra_pb2.Rel) -> int: - """Gets the plan ID from the relation.""" + """Get the plan ID from the relation.""" common = get_common_section(rel) return get_plan_id_from_common(common) # pylint: disable=no-member,fixme class OutputFieldTrackingVisitor(SubstraitPlanVisitor): - """Replaces all cast expressions with projects of casts instead.""" + """Collect which field references are computed for each relation.""" def __init__(self): + """Initialize the visitor.""" super().__init__() self._current_plan_id: int | None = None # The relation currently being processed. self._symbol_table = SymbolTable() def update_field_references(self, plan_id: int) -> None: - """Uses the field references using the specified portion of the plan.""" + """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) current_symbol = self._symbol_table.get_symbol(self._current_plan_id) current_symbol.input_fields.extend(source_symbol.output_fields) @@ -41,7 +42,7 @@ def update_field_references(self, plan_id: int) -> None: current_symbol.output_fields.extend(current_symbol.generated_fields) def visit_read_relation(self, rel: algebra_pb2.ReadRel) -> Any: - """Uses the field references from the read relation.""" + """Collect the field references from the read relation.""" super().visit_read_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) # TODO -- Validate this logic where complicated data structures are used. 
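
The get_plan_id_from_common helper shown above relies on a labeling convention in which each relation's RelCommon carries a ReferenceRel packed into its advanced-extension optimization slot. Here is a small sketch of that convention; only the Unpack direction appears verbatim in this patch, so the Pack direction is an inference rather than the exact LabelRelations implementation.

# Illustrative sketch of the plan-id labeling convention used by
# get_plan_id_from_common. The Pack direction is an inference; only the
# Unpack direction appears verbatim in this patch.
from substrait.gen.proto import algebra_pb2


def stamp_plan_id(common: algebra_pb2.RelCommon, plan_id: int) -> None:
    # Store the plan id as a ReferenceRel in the optimization extension slot.
    common.advanced_extension.optimization.Pack(
        algebra_pb2.ReferenceRel(subtree_ordinal=plan_id))


def read_plan_id(common: algebra_pb2.RelCommon) -> int:
    # Mirror of get_plan_id_from_common: unpack the ReferenceRel and return its ordinal.
    ref_rel = algebra_pb2.ReferenceRel()
    common.advanced_extension.optimization.Unpack(ref_rel)
    return ref_rel.subtree_ordinal


rel = algebra_pb2.Rel(read=algebra_pb2.ReadRel())
stamp_plan_id(rel.read.common, 42)
assert read_plan_id(rel.read.common) == 42
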
@@ -49,14 +50,17 @@ def visit_read_relation(self, rel: algebra_pb2.ReadRel) -> Any: symbol.output_fields.append(field) def visit_filter_relation(self, rel: algebra_pb2.FilterRel) -> Any: + """Collect the field references from the filter relation.""" super().visit_filter_relation(rel) self.update_field_references(get_plan_id(rel.input)) def visit_fetch_relation(self, rel: algebra_pb2.FetchRel) -> Any: + """Collect the field references from the fetch relation.""" super().visit_fetch_relation(rel) self.update_field_references(get_plan_id(rel.input)) def visit_aggregate_relation(self, rel: algebra_pb2.AggregateRel) -> Any: + """Collect the field references from the aggregate relation.""" super().visit_aggregate_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) for _ in rel.groupings: @@ -66,10 +70,12 @@ def visit_aggregate_relation(self, rel: algebra_pb2.AggregateRel) -> Any: self.update_field_references(get_plan_id(rel.input)) def visit_sort_relation(self, rel: algebra_pb2.SortRel) -> Any: + """Collect the field references from the sort relation.""" super().visit_sort_relation(rel) self.update_field_references(get_plan_id(rel.input)) def visit_project_relation(self, rel: algebra_pb2.ProjectRel) -> Any: + """Collect the field references from the project relation.""" super().visit_project_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) for _ in rel.expressions: @@ -77,13 +83,14 @@ def visit_project_relation(self, rel: algebra_pb2.ProjectRel) -> Any: self.update_field_references(get_plan_id(rel.input)) def visit_extension_single_relation(self, rel: algebra_pb2.ExtensionSingleRel) -> Any: + """Collect the field references from the extension single relation.""" super().visit_extension_single_relation(rel) self.update_field_references(get_plan_id(rel.input)) # TODO -- Add the other relation types. 
def visit_relation(self, rel: algebra_pb2.Rel) -> Any: - """Visits a relation node.""" + """Visit a relation node.""" new_plan_id = get_plan_id(rel) self._symbol_table.add_symbol(new_plan_id, parent=self._current_plan_id, @@ -96,6 +103,6 @@ def visit_relation(self, rel: algebra_pb2.Rel) -> Any: self._current_plan_id = old_plan_id def visit_plan(self, plan: plan_pb2 .Plan) -> Any: - """Visits a plan node.""" + """Visit a plan node.""" super().visit_plan(plan) return self._symbol_table diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index a4dc126..9ed8bc9 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -9,7 +9,7 @@ class RenameFunctions(SubstraitPlanVisitor): """Renames Substrait functions to match what Datafusion expects.""" def visit_plan(self, plan: plan_pb2.Plan) -> None: - """Modifies the provided plan so that functions are Datafusion compatible.""" + """Modify the provided plan so that functions are Datafusion compatible.""" super().visit_plan(plan) for extension in plan.extensions: diff --git a/src/gateway/converter/replace_local_files.py b/src/gateway/converter/replace_local_files.py index dc20c68..6f33165 100644 --- a/src/gateway/converter/replace_local_files.py +++ b/src/gateway/converter/replace_local_files.py @@ -11,6 +11,7 @@ class ReplaceLocalFilesWithNamedTable(SubstraitPlanVisitor): """Replaces all of the local file instances with named tables.""" def __init__(self): + """Initialize the visitor.""" self._file_groups: list[tuple[str, list[str]]] = [] super().__init__() diff --git a/src/gateway/converter/simplify_casts.py b/src/gateway/converter/simplify_casts.py index abbe86c..5337117 100644 --- a/src/gateway/converter/simplify_casts.py +++ b/src/gateway/converter/simplify_casts.py @@ -13,6 +13,7 @@ class SimplifyCasts(SubstraitPlanVisitor): """Replaces all cast expressions with projects of casts instead.""" def __init__(self, symbol_table: SymbolTable): + """Initialize the visitor.""" super().__init__() self._symbol_table = symbol_table self._current_plan_id: int | None = None # The relation currently being processed. @@ -21,7 +22,7 @@ def __init__(self, symbol_table: SymbolTable): self._previous_rewrite_expressions: list[list[algebra_pb2.Expression]] = [] def visit_cast(self, cast: algebra_pb2.Expression.Cast) -> Any: - """Visits a cast node.""" + """Visit a cast node.""" super().visit_cast(cast) # Acero only accepts casts of selections. 
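(Aside, for orientation: the single-input helpers in the hunks just below are plain structural accessors over the Rel oneof. A minimal, illustrative sketch of what SimplifyCasts.replace_single_input does for a filter relation; this sketch is not code from the patch itself:)

    from substrait.gen.proto import algebra_pb2
    from gateway.converter.simplify_casts import SimplifyCasts

    # Build a filter relation with a placeholder input, then swap that input out.
    filter_rel = algebra_pb2.Rel(filter=algebra_pb2.FilterRel(input=algebra_pb2.Rel()))
    new_input = algebra_pb2.Rel(read=algebra_pb2.ReadRel())
    SimplifyCasts.replace_single_input(filter_rel, new_input)
    assert filter_rel.filter.input.HasField('read')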
@@ -43,7 +44,7 @@ def visit_cast(self, cast: algebra_pb2.Expression.Cast) -> Any: @staticmethod def find_single_input(rel: algebra_pb2.Rel) -> algebra_pb2.Rel: - """Finds the single input to the relation.""" + """Find the single input to the relation.""" match rel.WhichOneof('rel_type'): case 'filter': return rel.filter.input @@ -63,7 +64,7 @@ def find_single_input(rel: algebra_pb2.Rel) -> algebra_pb2.Rel: @staticmethod def replace_single_input(rel: algebra_pb2.Rel, new_input: algebra_pb2.Rel): - """Updates the single input to the relation.""" + """Update the single input to the relation.""" match rel.WhichOneof('rel_type'): case 'filter': rel.filter.input.CopyFrom(new_input) @@ -82,14 +83,14 @@ def replace_single_input(rel: algebra_pb2.Rel, new_input: algebra_pb2.Rel): f'{rel.WhichOneof("rel_type")} are not implemented') def update_field_references(self, plan_id: int) -> None: - """Uses the field references using the specified portion of the plan.""" + """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) current_symbol = self._symbol_table.get_symbol(self._current_plan_id) current_symbol.input_fields.extend(source_symbol.output_fields) current_symbol.output_fields.extend(current_symbol.input_fields) def visit_relation(self, rel: algebra_pb2.Rel) -> Any: - """Visits a relation node.""" + """Visit a relation node.""" previous_plan_id = self._current_plan_id self._current_plan_id = get_plan_id(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index f495414..c4fcb0c 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -19,12 +19,14 @@ class ExtensionFunction: def __init__(self, uri: str, name: str, output_type: type_pb2.Type, max_args: int | None = None): + """Create the ExtensionFunction structure.""" self.uri = uri self.name = name self.output_type = output_type self.max_args = max_args def __lt__(self, obj) -> bool: + """Compare two ExtensionFunction objects.""" return self.uri < obj.uri and self.name < obj.name @@ -111,7 +113,7 @@ def __lt__(self, obj) -> bool: def lookup_spark_function(name: str, options: ConversionOptions) -> ExtensionFunction: - """Returns a Substrait function given a spark function name.""" + """Return a Substrait function given a spark function name.""" definition = SPARK_SUBSTRAIT_MAPPING.get(name) if not options.return_names_with_types: definition.name = definition.name.split(':', 1)[0] diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index afacb6d..46a02ba 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -6,7 +6,6 @@ import pathlib import pyarrow as pa -import pyarrow.parquet import pyspark.sql.connect.proto.base_pb2 as spark_pb2 import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 @@ -52,6 +51,7 @@ class SparkSubstraitConverter: """Converts SparkConnect plans to Substrait plans.""" def __init__(self, options: ConversionOptions): + """Initialize the converter.""" self._function_uris: dict[str, int] = {} self._functions: dict[str, ExtensionFunction] = {} self._current_plan_id: int | None = None # The relation currently being processed. 
@@ -62,7 +62,7 @@ def __init__(self, options: ConversionOptions): self._saved_extensions = {} def lookup_function_by_name(self, name: str) -> ExtensionFunction: - """Finds the function reference for a given Spark function name.""" + """Find the function reference for a given Spark function name.""" if name in self._functions: return self._functions.get(name) func = lookup_spark_function(name, self._conversion_options) @@ -75,14 +75,14 @@ def lookup_function_by_name(self, name: str) -> ExtensionFunction: return self._functions.get(name) def update_field_references(self, plan_id: int) -> None: - """Uses the field references using the specified portion of the plan.""" + """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) current_symbol = self._symbol_table.get_symbol(self._current_plan_id) current_symbol.input_fields.extend(source_symbol.output_fields) current_symbol.output_fields.extend(current_symbol.input_fields) def find_field_by_name(self, field_name: str) -> int | None: - """Looks up the field name in the current set of field references.""" + """Look up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) try: return current_symbol.output_fields.index(field_name) @@ -91,37 +91,37 @@ def find_field_by_name(self, field_name: str) -> int | None: def convert_boolean_literal( self, boolean: bool) -> algebra_pb2.Expression.Literal: - """Transforms a boolean into a Substrait expression literal.""" + """Transform a boolean into a Substrait expression literal.""" return algebra_pb2.Expression.Literal(boolean=boolean) def convert_short_literal( self, i: int) -> algebra_pb2.Expression.Literal: - """Transforms a short integer into a Substrait expression literal.""" + """Transform a short integer into a Substrait expression literal.""" return algebra_pb2.Expression.Literal(i16=i) def convert_integer_literal( self, i: int) -> algebra_pb2.Expression.Literal: - """Transforms an integer into a Substrait expression literal.""" + """Transform an integer into a Substrait expression literal.""" return algebra_pb2.Expression.Literal(i32=i) def convert_float_literal( self, f: float) -> algebra_pb2.Expression.Literal: - """Transforms a float into a Substrait expression literal.""" + """Transform a float into a Substrait expression literal.""" return algebra_pb2.Expression.Literal(fp32=f) def convert_double_literal( self, d: float) -> algebra_pb2.Expression.Literal: - """Transforms a double into a Substrait expression literal.""" + """Transform a double into a Substrait expression literal.""" return algebra_pb2.Expression.Literal(fp64=d) def convert_string_literal( self, s: str) -> algebra_pb2.Expression.Literal: - """Transforms a string into a Substrait expression literal.""" + """Transform a string into a Substrait expression literal.""" return algebra_pb2.Expression.Literal(string=s) def convert_literal_expression( self, literal: spark_exprs_pb2.Expression.Literal) -> algebra_pb2.Expression: - """Converts a Spark literal into a Substrait literal.""" + """Convert a Spark literal into a Substrait literal.""" match literal.WhichOneof('literal_type'): case 'null': # TODO -- Finish with the type implementation. 
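(A rough usage sketch of the literal helpers above, assuming a converter constructed with the DuckDB options from conversion_options.py; illustrative only:)

    import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2
    from gateway.converter.conversion_options import duck_db
    from gateway.converter.spark_to_substrait import SparkSubstraitConverter

    # A Spark Connect string literal becomes a Substrait string literal expression.
    converter = SparkSubstraitConverter(duck_db())
    spark_literal = spark_exprs_pb2.Expression.Literal(string='BUILDING')
    substrait_expr = converter.convert_literal_expression(spark_literal)
    assert substrait_expr.literal.string == 'BUILDING'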
@@ -168,7 +168,7 @@ def convert_literal_expression( def convert_unresolved_attribute( self, attr: spark_exprs_pb2.Expression.UnresolvedAttribute) -> algebra_pb2.Expression: - """Converts a Spark unresolved attribute into a Substrait field reference.""" + """Convert a Spark unresolved attribute into a Substrait field reference.""" field_ref = self.find_field_by_name(attr.unparsed_identifier) if field_ref is None: raise ValueError( @@ -185,7 +185,7 @@ def convert_unresolved_function( self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction) -> algebra_pb2.Expression: - """Converts a Spark unresolved function into a Substrait scalar function.""" + """Convert a Spark unresolved function into a Substrait scalar function.""" func = algebra_pb2.Expression.ScalarFunction() function_def = self.lookup_function_by_name(unresolved_function.function_name) func.function_reference = function_def.anchor @@ -202,12 +202,12 @@ def convert_unresolved_function( def convert_alias_expression( self, alias: spark_exprs_pb2.Expression.Alias) -> algebra_pb2.Expression: - """Converts a Spark alias into a Substrait expression.""" + """Convert a Spark alias into a Substrait expression.""" # TODO -- Utilize the alias name. return self.convert_expression(alias.expr) def convert_type_str(self, spark_type_str: str | None) -> type_pb2.Type: - """Converts a Spark type string into a Substrait type.""" + """Convert a Spark type string into a Substrait type.""" # TODO -- Properly handle nullability. match spark_type_str: case 'boolean': @@ -224,12 +224,12 @@ def convert_type_str(self, spark_type_str: str | None) -> type_pb2.Type: f'type {spark_type_str} not yet implemented.') def convert_type(self, spark_type: spark_types_pb2.DataType) -> type_pb2.Type: - """Converts a Spark type into a Substrait type.""" + """Convert a Spark type into a Substrait type.""" return self.convert_type_str(spark_type.WhichOneof('kind')) def convert_cast_expression( self, cast: spark_exprs_pb2.Expression.Cast) -> algebra_pb2.Expression: - """Converts a Spark cast expression into a Substrait cast expression.""" + """Convert a Spark cast expression into a Substrait cast expression.""" cast_rel = algebra_pb2.Expression.Cast(input=self.convert_expression(cast.expr)) match cast.WhichOneof('cast_to_type'): case 'type': @@ -243,7 +243,7 @@ def convert_cast_expression( return algebra_pb2.Expression(cast=cast_rel) def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Expression: - """Converts a SparkConnect expression to a Substrait expression.""" + """Convert a SparkConnect expression to a Substrait expression.""" match expr.WhichOneof('expr_type'): case 'literal': result = self.convert_literal_expression(expr.literal) @@ -293,7 +293,7 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex def convert_expression_to_aggregate_function( self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.AggregateFunction: - """Converts a SparkConnect expression to a Substrait expression.""" + """Convert a SparkConnect expression to a Substrait expression.""" func = algebra_pb2.AggregateFunction() expression = self.convert_expression(expr) match expression.WhichOneof('rex_type'): @@ -312,11 +312,11 @@ def convert_expression_to_aggregate_function( return func def convert_read_named_table_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Rel: - """Converts a read named table relation to a Substrait relation.""" + """Convert a read named table relation to a Substrait relation.""" raise 
NotImplementedError('named tables are not yet implemented') def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: - """Converts the Spark JSON schema string into a Substrait named type structure.""" + """Convert the Spark JSON schema string into a Substrait named type structure.""" if not schema_str: return None # TODO -- Deal with potential denial of service due to malformed JSON. @@ -359,6 +359,7 @@ def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: return schema def convert_arrow_schema(self, arrow_schema: pa.Schema) -> type_pb2.NamedStruct: + """Convert an Arrow schema into a Substrait named type structure.""" schema = type_pb2.NamedStruct() schema.struct.nullability = type_pb2.Type.NULLABILITY_REQUIRED @@ -392,7 +393,7 @@ def convert_arrow_schema(self, arrow_schema: pa.Schema) -> type_pb2.NamedStruct: return schema def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Rel: - """Converts a read data source relation into a Substrait relation.""" + """Convert a read data source relation into a Substrait relation.""" local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) if not schema: @@ -448,7 +449,7 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al return algebra_pb2.Rel(read=algebra_pb2.ReadRel(base_schema=schema, local_files=local)) def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: - """Creates the common metadata relation used by all relations.""" + """Create the common metadata relation used by all relations.""" if not self._conversion_options.use_emits_instead_of_direct: return algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct()) symbol = self._symbol_table.get_symbol(self._current_plan_id) @@ -464,7 +465,7 @@ def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: return algebra_pb2.RelCommon(emit=emit) def convert_read_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Rel: - """Converts a read relation into a Substrait relation.""" + """Convert a read relation into a Substrait relation.""" match rel.WhichOneof('read_type'): case 'named_table': result = self.convert_read_named_table_relation(rel.named_table) @@ -476,7 +477,7 @@ def convert_read_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Re return result def convert_filter_relation(self, rel: spark_relations_pb2.Filter) -> algebra_pb2.Rel: - """Converts a filter relation into a Substrait relation.""" + """Convert a filter relation into a Substrait relation.""" filter_rel = algebra_pb2.FilterRel(input=self.convert_relation(rel.input)) self.update_field_references(rel.input.common.plan_id) filter_rel.common.CopyFrom(self.create_common_relation()) @@ -484,7 +485,7 @@ def convert_filter_relation(self, rel: spark_relations_pb2.Filter) -> algebra_pb return algebra_pb2.Rel(filter=filter_rel) def convert_sort_relation(self, rel: spark_relations_pb2.Sort) -> algebra_pb2.Rel: - """Converts a sort relation into a Substrait relation.""" + """Convert a sort relation into a Substrait relation.""" sort = algebra_pb2.SortRel(input=self.convert_relation(rel.input)) self.update_field_references(rel.input.common.plan_id) sort.common.CopyFrom(self.create_common_relation()) @@ -505,7 +506,7 @@ def convert_sort_relation(self, rel: spark_relations_pb2.Sort) -> algebra_pb2.Re return algebra_pb2.Rel(sort=sort) def convert_limit_relation(self, rel: spark_relations_pb2.Limit) -> algebra_pb2.Rel: - """Converts a limit relation into 
a Substrait FetchRel relation.""" + """Convert a limit relation into a Substrait FetchRel relation.""" input_relation = self.convert_relation(rel.input) self.update_field_references(rel.input.common.plan_id) fetch = algebra_pb2.FetchRel(common=self.create_common_relation(), input=input_relation, @@ -513,7 +514,7 @@ def convert_limit_relation(self, rel: spark_relations_pb2.Limit) -> algebra_pb2. return algebra_pb2.Rel(fetch=fetch) def determine_expression_name(self, expr: spark_exprs_pb2.Expression) -> str | None: - """Determines the name of the expression.""" + """Determine the name of the expression.""" if expr.HasField('alias'): return expr.alias.name[0] self._seen_generated_names.setdefault('aggregate_expression', 0) @@ -521,7 +522,7 @@ def determine_expression_name(self, expr: spark_exprs_pb2.Expression) -> str | N return f'aggregate_expression{self._seen_generated_names["aggregate_expression"]}' def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> algebra_pb2.Rel: - """Converts an aggregate relation into a Substrait relation.""" + """Convert an aggregate relation into a Substrait relation.""" aggregate = algebra_pb2.AggregateRel(input=self.convert_relation(rel.input)) self.update_field_references(rel.input.common.plan_id) aggregate.common.CopyFrom(self.create_common_relation()) @@ -544,7 +545,7 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge # pylint: disable=too-many-locals,pointless-string-statement def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel: - """Converts a show string relation into a Substrait subplan.""" + """Convert a show string relation into a Substrait subplan.""" if not self._conversion_options.implement_show_string: result = self.convert_relation(rel.input) self.update_field_references(rel.input.common.plan_id) @@ -725,7 +726,7 @@ def compute_row_count_footer(num_rows: int) -> str: def convert_with_columns_relation( self, rel: spark_relations_pb2.WithColumns) -> algebra_pb2.Rel: - """Converts a with columns relation into a Substrait project relation.""" + """Convert a with columns relation into a Substrait project relation.""" input_rel = self.convert_relation(rel.input) project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) @@ -754,7 +755,7 @@ def convert_with_columns_relation( return algebra_pb2.Rel(project=project) def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.Rel: - """Converts a to dataframe relation into a Substrait project relation.""" + """Convert a to dataframe relation into a Substrait project relation.""" input_rel = self.convert_relation(rel.input) project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) @@ -769,7 +770,7 @@ def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.R return algebra_pb2.Rel(project=project) def convert_arrow_to_literal(self, val: pa.Scalar) -> algebra_pb2.Expression.Literal: - """Converts an Arrow scalar into a Substrait literal.""" + """Convert an Arrow scalar into a Substrait literal.""" literal = algebra_pb2.Expression.Literal() if isinstance(val, pa.BooleanScalar): literal.boolean = val.as_py() @@ -782,7 +783,7 @@ def convert_arrow_to_literal(self, val: pa.Scalar) -> algebra_pb2.Expression.Lit def convert_arrow_data_to_virtual_table(self, data: bytes) -> algebra_pb2.ReadRel.VirtualTable: - """Converts a Spark local relation into a virtual table.""" + """Convert a Spark 
local relation into a virtual table.""" table = algebra_pb2.ReadRel.VirtualTable() # Use pyarrow to convert the bytes into an arrow structure. with pa.ipc.open_stream(data) as arrow: @@ -795,7 +796,7 @@ def convert_arrow_data_to_virtual_table(self, return table def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> algebra_pb2.Rel: - """Converts a Spark local relation into a virtual table.""" + """Convert a Spark local relation into a virtual table.""" read = algebra_pb2.ReadRel( virtual_table=self.convert_arrow_data_to_virtual_table(rel.data)) schema = self.convert_schema(rel.schema) @@ -807,7 +808,7 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge return algebra_pb2.Rel(read=read) def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: - """Converts a Spark SQL relation into a Substrait relation.""" + """Convert a Spark SQL relation into a Substrait relation.""" plan = convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in plan.relations[0].root.names: @@ -820,7 +821,7 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: return plan.relations[0].root.input def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: - """Converts a Spark relation into a Substrait one.""" + """Convert a Spark relation into a Substrait one.""" self._symbol_table.add_symbol(rel.common.plan_id, parent=self._current_plan_id, symbol_type=rel.WhichOneof('rel_type')) old_plan_id = self._current_plan_id @@ -853,7 +854,7 @@ def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel return result def convert_plan(self, plan: spark_pb2.Plan) -> plan_pb2.Plan: - """Converts a Spark plan into a Substrait plan.""" + """Convert a Spark plan into a Substrait plan.""" result = plan_pb2.Plan() result.version.CopyFrom( plan_pb2.Version(minor_number=42, producer='spark-substrait-gateway')) diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index 2833c1f..0b12c2e 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -6,7 +6,7 @@ def convert_sql(sql: str) -> plan_pb2.Plan: - """Converts SQL into a Substrait plan.""" + """Convert SQL into a Substrait plan.""" result = plan_pb2.Plan() backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index c8510b2..6ab9ae5 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -8,20 +8,20 @@ def flatten(list_of_lists: list[list[Any]]) -> list[Any]: - """Flattens a list of lists into a list.""" + """Flatten a list of lists into a list.""" return list(itertools.chain.from_iterable(list_of_lists)) # pylint: disable=E1101 def fetch_relation(input_relation: algebra_pb2.Rel, num_rows: int) -> algebra_pb2.Rel: - """Constructs a Substrait fetch plan node.""" + """Construct a Substrait fetch plan node.""" return algebra_pb2.Rel(fetch=algebra_pb2.FetchRel(input=input_relation, count=num_rows)) def project_relation(input_relation: algebra_pb2.Rel, expressions: list[algebra_pb2.Expression]) -> algebra_pb2.Rel: - """Constructs a Substrait project plan node.""" + """Construct a Substrait project plan node.""" return algebra_pb2.Rel( project=algebra_pb2.ProjectRel(input=input_relation, expressions=expressions)) @@ -29,7 +29,7 @@ def 
project_relation(input_relation: algebra_pb2.Rel, # pylint: disable=fixme def aggregate_relation(input_relation: algebra_pb2.Rel, measures: list[algebra_pb2.AggregateFunction]) -> algebra_pb2.Rel: - """Constructs a Substrait aggregate plan node.""" + """Construct a Substrait aggregate plan node.""" aggregate = algebra_pb2.Rel( aggregate=algebra_pb2.AggregateRel( common=algebra_pb2.RelCommon(emit=algebra_pb2.RelCommon.Emit( @@ -43,7 +43,7 @@ def aggregate_relation(input_relation: algebra_pb2.Rel, def join_relation(left: algebra_pb2.Rel, right: algebra_pb2.Rel) -> algebra_pb2.Rel: - """Constructs a Substrait join plan node.""" + """Construct a Substrait join plan node.""" return algebra_pb2.Rel( join=algebra_pb2.JoinRel(common=algebra_pb2.RelCommon(), left=left, right=right, expression=algebra_pb2.Expression( @@ -53,7 +53,7 @@ def join_relation(left: algebra_pb2.Rel, right: algebra_pb2.Rel) -> algebra_pb2. def concat(function_info: ExtensionFunction, expressions: list[algebra_pb2.Expression]) -> algebra_pb2.Expression: - """Constructs a Substrait concat expression.""" + """Construct a Substrait concat expression.""" return algebra_pb2.Expression( scalar_function=algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, @@ -64,7 +64,7 @@ def concat(function_info: ExtensionFunction, def strlen(function_info: ExtensionFunction, expression: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a Substrait concat expression.""" + """Construct a Substrait string length expression.""" return algebra_pb2.Expression( scalar_function=algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, @@ -74,7 +74,7 @@ def strlen(function_info: ExtensionFunction, def cast_operation(expression: algebra_pb2.Expression, output_type: type_pb2.Type) -> algebra_pb2.Expression: - """Constructs a Substrait cast expression.""" + """Construct a Substrait cast expression.""" return algebra_pb2.Expression( cast=algebra_pb2.Expression.Cast(input=expression, type=output_type) ) @@ -82,7 +82,7 @@ def cast_operation(expression: algebra_pb2.Expression, def if_then_else_operation(if_expr: algebra_pb2.Expression, then_expr: algebra_pb2.Expression, else_expr: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a simplistic Substrait if-then-else expression.""" + """Construct a simplistic Substrait if-then-else expression.""" return algebra_pb2.Expression( if_then=algebra_pb2.Expression.IfThen( **{'ifs': [ @@ -92,7 +92,7 @@ def if_then_else_operation(if_expr: algebra_pb2.Expression, then_expr: algebra_p def field_reference(field_number: int) -> algebra_pb2.Expression: - """Constructs a Substrait field reference expression.""" + """Construct a Substrait field reference expression.""" return algebra_pb2.Expression( selection=algebra_pb2.Expression.FieldReference( direct_reference=algebra_pb2.Expression.ReferenceSegment( @@ -102,7 +102,7 @@ def field_reference(field_number: int) -> algebra_pb2.Expression: def max_agg_function(function_info: ExtensionFunction, field_number: int) -> algebra_pb2.AggregateFunction: - """Constructs a Substrait max aggregate function.""" + """Construct a Substrait max aggregate function.""" # TODO -- Reorganize all functions to belong to a class which determines the info. 
return algebra_pb2.AggregateFunction( function_reference=function_info.anchor, @@ -113,7 +113,7 @@ def max_agg_function(function_info: ExtensionFunction, def string_concat_agg_function(function_info: ExtensionFunction, field_number: int, separator: str = '') -> algebra_pb2.AggregateFunction: - """Constructs a Substrait string concat aggregate function.""" + """Construct a Substrait string concat aggregate function.""" return algebra_pb2.AggregateFunction( function_reference=function_info.anchor, output_type=function_info.output_type, @@ -123,7 +123,7 @@ def string_concat_agg_function(function_info: ExtensionFunction, def least_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a Substrait min expression.""" + """Construct a Substrait min expression.""" return if_then_else_operation( greater_function(greater_function_info, expr1, expr2), expr2, @@ -133,7 +133,7 @@ def least_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2. def greatest_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a Substrait max expression.""" + """Construct a Substrait max expression.""" return if_then_else_operation( greater_function(greater_function_info, expr1, expr2), expr1, @@ -144,7 +144,7 @@ def greatest_function(greater_function_info: ExtensionFunction, expr1: algebra_p def greater_or_equal_function(function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a Substrait min expression.""" + """Construct a Substrait greater-than-or-equal expression.""" return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, @@ -156,7 +156,7 @@ def greater_or_equal_function(function_info: ExtensionFunction, def greater_function(function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a Substrait min expression.""" + """Construct a Substrait greater-than expression.""" return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, @@ -168,7 +168,7 @@ def greater_function(function_info: ExtensionFunction, def minus_function(function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: - """Constructs a Substrait min expression.""" + """Construct a Substrait subtract expression.""" return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, @@ -180,7 +180,7 @@ def minus_function(function_info: ExtensionFunction, def repeat_function(function_info: ExtensionFunction, string: str, count: algebra_pb2.Expression) -> algebra_pb2.AggregateFunction: - """Constructs a Substrait concat expression.""" + """Construct a Substrait repeat expression.""" return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, @@ -192,7 +192,7 @@ def repeat_function(function_info: ExtensionFunction, def lpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, pad_string: str = ' ') -> algebra_pb2.AggregateFunction: - """Constructs a Substrait concat expression.""" + """Construct a Substrait lpad expression.""" # TODO -- 
Avoid a cast if we don't need it. cast_type = string_type() return algebra_pb2.Expression(scalar_function= @@ -209,7 +209,7 @@ def lpad_function(function_info: ExtensionFunction, def rpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, pad_string: str = ' ') -> algebra_pb2.AggregateFunction: - """Constructs a Substrait concat expression.""" + """Construct a Substrait rpad expression.""" # TODO -- Avoid a cast if we don't need it. cast_type = string_type() return algebra_pb2.Expression(scalar_function= @@ -224,17 +224,17 @@ def rpad_function(function_info: ExtensionFunction, def string_literal(val: str) -> algebra_pb2.Expression: - """Constructs a Substrait string literal expression.""" + """Construct a Substrait string literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(string=val)) def bigint_literal(val: int) -> algebra_pb2.Expression: - """Constructs a Substrait string literal expression.""" + """Construct a Substrait i64 literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(i64=val)) def string_type(required: bool = True) -> type_pb2.Type: - """Constructs a Substrait string type.""" + """Construct a Substrait string type.""" if required: nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED else: @@ -243,7 +243,7 @@ def string_type(required: bool = True) -> type_pb2.Type: def varchar_type(length: int = 1000, required: bool = True) -> type_pb2.Type: - """Constructs a Substrait varchar type.""" + """Construct a Substrait varchar type.""" if required: nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED else: @@ -252,7 +252,7 @@ def varchar_type(length: int = 1000, required: bool = True) -> type_pb2.Type: def integer_type(required: bool = True) -> type_pb2.Type: - """Constructs a Substrait i32 type.""" + """Construct a Substrait i32 type.""" if required: nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED else: diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index d173370..1b7eb3c 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -15,6 +15,7 @@ class PlanMetadata: output_fields: list[str] def __init__(self, plan_id: int): + """Create the PlanMetadata structure.""" self.plan_id = plan_id self.symbol_type = None self.parent_plan_id = None @@ -29,11 +30,12 @@ class SymbolTable: _symbols: dict[int, PlanMetadata] def __init__(self): + """Initialize the symbol table.""" self._symbols = {} # pylint: disable=E1101 def add_symbol(self, plan_id: int, parent: int | None, symbol_type: str | None): - """Creates a new symbol and returns it.""" + """Create a new symbol and return it.""" symbol = PlanMetadata(plan_id) symbol.symbol_type = symbol_type symbol.parent_plan_id = parent @@ -41,5 +43,5 @@ def add_symbol(self, plan_id: int, parent: int | None, symbol_type: str | None): return symbol def get_symbol(self, plan_id: int) -> PlanMetadata | None: - """Fetches the symbol with the requested plan id.""" + """Fetch the symbol with the requested plan id.""" return self._symbols.get(plan_id) diff --git a/src/gateway/converter/tools/duckdb_substrait_to_arrow.py b/src/gateway/converter/tools/duckdb_substrait_to_arrow.py index 8187aeb..455c551 100644 --- a/src/gateway/converter/tools/duckdb_substrait_to_arrow.py +++ b/src/gateway/converter/tools/duckdb_substrait_to_arrow.py @@ -10,8 +10,8 @@ # pylint: disable=E1101 -def simplify_casts(substrait_plan: 
plan_pb2.Plan) -> plan_pb2.Plan: - """Simplifies the casts in the provided Substrait plan.""" +def simplify_substrait_dialect(substrait_plan: plan_pb2.Plan) -> plan_pb2.Plan: + """Translate a DuckDB dialect Substrait plan to an Arrow friendly one.""" modified_plan = plan_pb2.Plan() modified_plan.CopyFrom(substrait_plan) # Add plan ids to every relation. @@ -28,7 +28,7 @@ def simplify_casts(substrait_plan: plan_pb2.Plan) -> plan_pb2.Plan: # pylint: disable=E1101 # ruff: noqa: T201 def main(): - """Converts the provided plans from the DuckDB Substrait dialect to Acero's.""" + """Convert the provided plans from the DuckDB Substrait dialect to Acero's.""" args = sys.argv[1:] if len(args) != 2: print("Usage: python duckdb_substrait_to_arrow.py ") @@ -38,7 +38,7 @@ def main(): plan_prototext = file.read() duckdb_plan = json_format.Parse(plan_prototext, plan_pb2.Plan()) - arrow_plan = simplify_casts(duckdb_plan) + arrow_plan = simplify_substrait_dialect(duckdb_plan) with open(args[1], "w", encoding='utf-8') as file: file.write(json_format.MessageToJson(arrow_plan)) diff --git a/src/gateway/converter/tools/tests/simplify_casts_test.py b/src/gateway/converter/tools/tests/simplify_casts_test.py index d9bb9a7..ae91237 100644 --- a/src/gateway/converter/tools/tests/simplify_casts_test.py +++ b/src/gateway/converter/tools/tests/simplify_casts_test.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest -from gateway.converter.tools.duckdb_substrait_to_arrow import simplify_casts +from gateway.converter.tools.duckdb_substrait_to_arrow import simplify_substrait_dialect from google.protobuf import json_format, text_format from hamcrest import assert_that, equal_to from substrait.gen.proto import plan_pb2 @@ -33,7 +33,7 @@ def test_simplify_casts(request, path): splan_prototext = file.read() expected_plan = json_format.Parse(splan_prototext, plan_pb2.Plan()) - arrow_plan = simplify_casts(source_plan) + arrow_plan = simplify_substrait_dialect(source_plan) if request.config.getoption('rebuild_goldens'): if arrow_plan != expected_plan: From 5192147f73d6e5c5dcae9d487eaa3c04bc2305e4 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 19 Apr 2024 11:01:51 -0700 Subject: [PATCH 20/58] feat: add side by side TPC-H tests that are pyspark dataframe-based (#45) --- src/gateway/tests/conftest.py | 23 + src/gateway/tests/test_sql_api.py | 36 +- .../tests/test_tpch_with_dataframe_api.py | 662 ++++++++++++++++++ 3 files changed, 691 insertions(+), 30 deletions(-) create mode 100644 src/gateway/tests/test_tpch_with_dataframe_api.py diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 1f3a466..5d82e92 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest +from gateway.backends.backend import Backend from gateway.demo.mystream_database import ( create_mystream_database, delete_mystream_database, @@ -112,3 +113,25 @@ def users_dataframe(spark_session, schema_users, users_location): return spark_session.read.format('parquet') \ .schema(from_arrow_schema(schema_users)) \ .parquet(users_location) + + +def _register_table(spark_session: SparkSession, name: str) -> None: + location = Backend.find_tpch() / name + spark_session.sql( + f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' + f'OPTIONS ( path "{location}" )') + + +@pytest.fixture(scope='function') +def spark_session_with_tpch_dataset(spark_session: SparkSession, source: str) -> SparkSession: + """Add the TPC-H dataset to the current 
spark session.""" + if source == 'spark': + _register_table(spark_session, 'customer') + _register_table(spark_session, 'lineitem') + _register_table(spark_session, 'nation') + _register_table(spark_session, 'orders') + _register_table(spark_session, 'part') + _register_table(spark_session, 'partsupp') + _register_table(spark_session, 'region') + _register_table(spark_session, 'supplier') + return spark_session diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 48deb01..37b9a12 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -3,10 +3,8 @@ from pathlib import Path import pytest -from gateway.backends.backend import Backend from hamcrest import assert_that, equal_to from pyspark import Row -from pyspark.sql.session import SparkSession from pyspark.testing import assertDataFrameEqual test_case_directory = Path(__file__).resolve().parent / 'data' @@ -16,39 +14,17 @@ sql_test_case_names = [p.stem for p in sql_test_case_paths] -def _register_table(spark_session: SparkSession, name: str) -> None: - location = Backend.find_tpch() / name - spark_session.sql( - f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' - f'OPTIONS ( path "{location}" )') - - -@pytest.fixture(scope='function') -def spark_session_with_customer_database(spark_session: SparkSession, source: str) -> SparkSession: - """Creates a temporary view of the customer database.""" - if source == 'spark': - _register_table(spark_session, 'customer') - _register_table(spark_session, 'lineitem') - _register_table(spark_session, 'nation') - _register_table(spark_session, 'orders') - _register_table(spark_session, 'part') - _register_table(spark_session, 'partsupp') - _register_table(spark_session, 'region') - _register_table(spark_session, 'supplier') - return spark_session - - # pylint: disable=missing-function-docstring # ruff: noqa: E712 class TestSqlAPI: """Tests of the SQL side of SparkConnect.""" - def test_count(self, spark_session_with_customer_database): - outcome = spark_session_with_customer_database.sql( + def test_count(self, spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.sql( 'SELECT COUNT(*) FROM customer').collect() assert_that(outcome[0][0], equal_to(149999)) - def test_limit(self, spark_session_with_customer_database): + def test_limit(self, spark_session_with_tpch_dataset): expected = [ Row(c_custkey=2, c_phone='23-768-687-3665', c_mktsegment='AUTOMOBILE'), Row(c_custkey=3, c_phone='11-719-748-3364', c_mktsegment='AUTOMOBILE'), @@ -56,7 +32,7 @@ def test_limit(self, spark_session_with_customer_database): Row(c_custkey=5, c_phone='13-750-942-6364', c_mktsegment='HOUSEHOLD'), Row(c_custkey=6, c_phone='30-114-968-4951', c_mktsegment='AUTOMOBILE'), ] - outcome = spark_session_with_customer_database.sql( + outcome = spark_session_with_tpch_dataset.sql( 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() assertDataFrameEqual(outcome, expected) @@ -66,10 +42,10 @@ def test_limit(self, spark_session_with_customer_database): sql_test_case_paths, ids=sql_test_case_names, ) - def test_tpch(self, spark_session_with_customer_database, path): + def test_tpch(self, spark_session_with_tpch_dataset, path): """Test the TPC-H queries.""" # Read the SQL to run. 
with open(path, "rb") as file: sql_bytes = file.read() sql = sql_bytes.decode('utf-8') - spark_session_with_customer_database.sql(sql).collect() + spark_session_with_tpch_dataset.sql(sql).collect() diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py new file mode 100644 index 0000000..6e7246d --- /dev/null +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -0,0 +1,662 @@ +# SPDX-License-Identifier: Apache-2.0 +"""TPC-H Dataframe tests for the Spark to Substrait Gateway server.""" +import datetime + +import pyspark +from pyspark import Row +from pyspark.sql.functions import avg, col, count, countDistinct, desc, try_sum, when +from pyspark.testing import assertDataFrameEqual + + +class TestTpchWithDataFrameAPI: + """Runs the TPC-H standard test suite against the dataframe side of SparkConnect.""" + + # pylint: disable=singleton-comparison + def test_query_01(self, spark_session_with_tpch_dataset): + expected = [ + Row(l_returnflag='A', l_linestatus='F', sum_qty=37734107.00, + sum_base_price=56586554400.73, sum_disc_price=53758257134.87, + sum_charge=55909065222.83, avg_qty=25.52, + avg_price=38273.13, avg_disc=0.05, count_order=1478493), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + outcome = lineitem.filter(col('l_shipdate') <= '1998-09-02').groupBy('l_returnflag', + 'l_linestatus').agg( + try_sum('l_quantity').alias('sum_qty'), + try_sum('l_extendedprice').alias('sum_base_price'), + try_sum(col('l_extendedprice') * (1 - col('l_discount'))).alias('sum_disc_price'), + try_sum(col('l_extendedprice') * (1 - col('l_discount')) * (1 + col('l_tax'))).alias( + 'sum_charge'), + avg('l_quantity').alias('avg_qty'), + avg('l_extendedprice').alias('avg_price'), + avg('l_discount').alias('avg_disc'), + count('*').alias('count_order')) + + sorted_outcome = outcome.sort('l_returnflag', 'l_linestatus').limit(1).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_02(self, spark_session_with_tpch_dataset): + expected = [ + Row(s_acctbal=9938.53, s_name='Supplier#000005359', n_name='UNITED KINGDOM', + p_partkey=185358, p_mfgr='Manufacturer#4', s_address='QKuHYh,vZGiwu2FWEJoLDx04', + s_phone='33-429-790-6131', + s_comment='uriously regular requests hag'), + Row(s_acctbal=9937.84, s_name='Supplier#000005969', n_name='ROMANIA', + p_partkey=108438, p_mfgr='Manufacturer#1', + s_address='ANDENSOSmk,miq23Xfb5RWt6dvUcvt6Qa', s_phone='29-520-692-3537', + s_comment='efully express instructions. 
regular requests against the slyly fin'), + ] + + part = spark_session_with_tpch_dataset.table('part') + supplier = spark_session_with_tpch_dataset.table('supplier') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + nation = spark_session_with_tpch_dataset.table('nation') + region = spark_session_with_tpch_dataset.table('region') + + europe = region.filter(col('r_name') == 'EUROPE').join( + nation, col('r_regionkey') == col('n_regionkey')).join( + supplier, col('n_nationkey') == col('s_nationkey')).join( + partsupp, col('s_suppkey') == col('ps_suppkey')) + + brass = part.filter((col('p_size') == 15) & (col('p_type').endswith('BRASS'))).join( + europe, col('ps_partkey') == col('p_partkey')) + + minCost = brass.groupBy(col('ps_partkey')).agg( + pyspark.sql.functions.min('ps_supplycost').alias('min')) + + outcome = brass.join(minCost, brass.ps_partkey == minCost.ps_partkey).filter( + col('ps_supplycost') == col('min')).select('s_acctbal', 's_name', 'n_name', 'p_partkey', + 'p_mfgr', 's_address', 's_phone', + 's_comment') + + sorted_outcome = outcome.sort( + desc('s_acctbal'), 'n_name', 's_name', 'p_partkey').limit(2).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_03(self, spark_session_with_tpch_dataset): + expected = [ + Row(l_orderkey=2456423, revenue=406181.01, o_orderdate=datetime.date(1995, 3, 5), + o_shippriority=0), + Row(l_orderkey=3459808, revenue=405838.70, o_orderdate=datetime.date(1995, 3, 4), + o_shippriority=0), + Row(l_orderkey=492164, revenue=390324.06, o_orderdate=datetime.date(1995, 2, 19), + o_shippriority=0), + Row(l_orderkey=1188320, revenue=384537.94, o_orderdate=datetime.date(1995, 3, 9), + o_shippriority=0), + Row(l_orderkey=2435712, revenue=378673.06, o_orderdate=datetime.date(1995, 2, 26), + o_shippriority=0), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + fcust = customer.filter(col('c_mktsegment') == 'BUILDING') + forders = orders.filter(col('o_orderdate') < '1995-03-15') + flineitems = lineitem.filter(lineitem.l_shipdate > '1995-03-15') + + outcome = fcust.join(forders, col('c_custkey') == forders.o_custkey).select( + 'o_orderkey', 'o_orderdate', 'o_shippriority').join( + flineitems, col('o_orderkey') == flineitems.l_orderkey).select( + 'l_orderkey', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), + 'o_orderdate', + 'o_shippriority').groupBy('l_orderkey', 'o_orderdate', 'o_shippriority').agg( + try_sum('volume').alias('revenue')).select( + 'l_orderkey', 'revenue', 'o_orderdate', 'o_shippriority') + + sorted_outcome = outcome.sort(desc('revenue'), 'o_orderdate').limit(5).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_04(self, spark_session_with_tpch_dataset): + expected = [ + Row(o_orderpriority='1-URGENT', order_count=10594), + Row(o_orderpriority='2-HIGH', order_count=10476), + Row(o_orderpriority='3-MEDIUM', order_count=10410), + Row(o_orderpriority='4-NOT SPECIFIED', order_count=10556), + Row(o_orderpriority='5-LOW', order_count=10487), + ] + + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + + forders = orders.filter( + (col('o_orderdate') >= '1993-07-01') & (col('o_orderdate') < '1993-10-01')) + flineitems = lineitem.filter(col('l_commitdate') < col('l_receiptdate')).select( + 'l_orderkey').distinct() + + outcome = 
flineitems.join( + forders, + col('l_orderkey') == col('o_orderkey')).groupBy('o_orderpriority').agg( + count('o_orderpriority').alias('order_count')) + + sorted_outcome = outcome.sort('o_orderpriority').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_05(self, spark_session_with_tpch_dataset): + expected = [ + Row(n_name='INDONESIA', revenue=55502041.17), + Row(n_name='VIETNAM', revenue=55295087.00), + Row(n_name='CHINA', revenue=53724494.26), + Row(n_name='INDIA', revenue=52035512.00), + Row(n_name='JAPAN', revenue=45410175.70), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + region = spark_session_with_tpch_dataset.table('region') + supplier = spark_session_with_tpch_dataset.table('supplier') + + forders = orders.filter(col('o_orderdate') >= '1994-01-01').filter( + col('o_orderdate') < '1995-01-01') + + outcome = region.filter(col('r_name') == 'ASIA').join( # r_name = 'ASIA' + nation, col('r_regionkey') == col('n_regionkey')).join( + supplier, col('n_nationkey') == col('s_nationkey')).join( + lineitem, col('s_suppkey') == col('l_suppkey')).select( + 'n_name', 'l_extendedprice', 'l_discount', 'l_quantity', 'l_orderkey', + 's_nationkey').join(forders, col('l_orderkey') == forders.o_orderkey).join( + customer, (col('o_custkey') == col('c_custkey')) & ( + col('s_nationkey') == col('c_nationkey'))).select( + 'n_name', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( + 'n_name').agg(try_sum('volume').alias('revenue')) + + sorted_outcome = outcome.sort('revenue').collect() + + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_06(self, spark_session_with_tpch_dataset): + expected = [ + Row(revenue=123141078.23), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + + outcome = lineitem.filter((col('l_shipdate') >= '1994-01-01') & + (col('l_shipdate') < '1995-01-01') & + (col('l_discount') >= 0.05) & + (col('l_discount') <= 0.07) & + (col('l_quantity') < 24)).agg( + try_sum(col('l_extendedprice') * col('l_discount'))).alias('revenue') + + assertDataFrameEqual(outcome, expected, atol=1e-2) + + def test_query_07(self, spark_session_with_tpch_dataset): + expected = [ + Row(supp_nation='FRANCE', cust_nation='GERMANY', l_year='1995', revenue=54639732.73), + Row(supp_nation='FRANCE', cust_nation='GERMANY', l_year='1996', revenue=54633083.31), + Row(supp_nation='GERMANY', cust_nation='FRANCE', l_year='1995', revenue=52531746.67), + Row(supp_nation='GERMANY', cust_nation='FRANCE', l_year='1996', revenue=52520549.02), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + supplier = spark_session_with_tpch_dataset.table('supplier') + nation = spark_session_with_tpch_dataset.table('nation') + + fnation = nation.filter((nation.n_name == 'FRANCE') | (nation.n_name == 'GERMANY')) + fline = lineitem.filter( + (col('l_shipdate') >= '1995-01-01') & (col('l_shipdate') <= '1996-12-31')) + + suppNation = fnation.join(supplier, col('n_nationkey') == col('s_nationkey')).join( + fline, col('s_suppkey') == col('l_suppkey')).select( + col('n_name').alias('supp_nation'), 'l_orderkey', 'l_extendedprice', 'l_discount', + 'l_shipdate') + + outcome = fnation.join(customer, 
col('n_nationkey') == col('c_nationkey')).join( + orders, col('c_custkey') == col('o_custkey')).select( + col('n_name').alias('cust_nation'), 'o_orderkey').join( + suppNation, col('o_orderkey') == suppNation.l_orderkey).filter( + (col('supp_nation') == 'FRANCE') & (col('cust_nation') == 'GERMANY') | ( + col('supp_nation') == 'GERMANY') & (col('cust_nation') == 'FRANCE')).select( + 'supp_nation', 'cust_nation', col('l_shipdate').substr(0, 4).alias('l_year'), + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( + 'supp_nation', 'cust_nation', 'l_year').agg( + try_sum('volume').alias('revenue')) + + sorted_outcome = outcome.sort('supp_nation', 'cust_nation', 'l_year').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_08(self, spark_session_with_tpch_dataset): + expected = [ + Row(o_year='1995', mkt_share=0.03), + Row(o_year='1996', mkt_share=0.04), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + region = spark_session_with_tpch_dataset.table('region') + supplier = spark_session_with_tpch_dataset.table('supplier') + + fregion = region.filter(col('r_name') == 'AMERICA') + forder = orders.filter((col('o_orderdate') >= '1995-01-01') & ( + col('o_orderdate') <= '1996-12-31')) + fpart = part.filter(col('p_type') == 'ECONOMY ANODIZED STEEL') + + nat = nation.join(supplier, col('n_nationkey') == col('s_nationkey')) + + line = lineitem.select( + 'l_partkey', 'l_suppkey', 'l_orderkey', + (col('l_extendedprice') * (1 - col('l_discount'))).alias( + 'volume')).join( + fpart, col('l_partkey') == fpart.p_partkey).join( + nat, col('l_suppkey') == nat.s_suppkey) + + outcome = nation.join(fregion, col('n_regionkey') == fregion.r_regionkey).select( + 'n_nationkey', 'n_name').join(customer, + col('n_nationkey') == col('c_nationkey')).select( + 'c_custkey').join(forder, col('c_custkey') == col('o_custkey')).select( + 'o_orderkey', 'o_orderdate').join(line, col('o_orderkey') == line.l_orderkey).select( + col('n_name'), col('o_orderdate').substr(0, 4).alias('o_year'), + col('volume')).withColumn('case_volume', + when(col('n_name') == 'BRAZIL', col('volume')).otherwise( + 0)).groupBy('o_year').agg( + (try_sum('case_volume') / try_sum('volume')).alias('mkt_share')) + + sorted_outcome = outcome.sort('o_year').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_09(self, spark_session_with_tpch_dataset): + # TODO -- Verify the correctness of these results against another version of the dataset. 
+ expected = [ + Row(n_name='ARGENTINA', o_year='1998', sum_profit=28341663.78), + Row(n_name='ARGENTINA', o_year='1997', sum_profit=47143964.12), + Row(n_name='ARGENTINA', o_year='1996', sum_profit=45255278.60), + Row(n_name='ARGENTINA', o_year='1995', sum_profit=45631769.21), + Row(n_name='ARGENTINA', o_year='1994', sum_profit=48268856.35), + ] + + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') + + linePart = part.filter(col('p_name').contains('green')).join( + lineitem, col('p_partkey') == lineitem.l_partkey) + natSup = nation.join(supplier, col('n_nationkey') == supplier.s_nationkey) + + outcome = linePart.join(natSup, col('l_suppkey') == natSup.s_suppkey).join( + partsupp, (col('l_suppkey') == partsupp.ps_suppkey) & ( + col('l_partkey') == partsupp.ps_partkey)).join( + orders, col('l_orderkey') == orders.o_orderkey).select( + 'n_name', col('o_orderdate').substr(0, 4).alias('o_year'), + (col('l_extendedprice') * (1 - col('l_discount')) - ( + col('ps_supplycost') * col('l_quantity'))).alias('amount')).groupBy( + 'n_name', 'o_year').agg(try_sum('amount').alias('sum_profit')) + + sorted_outcome = outcome.sort('n_name', desc('o_year')).limit(5).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_10(self, spark_session_with_tpch_dataset): + expected = [ + Row(c_custkey=57040, c_name='Customer#000057040', revenue=734235.25, + c_acctbal=632.87, n_name='JAPAN', c_address='Eioyzjf4pp', + c_phone='22-895-641-3466', + c_comment='sits. slyly regular requests sleep alongside of the regular inst'), + Row(c_custkey=143347, c_name='Customer#000143347', revenue=721002.69, + c_acctbal=2557.47, n_name='EGYPT', c_address='1aReFYv,Kw4', + c_phone='14-742-935-3718', + c_comment='ggle carefully enticing requests. final deposits use bold, bold ' + 'pinto beans. 
ironic, idle re'), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + orders = spark_session_with_tpch_dataset.table('orders') + + flineitem = lineitem.filter(col('l_returnflag') == 'R') + + outcome = orders.filter( + (col('o_orderdate') >= '1993-10-01') & (col('o_orderdate') < '1994-01-01')).join( + customer, col('o_custkey') == customer.c_custkey).join( + nation, col('c_nationkey') == nation.n_nationkey).join( + flineitem, col('o_orderkey') == flineitem.l_orderkey).select( + 'c_custkey', 'c_name', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), + 'c_acctbal', 'n_name', 'c_address', 'c_phone', 'c_comment').groupBy( + 'c_custkey', 'c_name', 'c_acctbal', 'c_phone', 'n_name', 'c_address', 'c_comment').agg( + try_sum('volume').alias('revenue')).select( + 'c_custkey', 'c_name', 'revenue', 'c_acctbal', 'n_name', 'c_address', 'c_phone', + 'c_comment') + + sorted_outcome = outcome.sort(desc('revenue')).limit(2).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_11(self, spark_session_with_tpch_dataset): + expected = [ + Row(ps_partkey=129760, value=17538456.86), + Row(ps_partkey=166726, value=16503353.92), + Row(ps_partkey=191287, value=16474801.97), + Row(ps_partkey=161758, value=16101755.54), + Row(ps_partkey=34452, value=15983844.72), + ] + + nation = spark_session_with_tpch_dataset.table('nation') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') + + tmp = nation.filter(col('n_name') == 'GERMANY').join( + supplier, col('n_nationkey') == supplier.s_nationkey).select( + 's_suppkey').join(partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( + 'ps_partkey', (col('ps_supplycost') * col('ps_availqty')).alias('value')) + + sumRes = tmp.agg(try_sum('value').alias('total_value')) + + outcome = tmp.groupBy('ps_partkey').agg( + (try_sum('value')).alias('part_value')).join( + sumRes, col('part_value') > col('total_value') * 0.0001) + + sorted_outcome = outcome.sort(desc('part_value')).limit(5).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_12(self, spark_session_with_tpch_dataset): + expected = [ + Row(l_shipmode='MAIL', high_line_count=6202, low_line_count=9324), + Row(l_shipmode='SHIP', high_line_count=6200, low_line_count=9262), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = lineitem.filter( + (col('l_shipmode') == 'MAIL') | (col('l_shipmode') == 'SHIP')).filter( + (col('l_commitdate') < col('l_receiptdate')) & + (col('l_shipdate') < col('l_commitdate')) & + (col('l_receiptdate') >= '1994-01-01') & (col('l_receiptdate') < '1995-01-01')).join( + orders, + col('l_orderkey') == orders.o_orderkey).select( + 'l_shipmode', 'o_orderpriority').groupBy('l_shipmode').agg( + count( + when((col('o_orderpriority') == '1-URGENT') | (col('o_orderpriority') == '2-HIGH'), + True)).alias('high_line_count'), + count( + when((col('o_orderpriority') != '1-URGENT') & (col('o_orderpriority') != '2-HIGH'), + True)).alias('low_line_count')) + + sorted_outcome = outcome.sort('l_shipmode').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_13(self, spark_session_with_tpch_dataset): + # TODO -- Verify the corretness of these results against another version of the dataset. 
+ expected = [ + Row(c_count=9, custdist=6641), + Row(c_count=10, custdist=6532), + Row(c_count=11, custdist=6014), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = customer.join( + orders, (col('c_custkey') == orders.o_custkey) & ( + ~col('o_comment').rlike('.*special.*requests.*')), 'left_outer').groupBy( + 'o_custkey').agg(count('o_orderkey').alias('c_count')).groupBy( + 'c_count').agg(count('o_custkey').alias('custdist')) + + sorted_outcome = outcome.sort(desc('custdist'), desc('c_count')).limit(3).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_14(self, spark_session_with_tpch_dataset): + expected = [ + Row(promo_revenue=16.38), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + part = spark_session_with_tpch_dataset.table('part') + + outcome = part.join(lineitem, (col('l_partkey') == col('p_partkey')) & + (col('l_shipdate') >= '1995-09-01') & + (col('l_shipdate') < '1995-10-01')).select( + 'p_type', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).agg( + try_sum(when(col('p_type').contains('PROMO'), col('value'))) * 100 / try_sum( + col('value')) + ).alias('promo_revenue') + + assertDataFrameEqual(outcome, expected, atol=1e-2) + + def test_query_15(self, spark_session_with_tpch_dataset): + expected = [ + Row(s_suppkey=8449, s_name='Supplier#000008449', s_address='Wp34zim9qYFbVctdW'), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + supplier = spark_session_with_tpch_dataset.table('supplier') + + revenue = lineitem.filter((col('l_shipdate') >= '1996-01-01') & + (col('l_shipdate') < '1996-04-01')).select( + 'l_suppkey', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).groupBy( + 'l_suppkey').agg(try_sum('value').alias('total')) + + outcome = revenue.agg(pyspark.sql.functions.max(col('total')).alias('max_total')).join( + revenue, col('max_total') == revenue.total).join( + supplier, col('l_suppkey') == supplier.s_suppkey).select( + 's_suppkey', 's_name', 's_address', 's_phone', 'total') + + sorted_outcome = outcome.sort('s_suppkey').limit(1).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_16(self, spark_session_with_tpch_dataset): + expected = [ + Row(p_brand='Brand#41', p_type='MEDIUM BRUSHED TIN', p_size=3, supplier_cnt=28), + Row(p_brand='Brand#54', p_type='STANDARD BRUSHED COPPER', p_size=14, supplier_cnt=27), + Row(p_brand='Brand#11', p_type='STANDARD BRUSHED TIN', p_size=23, supplier_cnt=24), + ] + + part = spark_session_with_tpch_dataset.table('part') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') + + fparts = part.filter((col('p_brand') != 'Brand#45') & + (~col('p_type').startswith('MEDIUM POLISHED')) & + (col('p_size').isin([3, 14, 23, 45, 49, 9, 19, 36]))).select( + 'p_partkey', 'p_brand', 'p_type', 'p_size') + + outcome = supplier.filter(~col('s_comment').rlike('.*Customer.*Complaints.*')).join( + partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( + 'ps_partkey', 'ps_suppkey').join( + fparts, col('ps_partkey') == fparts.p_partkey).groupBy( + 'p_brand', 'p_type', 'p_size').agg(countDistinct('ps_suppkey').alias('supplier_cnt')) + + sorted_outcome = outcome.sort( + desc('supplier_cnt'), 'p_brand', 'p_type', 'p_size').limit(3).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_17(self, 
spark_session_with_tpch_dataset): + expected = [ + Row(avg_yearly=348406.02), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + part = spark_session_with_tpch_dataset.table('part') + + fpart = part.filter( + (col('p_brand') == 'Brand#23') & (col('p_container') == 'MED BOX')).select( + 'p_partkey').join(lineitem, col('p_partkey') == lineitem.l_partkey, 'left_outer') + + outcome = fpart.groupBy('p_partkey').agg( + (avg('l_quantity') * 0.2).alias('avg_quantity')).select( + col('p_partkey').alias('key'), 'avg_quantity').join( + fpart, col('key') == fpart.p_partkey).filter( + col('l_quantity') < col('avg_quantity')).agg( + try_sum('l_extendedprice') / 7).alias('avg_yearly') + + assertDataFrameEqual(outcome, expected, atol=1e-2) + + def test_query_18(self, spark_session_with_tpch_dataset): + expected = [ + Row(c_name='Customer#000128120', c_custkey=128120, o_orderkey=4722021, + o_orderdate=datetime.date(1994, 4, 7), + o_totalprice=544089.09, sum_l_quantity=323.00), + Row(c_name='Customer#000144617', c_custkey=144617, o_orderkey=3043270, + o_orderdate=datetime.date(1997, 2, 12), + o_totalprice=530604.44, sum_l_quantity=317.00), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = lineitem.groupBy('l_orderkey').agg( + try_sum('l_quantity').alias('sum_quantity')).filter( + col('sum_quantity') > 300).select(col('l_orderkey').alias('key'), 'sum_quantity').join( + orders, orders.o_orderkey == col('key')).join( + lineitem, col('o_orderkey') == lineitem.l_orderkey).join( + customer, col('o_custkey') == customer.c_custkey).select( + 'l_quantity', 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', + 'o_totalprice').groupBy( + 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', 'o_totalprice').agg( + try_sum('l_quantity')) + + sorted_outcome = outcome.sort(desc('o_totalprice'), 'o_orderdate').limit(2).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_19(self, spark_session_with_tpch_dataset): + expected = [ + Row(revenue=3083843.06), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + part = spark_session_with_tpch_dataset.table('part') + + outcome = part.join(lineitem, col('l_partkey') == col('p_partkey')).filter( + col('l_shipmode').isin(['AIR', 'AIR REG']) & ( + col('l_shipinstruct') == 'DELIVER IN PERSON')).filter( + ((col('p_brand') == 'Brand#12') & ( + col('p_container').isin(['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) & + (col('l_quantity') >= 1) & (col('l_quantity') <= 11) & + (col('p_size') >= 1) & (col('p_size') <= 5)) | + ((col('p_brand') == 'Brand#23') & ( + col('p_container').isin(['MED BAG', 'MED BOX', 'MED PKG', 'MED PACK'])) & + (col('l_quantity') >= 10) & (col('l_quantity') <= 20) & + (col('p_size') >= 1) & (col('p_size') <= 10)) | + ((col('p_brand') == 'Brand#34') & ( + col('p_container').isin(['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) & + (col('l_quantity') >= 20) & (col('l_quantity') <= 30) & + (col('p_size') >= 1) & (col('p_size') <= 15))).select( + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).agg( + try_sum('volume').alias('revenue')) + + assertDataFrameEqual(outcome, expected, atol=1e-2) + + def test_query_20(self, spark_session_with_tpch_dataset): + expected = [ + Row(s_name='Supplier#000000020', s_address='iybAE,RmTymrZVYaFZva2SH,j'), + Row(s_name='Supplier#000000091', s_address='YV45D7TkfdQanOOZ7q9QxkyGUapU1oOWU6q3'), + 
Row(s_name='Supplier#000000205', s_address='rF uV8d0JNEk'), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') + + flineitem = lineitem.filter( + (col('l_shipdate') >= '1994-01-01') & (col('l_shipdate') < '1995-01-01')).groupBy( + 'l_partkey', 'l_suppkey').agg( + try_sum(col('l_quantity') * 0.5).alias('sum_quantity')) + + fnation = nation.filter(col('n_name') == 'CANADA') + nat_supp = supplier.select('s_suppkey', 's_name', 's_nationkey', 's_address').join( + fnation, col('s_nationkey') == fnation.n_nationkey) + + outcome = part.filter(col('p_name').startswith('forest')).select('p_partkey').join( + partsupp, col('p_partkey') == partsupp.ps_partkey).join( + flineitem, (col('ps_suppkey') == flineitem.l_suppkey) & ( + col('ps_partkey') == flineitem.l_partkey)).filter( + col('ps_availqty') > col('sum_quantity')).select('ps_suppkey').distinct().join( + nat_supp, col('ps_suppkey') == nat_supp.s_suppkey).select('s_name', 's_address') + + sorted_outcome = outcome.sort('s_name').limit(3).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_21(self, spark_session_with_tpch_dataset): + # TODO -- Verify the corretness of these results against another version of the dataset. + expected = [ + Row(s_name='Supplier#000002095', numwait=26), + Row(s_name='Supplier#000003063', numwait=26), + Row(s_name='Supplier#000006384', numwait=26), + Row(s_name='Supplier#000006450', numwait=26), + Row(s_name='Supplier#000000486', numwait=25), + ] + + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + orders = spark_session_with_tpch_dataset.table('orders') + supplier = spark_session_with_tpch_dataset.table('supplier') + + fsupplier = supplier.select('s_suppkey', 's_nationkey', 's_name') + + plineitem = lineitem.select('l_suppkey', 'l_orderkey', 'l_receiptdate', 'l_commitdate') + + flineitem = plineitem.filter(col('l_receiptdate') > col('l_commitdate')) + + line1 = plineitem.groupBy('l_orderkey').agg( + countDistinct('l_suppkey').alias('suppkey_count'), + pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select( + col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max') + + line2 = flineitem.groupBy('l_orderkey').agg( + countDistinct('l_suppkey').alias('suppkey_count'), + pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select( + col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max') + + forder = orders.select('o_orderkey', 'o_orderstatus').filter(col('o_orderstatus') == 'F') + + outcome = nation.filter(col('n_name') == 'SAUDI ARABIA').join( + fsupplier, col('n_nationkey') == fsupplier.s_nationkey).join( + flineitem, col('s_suppkey') == flineitem.l_suppkey).join( + forder, col('l_orderkey') == forder.o_orderkey).join( + line1, col('l_orderkey') == line1.key).filter( + (col('suppkey_count') > 1) | + ((col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max')))).select( + 's_name', 'l_orderkey', 'l_suppkey').join( + line2, col('l_orderkey') == line2.key, 'left_outer').select( + 's_name', 'l_orderkey', 'l_suppkey', 'suppkey_count', 'suppkey_max').filter( + (col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max'))).groupBy( + 's_name').agg(count(col('l_suppkey')).alias('numwait')) + + 
sorted_outcome = outcome.sort(desc('numwait'), 's_name').limit(5).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + + def test_query_22(self, spark_session_with_tpch_dataset): + expected = [ + Row(cntrycode='13', numcust=888, totacctbal=6737713.99), + Row(cntrycode='17', numcust=861, totacctbal=6460573.72), + Row(cntrycode='18', numcust=964, totacctbal=7236687.40), + Row(cntrycode='23', numcust=892, totacctbal=6701457.95), + Row(cntrycode='29', numcust=948, totacctbal=7158866.63), + Row(cntrycode='30', numcust=909, totacctbal=6808436.13), + Row(cntrycode='31', numcust=922, totacctbal=6806670.18), + ] + + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + + fcustomer = customer.select( + 'c_acctbal', 'c_custkey', (col('c_phone').substr(0, 2)).alias('cntrycode')).filter( + col('cntrycode').isin(['13', '31', '23', '29', '30', '18', '17'])) + + avg_customer = fcustomer.filter(col('c_acctbal') > 0.00).agg( + avg('c_acctbal').alias('avg_acctbal')) + + outcome = orders.groupBy('o_custkey').agg( + count('o_custkey')).select('o_custkey').join( + fcustomer, col('o_custkey') == fcustomer.c_custkey, 'right_outer').filter( + col('o_custkey').isNull()).join(avg_customer).filter( + col('c_acctbal') > col('avg_acctbal')).groupBy('cntrycode').agg( + count('c_custkey').alias('numcust'), try_sum('c_acctbal')) + + sorted_outcome = outcome.sort('cntrycode').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) From d27c2bdba8c4448ff7849bee8630001ca373e9e1 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Fri, 19 Apr 2024 18:36:53 -0400 Subject: [PATCH 21/58] feat: add spark to substrait named table relation conversion with schema resolution (#46) Add spark to substrait named table relation conversion with schema resolution --- src/gateway/backends/adbc_backend.py | 3 ++- src/gateway/backends/duckdb_backend.py | 3 --- src/gateway/converter/spark_to_substrait.py | 27 ++++++++++++++++++--- src/gateway/server.py | 2 ++ 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/gateway/backends/adbc_backend.py b/src/gateway/backends/adbc_backend.py index 874760f..d9772df 100644 --- a/src/gateway/backends/adbc_backend.py +++ b/src/gateway/backends/adbc_backend.py @@ -4,6 +4,7 @@ import duckdb import pyarrow as pa +import pyarrow.parquet as pq from adbc_driver_manager import dbapi from substrait.gen.proto import plan_pb2 @@ -60,7 +61,7 @@ def register_table(self, name: str, path: Path, extension: str = 'parquet') -> N # Sort the files because the later ones don't have enough data to construct a schema. file_paths = sorted([str(fp) for fp in file_paths]) # TODO: Support multiple paths. - reader = pa.parquet.ParquetFile(file_paths[0]) + reader = pq.ParquetFile(file_paths[0]) self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode="create") def describe_table(self, table_name: str): diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index fb9825c..77521db 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -37,9 +37,6 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against DuckDB.""" plan_data = plan.SerializeToString() - # TODO -- Rely on the client to register their own named tables. 
- self.register_tpch() - try: query_result = self._connection.from_substrait(proto=plan_data) except Exception as err: diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 46a02ba..e656511 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -10,6 +10,9 @@ import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 +from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 +from substrait.gen.proto.extensions import extensions_pb2 + from gateway.backends.backend_options import BackendOptions from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import ConversionOptions @@ -39,8 +42,6 @@ strlen, ) from gateway.converter.symbol_table import SymbolTable -from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 -from substrait.gen.proto.extensions import extensions_pb2 TABLE_NAME = "my_table" @@ -311,9 +312,27 @@ def convert_expression_to_aggregate_function( func.output_type.CopyFrom(function.output_type) return func - def convert_read_named_table_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Rel: + def convert_read_named_table_relation( + self, + rel: spark_relations_pb2.Read.named_table + ) -> algebra_pb2.Rel: """Convert a read named table relation to a Substrait relation.""" - raise NotImplementedError('named tables are not yet implemented') + table_name = rel.unparsed_identifier + + backend = find_backend(BackendOptions(self._conversion_options.backend.backend, True)) + tpch_location = backend.find_tpch() + backend.register_table(table_name, tpch_location / table_name) + arrow_schema = backend.describe_table(table_name) + schema = self.convert_arrow_schema(arrow_schema) + + symbol = self._symbol_table.get_symbol(self._current_plan_id) + for field_name in schema.names: + symbol.output_fields.append(field_name) + + return algebra_pb2.Rel( + read=algebra_pb2.ReadRel( + base_schema=schema, + named_table=algebra_pb2.ReadRel.NamedTable(names=[table_name]))) def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: """Convert the Spark JSON schema string into a Substrait named type structure.""" diff --git a/src/gateway/server.py b/src/gateway/server.py index c4e1978..84617bb 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -109,6 +109,8 @@ def ExecutePlan( raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) backend = find_backend(self._options.backend) + tpch_location = backend.find_tpch() + backend.register_table('customer', tpch_location / 'customer') results = backend.execute(substrait) _LOGGER.debug(' results are: %s', results) From 73a52a3ecc2780dd85c774f383b28d129612deb8 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 19 Apr 2024 16:54:19 -0700 Subject: [PATCH 22/58] feat: add support for converting spark joins (#47) --- src/gateway/converter/spark_to_substrait.py | 55 +++++++++++++++++++-- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index e656511..9385029 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -10,9 +10,6 @@ import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import 
pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 -from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 -from substrait.gen.proto.extensions import extensions_pb2 - from gateway.backends.backend_options import BackendOptions from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import ConversionOptions @@ -42,6 +39,8 @@ strlen, ) from gateway.converter.symbol_table import SymbolTable +from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 +from substrait.gen.proto.extensions import extensions_pb2 TABLE_NAME = "my_table" @@ -839,6 +838,54 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: # TODO -- Renumber all of the functions/extensions in the captured subplan. return plan.relations[0].root.input + def convert_spark_join_type( + self, join_type: spark_relations_pb2.Join.JoinType) -> algebra_pb2.JoinRel.JoinType: + """Convert a Spark join type into a Substrait join type.""" + match join_type: + case spark_relations_pb2.Join.JOIN_TYPE_UNSPECIFIED: + return algebra_pb2.JoinRel.JOIN_TYPE_UNSPECIFIED + case spark_relations_pb2.Join.JOIN_TYPE_INNER: + return algebra_pb2.JoinRel.JOIN_TYPE_INNER + case spark_relations_pb2.Join.JOIN_TYPE_FULL_OUTER: + return algebra_pb2.JoinRel.JOIN_TYPE_OUTER + case spark_relations_pb2.Join.JOIN_TYPE_LEFT_OUTER: + return algebra_pb2.JoinRel.JOIN_TYPE_LEFT + case spark_relations_pb2.Join.JOIN_TYPE_RIGHT_OUTER: + return algebra_pb2.JoinRel.JOIN_TYPE_RIGHT + case spark_relations_pb2.Join.JOIN_TYPE_LEFT_ANTI: + return algebra_pb2.JoinRel.JOIN_TYPE_ANTI + case spark_relations_pb2.Join.JOIN_TYPE_LEFT_SEMI: + return algebra_pb2.JoinRel.JOIN_TYPE_SEMI + case spark_relations_pb2.Join.CROSS: + raise RuntimeError('Internal error: cross joins should be handled elsewhere') + case _: + raise ValueError(f'Unexpected join type: {join_type}') + + def convert_cross_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Rel: + """Convert a Spark join relation into a Substrait join relation.""" + join = algebra_pb2.CrossRel(left=self.convert_relation(rel.left), + right=self.convert_relation(rel.right)) + self.update_field_references(rel.left.common.plan_id) + self.update_field_references(rel.right.common.plan_id) + if rel.HasField('join_condition'): + raise ValueError('Cross joins do not support having a join condition.') + join.common.CopyFrom(self.create_common_relation()) + return algebra_pb2.Rel(join=join) + + def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Rel: + """Convert a Spark join relation into a Substrait join relation.""" + if rel.join_type == spark_relations_pb2.Join.JOIN_TYPE_CROSS: + return self.convert_cross_join_relation(rel) + join = algebra_pb2.JoinRel(left=self.convert_relation(rel.left), + right=self.convert_relation(rel.right)) + self.update_field_references(rel.left.common.plan_id) + self.update_field_references(rel.right.common.plan_id) + if rel.HasField('join_condition'): + join.expression.CopyFrom(self.convert_expression(rel.join_condition)) + join.type = self.convert_spark_join_type(rel.join_type) + join.common.CopyFrom(self.create_common_relation()) + return algebra_pb2.Rel(join=join) + def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: """Convert a Spark relation into a Substrait one.""" self._symbol_table.add_symbol(rel.common.plan_id, parent=self._current_plan_id, @@ -866,6 +913,8 @@ def 
convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel result = self.convert_local_relation(rel.local_relation) case 'sql': result = self.convert_sql_relation(rel.sql) + case 'join': + result = self.convert_join_relation(rel.join) case _: raise ValueError( f'Unexpected Spark plan rel_type: {rel.WhichOneof("rel_type")}') From e5df0e75aa5d01e7a66c7565808e4ce46270b759 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Fri, 19 Apr 2024 20:00:36 -0400 Subject: [PATCH 23/58] fix: remove backend tpch registration and move it to server (#48) --- src/gateway/backends/datafusion_backend.py | 2 -- src/gateway/server.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index c462839..f921c29 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -29,8 +29,6 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against Datafusion.""" import datafusion.substrait - self.register_tpch() - file_groups = ReplaceLocalFilesWithNamedTable().visit_plan(plan) registered_tables = set() for files in file_groups: diff --git a/src/gateway/server.py b/src/gateway/server.py index 84617bb..d57a555 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -109,8 +109,7 @@ def ExecutePlan( raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) backend = find_backend(self._options.backend) - tpch_location = backend.find_tpch() - backend.register_table('customer', tpch_location / 'customer') + backend.register_tpch() results = backend.execute(substrait) _LOGGER.debug(' results are: %s', results) From 9a134509ebf70bc44f027af1d1af5e43e9779a70 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 19 Apr 2024 22:36:33 -0700 Subject: [PATCH 24/58] feat: separate passing DuckDB tests from xfailing ones (#49) This should help prevent the passing tests from breaking during ongoing development. 
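The mechanism this change relies on is an autouse pytest fixture: it reads the parametrized 'source' fixture plus the test's original name and attaches an xfail marker only to the backend/test combinations that are known to fail, so the passing DuckDB tests keep running unmarked. A condensed sketch of that pattern is shown here; the test name and reason string are placeholders rather than entries from the gateway's real suite, and 'source' is assumed to be the session-scoped fixture that selects the backend.

    import pytest

    @pytest.fixture(autouse=True)
    def mark_tests_as_xfail(request):
        """Marks known-failing backend/test combinations as expected failures."""
        # 'source' is the fixture parametrized over spark and the gateway backends.
        source = request.getfixturevalue('source')
        # originalname is the test function name without parametrization suffixes.
        originalname = request.keywords.node.originalname
        if source == 'gateway-over-duckdb' and originalname == 'test_example':
            request.node.add_marker(pytest.mark.xfail(reason='placeholder reason'))

Because the fixture is autouse, expected failures are reported as xfail instead of breaking the run, while everything else executes normally.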
--- src/gateway/tests/conftest.py | 2 +- src/gateway/tests/test_dataframe_api.py | 12 ++++++++++ src/gateway/tests/test_sql_api.py | 11 ++++++++++ .../tests/test_tpch_with_dataframe_api.py | 22 +++++++++++++++++++ 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 5d82e92..8fe1555 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -82,7 +82,7 @@ def schema_users(): @pytest.fixture(scope='session', params=['spark', - pytest.param('gateway-over-duckdb', marks=pytest.mark.xfail), + 'gateway-over-duckdb', pytest.param('gateway-over-datafusion', marks=pytest.mark.xfail( reason='Datafusion Substrait missing in CI'))]) diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 83ee15b..84f1f54 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -1,11 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the Spark to Substrait Gateway server.""" +import pytest from hamcrest import assert_that, equal_to from pyspark import Row from pyspark.sql.functions import col, substring from pyspark.testing import assertDataFrameEqual +@pytest.fixture(autouse=True) +def mark_dataframe_tests_as_xfail(request): + """Marks a subset of tests as expected to be fail.""" + source = request.getfixturevalue('source') + originalname = request.keywords.node.originalname + if source == 'gateway-over-duckdb' and (originalname == 'test_with_column' or + originalname == 'test_cast'): + request.node.add_marker( + pytest.mark.xfail(reason='DuckDB column binding error')) + + # pylint: disable=missing-function-docstring # ruff: noqa: E712 class TestDataFrameAPI: diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 37b9a12..4643bdb 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -14,6 +14,17 @@ sql_test_case_names = [p.stem for p in sql_test_case_paths] +@pytest.fixture(autouse=True) +def mark_tests_as_xfail(request): + """Marks a subset of tests as expected to be fail.""" + source = request.getfixturevalue('source') + originalname = request.keywords.node.originalname + if source == 'gateway-over-duckdb' and originalname == 'test_tpch': + path = request.getfixturevalue('path') + if path.stem in ['02', '04', '15', '16', '17', '18', '20', '21', '22']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB needs Delim join')) + + # pylint: disable=missing-function-docstring # ruff: noqa: E712 class TestSqlAPI: diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 6e7246d..d35bf9f 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -3,11 +3,33 @@ import datetime import pyspark +import pytest from pyspark import Row from pyspark.sql.functions import avg, col, count, countDistinct, desc, try_sum, when from pyspark.testing import assertDataFrameEqual +@pytest.fixture(autouse=True) +def mark_tests_as_xfail(request): + """Marks a subset of tests as expected to be fail.""" + source = request.getfixturevalue('source') + originalname = request.keywords.node.originalname + if source == 'gateway-over-duckdb' and originalname == 'test_query_01': + request.node.add_marker(pytest.mark.xfail(reason='date32[day] not handled')) + if source == 'gateway-over-duckdb' and originalname in [ + 'test_query_02', 'test_query_03', 
'test_query_05', 'test_query_07', 'test_query_08', + 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_13', + 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', 'test_query_20', + 'test_query_21', 'test_query_22']: + request.node.add_marker(pytest.mark.xfail(reason='AnalyzePlan not implemented')) + if source == 'gateway-over-duckdb' and originalname == 'test_query_04': + request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) + if source == 'gateway-over-duckdb' and originalname in ['test_query_06', 'test_query_14']: + request.node.add_marker(pytest.mark.xfail(reason='subquery_alias not implemented')) + if source == 'gateway-over-duckdb' and originalname == 'test_query_19': + request.node.add_marker(pytest.mark.xfail(reason='project not implemented')) + + class TestTpchWithDataFrameAPI: """Runs the TPC-H standard test suite against the dataframe side of SparkConnect.""" From bea90e66c557a0f5851ab4f3764cbf3af3a42233 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 22 Apr 2024 10:34:47 -0700 Subject: [PATCH 25/58] feat: add support for TPC-H query 1 (#50) Adds support for five new functions, implemented unresolved_star expressions for count(), added schema conversion for timestamp and datetime, and fixed the output field name used by groupings in aggregations. --- src/gateway/converter/data/00001.splan | 2 +- src/gateway/converter/spark_functions.py | 20 +++++++++++++++++++ src/gateway/converter/spark_to_substrait.py | 20 ++++++++++++++++--- .../converter/substrait_plan_visitor.py | 2 +- .../tests/test_tpch_with_dataframe_api.py | 2 -- 5 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/gateway/converter/data/00001.splan b/src/gateway/converter/data/00001.splan index c736777..d3f510b 100644 --- a/src/gateway/converter/data/00001.splan +++ b/src/gateway/converter/data/00001.splan @@ -399,7 +399,7 @@ relations { count: 10 } } - names: "grouping" + names: "artist_lastfm" names: "# of Listeners" } } diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index c4fcb0c..0841220 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -41,6 +41,10 @@ def __lt__(self, obj) -> bool: '/functions_comparison.yaml', 'equal:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '<=': ExtensionFunction( + '/functions_comparison.yaml', 'lte:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '>=': ExtensionFunction( '/functions_comparison.yaml', 'gte:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( @@ -49,10 +53,18 @@ def __lt__(self, obj) -> bool: '/functions_comparison.yaml', 'gt:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '+': ExtensionFunction( + '/functions_arithmetic.yaml', 'add:i64_i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '-': ExtensionFunction( '/functions_arithmetic.yaml', 'subtract:i64_i64', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '*': ExtensionFunction( + '/functions_arithmetic.yaml', 'multiply:i64_i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'array_contains': ExtensionFunction( '/functions_set.yaml', 
'index_in:str_list', type_pb2.Type( bool=type_pb2.Type.Boolean( @@ -61,6 +73,14 @@ def __lt__(self, obj) -> bool: '/functions_arithmetic.yaml', 'sum:int', type_pb2.Type( i32=type_pb2.Type.I32( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'try_sum': ExtensionFunction( + '/functions_arithmetic.yaml', 'sum:int', type_pb2.Type( + i32=type_pb2.Type.I32( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'avg': ExtensionFunction( + '/functions_arithmetic.yaml', 'avg:int', type_pb2.Type( + i32=type_pb2.Type.I32( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'regexp_extract_all': ExtensionFunction( '/functions_string.yaml', 'regexp_match:str_binary_str', type_pb2.Type( list=type_pb2.Type.List(type=type_pb2.Type(string=type_pb2.Type.String( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 9385029..bba711a 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -192,6 +192,10 @@ def convert_unresolved_function( for idx, arg in enumerate(unresolved_function.arguments): if function_def.max_args is not None and idx >= function_def.max_args: break + if unresolved_function.function_name == 'count' and arg.WhichOneof( + 'expr_type') == 'unresolved_star': + # Ignore all the rest of the arguments. + break func.arguments.append( algebra_pb2.FunctionArgument(value=self.convert_expression(arg))) if unresolved_function.is_distinct: @@ -256,7 +260,7 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex 'expression_string expression type not supported') case 'unresolved_star': raise NotImplementedError( - 'unresolved_star expression type not supported') + '* expressions are only supported within count aggregations') case 'alias': result = self.convert_alias_expression(expr.alias) case 'cast': @@ -404,6 +408,11 @@ def convert_arrow_schema(self, arrow_schema: pa.Schema) -> type_pb2.NamedStruct: field_type = type_pb2.Type(fp64=type_pb2.Type.FP64(nullability=nullability)) case 'string': field_type = type_pb2.Type(string=type_pb2.Type.String(nullability=nullability)) + case 'timestamp[us]': + field_type = type_pb2.Type( + timestamp=type_pb2.Type.Timestamp(nullability=nullability)) + case 'date32[day]': + field_type = type_pb2.Type(date=type_pb2.Type.Date(nullability=nullability)) case _: raise NotImplementedError(f'Unexpected field type: {field.type}') @@ -539,6 +548,12 @@ def determine_expression_name(self, expr: spark_exprs_pb2.Expression) -> str | N self._seen_generated_names['aggregate_expression'] += 1 return f'aggregate_expression{self._seen_generated_names["aggregate_expression"]}' + def determine_name_for_grouping(self, expr: spark_exprs_pb2.Expression) -> str: + """Determine the field name the grouping should use.""" + if expr.WhichOneof('expr_type') == 'unresolved_attribute': + return expr.unresolved_attribute.unparsed_identifier + return 'grouping' + def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> algebra_pb2.Rel: """Convert an aggregate relation into a Substrait relation.""" aggregate = algebra_pb2.AggregateRel(input=self.convert_relation(rel.input)) @@ -549,8 +564,7 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge aggregate.groupings.append( algebra_pb2.AggregateRel.Grouping( grouping_expressions=[self.convert_expression(grouping)])) - # TODO -- Use the same field name as what was selected in the grouping. 
- symbol.generated_fields.append('grouping') + symbol.generated_fields.append(self.determine_name_for_grouping(grouping)) for expr in rel.aggregate_expressions: aggregate.measures.append( algebra_pb2.AggregateRel.Measure( diff --git a/src/gateway/converter/substrait_plan_visitor.py b/src/gateway/converter/substrait_plan_visitor.py index 1a83ede..561ef47 100644 --- a/src/gateway/converter/substrait_plan_visitor.py +++ b/src/gateway/converter/substrait_plan_visitor.py @@ -476,7 +476,7 @@ def visit_expand_field(self, field: algebra_pb2.ExpandRel.ExpandField) -> Any: if field.HasField('consistent_field'): self.visit_expression(field.consistent_field) case _: - raise ValueError(f'Unexpected field type: {field.WhichOneof("field_type")}') + raise ValueError(f'Unexpected expand field type: {field.WhichOneof("field_type")}') def visit_read_relation(self, rel: algebra_pb2.ReadRel) -> Any: """Visits a read relation.""" diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index d35bf9f..66af920 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -14,8 +14,6 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and originalname == 'test_query_01': - request.node.add_marker(pytest.mark.xfail(reason='date32[day] not handled')) if source == 'gateway-over-duckdb' and originalname in [ 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_13', From 951755b67332bfa082cff44ef554e808603ae032 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 22 Apr 2024 20:27:54 -0700 Subject: [PATCH 26/58] feat: finish arrow backend implementation (#51) Since most of the tests are currently failing they've been left disabled for now (no sense wasting resources on testing something known failing). 
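The core of the Acero backend below is PyArrow's Substrait entry point: the serialized plan bytes are handed to pyarrow.substrait.run_query together with a table_provider callback that resolves any named table referenced by the plan into an Arrow table. A minimal sketch of that call shape follows; the table registry and file path are illustrative (the real backend keeps its registry in the class and fills it via register_table).

    import pyarrow as pa
    import pyarrow.parquet as pq
    import pyarrow.substrait

    # Illustrative mapping of named tables to Parquet files.
    registered_tables = {'customer': '/path/to/customer.parquet'}

    def provide_tables(names, unused_schema):
        # Called by Acero to resolve a named table referenced by the plan.
        for name in names:
            if name in registered_tables:
                return pq.read_table(registered_tables[name])
        raise ValueError(f'Table {names} not registered')

    def run_plan(plan_bytes: bytes) -> pa.Table:
        # Execute the serialized Substrait plan and materialize the results.
        reader = pa.substrait.run_query(plan_bytes, table_provider=provide_tables)
        return reader.read_all()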
--- src/gateway/backends/arrow_backend.py | 24 +++++++++++++++++++-- src/gateway/backends/backend.py | 4 ++-- src/gateway/converter/conversion_options.py | 7 ++++++ src/gateway/converter/data/00001.splan | 1 + src/gateway/converter/spark_to_substrait.py | 3 ++- src/gateway/server.py | 4 +++- src/gateway/tests/conftest.py | 2 ++ src/gateway/tests/test_dataframe_api.py | 4 ++++ 8 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/gateway/backends/arrow_backend.py b/src/gateway/backends/arrow_backend.py index 2c8019e..04df38f 100644 --- a/src/gateway/backends/arrow_backend.py +++ b/src/gateway/backends/arrow_backend.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 """Provides access to Acero.""" from pathlib import Path +from typing import ClassVar import pyarrow as pa +import pyarrow.substrait from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend @@ -11,13 +13,31 @@ class ArrowBackend(Backend): """Provides access to send Acero Substrait plans.""" + _registered_tables: ClassVar[dict[str, Path]] = {} + + def __init__(self, options): + """Initialize the Datafusion backend.""" + super().__init__(options) + # pylint: disable=import-outside-toplevel def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against Acero.""" plan_data = plan.SerializeToString() - reader = pa.substrait.run_query(plan_data) + reader = pa.substrait.run_query(plan_data, table_provider=self._provide_tables) return reader.read_all() def register_table(self, name: str, path: Path) -> None: """Register the given table with the backend.""" - raise NotImplementedError() + self._registered_tables[name] = path + + def drop_table(self, name: str) -> None: + """Asks the backend to drop the given table.""" + if self._registered_tables.get(name): + del self._registered_tables[name] + + def _provide_tables(self, names: list[str], unused_schema) -> pyarrow.Table: + """Provide the tables requested.""" + for name in names: + if name in self._registered_tables: + return pa.Table.from_pandas(pa.read_parquet(self._registered_tables[name])) + raise ValueError(f'Table {names} not found in {self._registered_tables}') diff --git a/src/gateway/backends/backend.py b/src/gateway/backends/backend.py index f3fd3ea..6a159bd 100644 --- a/src/gateway/backends/backend.py +++ b/src/gateway/backends/backend.py @@ -31,7 +31,7 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against Datafusion.""" raise NotImplementedError() - def register_table(self, name: str, path: Path | str, extension: str = 'parquet') -> None: + def register_table(self, name: str, path: Path) -> None: """Register the given table with the backend.""" raise NotImplementedError() @@ -39,7 +39,7 @@ def describe_table(self, name: str): """Asks the backend to describe the given table.""" raise NotImplementedError() - def drop_table(self, name: str): + def drop_table(self, name: str) -> None: """Asks the backend to drop the given table.""" raise NotImplementedError() diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 0064bf6..0f87207 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -31,6 +31,13 @@ def __init__(self, backend: BackendOptions = None): self.backend = backend +def arrow(): + """Return standard options to connect to the Acero backend.""" + options = ConversionOptions(backend=BackendOptions(Backend.ARROW)) + 
options.needs_scheme_in_path_uris = True + return options + + def datafusion(): """Return standard options to connect to a Datafusion backend.""" return ConversionOptions(backend=BackendOptions(Backend.DATAFUSION)) diff --git a/src/gateway/converter/data/00001.splan b/src/gateway/converter/data/00001.splan index d3f510b..eeac539 100644 --- a/src/gateway/converter/data/00001.splan +++ b/src/gateway/converter/data/00001.splan @@ -358,6 +358,7 @@ relations { measures { measure { function_reference: 4 + phase: AGGREGATION_PHASE_INITIAL_TO_RESULT output_type { i32 { nullability: NULLABILITY_REQUIRED diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index bba711a..a72c3aa 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -298,7 +298,8 @@ def convert_expression_to_aggregate_function( self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.AggregateFunction: """Convert a SparkConnect expression to a Substrait expression.""" - func = algebra_pb2.AggregateFunction() + func = algebra_pb2.AggregateFunction( + phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT) expression = self.convert_expression(expr) match expression.WhichOneof('rex_type'): case 'scalar_function': diff --git a/src/gateway/server.py b/src/gateway/server.py index d57a555..17917b4 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -12,7 +12,7 @@ from pyspark.sql.connect.proto import types_pb2 from gateway.backends.backend_selector import find_backend -from gateway.converter.conversion_options import datafusion, duck_db +from gateway.converter.conversion_options import arrow, datafusion, duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter from gateway.converter.sql_to_substrait import convert_sql @@ -160,6 +160,8 @@ def Config(self, request, context): if pair.key == 'spark-substrait-gateway.backend': # Set the server backend for all connections (including ongoing ones). 
match pair.value: + case 'arrow': + self._options = arrow() case 'duckdb': self._options = duck_db() case 'datafusion': diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 8fe1555..3941c0d 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -97,6 +97,8 @@ def spark_session(source): match source: case 'spark': session_generator = _create_local_spark_session() + case 'gateway-over-arrow': + session_generator = _create_gateway_session('arrow') case 'gateway-over-datafusion': session_generator = _create_gateway_session('datafusion') case 'gateway-over-duckdb': diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 84f1f54..5f08b41 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -23,6 +23,10 @@ def mark_dataframe_tests_as_xfail(request): class TestDataFrameAPI: """Tests of the dataframe side of SparkConnect.""" + def test_collect(self, users_dataframe): + outcome = users_dataframe.collect() + assert len(outcome) == 100 + # pylint: disable=singleton-comparison def test_filter(self, users_dataframe): outcome = users_dataframe.filter(col('paid_for_service') == True).collect() From 4a38309cc4d302d859836ad7ee2ce69a540391b2 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 22 Apr 2024 20:29:18 -0700 Subject: [PATCH 27/58] feat: implement AnalyzePlan for schema requests of plans (#52) In lieu of computing the schema from the Spark plan this PR merely executes the plan and then returns the final schema. --- src/gateway/converter/spark_functions.py | 20 +++++++++++++++++++ src/gateway/converter/spark_to_substrait.py | 3 ++- src/gateway/server.py | 13 +++++++++++- .../tests/test_tpch_with_dataframe_api.py | 14 +++++++------ 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 0841220..84dc7f6 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -49,6 +49,10 @@ def __lt__(self, obj) -> bool: '/functions_comparison.yaml', 'gte:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '<': ExtensionFunction( + '/functions_comparison.yaml', 'lt:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '>': ExtensionFunction( '/functions_comparison.yaml', 'gt:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( @@ -89,6 +93,10 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'substring:str_int_int', type_pb2.Type( string=type_pb2.Type.String( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'endswith': ExtensionFunction( + '/functions_string.yaml', 'ends_with:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'length': ExtensionFunction( '/functions_string.yaml', 'char_length:str', type_pb2.Type( i64=type_pb2.Type.I64( @@ -128,6 +136,18 @@ def __lt__(self, obj) -> bool: 'count': ExtensionFunction( '/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type( i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'and': ExtensionFunction( + '/functions_boolean.yaml', 'and:bool_bool', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'or': ExtensionFunction( + '/functions_boolean.yaml', 
'or:bool_bool', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'not': ExtensionFunction( + '/functions_boolean.yaml', 'not:bool', type_pb2.Type( + bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))) } diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index a72c3aa..63289c6 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -67,7 +67,8 @@ def lookup_function_by_name(self, name: str) -> ExtensionFunction: return self._functions.get(name) func = lookup_spark_function(name, self._conversion_options) if not func: - raise LookupError(f'function name {name} does not have a known Substrait conversion') + raise LookupError( + f'Spark function named {name} does not have a known Substrait conversion.') func.anchor = len(self._functions) + 1 self._functions[name] = func if not self._function_uris.get(func.uri): diff --git a/src/gateway/server.py b/src/gateway/server.py index 17917b4..b063ece 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -148,7 +148,18 @@ def ExecutePlan( def AnalyzePlan(self, request, context): """Analyze the given plan and return the results.""" _LOGGER.info('AnalyzePlan: %s', request) - return pb2.AnalyzePlanResponse(session_id=request.session_id) + if request.schema: + convert = SparkSubstraitConverter(self._options) + substrait = convert.convert_plan(request.schema.plan) + backend = find_backend(self._options.backend) + backend.register_tpch() + results = backend.execute(substrait) + _LOGGER.debug(' results are: %s', results) + return pb2.AnalyzePlanResponse( + session_id=request.session_id, + schema=pb2.AnalyzePlanResponse.Schema(schema=convert_pyarrow_schema_to_spark( + results.schema))) + raise NotImplementedError('AnalyzePlan not yet implemented for non-Schema requests.') def Config(self, request, context): """Get or set the configuration of the server.""" diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 66af920..9816052 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -15,16 +15,18 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', - 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_13', - 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', 'test_query_20', - 'test_query_21', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='AnalyzePlan not implemented')) + 'test_query_02', 'test_query_08', 'test_query_09']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if source == 'gateway-over-duckdb' and originalname == 'test_query_04': request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) if source == 'gateway-over-duckdb' and originalname in ['test_query_06', 'test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='subquery_alias not implemented')) - if source == 'gateway-over-duckdb' and originalname == 'test_query_19': + if source == 'gateway-over-duckdb' and originalname == 'test_query_13': + request.node.add_marker(pytest.mark.xfail(reason='function rlike not 
implemented')) + if source == 'gateway-over-duckdb' and originalname in [ + 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_10', 'test_query_11', + 'test_query_12', 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', + 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: request.node.add_marker(pytest.mark.xfail(reason='project not implemented')) From 2bed729751b3df9a72f0dd68689661bd8232c837 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 22 Apr 2024 23:06:40 -0700 Subject: [PATCH 28/58] feat: implement conversion of spark project relation to substrait (#54) --- src/gateway/converter/spark_functions.py | 4 ++++ src/gateway/converter/spark_to_substrait.py | 21 ++++++++++++++++++- .../tests/test_tpch_with_dataframe_api.py | 13 ++++++------ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 84dc7f6..54c5bc3 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -93,6 +93,10 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'substring:str_int_int', type_pb2.Type( string=type_pb2.Type.String( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'startswith': ExtensionFunction( + '/functions_string.yaml', 'starts_with:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'endswith': ExtensionFunction( '/functions_string.yaml', 'ends_with:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 63289c6..360480f 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -208,7 +208,6 @@ def convert_unresolved_function( def convert_alias_expression( self, alias: spark_exprs_pb2.Expression.Alias) -> algebra_pb2.Expression: """Convert a Spark alias into a Substrait expression.""" - # TODO -- Utilize the alias name. 
return self.convert_expression(alias.expr) def convert_type_str(self, spark_type_str: str | None) -> type_pb2.Type: @@ -902,6 +901,24 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re join.common.CopyFrom(self.create_common_relation()) return algebra_pb2.Rel(join=join) + def convert_project_relation( + self, rel: spark_relations_pb2.Project) -> algebra_pb2.Rel: + """Convert a Spark project relation into a Substrait project relation.""" + input_rel = self.convert_relation(rel.input) + project = algebra_pb2.ProjectRel(input=input_rel) + self.update_field_references(rel.input.common.plan_id) + symbol = self._symbol_table.get_symbol(self._current_plan_id) + for field_number, expr in enumerate(rel.expressions): + project.expressions.append(self.convert_expression(expr)) + if expr.HasField('alias'): + name = expr.alias.name[0] + else: + name = f'generated_field_{field_number}' + symbol.generated_fields.append(name) + symbol.output_fields.append(name) + project.common.CopyFrom(self.create_common_relation()) + return algebra_pb2.Rel(project=project) + def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: """Convert a Spark relation into a Substrait one.""" self._symbol_table.add_symbol(rel.common.plan_id, parent=self._current_plan_id, @@ -931,6 +948,8 @@ def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel result = self.convert_sql_relation(rel.sql) case 'join': result = self.convert_join_relation(rel.join) + case 'project': + result = self.convert_project_relation(rel.project) case _: raise ValueError( f'Unexpected Spark plan rel_type: {rel.WhichOneof("rel_type")}') diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 9816052..2c8c10c 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -15,19 +15,20 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_02', 'test_query_08', 'test_query_09']: + 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', + 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_15', 'test_query_17', + 'test_query_18', 'test_query_20', 'test_query_21', 'test_query_22']: request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if source == 'gateway-over-duckdb' and originalname == 'test_query_04': request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) if source == 'gateway-over-duckdb' and originalname in ['test_query_06', 'test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='subquery_alias not implemented')) + if source == 'gateway-over-duckdb' and originalname == 'test_query_12': + request.node.add_marker(pytest.mark.xfail(reason='function when not implemented')) if source == 'gateway-over-duckdb' and originalname == 'test_query_13': request.node.add_marker(pytest.mark.xfail(reason='function rlike not implemented')) - if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_10', 'test_query_11', - 'test_query_12', 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', - 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='project not implemented')) + if source == 
'gateway-over-duckdb' and originalname in ['test_query_16', 'test_query_19']: + request.node.add_marker(pytest.mark.xfail(reason='function in not implemented')) class TestTpchWithDataFrameAPI: From 63f2be8bf6f135914665202da19cb077723f817c Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Tue, 23 Apr 2024 13:12:55 -0400 Subject: [PATCH 29/58] feat: add tests for read/table apis and schema resolution (#53) --- src/gateway/tests/conftest.py | 8 ++++++++ src/gateway/tests/test_dataframe_api.py | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 3941c0d..44b118e 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -137,3 +137,11 @@ def spark_session_with_tpch_dataset(spark_session: SparkSession, source: str) -> _register_table(spark_session, 'region') _register_table(spark_session, 'supplier') return spark_session + + +@pytest.fixture(scope='function') +def spark_session_with_customer_dataset(spark_session: SparkSession, source: str) -> SparkSession: + """Add the TPC-H dataset to the current spark session.""" + if source == 'spark': + _register_table(spark_session, 'customer') + return spark_session diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 5f08b41..c5e91eb 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the Spark to Substrait Gateway server.""" import pytest +from gateway.backends.backend import Backend from hamcrest import assert_that, equal_to from pyspark import Row from pyspark.sql.functions import col, substring @@ -101,3 +102,27 @@ def test_cast(self, users_dataframe): 'user_id', substring(col('user_id'), 5, 3).cast('integer')).limit(1).collect() assertDataFrameEqual(outcome, expected) + + def test_data_source_schema(self, spark_session): + location_customer = str(Backend.find_tpch() / 'customer') + schema = spark_session.read.parquet(location_customer).schema + assert len(schema) == 8 + + def test_data_source_filter(self, spark_session): + location_customer = str(Backend.find_tpch() / 'customer') + customer_dataframe = spark_session.read.parquet(location_customer) + outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + assert len(outcome) == 29968 + + def test_table(self, spark_session_with_customer_dataset): + outcome = spark_session_with_customer_dataset.table('customer').collect() + assert len(outcome) == 149999 + + def test_table_schema(self, spark_session_with_customer_dataset): + schema = spark_session_with_customer_dataset.table('customer').schema + assert len(schema) == 8 + + def test_table_filter(self, spark_session_with_customer_dataset): + customer_dataframe = spark_session_with_customer_dataset.table('customer') + outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + assert len(outcome) == 29968 From 9a9eb6e52c2767146c35e53a1bb3a02fb9e22e07 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 23 Apr 2024 13:49:47 -0700 Subject: [PATCH 30/58] feat: implement subquery_alias (#55) This fixes TPC-H query #6 for DuckDB. 
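For orientation, a minimal sketch (hypothetical session and table names) of the kind of DataFrame call that is expected to reach the gateway as a subquery_alias relation; as the convert_subquery_alias_relation hunk below shows, the converter currently forwards the aliased input and its field references unchanged and leaves the alias name itself for later work.

    from pyspark.sql.functions import col

    # Assumes a SparkSession `spark` already connected to the gateway and a
    # registered 'lineitem' table; the .alias() call is what is expected to
    # arrive as a subquery_alias relation in the Spark Connect plan.
    lineitem = spark.table('lineitem').alias('l')
    discounted = lineitem.filter(col('l_discount') >= 0.05).select('l_extendedprice')
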
--- src/gateway/converter/spark_functions.py | 4 ++++ src/gateway/converter/spark_to_substrait.py | 10 ++++++++++ src/gateway/tests/test_tpch_with_dataframe_api.py | 4 +--- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 54c5bc3..170a0b9 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -69,6 +69,10 @@ def __lt__(self, obj) -> bool: '/functions_arithmetic.yaml', 'multiply:i64_i64', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + '/': ExtensionFunction( + '/functions_arithmetic.yaml', 'divide:i64_i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'array_contains': ExtensionFunction( '/functions_set.yaml', 'index_in:str_list', type_pb2.Type( bool=type_pb2.Type.Boolean( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 360480f..884ffad 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -919,6 +919,14 @@ def convert_project_relation( project.common.CopyFrom(self.create_common_relation()) return algebra_pb2.Rel(project=project) + def convert_subquery_alias_relation(self, + rel: spark_relations_pb2.SubqueryAlias) -> algebra_pb2.Rel: + """Convert a Spark subquery alias relation into a Substrait relation.""" + # TODO -- Utilize rel.alias somehow. + result = self.convert_relation(rel.input) + self.update_field_references(rel.input.common.plan_id) + return result + def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: """Convert a Spark relation into a Substrait one.""" self._symbol_table.add_symbol(rel.common.plan_id, parent=self._current_plan_id, @@ -950,6 +958,8 @@ def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel result = self.convert_join_relation(rel.join) case 'project': result = self.convert_project_relation(rel.project) + case 'subquery_alias': + result = self.convert_subquery_alias_relation(rel.subquery_alias) case _: raise ValueError( f'Unexpected Spark plan rel_type: {rel.WhichOneof("rel_type")}') diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 2c8c10c..95765db 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -21,9 +21,7 @@ def mark_tests_as_xfail(request): request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if source == 'gateway-over-duckdb' and originalname == 'test_query_04': request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) - if source == 'gateway-over-duckdb' and originalname in ['test_query_06', 'test_query_14']: - request.node.add_marker(pytest.mark.xfail(reason='subquery_alias not implemented')) - if source == 'gateway-over-duckdb' and originalname == 'test_query_12': + if source == 'gateway-over-duckdb' and originalname in ['test_query_12', 'test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='function when not implemented')) if source == 'gateway-over-duckdb' and originalname == 'test_query_13': request.node.add_marker(pytest.mark.xfail(reason='function rlike not implemented')) From 53dbc61beb62c1df576a30d414fface41b95ef64 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 23 Apr 2024 17:19:04 -0700 Subject: [PATCH 31/58] 
feat: provide support for the when condition (#56) The corresponding implementation for the when function in Substrait is the IfThen expression. --- src/gateway/converter/spark_functions.py | 4 ++ src/gateway/converter/spark_to_substrait.py | 56 +++++++++++++++++++ .../tests/test_tpch_with_dataframe_api.py | 7 +-- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 170a0b9..6a0666d 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -105,6 +105,10 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'ends_with:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'contains': ExtensionFunction( + '/functions_string.yaml', 'contains:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'length': ExtensionFunction( '/functions_string.yaml', 'char_length:str', type_pb2.Type( i64=type_pb2.Type.I64( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 884ffad..dc6fc67 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -182,12 +182,68 @@ def convert_unresolved_attribute( field=field_ref)), root_reference=algebra_pb2.Expression.FieldReference.RootReference())) + def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2.Type: + """Determine the type of a Substrait expression.""" + if expr.WhichOneof('rex_type') == 'literal': + match expr.literal.WhichOneof('literal_type'): + case 'boolean': + return type_pb2.Type(bool=type_pb2.Type.Boolean()) + case 'i8': + return type_pb2.Type(i8=type_pb2.Type.I8()) + case 'i16': + return type_pb2.Type(i16=type_pb2.Type.I16()) + case 'i32': + return type_pb2.Type(i32=type_pb2.Type.I32()) + case 'i64': + return type_pb2.Type(i64=type_pb2.Type.I64()) + case 'float': + return type_pb2.Type(fp32=type_pb2.Type.FP32()) + case 'double': + return type_pb2.Type(fp64=type_pb2.Type.FP64()) + case 'string': + return type_pb2.Type(string=type_pb2.Type.String()) + case _: + raise NotImplementedError( + 'Type determination not implemented for literal of type ' + f'{expr.literal.WhichOneof("literal_type")}.') + if expr.WhichOneof('rex_type') == 'scalar_function': + return expr.scalar_function.output_type + if expr.WhichOneof('rex_type') == 'selection': + # TODO -- Figure out how to determine the type of a field reference. 
+ return type_pb2.Type(i32=type_pb2.Type.I32()) + raise NotImplementedError( + 'Type determination not implemented for expressions of type ' + f'{expr.WhichOneof("rex_type")}.') + + def convert_when_function( + self, + when: spark_exprs_pb2.Expression.UnresolvedFunction) -> algebra_pb2.Expression: + """Convert a Spark when function into a Substrait if-then expression.""" + ifthen = algebra_pb2.Expression.IfThen() + for i in range(0, len(when.arguments) - 1, 2): + clause = algebra_pb2.Expression.IfThen.IfClause() + getattr(clause, 'if').CopyFrom(self.convert_expression(when.arguments[i])) + clause.then.CopyFrom(self.convert_expression(when.arguments[i + 1])) + ifthen.ifs.append(clause) + if len(when.arguments) % 2 == 1: + getattr(ifthen, 'else').CopyFrom( + self.convert_expression(when.arguments[len(when.arguments) - 1])) + else: + getattr(ifthen, 'else').CopyFrom( + algebra_pb2.Expression( + literal=algebra_pb2.Expression.Literal( + null=self.determine_type_of_expression(ifthen.ifs[-1].then)))) + + return algebra_pb2.Expression(if_then=ifthen) + def convert_unresolved_function( self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction) -> algebra_pb2.Expression: """Convert a Spark unresolved function into a Substrait scalar function.""" func = algebra_pb2.Expression.ScalarFunction() + if unresolved_function.function_name == 'when': + return self.convert_when_function(unresolved_function) function_def = self.lookup_function_by_name(unresolved_function.function_name) func.function_reference = function_def.anchor for idx, arg in enumerate(unresolved_function.arguments): diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 95765db..0478c80 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -16,13 +16,12 @@ def mark_tests_as_xfail(request): originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb' and originalname in [ 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', - 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_15', 'test_query_17', - 'test_query_18', 'test_query_20', 'test_query_21', 'test_query_22']: + 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_14', + 'test_query_15', 'test_query_17', 'test_query_18', 'test_query_20', 'test_query_21', + 'test_query_22']: request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if source == 'gateway-over-duckdb' and originalname == 'test_query_04': request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) - if source == 'gateway-over-duckdb' and originalname in ['test_query_12', 'test_query_14']: - request.node.add_marker(pytest.mark.xfail(reason='function when not implemented')) if source == 'gateway-over-duckdb' and originalname == 'test_query_13': request.node.add_marker(pytest.mark.xfail(reason='function rlike not implemented')) if source == 'gateway-over-duckdb' and originalname in ['test_query_16', 'test_query_19']: From b67932d74a5cd3ee17c72c179e3a56e72309e881 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 23 Apr 2024 18:41:57 -0700 Subject: [PATCH 32/58] feat: add format back to the register table method signature (#57) There is one register_table call that uses format (in spark_to_substrait). This fixes that and marks the tests that now pass again as passing. 
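Purely as a standalone sketch of the signature pattern (ToyBackend and the paths are made up for illustration, not the gateway's Backend class): the extra keyword defaults to 'parquet', so existing two-argument call sites keep working while callers that know the source format can pass it through.

    from pathlib import Path

    # Minimal illustration of the widened register_table signature.
    class ToyBackend:
        def __init__(self):
            self.tables: dict[str, tuple[Path, str]] = {}

        def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None:
            # Record the location and format; a real backend would load the files.
            self.tables[name] = (path, file_format)

    backend = ToyBackend()
    backend.register_table('customer', Path('/data/tpch/customer'))
    backend.register_table('events', Path('/data/events'), file_format='csv')
    print(backend.tables['customer'][1])  # parquet
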
--- src/gateway/backends/adbc_backend.py | 4 +-- src/gateway/backends/arrow_backend.py | 2 +- src/gateway/backends/backend.py | 2 +- src/gateway/backends/datafusion_backend.py | 2 +- src/gateway/backends/duckdb_backend.py | 2 +- src/gateway/tests/conftest.py | 5 ++-- src/gateway/tests/test_dataframe_api.py | 7 +++++ src/gateway/tests/test_sql_api.py | 27 +++++++++++++++++ .../tests/test_tpch_with_dataframe_api.py | 30 +++++++++++-------- 9 files changed, 60 insertions(+), 21 deletions(-) diff --git a/src/gateway/backends/adbc_backend.py b/src/gateway/backends/adbc_backend.py index d9772df..e89e7fb 100644 --- a/src/gateway/backends/adbc_backend.py +++ b/src/gateway/backends/adbc_backend.py @@ -54,9 +54,9 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: res = cur.adbc_statement.execute_query() return _import(res[0]).read_all() - def register_table(self, name: str, path: Path, extension: str = 'parquet') -> None: + def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None: """Register the given table with the backend.""" - file_paths = sorted(Path(path).glob(f'*.{extension}')) + file_paths = sorted(Path(path).glob(f'*.{file_format}')) if len(file_paths) > 0: # Sort the files because the later ones don't have enough data to construct a schema. file_paths = sorted([str(fp) for fp in file_paths]) diff --git a/src/gateway/backends/arrow_backend.py b/src/gateway/backends/arrow_backend.py index 04df38f..eca41fb 100644 --- a/src/gateway/backends/arrow_backend.py +++ b/src/gateway/backends/arrow_backend.py @@ -26,7 +26,7 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: reader = pa.substrait.run_query(plan_data, table_provider=self._provide_tables) return reader.read_all() - def register_table(self, name: str, path: Path) -> None: + def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None: """Register the given table with the backend.""" self._registered_tables[name] = path diff --git a/src/gateway/backends/backend.py b/src/gateway/backends/backend.py index 6a159bd..eb3c4ab 100644 --- a/src/gateway/backends/backend.py +++ b/src/gateway/backends/backend.py @@ -31,7 +31,7 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against Datafusion.""" raise NotImplementedError() - def register_table(self, name: str, path: Path) -> None: + def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None: """Register the given table with the backend.""" raise NotImplementedError() diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index f921c29..3311cb6 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -59,7 +59,7 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: for table_name in registered_tables: self._connection.deregister_table(table_name) - def register_table(self, name: str, path: Path) -> None: + def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None: """Register the given table with the backend.""" files = Backend.expand_location(path) self._connection.register_parquet(name, files[0]) diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index 77521db..c35078c 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -44,7 +44,7 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: df = query_result.df() return 
pa.Table.from_pandas(df=df) - def register_table(self, table_name: str, location: Path) -> None: + def register_table(self, table_name: str, location: Path, file_format: str = 'parquet') -> None: """Register the given table with the backend.""" files = Backend.expand_location(location) if not files: diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 44b118e..a99d24c 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -83,9 +83,8 @@ def schema_users(): @pytest.fixture(scope='session', params=['spark', 'gateway-over-duckdb', - pytest.param('gateway-over-datafusion', - marks=pytest.mark.xfail( - reason='Datafusion Substrait missing in CI'))]) + 'gateway-over-datafusion', + ]) def source(request) -> str: """Provides the source (backend) to be used.""" return request.param diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index c5e91eb..d414ebc 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -17,6 +17,13 @@ def mark_dataframe_tests_as_xfail(request): originalname == 'test_cast'): request.node.add_marker( pytest.mark.xfail(reason='DuckDB column binding error')) + elif source == 'gateway-over-datafusion': + if originalname in [ + 'test_data_source_schema', 'test_data_source_filter', 'test_table', 'test_table_schema', + 'test_table_filter']: + request.node.add_marker(pytest.mark.xfail(reason='Gateway internal iterating error')) + else: + pytest.importorskip("datafusion.substrait") # pylint: disable=missing-function-docstring diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 4643bdb..b6d4553 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -23,6 +23,33 @@ def mark_tests_as_xfail(request): path = request.getfixturevalue('path') if path.stem in ['02', '04', '15', '16', '17', '18', '20', '21', '22']: request.node.add_marker(pytest.mark.xfail(reason='DuckDB needs Delim join')) + if source == 'gateway-over-datafusion': + pytest.importorskip("datafusion.substrait") + if originalname == 'test_count': + request.node.add_marker(pytest.mark.xfail(reason='COUNT() not implemented')) + if originalname in ['test_tpch']: + path = request.getfixturevalue('path') + if path.stem in ['01']: + request.node.add_marker(pytest.mark.xfail(reason='COUNT() not implemented')) + elif path.stem in ['07']: + request.node.add_marker(pytest.mark.xfail(reason='Projection uniqueness error')) + elif path.stem in ['08']: + request.node.add_marker(pytest.mark.xfail(reason='aggregation error')) + elif path.stem in ['09']: + request.node.add_marker(pytest.mark.xfail(reason='instr not implemented')) + elif path.stem in ['11', '15']: + request.node.add_marker(pytest.mark.xfail(reason='first not implemented')) + elif path.stem in ['13']: + request.node.add_marker(pytest.mark.xfail(reason='not rlike not implemented')) + elif path.stem in ['16']: + request.node.add_marker(pytest.mark.xfail(reason='mark join not implemented')) + elif path.stem in ['18']: + request.node.add_marker(pytest.mark.xfail(reason='out of bounds error')) + elif path.stem in ['19']: + request.node.add_marker(pytest.mark.xfail(reason='multiargument OR not supported')) + elif path.stem in ['02', '04', '17', '20', '21', '22']: + request.node.add_marker(pytest.mark.xfail(reason='DataFusion needs Delim join')) + request.node.add_marker(pytest.mark.xfail(reason='Gateway internal iterating error')) # pylint: disable=missing-function-docstring 
diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 0478c80..2f2793f 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -14,18 +14,24 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', - 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_14', - 'test_query_15', 'test_query_17', 'test_query_18', 'test_query_20', 'test_query_21', - 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) - if source == 'gateway-over-duckdb' and originalname == 'test_query_04': - request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) - if source == 'gateway-over-duckdb' and originalname == 'test_query_13': - request.node.add_marker(pytest.mark.xfail(reason='function rlike not implemented')) - if source == 'gateway-over-duckdb' and originalname in ['test_query_16', 'test_query_19']: - request.node.add_marker(pytest.mark.xfail(reason='function in not implemented')) + if source == 'gateway-over-duckdb': + if originalname in [ + 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', + 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_14', + 'test_query_15', 'test_query_17', 'test_query_18', 'test_query_20', 'test_query_21', + 'test_query_22']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) + if originalname == 'test_query_04': + request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) + if originalname in ['test_query_12', 'test_query_14']: + request.node.add_marker(pytest.mark.xfail(reason='function when not implemented')) + if originalname == 'test_query_13': + request.node.add_marker(pytest.mark.xfail(reason='function rlike not implemented')) + if originalname in ['test_query_16', 'test_query_19']: + request.node.add_marker(pytest.mark.xfail(reason='function in not implemented')) + if source == 'gateway-over-datafusion': + pytest.importorskip("datafusion.substrait") + request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) class TestTpchWithDataFrameAPI: From ba2ebedec23380805cece5083614ec37447f1b3c Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 24 Apr 2024 15:29:51 -0700 Subject: [PATCH 33/58] feat: improved arrow backend support (#58) - Adds workarounds to increase the likelihood that a function will be matched against the Acero engine. - Implements describe_table. - The tests are still not ready to be run continuously. 
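As a rough standalone sketch of the name-mapping part of the workaround (the helper name is hypothetical; the real RenameFunctionsForArrow visitor below also rewrites extension URIs): Acero matches extension functions by URI and name, so Substrait names without a direct Arrow equivalent are remapped before execution while any compound signature suffix is preserved.

    # Illustration of the renaming idea; char_length -> utf8_length is the
    # only mapping introduced in this patch.
    ARROW_NAME_MAP = {'char_length': 'utf8_length'}

    def arrow_friendly_name(substrait_name: str) -> str:
        name, sep, signature = substrait_name.partition(':')
        name = ARROW_NAME_MAP.get(name, name)
        return f'{name}:{signature}' if sep else name

    print(arrow_friendly_name('char_length:str'))  # utf8_length:str
    print(arrow_friendly_name('add:i64_i64'))      # add:i64_i64 (unchanged)
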
--- src/gateway/backends/arrow_backend.py | 10 +++- src/gateway/backends/backend_options.py | 2 + src/gateway/backends/datafusion_backend.py | 4 +- src/gateway/converter/conversion_options.py | 11 ++-- src/gateway/converter/rename_functions.py | 60 ++++++++++++++++++++- 5 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/gateway/backends/arrow_backend.py b/src/gateway/backends/arrow_backend.py index eca41fb..a237c65 100644 --- a/src/gateway/backends/arrow_backend.py +++ b/src/gateway/backends/arrow_backend.py @@ -8,6 +8,7 @@ from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend +from gateway.converter.rename_functions import RenameFunctionsForArrow class ArrowBackend(Backend): @@ -18,10 +19,13 @@ class ArrowBackend(Backend): def __init__(self, options): """Initialize the Datafusion backend.""" super().__init__(options) + self._use_uri_workaround = options.use_arrow_uri_workaround # pylint: disable=import-outside-toplevel def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against Acero.""" + RenameFunctionsForArrow(use_uri_workaround=self._use_uri_workaround).visit_plan(plan) + plan_data = plan.SerializeToString() reader = pa.substrait.run_query(plan_data, table_provider=self._provide_tables) return reader.read_all() @@ -35,9 +39,13 @@ def drop_table(self, name: str) -> None: if self._registered_tables.get(name): del self._registered_tables[name] + def describe_table(self, name: str): + """Return the schema of the given table.""" + return pa.parquet.read_table(self._registered_tables[name]).schema + def _provide_tables(self, names: list[str], unused_schema) -> pyarrow.Table: """Provide the tables requested.""" for name in names: if name in self._registered_tables: - return pa.Table.from_pandas(pa.read_parquet(self._registered_tables[name])) + return pa.parquet.read_table(self._registered_tables[name]) raise ValueError(f'Table {names} not found in {self._registered_tables}') diff --git a/src/gateway/backends/backend_options.py b/src/gateway/backends/backend_options.py index 7039acb..5f0a578 100644 --- a/src/gateway/backends/backend_options.py +++ b/src/gateway/backends/backend_options.py @@ -23,3 +23,5 @@ def __init__(self, backend: Backend, use_adbc: bool = False): """Create a BackendOptions structure.""" self.backend = backend self.use_adbc = use_adbc + + self.use_arrow_uri_workaround = False diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index 3311cb6..c722198 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -6,7 +6,7 @@ from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend -from gateway.converter.rename_functions import RenameFunctions +from gateway.converter.rename_functions import RenameFunctionsForDatafusion from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable @@ -38,7 +38,7 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: self.register_table(table_name, file) registered_tables.add(files[0]) - RenameFunctions().visit_plan(plan) + RenameFunctionsForDatafusion().visit_plan(plan) try: plan_data = plan.SerializeToString() diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 0f87207..d3ad94e 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -10,14 +10,6 @@ class ConversionOptions: """Holds all the possible 
conversion options.""" - use_named_table_workaround: bool - needs_scheme_in_path_uris: bool - use_project_emit_workaround: bool - use_project_emit_workaround2: bool - use_emits_instead_of_direct: bool - - return_names_with_types: bool - def __init__(self, backend: BackendOptions = None): """Initialize the conversion options.""" self.use_named_table_workaround = False @@ -35,6 +27,9 @@ def arrow(): """Return standard options to connect to the Acero backend.""" options = ConversionOptions(backend=BackendOptions(Backend.ARROW)) options.needs_scheme_in_path_uris = True + options.return_names_with_types = True + options.implement_show_string = False + options.backend.use_arrow_uri_workaround = True return options diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index 9ed8bc9..34f1c9c 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -5,7 +5,7 @@ # pylint: disable=no-member,fixme -class RenameFunctions(SubstraitPlanVisitor): +class RenameFunctionsForDatafusion(SubstraitPlanVisitor): """Renames Substrait functions to match what Datafusion expects.""" def visit_plan(self, plan: plan_pb2.Plan) -> None: @@ -34,3 +34,61 @@ def visit_plan(self, plan: plan_pb2.Plan) -> None: extension.extension_function.name = 'instr' elif extension.extension_function.name == 'extract': extension.extension_function.name = 'date_part' + + +# pylint: disable=no-member,fixme +class RenameFunctionsForArrow(SubstraitPlanVisitor): + """Renames Substrait functions to match what Datafusion expects.""" + + def __init__(self, use_uri_workaround=False): + """Initialize the RenameFunctionsForArrow class.""" + self._extensions: dict[int, str] = {} + self._use_uri_workaround = use_uri_workaround + super().__init__() + + def normalize_extension_uris(self, plan: plan_pb2.Plan) -> None: + """Normalize the URI.""" + for extension in plan.extension_uris: + if self._use_uri_workaround: + extension.uri = 'urn:arrow:substrait_simple_extension_function' + else: + if extension.uri.startswith('/'): + extension.uri = extension.uri.replace( + '/', 'https://github.com/substrait-io/substrait/blob/main/extensions/') + + def index_extension_uris(self, plan: plan_pb2.Plan) -> None: + """Add the extension URIs into a dictionary.""" + self._extensions: dict[int, str] = {} + for extension in plan.extension_uris: + self._extensions[extension.extension_uri_anchor] = extension.uri + + def visit_plan(self, plan: plan_pb2.Plan) -> None: + """Modify the provided plan so that functions are Arrow compatible.""" + super().visit_plan(plan) + + self.normalize_extension_uris(plan) + self.index_extension_uris(plan) + + for extension in plan.extensions: + if extension.WhichOneof('mapping_type') != 'extension_function': + continue + + if ':' in extension.extension_function.name: + name, signature = extension.extension_function.name.split(':', 2) + else: + name = extension.extension_function.name + signature = None + + # TODO -- Take the URI references into account. + changed = False + if name == 'char_length': + changed = True + name = 'utf8_length' + + if not changed: + continue + + if signature: + extension.extension_function.name = f'{name}:{signature}' + else: + extension.extension_function.name = name From b887caee189e74f89fbd9595ff1bd07485d88c0f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 24 Apr 2024 17:22:10 -0700 Subject: [PATCH 34/58] feat: split the tests by backend (#59) This will allow individual test suites to be more rapidly iterated on. 
To use, specify the desired source as an argument to pytest: `pytest -m gateway-over-duckdb` or `pytest -m general` --- .github/workflows/test.yml | 3 ++- pyproject.toml | 7 +++++++ src/gateway/tests/conftest.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9d34e23..085c68b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,6 +15,7 @@ jobs: strategy: matrix: os: [macos-latest, ubuntu-latest] + source: ["general", "spark", "gateway-over-duckdb", "gateway-over-datafusion"] python: ["3.10"] runs-on: ${{ matrix.os }} steps: @@ -36,4 +37,4 @@ jobs: - name: Run tests shell: bash -el {0} run: | - pytest + pytest -m ${{ matrix.source }} diff --git a/pyproject.toml b/pyproject.toml index 78b1a05..5ca8d61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,13 @@ test = ["pytest >= 7.0.0"] [tool.pytest.ini_options] pythonpath = "src" addopts = "--ignore=third_party" +markers = [ + "spark: mark a test as running against Spark", + "gateway-over-arrow: mark a test as running against the gateway using Arrow", + "gateway-over-duckdb: mark a test as running against the gateway using DuckDB", + "gateway-over-datafusion: mark a test as running against the gateway using DataFusion", + "general: mark a test as not specific to a backend", +] [build-system] requires = ["setuptools>=61.0.0", "setuptools_scm[toml]>=6.2.0"] diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index a99d24c..fb305c8 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Test fixtures for pytest of the gateway server.""" +import re from pathlib import Path import pytest @@ -14,6 +15,15 @@ from pyspark.sql.session import SparkSession +def pytest_collection_modifyitems(items): + for item in items: + if 'source' in getattr(item, 'fixturenames', ()): + source = re.search(r'\[([^,]+?)(-\d+)?]$', item.name).group(1) + item.add_marker(source) + continue + item.add_marker('general') + + # ruff: noqa: T201 def _create_local_spark_session() -> SparkSession: """Creates a local spark session for testing.""" From 8491ec206f5492ce98e34f3b90b2eaac0aa8014d Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 24 Apr 2024 22:09:31 -0700 Subject: [PATCH 35/58] feat: implement the SparkConnect in function (#60) There are two implementations here. For cases where the backend supports switch expressions we try that first (only works with literal values). Failing that a heavier weight if-then expression is crafted instead. 
--- src/gateway/converter/conversion_options.py | 2 + src/gateway/converter/spark_to_substrait.py | 49 +++++++++++++++++-- src/gateway/converter/substrait_builder.py | 17 +++++++ .../tests/test_tpch_with_dataframe_api.py | 8 +-- 4 files changed, 67 insertions(+), 9 deletions(-) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index d3ad94e..5826564 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -15,6 +15,7 @@ def __init__(self, backend: BackendOptions = None): self.use_named_table_workaround = False self.needs_scheme_in_path_uris = False self.use_emits_instead_of_direct = False + self.use_switch_expressions_where_possible = True self.return_names_with_types = False @@ -42,4 +43,5 @@ def duck_db(): """Return standard options to connect to a DuckDB backend.""" options = ConversionOptions(backend=BackendOptions(Backend.DUCKDB)) options.return_names_with_types = True + options.use_switch_expressions_where_possible = False return options diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index dc6fc67..1b79009 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -18,8 +18,10 @@ from gateway.converter.substrait_builder import ( aggregate_relation, bigint_literal, + bool_literal, cast_operation, concat, + equal_function, fetch_relation, field_reference, flatten, @@ -236,14 +238,55 @@ def convert_when_function( return algebra_pb2.Expression(if_then=ifthen) + def convert_in_function( + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction) -> algebra_pb2.Expression: + """Convert a Spark in function into a Substrait switch expression.""" + + def is_switch_expression_appropriate() -> bool: + """Determine if the IN function is appropriate for a switch expression.""" + if not self._conversion_options.use_switch_expressions_where_possible: + return False + return all(a.WhichOneof("expr_type") == "literal" for a in in_.arguments[1:]) + + if is_switch_expression_appropriate(): + switch = algebra_pb2.Expression.SwitchExpression( + match=self.convert_expression(in_.arguments[0])) + + for arg in in_.arguments[1:]: + ifvalue = algebra_pb2.Expression.SwitchExpression.IfValue(then=bool_literal(True)) + expr = self.convert_literal_expression(arg.literal) + getattr(ifvalue, 'if').CopyFrom(expr.literal) + switch.ifs.append(ifvalue) + + getattr(switch, 'else').CopyFrom(bool_literal(False)) + + return algebra_pb2.Expression(switch_expression=switch) + + equal_func = self.lookup_function_by_name('==') + + ifthen = algebra_pb2.Expression.IfThen() + + match = self.convert_expression(in_.arguments[0]) + for arg in in_.arguments[1:]: + clause = algebra_pb2.Expression.IfThen.IfClause(then=bool_literal(True)) + getattr(clause, 'if').CopyFrom( + equal_function(equal_func, match, self.convert_expression(arg))) + ifthen.ifs.append(clause) + + getattr(ifthen, 'else').CopyFrom(bool_literal(False)) + + return algebra_pb2.Expression(if_then=ifthen) + def convert_unresolved_function( self, - unresolved_function: - spark_exprs_pb2.Expression.UnresolvedFunction) -> algebra_pb2.Expression: + unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction + ) -> algebra_pb2.Expression: """Convert a Spark unresolved function into a Substrait scalar function.""" - func = algebra_pb2.Expression.ScalarFunction() if unresolved_function.function_name == 'when': return 
self.convert_when_function(unresolved_function) + if unresolved_function.function_name == 'in': + return self.convert_in_function(unresolved_function) + func = algebra_pb2.Expression.ScalarFunction() function_def = self.lookup_function_by_name(unresolved_function.function_name) func.function_reference = function_def.anchor for idx, arg in enumerate(unresolved_function.arguments): diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 6ab9ae5..a5e4d37 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -141,6 +141,18 @@ def greatest_function(greater_function_info: ExtensionFunction, expr1: algebra_p ) +def equal_function(function_info: ExtensionFunction, + expr1: algebra_pb2.Expression, + expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: + """Construct a Substrait min expression.""" + return algebra_pb2.Expression(scalar_function= + algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[algebra_pb2.FunctionArgument(value=expr1), + algebra_pb2.FunctionArgument(value=expr2)])) + + def greater_or_equal_function(function_info: ExtensionFunction, expr1: algebra_pb2.Expression, expr2: algebra_pb2.Expression) -> algebra_pb2.Expression: @@ -223,6 +235,11 @@ def rpad_function(function_info: ExtensionFunction, value=cast_operation(string_literal(pad_string), cast_type))])) +def bool_literal(val: bool) -> algebra_pb2.Expression: + """Construct a Substrait boolean literal expression.""" + return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(boolean=val)) + + def string_literal(val: str) -> algebra_pb2.Expression: """Construct a Substrait string literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(string=val)) diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 2f2793f..19c2e7a 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -18,17 +18,13 @@ def mark_tests_as_xfail(request): if originalname in [ 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_14', - 'test_query_15', 'test_query_17', 'test_query_18', 'test_query_20', 'test_query_21', - 'test_query_22']: + 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', 'test_query_19', + 'test_query_20', 'test_query_21', 'test_query_22']: request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if originalname == 'test_query_04': request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) - if originalname in ['test_query_12', 'test_query_14']: - request.node.add_marker(pytest.mark.xfail(reason='function when not implemented')) if originalname == 'test_query_13': request.node.add_marker(pytest.mark.xfail(reason='function rlike not implemented')) - if originalname in ['test_query_16', 'test_query_19']: - request.node.add_marker(pytest.mark.xfail(reason='function in not implemented')) if source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) From 42c5825bc1ca6db32cdc2449d65c7220b8c107b2 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 25 Apr 2024 09:33:00 -0700 Subject: [PATCH 36/58] feat: implement rlike 
function (#61) --- src/gateway/converter/conversion_options.py | 2 ++ src/gateway/converter/spark_functions.py | 7 +++++ src/gateway/converter/spark_to_substrait.py | 29 ++++++++++++++++++- src/gateway/converter/substrait_builder.py | 15 ++++++++++ .../tests/test_tpch_with_dataframe_api.py | 8 ++--- 5 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 5826564..3235720 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -16,6 +16,7 @@ def __init__(self, backend: BackendOptions = None): self.needs_scheme_in_path_uris = False self.use_emits_instead_of_direct = False self.use_switch_expressions_where_possible = True + self.use_duckdb_regexp_matches_function = False self.return_names_with_types = False @@ -44,4 +45,5 @@ def duck_db(): options = ConversionOptions(backend=BackendOptions(Backend.DUCKDB)) options.return_names_with_types = True options.use_switch_expressions_where_possible = False + options.use_duckdb_regexp_matches_function = True return options diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 6a0666d..a284957 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -93,6 +93,13 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'regexp_match:str_binary_str', type_pb2.Type( list=type_pb2.Type.List(type=type_pb2.Type(string=type_pb2.Type.String( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))))), + 'regexp_substring': ExtensionFunction( + '/functions_string.yaml', 'regexp_substring:str_str_i64_i64', type_pb2.Type( + i64=type_pb2.Type.I64(nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'DUCKDB_regexp_matches': ExtensionFunction( + '/functions_string.yaml', 'regexp_matches:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'substring': ExtensionFunction( '/functions_string.yaml', 'substring:str_int_int', type_pb2.Type( string=type_pb2.Type.String( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 1b79009..07b5949 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -34,6 +34,7 @@ max_agg_function, minus_function, project_relation, + regexp_strpos_function, repeat_function, string_concat_agg_function, string_literal, @@ -47,7 +48,6 @@ TABLE_NAME = "my_table" -# pylint: disable=E1101,fixme,too-many-public-methods # ruff: noqa: RUF005 class SparkSubstraitConverter: """Converts SparkConnect plans to Substrait plans.""" @@ -277,6 +277,31 @@ def is_switch_expression_appropriate() -> bool: return algebra_pb2.Expression(if_then=ifthen) + def convert_rlike_function( + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction) -> algebra_pb2.Expression: + """Convert a Spark rlike function into a Substrait expression.""" + if self._conversion_options.use_duckdb_regexp_matches_function: + regexp_matches_func = self.lookup_function_by_name('DUCKDB_regexp_matches') + return algebra_pb2.Expression( + scalar_function=algebra_pb2.Expression.ScalarFunction( + function_reference=regexp_matches_func.anchor, + arguments=[ + algebra_pb2.FunctionArgument( + value=self.convert_expression(in_.arguments[0])), + algebra_pb2.FunctionArgument( + value=self.convert_expression(in_.arguments[1])) + ], + 
output_type=regexp_matches_func.output_type)) + + regexp_strpos_func = self.lookup_function_by_name('regexp_strpos') + greater_func = self.lookup_function_by_name('>') + + regexp_expr = regexp_strpos_function(regexp_strpos_func, + self.convert_expression(in_.arguments[1]), + self.convert_expression(in_.arguments[0]), + bigint_literal(1), bigint_literal(1)) + return greater_function(greater_func, regexp_expr, bigint_literal(0)) + def convert_unresolved_function( self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction @@ -286,6 +311,8 @@ def convert_unresolved_function( return self.convert_when_function(unresolved_function) if unresolved_function.function_name == 'in': return self.convert_in_function(unresolved_function) + if unresolved_function.function_name == 'rlike': + return self.convert_rlike_function(unresolved_function) func = algebra_pb2.Expression.ScalarFunction() function_def = self.lookup_function_by_name(unresolved_function.function_name) func.function_reference = function_def.anchor diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index a5e4d37..924b6bc 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -235,6 +235,21 @@ def rpad_function(function_info: ExtensionFunction, value=cast_operation(string_literal(pad_string), cast_type))])) +def regexp_strpos_function(function_info: ExtensionFunction, + input: algebra_pb2.Expression, pattern: algebra_pb2.Expression, + position: algebra_pb2.Expression, + occurrence: algebra_pb2.Expression) -> algebra_pb2.AggregateFunction: + """Construct a Substrait regex substring expression.""" + return algebra_pb2.Expression(scalar_function=algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[ + algebra_pb2.FunctionArgument(value=input), + algebra_pb2.FunctionArgument(value=pattern), + algebra_pb2.FunctionArgument(value=occurrence), + algebra_pb2.FunctionArgument(value=position)])) + + def bool_literal(val: bool) -> algebra_pb2.Expression: """Construct a Substrait boolean literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(boolean=val)) diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 19c2e7a..5778543 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -17,14 +17,12 @@ def mark_tests_as_xfail(request): if source == 'gateway-over-duckdb': if originalname in [ 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', - 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_14', - 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', 'test_query_19', - 'test_query_20', 'test_query_21', 'test_query_22']: + 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_13', + 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', + 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if originalname == 'test_query_04': request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) - if originalname == 'test_query_13': - request.node.add_marker(pytest.mark.xfail(reason='function rlike not implemented')) if source == 'gateway-over-datafusion': 
pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) From 36de5950cd9c2639eb203b5bb55bec6834bc96ad Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 25 Apr 2024 17:02:14 -0700 Subject: [PATCH 37/58] feat: implement deduplicate (#62) --- src/gateway/converter/spark_functions.py | 4 ++++ src/gateway/converter/spark_to_substrait.py | 23 +++++++++++++++++++ src/gateway/tests/test_sql_api.py | 4 +++- .../tests/test_tpch_with_dataframe_api.py | 15 +++++------- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index a284957..bc78ded 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -156,6 +156,10 @@ def __lt__(self, obj) -> bool: '/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'any_value': ExtensionFunction( + '/functions_aggregate_generic.yaml', 'any_value:any', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'and': ExtensionFunction( '/functions_boolean.yaml', 'and:bool_bool', type_pb2.Type( bool=type_pb2.Type.Boolean( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 07b5949..f1ed07f 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -1053,6 +1053,27 @@ def convert_subquery_alias_relation(self, self.update_field_references(rel.input.common.plan_id) return result + def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> algebra_pb2.Rel: + """Convert a Spark deduplicate relation into a Substrait aggregation.""" + any_value_func = self.lookup_function_by_name('any_value') + + aggregate = algebra_pb2.AggregateRel(input=self.convert_relation(rel.input)) + self.update_field_references(rel.input.common.plan_id) + aggregate.common.CopyFrom(self.create_common_relation()) + symbol = self._symbol_table.get_symbol(self._current_plan_id) + grouping = aggregate.groupings.add() + for idx, field in enumerate(symbol.input_fields): + grouping.grouping_expressions.append(field_reference(idx)) + aggregate.measures.append( + algebra_pb2.AggregateRel.Measure( + measure=algebra_pb2.AggregateFunction( + function_reference=any_value_func.anchor, + arguments=[algebra_pb2.FunctionArgument(value=field_reference(idx))], + phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT, + output_type=type_pb2.Type(bool=type_pb2.Type.Boolean())))) + symbol.generated_fields.append(field) + return algebra_pb2.Rel(aggregate=aggregate) + def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: """Convert a Spark relation into a Substrait one.""" self._symbol_table.add_symbol(rel.common.plan_id, parent=self._current_plan_id, @@ -1086,6 +1107,8 @@ def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel result = self.convert_project_relation(rel.project) case 'subquery_alias': result = self.convert_subquery_alias_relation(rel.subquery_alias) + case 'deduplicate': + result = self.convert_deduplicate_relation(rel.deduplicate) case _: raise ValueError( f'Unexpected Spark plan rel_type: {rel.WhichOneof("rel_type")}') diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index b6d4553..13d1b7f 100644 --- 
a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -37,10 +37,12 @@ def mark_tests_as_xfail(request): request.node.add_marker(pytest.mark.xfail(reason='aggregation error')) elif path.stem in ['09']: request.node.add_marker(pytest.mark.xfail(reason='instr not implemented')) - elif path.stem in ['11', '15']: + elif path.stem in ['11']: request.node.add_marker(pytest.mark.xfail(reason='first not implemented')) elif path.stem in ['13']: request.node.add_marker(pytest.mark.xfail(reason='not rlike not implemented')) + elif path.stem in ['15']: + request.node.add_marker(pytest.mark.xfail(reason='empty table error')) elif path.stem in ['16']: request.node.add_marker(pytest.mark.xfail(reason='mark join not implemented')) elif path.stem in ['18']: diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 5778543..3da9fda 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -14,15 +14,12 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb': - if originalname in [ - 'test_query_02', 'test_query_03', 'test_query_05', 'test_query_07', 'test_query_08', - 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', 'test_query_13', - 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', - 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) - if originalname == 'test_query_04': - request.node.add_marker(pytest.mark.xfail(reason='deduplicate not implemented')) + if source == 'gateway-over-duckdb' and originalname in [ + 'test_query_02', 'test_query_03', 'test_query_04', 'test_query_05', 'test_query_07', + 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', + 'test_query_13', 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', + 'test_query_18', 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) From c5acd91db560589f4b87066ef6dd7c73fd1e32b6 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 25 Apr 2024 18:04:24 -0700 Subject: [PATCH 38/58] chore: mark proper reason for failing SQL test (#63) --- src/gateway/tests/test_sql_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 13d1b7f..16f4510 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -21,8 +21,10 @@ def mark_tests_as_xfail(request): originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb' and originalname == 'test_tpch': path = request.getfixturevalue('path') - if path.stem in ['02', '04', '15', '16', '17', '18', '20', '21', '22']: + if path.stem in ['02', '04', '16', '17', '18', '20', '21', '22']: request.node.add_marker(pytest.mark.xfail(reason='DuckDB needs Delim join')) + if path.stem in ['15']: + request.node.add_marker(pytest.mark.xfail(reason='Rounding inconsistency')) if source == 'gateway-over-datafusion': 
pytest.importorskip("datafusion.substrait") if originalname == 'test_count': From d0e9376a62252f509034e161327503acbfbd3350 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Mon, 6 May 2024 15:52:09 -0400 Subject: [PATCH 39/58] feat: add createTempView table registration (#64) --- src/gateway/backends/adbc_backend.py | 2 +- src/gateway/backends/datafusion_backend.py | 13 ++++- src/gateway/backends/duckdb_backend.py | 7 ++- src/gateway/converter/spark_to_substrait.py | 19 +++++-- src/gateway/converter/sql_to_substrait.py | 9 ++-- src/gateway/server.py | 59 +++++++++++++++++---- src/gateway/tests/conftest.py | 10 ++-- src/gateway/tests/test_dataframe_api.py | 29 ++++++++-- 8 files changed, 118 insertions(+), 30 deletions(-) diff --git a/src/gateway/backends/adbc_backend.py b/src/gateway/backends/adbc_backend.py index e89e7fb..57a40f1 100644 --- a/src/gateway/backends/adbc_backend.py +++ b/src/gateway/backends/adbc_backend.py @@ -62,7 +62,7 @@ def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> file_paths = sorted([str(fp) for fp in file_paths]) # TODO: Support multiple paths. reader = pq.ParquetFile(file_paths[0]) - self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode="create") + self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode='create') def describe_table(self, table_name: str): """Asks the backend to describe the given table.""" diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index c722198..950af4d 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -23,6 +23,7 @@ def __init__(self, options): def create_connection(self) -> None: """Create a connection to the backend.""" import datafusion + self._connection = datafusion.SessionContext() def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: @@ -59,7 +60,15 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: for table_name in registered_tables: self._connection.deregister_table(table_name) - def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None: + def register_table( + self, name: str, location: Path, file_format: str = 'parquet' + ) -> None: """Register the given table with the backend.""" - files = Backend.expand_location(path) + files = Backend.expand_location(location) + if not files: + raise ValueError(f"No parquet files found at {location}") + # TODO: Add options to skip table registration if it already exists instead + # of deregistering it. 
+ if self._connection.table_exist(name): + self._connection.deregister_table(name) self._connection.register_parquet(name, files[0]) diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index c35078c..af5e9b9 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -44,7 +44,12 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: df = query_result.df() return pa.Table.from_pandas(df=df) - def register_table(self, table_name: str, location: Path, file_format: str = 'parquet') -> None: + def register_table( + self, + table_name: str, + location: Path, + file_format: str = "parquet" + ) -> None: """Register the given table with the backend.""" files = Backend.expand_location(location) if not files: diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index f1ed07f..02e3493 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -62,6 +62,11 @@ def __init__(self, options: ConversionOptions): self._seen_generated_names = {} self._saved_extension_uris = {} self._saved_extensions = {} + self._backend_with_tempview = None + + def set_tempview_backend(self, backend) -> None: + """Save the backend being used to create the temporary dataframe.""" + self._backend_with_tempview = backend def lookup_function_by_name(self, name: str) -> ExtensionFunction: """Find the function reference for a given Spark function name.""" @@ -449,9 +454,14 @@ def convert_read_named_table_relation( """Convert a read named table relation to a Substrait relation.""" table_name = rel.unparsed_identifier - backend = find_backend(BackendOptions(self._conversion_options.backend.backend, True)) - tpch_location = backend.find_tpch() - backend.register_table(table_name, tpch_location / table_name) + if self._backend_with_tempview: + backend = self._backend_with_tempview + else: + # TODO -- Remove this once we have a persistent backend per session. + backend = find_backend(BackendOptions(self._conversion_options.backend.backend, + use_adbc=True)) + tpch_location = backend.find_tpch() + backend.register_table(table_name, tpch_location / table_name) arrow_schema = backend.describe_table(table_name) schema = self.convert_arrow_schema(arrow_schema) @@ -968,7 +978,8 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: """Convert a Spark SQL relation into a Substrait relation.""" - plan = convert_sql(rel.query) + # TODO -- Handle multithreading in the case with a persistent backend. + plan = convert_sql(rel.query, self._backend_with_tempview) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in plan.relations[0].root.names: symbol.output_fields.append(field_name) diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index 0b12c2e..611b174 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -5,12 +5,15 @@ from substrait.gen.proto import plan_pb2 -def convert_sql(sql: str) -> plan_pb2.Plan: +def convert_sql(sql: str, backend=None) -> plan_pb2.Plan: """Convert SQL into a Substrait plan.""" result = plan_pb2.Plan() - backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) - backend.register_tpch() + # If backend is not provided or is not a DuckDBBackend, set one up. 
+ # DuckDB is used as the SQL conversion engine. + if not isinstance(backend, backend_selector.DuckDBBackend): + backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) + backend.register_tpch() connection = backend.get_connection() proto_bytes = connection.get_substrait(query=sql).fetchone()[0] diff --git a/src/gateway/server.py b/src/gateway/server.py index b063ece..d6e1def 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -10,7 +10,9 @@ import pyspark.sql.connect.proto.base_pb2 as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc from pyspark.sql.connect.proto import types_pb2 +from substrait.gen.proto import algebra_pb2 +from gateway.backends.backend_options import BackendOptions from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import arrow, datafusion, duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter @@ -79,6 +81,20 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: return types_pb2.DataType(struct=types_pb2.DataType.Struct(fields=fields)) +def create_dataframe_view(rel: pb2.Plan, conversion_options, backend) -> algebra_pb2.Rel: + """Register the temporary dataframe.""" + dataframe_view_name = rel.command.create_dataframe_view.name + read_data_source_relation = rel.command.create_dataframe_view.input.read.data_source + format = read_data_source_relation.format + path = read_data_source_relation.paths[0] + + if not backend: + backend = find_backend(BackendOptions(conversion_options.backend.backend, False)) + backend.register_table(dataframe_view_name, path, format) + + return backend + + # pylint: disable=E1101,fixme class SparkConnectService(pb2_grpc.SparkConnectServiceServicer): """Provides the SparkConnect service.""" @@ -88,6 +104,9 @@ def __init__(self, *args, **kwargs): """Initialize the SparkConnect service.""" # This is the central point for configuring the behavior of the service. 
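# Illustrative sketch, not part of the patch itself: the create_dataframe_view helper
# above is what backs createOrReplaceTempView on the client side.  The view's parquet
# path is registered with a per-session backend so that a later spark.table() lookup
# can resolve the name.  Roughly, mirroring the tests added later in this patch
# ('spark_session' is the test fixture and the TPC-H 'customer' path comes from
# Backend.find_tpch()):
#
#     location_customer = str(Backend.find_tpch() / 'customer')
#     df_customer = spark_session.read.parquet(location_customer)
#     df_customer.createOrReplaceTempView('mytempview')
#     outcome = spark_session.table('mytempview').collect()
#     assert len(outcome) == 149999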
self._options = duck_db() + self._backend_with_tempview = None + self._tempview_session_id = None + self._converter = None def ExecutePlan( self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext) -> Generator[ @@ -96,20 +115,38 @@ def ExecutePlan( _LOGGER.info('ExecutePlan: %s', request) match request.plan.WhichOneof('op_type'): case 'root': - convert = SparkSubstraitConverter(self._options) - substrait = convert.convert_plan(request.plan) + if not self._converter and self._tempview_session_id != request.session_id: + self._converter = SparkSubstraitConverter(self._options) + substrait = self._converter.convert_plan(request.plan) case 'command': match request.plan.command.WhichOneof('command_type'): case 'sql_command': - substrait = convert_sql(request.plan.command.sql_command.sql) + if (self._backend_with_tempview and + self._tempview_session_id == request.session_id): + substrait = convert_sql(request.plan.command.sql_command.sql, + self._backend_with_tempview) + else: + substrait = convert_sql(request.plan.command.sql_command.sql) + case 'create_dataframe_view': + if not self._converter and self._tempview_session_id != request.session_id: + self._converter = SparkSubstraitConverter(self._options) + self._backend_with_tempview = create_dataframe_view( + request.plan, self._options, self._backend_with_tempview) + self._tempview_session_id = request.session_id + self._converter.set_tempview_backend(self._backend_with_tempview) + + return case _: type = request.plan.command.WhichOneof("command_type") raise NotImplementedError(f'Unsupported command type: {type}') case _: raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) - backend = find_backend(self._options.backend) - backend.register_tpch() + if self._backend_with_tempview and self._tempview_session_id == request.session_id: + backend = self._backend_with_tempview + else: + backend = find_backend(self._options.backend) + backend.register_tpch() results = backend.execute(substrait) _LOGGER.debug(' results are: %s', results) @@ -149,10 +186,14 @@ def AnalyzePlan(self, request, context): """Analyze the given plan and return the results.""" _LOGGER.info('AnalyzePlan: %s', request) if request.schema: - convert = SparkSubstraitConverter(self._options) - substrait = convert.convert_plan(request.schema.plan) - backend = find_backend(self._options.backend) - backend.register_tpch() + if not self._converter: + self._converter = SparkSubstraitConverter(self._options) + substrait = self._converter.convert_plan(request.schema.plan) + if self._backend_with_tempview and self._tempview_session_id == request.session_id: + backend = self._backend_with_tempview + else: + backend = find_backend(self._options.backend) + backend.register_tpch() results = backend.execute(substrait) _LOGGER.debug(' results are: %s', results) return pb2.AnalyzePlanResponse( diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index fb305c8..5e2221b 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -62,7 +62,7 @@ def _create_gateway_session(backend: str) -> SparkSession: spark_gateway.stop() -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope='function', autouse=True) def manage_database() -> None: """Creates the mystream database for use throughout all the tests.""" create_mystream_database() @@ -70,7 +70,7 @@ def manage_database() -> None: delete_mystream_database() -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope='function', 
autouse=True) def gateway_server(): """Starts up a spark to substrait gateway service.""" server = serve(50052, wait=False) @@ -78,13 +78,13 @@ def gateway_server(): server.stop(None) -@pytest.fixture(scope='session') +@pytest.fixture(scope='function') def users_location() -> str: """Provides the location of the users database.""" return str(Path('users.parquet').resolve()) -@pytest.fixture(scope='session') +@pytest.fixture(scope='function') def schema_users(): """Provides the schema of the users database.""" return get_mystream_schema('users') @@ -100,7 +100,7 @@ def source(request) -> str: return request.param -@pytest.fixture(scope='session') +@pytest.fixture(scope='function') def spark_session(source): """Provides spark sessions connecting to various backends.""" match source: diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index d414ebc..a0f9dc3 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -13,14 +13,17 @@ def mark_dataframe_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and (originalname == 'test_with_column' or - originalname == 'test_cast'): - request.node.add_marker( - pytest.mark.xfail(reason='DuckDB column binding error')) + if source == 'gateway-over-duckdb': + if originalname == 'test_with_column' or originalname == 'test_cast': + request.node.add_marker(pytest.mark.xfail(reason='DuckDB column binding error')) + elif originalname in [ + 'test_create_or_replace_temp_view', 'test_create_or_replace_multiple_temp_views']: + request.node.add_marker(pytest.mark.xfail(reason='ADBC DuckDB from_substrait error')) elif source == 'gateway-over-datafusion': if originalname in [ 'test_data_source_schema', 'test_data_source_filter', 'test_table', 'test_table_schema', - 'test_table_filter']: + 'test_table_filter', 'test_create_or_replace_temp_view', + 'test_create_or_replace_multiple_temp_views',]: request.node.add_marker(pytest.mark.xfail(reason='Gateway internal iterating error')) else: pytest.importorskip("datafusion.substrait") @@ -133,3 +136,19 @@ def test_table_filter(self, spark_session_with_customer_dataset): customer_dataframe = spark_session_with_customer_dataset.table('customer') outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() assert len(outcome) == 29968 + + def test_create_or_replace_temp_view(self, spark_session): + location_customer = str(Backend.find_tpch() / 'customer') + df_customer = spark_session.read.parquet(location_customer) + df_customer.createOrReplaceTempView("mytempview") + outcome = spark_session.table('mytempview').collect() + assert len(outcome) == 149999 + + def test_create_or_replace_multiple_temp_views(self, spark_session): + location_customer = str(Backend.find_tpch() / 'customer') + df_customer = spark_session.read.parquet(location_customer) + df_customer.createOrReplaceTempView("mytempview1") + df_customer.createOrReplaceTempView("mytempview2") + outcome1 = spark_session.table('mytempview1').collect() + outcome2 = spark_session.table('mytempview2').collect() + assert len(outcome1) == len(outcome2) == 149999 From df0cbe54deaa00c7681ed27b80ec047f6e4a4261 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 7 May 2024 16:43:59 -0700 Subject: [PATCH 40/58] feat: add a safety project behind local reads to address an arrow bug (#66) --- 
src/gateway/converter/conversion_options.py | 2 ++ src/gateway/converter/spark_to_substrait.py | 21 +++++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 3235720..f9d8b2a 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -17,6 +17,7 @@ def __init__(self, backend: BackendOptions = None): self.use_emits_instead_of_direct = False self.use_switch_expressions_where_possible = True self.use_duckdb_regexp_matches_function = False + self.safety_project_read_relations = False self.return_names_with_types = False @@ -32,6 +33,7 @@ def arrow(): options.return_names_with_types = True options.implement_show_string = False options.backend.use_arrow_uri_workaround = True + options.safety_project_read_relations = True return options diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 02e3493..799d343 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -472,7 +472,8 @@ def convert_read_named_table_relation( return algebra_pb2.Rel( read=algebra_pb2.ReadRel( base_schema=schema, - named_table=algebra_pb2.ReadRel.NamedTable(names=[table_name]))) + named_table=algebra_pb2.ReadRel.NamedTable(names=[table_name]), + common=self.create_common_relation())) def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: """Convert the Spark JSON schema string into a Substrait named type structure.""" @@ -575,7 +576,8 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al return algebra_pb2.Rel( read=algebra_pb2.ReadRel(base_schema=schema, named_table=algebra_pb2.ReadRel.NamedTable( - names=['demotable']))) + names=['demotable']), + common=self.create_common_relation())) if pathlib.Path(rel.paths[0]).is_dir(): file_paths = glob.glob(f'{rel.paths[0]}/*{rel.format}') else: @@ -610,7 +612,19 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al case _: raise NotImplementedError(f'Unexpected file format: {rel.format}') local.items.append(file_or_files) - return algebra_pb2.Rel(read=algebra_pb2.ReadRel(base_schema=schema, local_files=local)) + result = algebra_pb2.Rel(read=algebra_pb2.ReadRel(base_schema=schema, local_files=local, + common=self.create_common_relation())) + if not self._conversion_options.safety_project_read_relations: + return result + + project = algebra_pb2.ProjectRel( + input=result, + common=algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct())) + for field_number in range(len(symbol.output_fields)): + project.expressions.append(field_reference(field_number)) + project.common.emit.output_mapping.append(field_number) + + return algebra_pb2.Rel(project=project) def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: """Create the common metadata relation used by all relations.""" @@ -637,7 +651,6 @@ def convert_read_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Re result = self.convert_read_data_source_relation(rel.data_source) case _: raise ValueError(f'Unexpected read type: {rel.WhichOneof("read_type")}') - result.read.common.CopyFrom(self.create_common_relation()) return result def convert_filter_relation(self, rel: spark_relations_pb2.Filter) -> algebra_pb2.Rel: From 3725889ed607013e272751134f738528641e83a5 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 8 May 2024 16:57:15 -0700 
Subject: [PATCH 41/58] feat: add plan validation to all of the Substrait plans used in tests (#65) --- src/gateway/backends/backend_options.py | 4 + src/gateway/converter/add_extension_uris.py | 30 + src/gateway/converter/data/count.sql-splan | 3 + src/gateway/converter/spark_to_substrait.py | 9 +- src/gateway/converter/sql_to_substrait.py | 11 +- src/gateway/converter/substrait_builder.py | 3 +- src/gateway/server.py | 75 +- src/gateway/tests/plan_validator.py | 48 + src/gateway/tests/test_dataframe_api.py | 45 +- src/gateway/tests/test_sql_api.py | 37 +- .../tests/test_tpch_with_dataframe_api.py | 825 ++++++++++-------- 11 files changed, 671 insertions(+), 419 deletions(-) create mode 100644 src/gateway/converter/add_extension_uris.py create mode 100644 src/gateway/tests/plan_validator.py diff --git a/src/gateway/backends/backend_options.py b/src/gateway/backends/backend_options.py index 5f0a578..466e3fe 100644 --- a/src/gateway/backends/backend_options.py +++ b/src/gateway/backends/backend_options.py @@ -11,6 +11,10 @@ class Backend(Enum): DATAFUSION = 2 DUCKDB = 3 + def __str__(self): + """Return the string representation of the backend.""" + return self.name.lower() + @dataclasses.dataclass class BackendOptions: diff --git a/src/gateway/converter/add_extension_uris.py b/src/gateway/converter/add_extension_uris.py new file mode 100644 index 0000000..ae15053 --- /dev/null +++ b/src/gateway/converter/add_extension_uris.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A library to search Substrait plan for local files.""" +from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor +from substrait.gen.proto import plan_pb2 +from substrait.gen.proto.extensions import extensions_pb2 + + +# pylint: disable=E1101,no-member +class AddExtensionUris(SubstraitPlanVisitor): + """Ensures that the plan has extension URI definitions for all references.""" + + def visit_plan(self, plan: plan_pb2.Plan) -> None: + """Modify the provided plan so that all functions have URI references.""" + super().visit_plan(plan) + + known_uris: list[int] = [] + for uri in plan.extension_uris: + known_uris.append(uri.extension_uri_anchor) + + for extension in plan.extensions: + if extension.WhichOneof('mapping_type') != 'extension_function': + continue + + if extension.extension_function.extension_uri_reference not in known_uris: + # TODO -- Make sure this hack occurs at most once. 
+ uri = extensions_pb2.SimpleExtensionURI( + uri='urn:arrow:substrait_simple_extension_function', + extension_uri_anchor=extension.extension_function.extension_uri_reference) + plan.extension_uris.append(uri) + known_uris.append(extension.extension_function.extension_uri_reference) diff --git a/src/gateway/converter/data/count.sql-splan b/src/gateway/converter/data/count.sql-splan index 49eb7f5..58ade53 100644 --- a/src/gateway/converter/data/count.sql-splan +++ b/src/gateway/converter/data/count.sql-splan @@ -1,3 +1,6 @@ +extension_uris { + uri: "urn:arrow:substrait_simple_extension_function" +} extensions { extension_function { function_anchor: 1 diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 799d343..4bec1eb 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -190,7 +190,7 @@ def convert_unresolved_attribute( root_reference=algebra_pb2.Expression.FieldReference.RootReference())) def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2.Type: - """Determine the type of a Substrait expression.""" + """Determine the type of the Substrait expression.""" if expr.WhichOneof('rex_type') == 'literal': match expr.literal.WhichOneof('literal_type'): case 'boolean': @@ -723,6 +723,10 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge symbol.generated_fields.append(self.determine_expression_name(expr)) symbol.output_fields.clear() symbol.output_fields.extend(symbol.generated_fields) + if len(rel.grouping_expressions) > 1: + # Hide the grouping source from the downstream relations. + for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)): + aggregate.common.emit.output_mapping.append(i) return algebra_pb2.Rel(aggregate=aggregate) # pylint: disable=too-many-locals,pointless-string-statement @@ -1094,7 +1098,8 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> function_reference=any_value_func.anchor, arguments=[algebra_pb2.FunctionArgument(value=field_reference(idx))], phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT, - output_type=type_pb2.Type(bool=type_pb2.Type.Boolean())))) + output_type=type_pb2.Type(bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.NULLABILITY_REQUIRED))))) symbol.generated_fields.append(field) return algebra_pb2.Rel(aggregate=aggregate) diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index 611b174..a5d4d14 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -2,12 +2,13 @@ """Routines to convert SparkConnect plans to Substrait plans.""" from gateway.backends import backend_selector from gateway.backends.backend_options import Backend, BackendOptions +from gateway.converter.add_extension_uris import AddExtensionUris from substrait.gen.proto import plan_pb2 def convert_sql(sql: str, backend=None) -> plan_pb2.Plan: """Convert SQL into a Substrait plan.""" - result = plan_pb2.Plan() + plan = plan_pb2.Plan() # If backend is not provided or is not a DuckDBBackend, set one up. # DuckDB is used as the SQL conversion engine. 
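A minimal sketch of what the AddExtensionUris visitor introduced above does, using only
classes added in this patch (the 'count:any' name is just an example value): any
extension_function whose URI anchor is not declared in extension_uris gets a placeholder
URI appended.

    from gateway.converter.add_extension_uris import AddExtensionUris
    from substrait.gen.proto import plan_pb2

    plan = plan_pb2.Plan()
    extension = plan.extensions.add()
    extension.extension_function.extension_uri_reference = 1
    extension.extension_function.function_anchor = 1
    extension.extension_function.name = 'count:any'

    AddExtensionUris().visit_plan(plan)
    # plan.extension_uris now holds a single entry whose extension_uri_anchor is 1 and
    # whose uri is the placeholder 'urn:arrow:substrait_simple_extension_function'.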
@@ -17,5 +18,9 @@ def convert_sql(sql: str, backend=None) -> plan_pb2.Plan: connection = backend.get_connection() proto_bytes = connection.get_substrait(query=sql).fetchone()[0] - result.ParseFromString(proto_bytes) - return result + plan.ParseFromString(proto_bytes) + + # TODO -- Remove this after the SQL converter is fixed. + AddExtensionUris().visit_plan(plan) + + return plan diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 924b6bc..6eca52d 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -97,7 +97,8 @@ def field_reference(field_number: int) -> algebra_pb2.Expression: selection=algebra_pb2.Expression.FieldReference( direct_reference=algebra_pb2.Expression.ReferenceSegment( struct_field=algebra_pb2.Expression.ReferenceSegment.StructField( - field=field_number)))) + field=field_number)), + root_reference=algebra_pb2.Expression.FieldReference.RootReference())) def max_agg_function(function_info: ExtensionFunction, diff --git a/src/gateway/server.py b/src/gateway/server.py index d6e1def..9b287dc 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -9,8 +9,9 @@ import pyarrow as pa import pyspark.sql.connect.proto.base_pb2 as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc +from google.protobuf.json_format import MessageToJson from pyspark.sql.connect.proto import types_pb2 -from substrait.gen.proto import algebra_pb2 +from substrait.gen.proto import algebra_pb2, plan_pb2 from gateway.backends.backend_options import BackendOptions from gateway.backends.backend_selector import find_backend @@ -94,6 +95,45 @@ def create_dataframe_view(rel: pb2.Plan, conversion_options, backend) -> algebra return backend +class Statistics: + """Statistics about the requests made to the server.""" + + def __init__(self): + """Initialize the statistics.""" + self.config_requests: int = 0 + self.analyze_requests: int = 0 + self.execute_requests: int = 0 + self.add_artifacts_requests: int = 0 + self.artifact_status_requests: int = 0 + self.interrupt_requests: int = 0 + self.reattach_requests: int = 0 + self.release_requests: int = 0 + + self.requests: list[str] = [] + self.plans: list[str] = [] + + def add_request(self, request): + """Remember a request for later introspection.""" + self.requests.append(str(request)) + + def add_plan(self, plan: plan_pb2.Plan): + """Remember a plan for later introspection.""" + self.plans.append(MessageToJson(plan)) + + def reset(self): + """Reset the statistics.""" + self.config_requests = 0 + self.analyze_requests = 0 + self.execute_requests = 0 + self.add_artifacts_requests = 0 + self.artifact_status_requests = 0 + self.interrupt_requests = 0 + self.reattach_requests = 0 + self.release_requests = 0 + + self.requests = [] + self.plans = [] + # pylint: disable=E1101,fixme class SparkConnectService(pb2_grpc.SparkConnectServiceServicer): @@ -107,11 +147,14 @@ def __init__(self, *args, **kwargs): self._backend_with_tempview = None self._tempview_session_id = None self._converter = None + self._statistics = Statistics() def ExecutePlan( self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: """Execute the given plan and return the results.""" + self._statistics.execute_requests += 1 + self._statistics.add_request(request) _LOGGER.info('ExecutePlan: %s', request) match request.plan.WhichOneof('op_type'): case 'root': @@ -147,6 +190,7 @@ def ExecutePlan( else: backend = 
find_backend(self._options.backend) backend.register_tpch() + self._statistics.add_plan(substrait) results = backend.execute(substrait) _LOGGER.debug(' results are: %s', results) @@ -184,6 +228,8 @@ def ExecutePlan( def AnalyzePlan(self, request, context): """Analyze the given plan and return the results.""" + self._statistics.analyze_requests += 1 + self._statistics.add_request(request) _LOGGER.info('AnalyzePlan: %s', request) if request.schema: if not self._converter: @@ -194,6 +240,7 @@ def AnalyzePlan(self, request, context): else: backend = find_backend(self._options.backend) backend.register_tpch() + self._statistics.add_plan(substrait) results = backend.execute(substrait) _LOGGER.debug(' results are: %s', results) return pb2.AnalyzePlanResponse( @@ -204,6 +251,7 @@ def AnalyzePlan(self, request, context): def Config(self, request, context): """Get or set the configuration of the server.""" + self._statistics.config_requests += 1 _LOGGER.info('Config: %s', request) response = pb2.ConfigResponse(session_id=request.session_id) match request.operation.WhichOneof('op_type'): @@ -220,23 +268,44 @@ def Config(self, request, context): self._options = datafusion() case _: raise ValueError(f'Unknown backend: {pair.value}') + elif pair.key == 'spark-substrait-gateway.reset_statistics': + self._statistics.reset() response.pairs.extend(request.operation.set.pairs) + case 'get': + for key in request.operation.get.keys: + if key == 'spark-substrait-gateway.backend': + response.pairs.add(key=key, value=str(self._options.backend.backend)) + elif key == 'spark-substrait-gateway.plan_count': + response.pairs.add(key=key, value=str(len(self._statistics.plans))) + elif key.startswith('spark-substrait-gateway.plan.'): + index = int(key[len('spark-substrait-gateway.plan.'):]) + if 0 <= index - 1 < len(self._statistics.plans): + response.pairs.add(key=key, value=self._statistics.plans[index - 1]) + else: + raise NotImplementedError(f'Unknown config item: {key}') case 'get_with_default': - response.pairs.extend(request.operation.get_with_default.pairs) + for pair in request.operation.get_with_default.pairs: + if pair.key == 'spark-substrait-gateway.backend': + response.pairs.add(key=pair.key, value=str(self._options.backend.backend)) + else: + response.pairs.append(pair) return response def AddArtifacts(self, request_iterator, context): """Add the given artifacts to the server.""" + self._statistics.add_artifacts_requests += 1 _LOGGER.info('AddArtifacts') return pb2.AddArtifactsResponse() def ArtifactStatus(self, request, context): """Get the status of the given artifact.""" + self._statistics.artifact_status_requests += 1 _LOGGER.info('ArtifactStatus') return pb2.ArtifactStatusesResponse() def Interrupt(self, request, context): """Interrupt the execution of the given plan.""" + self._statistics.interrupt_requests += 1 _LOGGER.info('Interrupt') return pb2.InterruptResponse() @@ -244,6 +313,7 @@ def ReattachExecute( self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: """Reattach the execution of the given plan.""" + self._statistics.reattach_requests += 1 _LOGGER.info('ReattachExecute') yield pb2.ExecutePlanResponse( session_id=request.session_id, @@ -251,6 +321,7 @@ def ReattachExecute( def ReleaseExecute(self, request, context): """Release the execution of the given plan.""" + self._statistics.release_requests += 1 _LOGGER.info('ReleaseExecute') return pb2.ReleaseExecuteResponse() diff --git a/src/gateway/tests/plan_validator.py 
b/src/gateway/tests/plan_validator.py new file mode 100644 index 0000000..3030385 --- /dev/null +++ b/src/gateway/tests/plan_validator.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager + +import pytest +import substrait_validator +from google.protobuf import json_format +from pyspark.errors.exceptions.connect import SparkConnectGrpcException +from substrait.gen.proto import plan_pb2 + + +def validate_plan(json_plan: str): + substrait_plan = json_format.Parse(json_plan, plan_pb2.Plan()) + diagnostics = substrait_validator.plan_to_diagnostics(substrait_plan.SerializeToString()) + issues = [] + for issue in diagnostics: + if issue.adjusted_level >= substrait_validator.Diagnostic.LEVEL_ERROR: + issues.append(issue.msg) + if issues: + issues_as_text = '\n'.join(f' → {issue}' for issue in issues) + pytest.fail(f'Validation failed. Issues:\n{issues_as_text}\n\nPlan:\n{substrait_plan}\n', + pytrace=False) + + +@contextmanager +def utilizes_valid_plans(session): + """Validates that the plans used by the gateway backend pass validation.""" + if hasattr(session, 'sparkSession'): + session = session.sparkSession + # Reset the statistics, so we only see the plans that were created during our lifetime. + if session.conf.get('spark-substrait-gateway.backend', 'spark') != 'spark': + session.conf.set('spark-substrait-gateway.reset_statistics', None) + try: + exception = None + yield + except SparkConnectGrpcException as e: + exception = e + if session.conf.get('spark-substrait-gateway.backend', 'spark') == 'spark': + return + plan_count = int(session.conf.get('spark-substrait-gateway.plan_count')) + first_plan = None + for i in range(plan_count): + plan = session.conf.get(f'spark-substrait-gateway.plan.{i + 1}') + if first_plan is None: + first_plan = plan + validate_plan(plan) + if exception: + pytest.fail(f'Exception raised during plan validation: {exception.message}\n\n' + f'First Plan:\n{first_plan}\n', pytrace=False) diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index a0f9dc3..9a72843 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -2,6 +2,7 @@ """Tests for the Spark to Substrait Gateway server.""" import pytest from gateway.backends.backend import Backend +from gateway.tests.plan_validator import utilizes_valid_plans from hamcrest import assert_that, equal_to from pyspark import Row from pyspark.sql.functions import col, substring @@ -35,12 +36,16 @@ class TestDataFrameAPI: """Tests of the dataframe side of SparkConnect.""" def test_collect(self, users_dataframe): - outcome = users_dataframe.collect() + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.collect() + assert len(outcome) == 100 # pylint: disable=singleton-comparison def test_filter(self, users_dataframe): - outcome = users_dataframe.filter(col('paid_for_service') == True).collect() + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.filter(col('paid_for_service') == True).collect() + assert len(outcome) == 29 # pylint: disable=singleton-comparison @@ -53,8 +58,10 @@ def test_filter_with_show(self, users_dataframe, capsys): +-------------+---------------+----------------+ ''' - users_dataframe.filter(col('paid_for_service') == True).limit(2).show() - outcome = capsys.readouterr().out + with utilizes_valid_plans(users_dataframe): + users_dataframe.filter(col('paid_for_service') == True).limit(2).show() + outcome = capsys.readouterr().out + 
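# Illustrative note, not part of the patch: utilizes_valid_plans (defined in
# plan_validator.py above) ties these tests to the new statistics plumbing in the
# gateway.  Roughly equivalent client-side steps:
#
#     spark_session.conf.set('spark-substrait-gateway.reset_statistics', None)
#     ...run the dataframe operations under test...
#     plan_count = int(spark_session.conf.get('spark-substrait-gateway.plan_count'))
#     for i in range(plan_count):
#         validate_plan(spark_session.conf.get(f'spark-substrait-gateway.plan.{i + 1}'))
#
# validate_plan runs substrait_validator over each captured plan and fails the test on
# any error-level diagnostic.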
assert_that(outcome, equal_to(expected)) # pylint: disable=singleton-comparison @@ -67,8 +74,10 @@ def test_filter_with_show_with_limit(self, users_dataframe, capsys): only showing top 1 row ''' - users_dataframe.filter(col('paid_for_service') == True).show(1) - outcome = capsys.readouterr().out + with utilizes_valid_plans(users_dataframe): + users_dataframe.filter(col('paid_for_service') == True).show(1) + outcome = capsys.readouterr().out + assert_that(outcome, equal_to(expected)) # pylint: disable=singleton-comparison @@ -85,7 +94,9 @@ def test_filter_with_show_and_truncate(self, users_dataframe, capsys): assert_that(outcome, equal_to(expected)) def test_count(self, users_dataframe): - outcome = users_dataframe.count() + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.count() + assert outcome == 100 def test_limit(self, users_dataframe): @@ -93,24 +104,32 @@ def test_limit(self, users_dataframe): Row(user_id='user849118289', name='Brooke Jones', paid_for_service=False), Row(user_id='user954079192', name='Collin Frank', paid_for_service=False), ] - outcome = users_dataframe.limit(2).collect() + + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.limit(2).collect() + assertDataFrameEqual(outcome, expected) def test_with_column(self, users_dataframe): expected = [ Row(user_id='user849118289', name='Brooke Jones', paid_for_service=False), ] - outcome = users_dataframe.withColumn( - 'user_id', col('user_id')).limit(1).collect() + + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.withColumn('user_id', col('user_id')).limit(1).collect() + assertDataFrameEqual(outcome, expected) def test_cast(self, users_dataframe): expected = [ Row(user_id=849, name='Brooke Jones', paid_for_service=False), ] - outcome = users_dataframe.withColumn( - 'user_id', - substring(col('user_id'), 5, 3).cast('integer')).limit(1).collect() + + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.withColumn( + 'user_id', + substring(col('user_id'), 5, 3).cast('integer')).limit(1).collect() + assertDataFrameEqual(outcome, expected) def test_data_source_schema(self, spark_session): diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 16f4510..4199770 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest +from gateway.tests.plan_validator import utilizes_valid_plans from hamcrest import assert_that, equal_to from pyspark import Row from pyspark.testing import assertDataFrameEqual @@ -19,16 +20,23 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and originalname == 'test_tpch': - path = request.getfixturevalue('path') - if path.stem in ['02', '04', '16', '17', '18', '20', '21', '22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB needs Delim join')) - if path.stem in ['15']: - request.node.add_marker(pytest.mark.xfail(reason='Rounding inconsistency')) + if source == 'gateway-over-duckdb': + if originalname == 'test_tpch': + path = request.getfixturevalue('path') + if path.stem in ['02', '04', '16', '17', '18', '20', '21', '22']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB needs Delim join')) + if path.stem in ['15']: + request.node.add_marker(pytest.mark.xfail(reason='Rounding inconsistency')) + else: + 
request.node.add_marker(pytest.mark.xfail(reason='Too few names returned')) + else: + request.node.add_marker(pytest.mark.xfail(reason='Too few names returned')) if source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") if originalname == 'test_count': request.node.add_marker(pytest.mark.xfail(reason='COUNT() not implemented')) + if originalname == 'test_limit': + request.node.add_marker(pytest.mark.xfail(reason='Too few names returned')) if originalname in ['test_tpch']: path = request.getfixturevalue('path') if path.stem in ['01']: @@ -62,8 +70,10 @@ class TestSqlAPI: """Tests of the SQL side of SparkConnect.""" def test_count(self, spark_session_with_tpch_dataset): - outcome = spark_session_with_tpch_dataset.sql( - 'SELECT COUNT(*) FROM customer').collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.sql( + 'SELECT COUNT(*) FROM customer').collect() + assert_that(outcome[0][0], equal_to(149999)) def test_limit(self, spark_session_with_tpch_dataset): @@ -74,8 +84,11 @@ def test_limit(self, spark_session_with_tpch_dataset): Row(c_custkey=5, c_phone='13-750-942-6364', c_mktsegment='HOUSEHOLD'), Row(c_custkey=6, c_phone='30-114-968-4951', c_mktsegment='AUTOMOBILE'), ] - outcome = spark_session_with_tpch_dataset.sql( - 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.sql( + 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() + assertDataFrameEqual(outcome, expected) @pytest.mark.timeout(60) @@ -90,4 +103,6 @@ def test_tpch(self, spark_session_with_tpch_dataset, path): with open(path, "rb") as file: sql_bytes = file.read() sql = sql_bytes.decode('utf-8') - spark_session_with_tpch_dataset.sql(sql).collect() + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + spark_session_with_tpch_dataset.sql(sql).collect() diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 3da9fda..6b338dc 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -4,6 +4,7 @@ import pyspark import pytest +from gateway.tests.plan_validator import utilizes_valid_plans from pyspark import Row from pyspark.sql.functions import avg, col, count, countDistinct, desc, try_sum, when from pyspark.testing import assertDataFrameEqual @@ -15,11 +16,11 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_02', 'test_query_03', 'test_query_04', 'test_query_05', 'test_query_07', - 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', - 'test_query_13', 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', - 'test_query_18', 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) + 'test_query_02', 'test_query_03', 'test_query_04', 'test_query_05', 'test_query_07', + 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', + 'test_query_13', 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', + 'test_query_18', 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB 
binder error')) if source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) @@ -37,20 +38,23 @@ def test_query_01(self, spark_session_with_tpch_dataset): avg_price=38273.13, avg_disc=0.05, count_order=1478493), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - outcome = lineitem.filter(col('l_shipdate') <= '1998-09-02').groupBy('l_returnflag', - 'l_linestatus').agg( - try_sum('l_quantity').alias('sum_qty'), - try_sum('l_extendedprice').alias('sum_base_price'), - try_sum(col('l_extendedprice') * (1 - col('l_discount'))).alias('sum_disc_price'), - try_sum(col('l_extendedprice') * (1 - col('l_discount')) * (1 + col('l_tax'))).alias( - 'sum_charge'), - avg('l_quantity').alias('avg_qty'), - avg('l_extendedprice').alias('avg_price'), - avg('l_discount').alias('avg_disc'), - count('*').alias('count_order')) - - sorted_outcome = outcome.sort('l_returnflag', 'l_linestatus').limit(1).collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + outcome = lineitem.filter(col('l_shipdate') <= '1998-09-02').groupBy( + 'l_returnflag', 'l_linestatus').agg( + try_sum('l_quantity').alias('sum_qty'), + try_sum('l_extendedprice').alias('sum_base_price'), + try_sum(col('l_extendedprice') * (1 - col('l_discount'))).alias('sum_disc_price'), + try_sum( + col('l_extendedprice') * (1 - col('l_discount')) * (1 + col('l_tax'))).alias( + 'sum_charge'), + avg('l_quantity').alias('avg_qty'), + avg('l_extendedprice').alias('avg_price'), + avg('l_discount').alias('avg_disc'), + count('*').alias('count_order')) + + sorted_outcome = outcome.sort('l_returnflag', 'l_linestatus').limit(1).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_02(self, spark_session_with_tpch_dataset): @@ -65,30 +69,33 @@ def test_query_02(self, spark_session_with_tpch_dataset): s_comment='efully express instructions. 
regular requests against the slyly fin'), ] - part = spark_session_with_tpch_dataset.table('part') - supplier = spark_session_with_tpch_dataset.table('supplier') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - nation = spark_session_with_tpch_dataset.table('nation') - region = spark_session_with_tpch_dataset.table('region') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + part = spark_session_with_tpch_dataset.table('part') + supplier = spark_session_with_tpch_dataset.table('supplier') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + nation = spark_session_with_tpch_dataset.table('nation') + region = spark_session_with_tpch_dataset.table('region') + + europe = region.filter(col('r_name') == 'EUROPE').join( + nation, col('r_regionkey') == col('n_regionkey')).join( + supplier, col('n_nationkey') == col('s_nationkey')).join( + partsupp, col('s_suppkey') == col('ps_suppkey')) - europe = region.filter(col('r_name') == 'EUROPE').join( - nation, col('r_regionkey') == col('n_regionkey')).join( - supplier, col('n_nationkey') == col('s_nationkey')).join( - partsupp, col('s_suppkey') == col('ps_suppkey')) + brass = part.filter((col('p_size') == 15) & (col('p_type').endswith('BRASS'))).join( + europe, col('ps_partkey') == col('p_partkey')) - brass = part.filter((col('p_size') == 15) & (col('p_type').endswith('BRASS'))).join( - europe, col('ps_partkey') == col('p_partkey')) + minCost = brass.groupBy(col('ps_partkey')).agg( + pyspark.sql.functions.min('ps_supplycost').alias('min')) - minCost = brass.groupBy(col('ps_partkey')).agg( - pyspark.sql.functions.min('ps_supplycost').alias('min')) + outcome = brass.join(minCost, brass.ps_partkey == minCost.ps_partkey).filter( + col('ps_supplycost') == col('min')).select('s_acctbal', 's_name', 'n_name', + 'p_partkey', + 'p_mfgr', 's_address', 's_phone', + 's_comment') - outcome = brass.join(minCost, brass.ps_partkey == minCost.ps_partkey).filter( - col('ps_supplycost') == col('min')).select('s_acctbal', 's_name', 'n_name', 'p_partkey', - 'p_mfgr', 's_address', 's_phone', - 's_comment') + sorted_outcome = outcome.sort( + desc('s_acctbal'), 'n_name', 's_name', 'p_partkey').limit(2).collect() - sorted_outcome = outcome.sort( - desc('s_acctbal'), 'n_name', 's_name', 'p_partkey').limit(2).collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_03(self, spark_session_with_tpch_dataset): @@ -105,25 +112,27 @@ def test_query_03(self, spark_session_with_tpch_dataset): o_shippriority=0), ] - customer = spark_session_with_tpch_dataset.table('customer') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - orders = spark_session_with_tpch_dataset.table('orders') - - fcust = customer.filter(col('c_mktsegment') == 'BUILDING') - forders = orders.filter(col('o_orderdate') < '1995-03-15') - flineitems = lineitem.filter(lineitem.l_shipdate > '1995-03-15') - - outcome = fcust.join(forders, col('c_custkey') == forders.o_custkey).select( - 'o_orderkey', 'o_orderdate', 'o_shippriority').join( - flineitems, col('o_orderkey') == flineitems.l_orderkey).select( - 'l_orderkey', - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), - 'o_orderdate', - 'o_shippriority').groupBy('l_orderkey', 'o_orderdate', 'o_shippriority').agg( - try_sum('volume').alias('revenue')).select( - 'l_orderkey', 'revenue', 'o_orderdate', 'o_shippriority') - - sorted_outcome = outcome.sort(desc('revenue'), 'o_orderdate').limit(5).collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = 
spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + fcust = customer.filter(col('c_mktsegment') == 'BUILDING') + forders = orders.filter(col('o_orderdate') < '1995-03-15') + flineitems = lineitem.filter(lineitem.l_shipdate > '1995-03-15') + + outcome = fcust.join(forders, col('c_custkey') == forders.o_custkey).select( + 'o_orderkey', 'o_orderdate', 'o_shippriority').join( + flineitems, col('o_orderkey') == flineitems.l_orderkey).select( + 'l_orderkey', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), + 'o_orderdate', + 'o_shippriority').groupBy('l_orderkey', 'o_orderdate', 'o_shippriority').agg( + try_sum('volume').alias('revenue')).select( + 'l_orderkey', 'revenue', 'o_orderdate', 'o_shippriority') + + sorted_outcome = outcome.sort(desc('revenue'), 'o_orderdate').limit(5).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_04(self, spark_session_with_tpch_dataset): @@ -135,20 +144,22 @@ def test_query_04(self, spark_session_with_tpch_dataset): Row(o_orderpriority='5-LOW', order_count=10487), ] - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') - forders = orders.filter( - (col('o_orderdate') >= '1993-07-01') & (col('o_orderdate') < '1993-10-01')) - flineitems = lineitem.filter(col('l_commitdate') < col('l_receiptdate')).select( - 'l_orderkey').distinct() + forders = orders.filter( + (col('o_orderdate') >= '1993-07-01') & (col('o_orderdate') < '1993-10-01')) + flineitems = lineitem.filter(col('l_commitdate') < col('l_receiptdate')).select( + 'l_orderkey').distinct() - outcome = flineitems.join( - forders, - col('l_orderkey') == col('o_orderkey')).groupBy('o_orderpriority').agg( - count('o_orderpriority').alias('order_count')) + outcome = flineitems.join( + forders, + col('l_orderkey') == col('o_orderkey')).groupBy('o_orderpriority').agg( + count('o_orderpriority').alias('order_count')) + + sorted_outcome = outcome.sort('o_orderpriority').collect() - sorted_outcome = outcome.sort('o_orderpriority').collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_05(self, spark_session_with_tpch_dataset): @@ -160,29 +171,30 @@ def test_query_05(self, spark_session_with_tpch_dataset): Row(n_name='JAPAN', revenue=45410175.70), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - region = spark_session_with_tpch_dataset.table('region') - supplier = spark_session_with_tpch_dataset.table('supplier') - - forders = orders.filter(col('o_orderdate') >= '1994-01-01').filter( - col('o_orderdate') < '1995-01-01') - - outcome = region.filter(col('r_name') == 'ASIA').join( # r_name = 'ASIA' - nation, col('r_regionkey') == col('n_regionkey')).join( - supplier, col('n_nationkey') == col('s_nationkey')).join( - lineitem, col('s_suppkey') == col('l_suppkey')).select( - 'n_name', 'l_extendedprice', 'l_discount', 'l_quantity', 'l_orderkey', - 's_nationkey').join(forders, col('l_orderkey') == forders.o_orderkey).join( - customer, (col('o_custkey') == col('c_custkey')) & ( - 
col('s_nationkey') == col('c_nationkey'))).select( - 'n_name', - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( - 'n_name').agg(try_sum('volume').alias('revenue')) - - sorted_outcome = outcome.sort('revenue').collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + region = spark_session_with_tpch_dataset.table('region') + supplier = spark_session_with_tpch_dataset.table('supplier') + + forders = orders.filter(col('o_orderdate') >= '1994-01-01').filter( + col('o_orderdate') < '1995-01-01') + + outcome = region.filter(col('r_name') == 'ASIA').join( # r_name = 'ASIA' + nation, col('r_regionkey') == col('n_regionkey')).join( + supplier, col('n_nationkey') == col('s_nationkey')).join( + lineitem, col('s_suppkey') == col('l_suppkey')).select( + 'n_name', 'l_extendedprice', 'l_discount', 'l_quantity', 'l_orderkey', + 's_nationkey').join(forders, col('l_orderkey') == forders.o_orderkey).join( + customer, (col('o_custkey') == col('c_custkey')) & ( + col('s_nationkey') == col('c_nationkey'))).select( + 'n_name', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( + 'n_name').agg(try_sum('volume').alias('revenue')) + + sorted_outcome = outcome.sort('revenue').collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) @@ -191,14 +203,15 @@ def test_query_06(self, spark_session_with_tpch_dataset): Row(revenue=123141078.23), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') - outcome = lineitem.filter((col('l_shipdate') >= '1994-01-01') & - (col('l_shipdate') < '1995-01-01') & - (col('l_discount') >= 0.05) & - (col('l_discount') <= 0.07) & - (col('l_quantity') < 24)).agg( - try_sum(col('l_extendedprice') * col('l_discount'))).alias('revenue') + outcome = lineitem.filter((col('l_shipdate') >= '1994-01-01') & + (col('l_shipdate') < '1995-01-01') & + (col('l_discount') >= 0.05) & + (col('l_discount') <= 0.07) & + (col('l_quantity') < 24)).agg( + try_sum(col('l_extendedprice') * col('l_discount'))).alias('revenue') assertDataFrameEqual(outcome, expected, atol=1e-2) @@ -210,33 +223,35 @@ def test_query_07(self, spark_session_with_tpch_dataset): Row(supp_nation='GERMANY', cust_nation='FRANCE', l_year='1996', revenue=52520549.02), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - supplier = spark_session_with_tpch_dataset.table('supplier') - nation = spark_session_with_tpch_dataset.table('nation') - - fnation = nation.filter((nation.n_name == 'FRANCE') | (nation.n_name == 'GERMANY')) - fline = lineitem.filter( - (col('l_shipdate') >= '1995-01-01') & (col('l_shipdate') <= '1996-12-31')) - - suppNation = fnation.join(supplier, col('n_nationkey') == col('s_nationkey')).join( - fline, col('s_suppkey') == col('l_suppkey')).select( - col('n_name').alias('supp_nation'), 'l_orderkey', 'l_extendedprice', 'l_discount', - 'l_shipdate') - - outcome = fnation.join(customer, col('n_nationkey') == col('c_nationkey')).join( - orders, col('c_custkey') == col('o_custkey')).select( - col('n_name').alias('cust_nation'), 
'o_orderkey').join( - suppNation, col('o_orderkey') == suppNation.l_orderkey).filter( - (col('supp_nation') == 'FRANCE') & (col('cust_nation') == 'GERMANY') | ( - col('supp_nation') == 'GERMANY') & (col('cust_nation') == 'FRANCE')).select( - 'supp_nation', 'cust_nation', col('l_shipdate').substr(0, 4).alias('l_year'), - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( - 'supp_nation', 'cust_nation', 'l_year').agg( - try_sum('volume').alias('revenue')) - - sorted_outcome = outcome.sort('supp_nation', 'cust_nation', 'l_year').collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + supplier = spark_session_with_tpch_dataset.table('supplier') + nation = spark_session_with_tpch_dataset.table('nation') + + fnation = nation.filter((nation.n_name == 'FRANCE') | (nation.n_name == 'GERMANY')) + fline = lineitem.filter( + (col('l_shipdate') >= '1995-01-01') & (col('l_shipdate') <= '1996-12-31')) + + suppNation = fnation.join(supplier, col('n_nationkey') == col('s_nationkey')).join( + fline, col('s_suppkey') == col('l_suppkey')).select( + col('n_name').alias('supp_nation'), 'l_orderkey', 'l_extendedprice', 'l_discount', + 'l_shipdate') + + outcome = fnation.join(customer, col('n_nationkey') == col('c_nationkey')).join( + orders, col('c_custkey') == col('o_custkey')).select( + col('n_name').alias('cust_nation'), 'o_orderkey').join( + suppNation, col('o_orderkey') == suppNation.l_orderkey).filter( + (col('supp_nation') == 'FRANCE') & (col('cust_nation') == 'GERMANY') | ( + col('supp_nation') == 'GERMANY') & (col('cust_nation') == 'FRANCE')).select( + 'supp_nation', 'cust_nation', col('l_shipdate').substr(0, 4).alias('l_year'), + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( + 'supp_nation', 'cust_nation', 'l_year').agg( + try_sum('volume').alias('revenue')) + + sorted_outcome = outcome.sort('supp_nation', 'cust_nation', 'l_year').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_08(self, spark_session_with_tpch_dataset): @@ -245,40 +260,43 @@ def test_query_08(self, spark_session_with_tpch_dataset): Row(o_year='1996', mkt_share=0.04), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - part = spark_session_with_tpch_dataset.table('part') - region = spark_session_with_tpch_dataset.table('region') - supplier = spark_session_with_tpch_dataset.table('supplier') - - fregion = region.filter(col('r_name') == 'AMERICA') - forder = orders.filter((col('o_orderdate') >= '1995-01-01') & ( - col('o_orderdate') <= '1996-12-31')) - fpart = part.filter(col('p_type') == 'ECONOMY ANODIZED STEEL') - - nat = nation.join(supplier, col('n_nationkey') == col('s_nationkey')) - - line = lineitem.select( - 'l_partkey', 'l_suppkey', 'l_orderkey', - (col('l_extendedprice') * (1 - col('l_discount'))).alias( - 'volume')).join( - fpart, col('l_partkey') == fpart.p_partkey).join( - nat, col('l_suppkey') == nat.s_suppkey) - - outcome = nation.join(fregion, col('n_regionkey') == fregion.r_regionkey).select( - 'n_nationkey', 'n_name').join(customer, - col('n_nationkey') == col('c_nationkey')).select( - 'c_custkey').join(forder, col('c_custkey') == 
col('o_custkey')).select( - 'o_orderkey', 'o_orderdate').join(line, col('o_orderkey') == line.l_orderkey).select( - col('n_name'), col('o_orderdate').substr(0, 4).alias('o_year'), - col('volume')).withColumn('case_volume', - when(col('n_name') == 'BRAZIL', col('volume')).otherwise( - 0)).groupBy('o_year').agg( - (try_sum('case_volume') / try_sum('volume')).alias('mkt_share')) - - sorted_outcome = outcome.sort('o_year').collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + region = spark_session_with_tpch_dataset.table('region') + supplier = spark_session_with_tpch_dataset.table('supplier') + + fregion = region.filter(col('r_name') == 'AMERICA') + forder = orders.filter((col('o_orderdate') >= '1995-01-01') & ( + col('o_orderdate') <= '1996-12-31')) + fpart = part.filter(col('p_type') == 'ECONOMY ANODIZED STEEL') + + nat = nation.join(supplier, col('n_nationkey') == col('s_nationkey')) + + line = lineitem.select( + 'l_partkey', 'l_suppkey', 'l_orderkey', + (col('l_extendedprice') * (1 - col('l_discount'))).alias( + 'volume')).join( + fpart, col('l_partkey') == fpart.p_partkey).join( + nat, col('l_suppkey') == nat.s_suppkey) + + outcome = nation.join(fregion, col('n_regionkey') == fregion.r_regionkey).select( + 'n_nationkey', 'n_name').join(customer, + col('n_nationkey') == col('c_nationkey')).select( + 'c_custkey').join(forder, col('c_custkey') == col('o_custkey')).select( + 'o_orderkey', 'o_orderdate').join(line, + col('o_orderkey') == line.l_orderkey).select( + col('n_name'), col('o_orderdate').substr(0, 4).alias('o_year'), + col('volume')).withColumn('case_volume', + when(col('n_name') == 'BRAZIL', col('volume')).otherwise( + 0)).groupBy('o_year').agg( + (try_sum('case_volume') / try_sum('volume')).alias('mkt_share')) + + sorted_outcome = outcome.sort('o_year').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_09(self, spark_session_with_tpch_dataset): @@ -291,27 +309,29 @@ def test_query_09(self, spark_session_with_tpch_dataset): Row(n_name='ARGENTINA', o_year='1994', sum_profit=48268856.35), ] - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - part = spark_session_with_tpch_dataset.table('part') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - supplier = spark_session_with_tpch_dataset.table('supplier') - - linePart = part.filter(col('p_name').contains('green')).join( - lineitem, col('p_partkey') == lineitem.l_partkey) - natSup = nation.join(supplier, col('n_nationkey') == supplier.s_nationkey) - - outcome = linePart.join(natSup, col('l_suppkey') == natSup.s_suppkey).join( - partsupp, (col('l_suppkey') == partsupp.ps_suppkey) & ( - col('l_partkey') == partsupp.ps_partkey)).join( - orders, col('l_orderkey') == orders.o_orderkey).select( - 'n_name', col('o_orderdate').substr(0, 4).alias('o_year'), - (col('l_extendedprice') * (1 - col('l_discount')) - ( - col('ps_supplycost') * col('l_quantity'))).alias('amount')).groupBy( - 'n_name', 'o_year').agg(try_sum('amount').alias('sum_profit')) - - sorted_outcome = outcome.sort('n_name', desc('o_year')).limit(5).collect() + with 
utilizes_valid_plans(spark_session_with_tpch_dataset): + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') + + linePart = part.filter(col('p_name').contains('green')).join( + lineitem, col('p_partkey') == lineitem.l_partkey) + natSup = nation.join(supplier, col('n_nationkey') == supplier.s_nationkey) + + outcome = linePart.join(natSup, col('l_suppkey') == natSup.s_suppkey).join( + partsupp, (col('l_suppkey') == partsupp.ps_suppkey) & ( + col('l_partkey') == partsupp.ps_partkey)).join( + orders, col('l_orderkey') == orders.o_orderkey).select( + 'n_name', col('o_orderdate').substr(0, 4).alias('o_year'), + (col('l_extendedprice') * (1 - col('l_discount')) - ( + col('ps_supplycost') * col('l_quantity'))).alias('amount')).groupBy( + 'n_name', 'o_year').agg(try_sum('amount').alias('sum_profit')) + + sorted_outcome = outcome.sort('n_name', desc('o_year')).limit(5).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_10(self, spark_session_with_tpch_dataset): @@ -327,27 +347,30 @@ def test_query_10(self, spark_session_with_tpch_dataset): 'pinto beans. ironic, idle re'), ] - customer = spark_session_with_tpch_dataset.table('customer') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - orders = spark_session_with_tpch_dataset.table('orders') - - flineitem = lineitem.filter(col('l_returnflag') == 'R') - - outcome = orders.filter( - (col('o_orderdate') >= '1993-10-01') & (col('o_orderdate') < '1994-01-01')).join( - customer, col('o_custkey') == customer.c_custkey).join( - nation, col('c_nationkey') == nation.n_nationkey).join( - flineitem, col('o_orderkey') == flineitem.l_orderkey).select( - 'c_custkey', 'c_name', - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), - 'c_acctbal', 'n_name', 'c_address', 'c_phone', 'c_comment').groupBy( - 'c_custkey', 'c_name', 'c_acctbal', 'c_phone', 'n_name', 'c_address', 'c_comment').agg( - try_sum('volume').alias('revenue')).select( - 'c_custkey', 'c_name', 'revenue', 'c_acctbal', 'n_name', 'c_address', 'c_phone', - 'c_comment') - - sorted_outcome = outcome.sort(desc('revenue')).limit(2).collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + orders = spark_session_with_tpch_dataset.table('orders') + + flineitem = lineitem.filter(col('l_returnflag') == 'R') + + outcome = orders.filter( + (col('o_orderdate') >= '1993-10-01') & (col('o_orderdate') < '1994-01-01')).join( + customer, col('o_custkey') == customer.c_custkey).join( + nation, col('c_nationkey') == nation.n_nationkey).join( + flineitem, col('o_orderkey') == flineitem.l_orderkey).select( + 'c_custkey', 'c_name', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), + 'c_acctbal', 'n_name', 'c_address', 'c_phone', 'c_comment').groupBy( + 'c_custkey', 'c_name', 'c_acctbal', 'c_phone', 'n_name', 'c_address', + 'c_comment').agg( + try_sum('volume').alias('revenue')).select( + 'c_custkey', 'c_name', 'revenue', 'c_acctbal', 'n_name', 'c_address', 'c_phone', + 'c_comment') + + 
sorted_outcome = outcome.sort(desc('revenue')).limit(2).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_11(self, spark_session_with_tpch_dataset): @@ -359,22 +382,24 @@ def test_query_11(self, spark_session_with_tpch_dataset): Row(ps_partkey=34452, value=15983844.72), ] - nation = spark_session_with_tpch_dataset.table('nation') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - supplier = spark_session_with_tpch_dataset.table('supplier') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + nation = spark_session_with_tpch_dataset.table('nation') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') + + tmp = nation.filter(col('n_name') == 'GERMANY').join( + supplier, col('n_nationkey') == supplier.s_nationkey).select( + 's_suppkey').join(partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( + 'ps_partkey', (col('ps_supplycost') * col('ps_availqty')).alias('value')) - tmp = nation.filter(col('n_name') == 'GERMANY').join( - supplier, col('n_nationkey') == supplier.s_nationkey).select( - 's_suppkey').join(partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( - 'ps_partkey', (col('ps_supplycost') * col('ps_availqty')).alias('value')) + sumRes = tmp.agg(try_sum('value').alias('total_value')) - sumRes = tmp.agg(try_sum('value').alias('total_value')) + outcome = tmp.groupBy('ps_partkey').agg( + (try_sum('value')).alias('part_value')).join( + sumRes, col('part_value') > col('total_value') * 0.0001) - outcome = tmp.groupBy('ps_partkey').agg( - (try_sum('value')).alias('part_value')).join( - sumRes, col('part_value') > col('total_value') * 0.0001) + sorted_outcome = outcome.sort(desc('part_value')).limit(5).collect() - sorted_outcome = outcome.sort(desc('part_value')).limit(5).collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_12(self, spark_session_with_tpch_dataset): @@ -383,25 +408,30 @@ def test_query_12(self, spark_session_with_tpch_dataset): Row(l_shipmode='SHIP', high_line_count=6200, low_line_count=9262), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - orders = spark_session_with_tpch_dataset.table('orders') - - outcome = lineitem.filter( - (col('l_shipmode') == 'MAIL') | (col('l_shipmode') == 'SHIP')).filter( - (col('l_commitdate') < col('l_receiptdate')) & - (col('l_shipdate') < col('l_commitdate')) & - (col('l_receiptdate') >= '1994-01-01') & (col('l_receiptdate') < '1995-01-01')).join( - orders, - col('l_orderkey') == orders.o_orderkey).select( - 'l_shipmode', 'o_orderpriority').groupBy('l_shipmode').agg( - count( - when((col('o_orderpriority') == '1-URGENT') | (col('o_orderpriority') == '2-HIGH'), - True)).alias('high_line_count'), - count( - when((col('o_orderpriority') != '1-URGENT') & (col('o_orderpriority') != '2-HIGH'), - True)).alias('low_line_count')) - - sorted_outcome = outcome.sort('l_shipmode').collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = lineitem.filter( + (col('l_shipmode') == 'MAIL') | (col('l_shipmode') == 'SHIP')).filter( + (col('l_commitdate') < col('l_receiptdate')) & + (col('l_shipdate') < col('l_commitdate')) & + (col('l_receiptdate') >= '1994-01-01') & ( + col('l_receiptdate') < '1995-01-01')).join( + orders, + col('l_orderkey') == orders.o_orderkey).select( + 'l_shipmode', 
'o_orderpriority').groupBy('l_shipmode').agg( + count( + when((col('o_orderpriority') == '1-URGENT') | ( + col('o_orderpriority') == '2-HIGH'), + True)).alias('high_line_count'), + count( + when((col('o_orderpriority') != '1-URGENT') & ( + col('o_orderpriority') != '2-HIGH'), + True)).alias('low_line_count')) + + sorted_outcome = outcome.sort('l_shipmode').collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_13(self, spark_session_with_tpch_dataset): @@ -412,16 +442,18 @@ def test_query_13(self, spark_session_with_tpch_dataset): Row(c_count=11, custdist=6014), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = customer.join( + orders, (col('c_custkey') == orders.o_custkey) & ( + ~col('o_comment').rlike('.*special.*requests.*')), 'left_outer').groupBy( + 'o_custkey').agg(count('o_orderkey').alias('c_count')).groupBy( + 'c_count').agg(count('o_custkey').alias('custdist')) - outcome = customer.join( - orders, (col('c_custkey') == orders.o_custkey) & ( - ~col('o_comment').rlike('.*special.*requests.*')), 'left_outer').groupBy( - 'o_custkey').agg(count('o_orderkey').alias('c_count')).groupBy( - 'c_count').agg(count('o_custkey').alias('custdist')) + sorted_outcome = outcome.sort(desc('custdist'), desc('c_count')).limit(3).collect() - sorted_outcome = outcome.sort(desc('custdist'), desc('c_count')).limit(3).collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_14(self, spark_session_with_tpch_dataset): @@ -429,16 +461,17 @@ def test_query_14(self, spark_session_with_tpch_dataset): Row(promo_revenue=16.38), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - part = spark_session_with_tpch_dataset.table('part') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + part = spark_session_with_tpch_dataset.table('part') - outcome = part.join(lineitem, (col('l_partkey') == col('p_partkey')) & - (col('l_shipdate') >= '1995-09-01') & - (col('l_shipdate') < '1995-10-01')).select( - 'p_type', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).agg( - try_sum(when(col('p_type').contains('PROMO'), col('value'))) * 100 / try_sum( - col('value')) - ).alias('promo_revenue') + outcome = part.join(lineitem, (col('l_partkey') == col('p_partkey')) & + (col('l_shipdate') >= '1995-09-01') & + (col('l_shipdate') < '1995-10-01')).select( + 'p_type', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).agg( + try_sum(when(col('p_type').contains('PROMO'), col('value'))) * 100 / try_sum( + col('value')) + ).alias('promo_revenue') assertDataFrameEqual(outcome, expected, atol=1e-2) @@ -447,20 +480,23 @@ def test_query_15(self, spark_session_with_tpch_dataset): Row(s_suppkey=8449, s_name='Supplier#000008449', s_address='Wp34zim9qYFbVctdW'), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - supplier = spark_session_with_tpch_dataset.table('supplier') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + supplier = spark_session_with_tpch_dataset.table('supplier') + + revenue = lineitem.filter((col('l_shipdate') >= '1996-01-01') & + (col('l_shipdate') < '1996-04-01')).select( + 'l_suppkey', + 
(col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).groupBy( + 'l_suppkey').agg(try_sum('value').alias('total')) - revenue = lineitem.filter((col('l_shipdate') >= '1996-01-01') & - (col('l_shipdate') < '1996-04-01')).select( - 'l_suppkey', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).groupBy( - 'l_suppkey').agg(try_sum('value').alias('total')) + outcome = revenue.agg(pyspark.sql.functions.max(col('total')).alias('max_total')).join( + revenue, col('max_total') == revenue.total).join( + supplier, col('l_suppkey') == supplier.s_suppkey).select( + 's_suppkey', 's_name', 's_address', 's_phone', 'total') - outcome = revenue.agg(pyspark.sql.functions.max(col('total')).alias('max_total')).join( - revenue, col('max_total') == revenue.total).join( - supplier, col('l_suppkey') == supplier.s_suppkey).select( - 's_suppkey', 's_name', 's_address', 's_phone', 'total') + sorted_outcome = outcome.sort('s_suppkey').limit(1).collect() - sorted_outcome = outcome.sort('s_suppkey').limit(1).collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_16(self, spark_session_with_tpch_dataset): @@ -470,23 +506,26 @@ def test_query_16(self, spark_session_with_tpch_dataset): Row(p_brand='Brand#11', p_type='STANDARD BRUSHED TIN', p_size=23, supplier_cnt=24), ] - part = spark_session_with_tpch_dataset.table('part') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - supplier = spark_session_with_tpch_dataset.table('supplier') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + part = spark_session_with_tpch_dataset.table('part') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') - fparts = part.filter((col('p_brand') != 'Brand#45') & - (~col('p_type').startswith('MEDIUM POLISHED')) & - (col('p_size').isin([3, 14, 23, 45, 49, 9, 19, 36]))).select( - 'p_partkey', 'p_brand', 'p_type', 'p_size') + fparts = part.filter((col('p_brand') != 'Brand#45') & + (~col('p_type').startswith('MEDIUM POLISHED')) & + (col('p_size').isin([3, 14, 23, 45, 49, 9, 19, 36]))).select( + 'p_partkey', 'p_brand', 'p_type', 'p_size') - outcome = supplier.filter(~col('s_comment').rlike('.*Customer.*Complaints.*')).join( - partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( - 'ps_partkey', 'ps_suppkey').join( - fparts, col('ps_partkey') == fparts.p_partkey).groupBy( - 'p_brand', 'p_type', 'p_size').agg(countDistinct('ps_suppkey').alias('supplier_cnt')) + outcome = supplier.filter(~col('s_comment').rlike('.*Customer.*Complaints.*')).join( + partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( + 'ps_partkey', 'ps_suppkey').join( + fparts, col('ps_partkey') == fparts.p_partkey).groupBy( + 'p_brand', 'p_type', 'p_size').agg( + countDistinct('ps_suppkey').alias('supplier_cnt')) + + sorted_outcome = outcome.sort( + desc('supplier_cnt'), 'p_brand', 'p_type', 'p_size').limit(3).collect() - sorted_outcome = outcome.sort( - desc('supplier_cnt'), 'p_brand', 'p_type', 'p_size').limit(3).collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_17(self, spark_session_with_tpch_dataset): @@ -494,19 +533,20 @@ def test_query_17(self, spark_session_with_tpch_dataset): Row(avg_yearly=348406.02), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - part = spark_session_with_tpch_dataset.table('part') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + part = 
spark_session_with_tpch_dataset.table('part') - fpart = part.filter( - (col('p_brand') == 'Brand#23') & (col('p_container') == 'MED BOX')).select( - 'p_partkey').join(lineitem, col('p_partkey') == lineitem.l_partkey, 'left_outer') + fpart = part.filter( + (col('p_brand') == 'Brand#23') & (col('p_container') == 'MED BOX')).select( + 'p_partkey').join(lineitem, col('p_partkey') == lineitem.l_partkey, 'left_outer') - outcome = fpart.groupBy('p_partkey').agg( - (avg('l_quantity') * 0.2).alias('avg_quantity')).select( - col('p_partkey').alias('key'), 'avg_quantity').join( - fpart, col('key') == fpart.p_partkey).filter( - col('l_quantity') < col('avg_quantity')).agg( - try_sum('l_extendedprice') / 7).alias('avg_yearly') + outcome = fpart.groupBy('p_partkey').agg( + (avg('l_quantity') * 0.2).alias('avg_quantity')).select( + col('p_partkey').alias('key'), 'avg_quantity').join( + fpart, col('key') == fpart.p_partkey).filter( + col('l_quantity') < col('avg_quantity')).agg( + try_sum('l_extendedprice') / 7).alias('avg_yearly') assertDataFrameEqual(outcome, expected, atol=1e-2) @@ -520,22 +560,25 @@ def test_query_18(self, spark_session_with_tpch_dataset): o_totalprice=530604.44, sum_l_quantity=317.00), ] - customer = spark_session_with_tpch_dataset.table('customer') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - orders = spark_session_with_tpch_dataset.table('orders') - - outcome = lineitem.groupBy('l_orderkey').agg( - try_sum('l_quantity').alias('sum_quantity')).filter( - col('sum_quantity') > 300).select(col('l_orderkey').alias('key'), 'sum_quantity').join( - orders, orders.o_orderkey == col('key')).join( - lineitem, col('o_orderkey') == lineitem.l_orderkey).join( - customer, col('o_custkey') == customer.c_custkey).select( - 'l_quantity', 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', - 'o_totalprice').groupBy( - 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', 'o_totalprice').agg( - try_sum('l_quantity')) - - sorted_outcome = outcome.sort(desc('o_totalprice'), 'o_orderdate').limit(2).collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = lineitem.groupBy('l_orderkey').agg( + try_sum('l_quantity').alias('sum_quantity')).filter( + col('sum_quantity') > 300).select(col('l_orderkey').alias('key'), + 'sum_quantity').join( + orders, orders.o_orderkey == col('key')).join( + lineitem, col('o_orderkey') == lineitem.l_orderkey).join( + customer, col('o_custkey') == customer.c_custkey).select( + 'l_quantity', 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', + 'o_totalprice').groupBy( + 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', 'o_totalprice').agg( + try_sum('l_quantity')) + + sorted_outcome = outcome.sort(desc('o_totalprice'), 'o_orderdate').limit(2).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_19(self, spark_session_with_tpch_dataset): @@ -543,26 +586,27 @@ def test_query_19(self, spark_session_with_tpch_dataset): Row(revenue=3083843.06), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - part = spark_session_with_tpch_dataset.table('part') - - outcome = part.join(lineitem, col('l_partkey') == col('p_partkey')).filter( - col('l_shipmode').isin(['AIR', 'AIR REG']) & ( - col('l_shipinstruct') == 'DELIVER IN PERSON')).filter( - ((col('p_brand') == 'Brand#12') & ( - col('p_container').isin(['SM CASE', 'SM 
BOX', 'SM PACK', 'SM PKG'])) & - (col('l_quantity') >= 1) & (col('l_quantity') <= 11) & - (col('p_size') >= 1) & (col('p_size') <= 5)) | - ((col('p_brand') == 'Brand#23') & ( - col('p_container').isin(['MED BAG', 'MED BOX', 'MED PKG', 'MED PACK'])) & - (col('l_quantity') >= 10) & (col('l_quantity') <= 20) & - (col('p_size') >= 1) & (col('p_size') <= 10)) | - ((col('p_brand') == 'Brand#34') & ( - col('p_container').isin(['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) & - (col('l_quantity') >= 20) & (col('l_quantity') <= 30) & - (col('p_size') >= 1) & (col('p_size') <= 15))).select( - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).agg( - try_sum('volume').alias('revenue')) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + part = spark_session_with_tpch_dataset.table('part') + + outcome = part.join(lineitem, col('l_partkey') == col('p_partkey')).filter( + col('l_shipmode').isin(['AIR', 'AIR REG']) & ( + col('l_shipinstruct') == 'DELIVER IN PERSON')).filter( + ((col('p_brand') == 'Brand#12') & ( + col('p_container').isin(['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) & + (col('l_quantity') >= 1) & (col('l_quantity') <= 11) & + (col('p_size') >= 1) & (col('p_size') <= 5)) | + ((col('p_brand') == 'Brand#23') & ( + col('p_container').isin(['MED BAG', 'MED BOX', 'MED PKG', 'MED PACK'])) & + (col('l_quantity') >= 10) & (col('l_quantity') <= 20) & + (col('p_size') >= 1) & (col('p_size') <= 10)) | + ((col('p_brand') == 'Brand#34') & ( + col('p_container').isin(['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) & + (col('l_quantity') >= 20) & (col('l_quantity') <= 30) & + (col('p_size') >= 1) & (col('p_size') <= 15))).select( + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).agg( + try_sum('volume').alias('revenue')) assertDataFrameEqual(outcome, expected, atol=1e-2) @@ -573,29 +617,31 @@ def test_query_20(self, spark_session_with_tpch_dataset): Row(s_name='Supplier#000000205', s_address='rF uV8d0JNEk'), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - part = spark_session_with_tpch_dataset.table('part') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - supplier = spark_session_with_tpch_dataset.table('supplier') - - flineitem = lineitem.filter( - (col('l_shipdate') >= '1994-01-01') & (col('l_shipdate') < '1995-01-01')).groupBy( - 'l_partkey', 'l_suppkey').agg( - try_sum(col('l_quantity') * 0.5).alias('sum_quantity')) - - fnation = nation.filter(col('n_name') == 'CANADA') - nat_supp = supplier.select('s_suppkey', 's_name', 's_nationkey', 's_address').join( - fnation, col('s_nationkey') == fnation.n_nationkey) - - outcome = part.filter(col('p_name').startswith('forest')).select('p_partkey').join( - partsupp, col('p_partkey') == partsupp.ps_partkey).join( - flineitem, (col('ps_suppkey') == flineitem.l_suppkey) & ( - col('ps_partkey') == flineitem.l_partkey)).filter( - col('ps_availqty') > col('sum_quantity')).select('ps_suppkey').distinct().join( - nat_supp, col('ps_suppkey') == nat_supp.s_suppkey).select('s_name', 's_address') - - sorted_outcome = outcome.sort('s_name').limit(3).collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = 
spark_session_with_tpch_dataset.table('supplier') + + flineitem = lineitem.filter( + (col('l_shipdate') >= '1994-01-01') & (col('l_shipdate') < '1995-01-01')).groupBy( + 'l_partkey', 'l_suppkey').agg( + try_sum(col('l_quantity') * 0.5).alias('sum_quantity')) + + fnation = nation.filter(col('n_name') == 'CANADA') + nat_supp = supplier.select('s_suppkey', 's_name', 's_nationkey', 's_address').join( + fnation, col('s_nationkey') == fnation.n_nationkey) + + outcome = part.filter(col('p_name').startswith('forest')).select('p_partkey').join( + partsupp, col('p_partkey') == partsupp.ps_partkey).join( + flineitem, (col('ps_suppkey') == flineitem.l_suppkey) & ( + col('ps_partkey') == flineitem.l_partkey)).filter( + col('ps_availqty') > col('sum_quantity')).select('ps_suppkey').distinct().join( + nat_supp, col('ps_suppkey') == nat_supp.s_suppkey).select('s_name', 's_address') + + sorted_outcome = outcome.sort('s_name').limit(3).collect() + assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_21(self, spark_session_with_tpch_dataset): @@ -608,43 +654,46 @@ def test_query_21(self, spark_session_with_tpch_dataset): Row(s_name='Supplier#000000486', numwait=25), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - orders = spark_session_with_tpch_dataset.table('orders') - supplier = spark_session_with_tpch_dataset.table('supplier') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + orders = spark_session_with_tpch_dataset.table('orders') + supplier = spark_session_with_tpch_dataset.table('supplier') + + fsupplier = supplier.select('s_suppkey', 's_nationkey', 's_name') - fsupplier = supplier.select('s_suppkey', 's_nationkey', 's_name') + plineitem = lineitem.select('l_suppkey', 'l_orderkey', 'l_receiptdate', 'l_commitdate') - plineitem = lineitem.select('l_suppkey', 'l_orderkey', 'l_receiptdate', 'l_commitdate') + flineitem = plineitem.filter(col('l_receiptdate') > col('l_commitdate')) - flineitem = plineitem.filter(col('l_receiptdate') > col('l_commitdate')) + line1 = plineitem.groupBy('l_orderkey').agg( + countDistinct('l_suppkey').alias('suppkey_count'), + pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select( + col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max') - line1 = plineitem.groupBy('l_orderkey').agg( - countDistinct('l_suppkey').alias('suppkey_count'), - pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select( - col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max') + line2 = flineitem.groupBy('l_orderkey').agg( + countDistinct('l_suppkey').alias('suppkey_count'), + pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select( + col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max') - line2 = flineitem.groupBy('l_orderkey').agg( - countDistinct('l_suppkey').alias('suppkey_count'), - pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select( - col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max') + forder = orders.select('o_orderkey', 'o_orderstatus').filter( + col('o_orderstatus') == 'F') - forder = orders.select('o_orderkey', 'o_orderstatus').filter(col('o_orderstatus') == 'F') + outcome = nation.filter(col('n_name') == 'SAUDI ARABIA').join( + fsupplier, col('n_nationkey') == fsupplier.s_nationkey).join( + flineitem, col('s_suppkey') == flineitem.l_suppkey).join( + 
forder, col('l_orderkey') == forder.o_orderkey).join( + line1, col('l_orderkey') == line1.key).filter( + (col('suppkey_count') > 1) | + ((col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max')))).select( + 's_name', 'l_orderkey', 'l_suppkey').join( + line2, col('l_orderkey') == line2.key, 'left_outer').select( + 's_name', 'l_orderkey', 'l_suppkey', 'suppkey_count', 'suppkey_max').filter( + (col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max'))).groupBy( + 's_name').agg(count(col('l_suppkey')).alias('numwait')) - outcome = nation.filter(col('n_name') == 'SAUDI ARABIA').join( - fsupplier, col('n_nationkey') == fsupplier.s_nationkey).join( - flineitem, col('s_suppkey') == flineitem.l_suppkey).join( - forder, col('l_orderkey') == forder.o_orderkey).join( - line1, col('l_orderkey') == line1.key).filter( - (col('suppkey_count') > 1) | - ((col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max')))).select( - 's_name', 'l_orderkey', 'l_suppkey').join( - line2, col('l_orderkey') == line2.key, 'left_outer').select( - 's_name', 'l_orderkey', 'l_suppkey', 'suppkey_count', 'suppkey_max').filter( - (col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max'))).groupBy( - 's_name').agg(count(col('l_suppkey')).alias('numwait')) + sorted_outcome = outcome.sort(desc('numwait'), 's_name').limit(5).collect() - sorted_outcome = outcome.sort(desc('numwait'), 's_name').limit(5).collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) def test_query_22(self, spark_session_with_tpch_dataset): @@ -658,22 +707,24 @@ def test_query_22(self, spark_session_with_tpch_dataset): Row(cntrycode='31', numcust=922, totacctbal=6806670.18), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + + fcustomer = customer.select( + 'c_acctbal', 'c_custkey', (col('c_phone').substr(0, 2)).alias('cntrycode')).filter( + col('cntrycode').isin(['13', '31', '23', '29', '30', '18', '17'])) - fcustomer = customer.select( - 'c_acctbal', 'c_custkey', (col('c_phone').substr(0, 2)).alias('cntrycode')).filter( - col('cntrycode').isin(['13', '31', '23', '29', '30', '18', '17'])) + avg_customer = fcustomer.filter(col('c_acctbal') > 0.00).agg( + avg('c_acctbal').alias('avg_acctbal')) - avg_customer = fcustomer.filter(col('c_acctbal') > 0.00).agg( - avg('c_acctbal').alias('avg_acctbal')) + outcome = orders.groupBy('o_custkey').agg( + count('o_custkey')).select('o_custkey').join( + fcustomer, col('o_custkey') == fcustomer.c_custkey, 'right_outer').filter( + col('o_custkey').isNull()).join(avg_customer).filter( + col('c_acctbal') > col('avg_acctbal')).groupBy('cntrycode').agg( + count('c_custkey').alias('numcust'), try_sum('c_acctbal')) - outcome = orders.groupBy('o_custkey').agg( - count('o_custkey')).select('o_custkey').join( - fcustomer, col('o_custkey') == fcustomer.c_custkey, 'right_outer').filter( - col('o_custkey').isNull()).join(avg_customer).filter( - col('c_acctbal') > col('avg_acctbal')).groupBy('cntrycode').agg( - count('c_custkey').alias('numcust'), try_sum('c_acctbal')) + sorted_outcome = outcome.sort('cntrycode').collect() - sorted_outcome = outcome.sort('cntrycode').collect() assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) From f3d12403ba5ff1ac129675f70d4a8d498322ed33 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: 
Thu, 9 May 2024 11:14:31 -0400 Subject: [PATCH 42/58] feat: use a persistent backend throughout the test cases (#67) --- src/gateway/converter/spark_to_substrait.py | 23 +++---- src/gateway/converter/sql_to_substrait.py | 10 +-- src/gateway/server.py | 68 ++++++++++----------- src/gateway/tests/conftest.py | 42 ++++++------- src/gateway/tests/test_dataframe_api.py | 12 ++-- 5 files changed, 73 insertions(+), 82 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 4bec1eb..82c6352 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -62,11 +62,11 @@ def __init__(self, options: ConversionOptions): self._seen_generated_names = {} self._saved_extension_uris = {} self._saved_extensions = {} - self._backend_with_tempview = None + self._backend = None - def set_tempview_backend(self, backend) -> None: + def set_backend(self, backend) -> None: """Save the backend being used to create the temporary dataframe.""" - self._backend_with_tempview = backend + self._backend = backend def lookup_function_by_name(self, name: str) -> ExtensionFunction: """Find the function reference for a given Spark function name.""" @@ -454,15 +454,12 @@ def convert_read_named_table_relation( """Convert a read named table relation to a Substrait relation.""" table_name = rel.unparsed_identifier - if self._backend_with_tempview: - backend = self._backend_with_tempview - else: - # TODO -- Remove this once we have a persistent backend per session. - backend = find_backend(BackendOptions(self._conversion_options.backend.backend, - use_adbc=True)) - tpch_location = backend.find_tpch() - backend.register_table(table_name, tpch_location / table_name) - arrow_schema = backend.describe_table(table_name) + # An ADBC backend is required in order to get the arrow schema + temp_backend = find_backend(BackendOptions(self._conversion_options.backend.backend, + use_adbc=True)) + tpch_location = temp_backend.find_tpch() + temp_backend.register_table(table_name, tpch_location / table_name) + arrow_schema = temp_backend.describe_table(table_name) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) @@ -996,7 +993,7 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: """Convert a Spark SQL relation into a Substrait relation.""" # TODO -- Handle multithreading in the case with a persistent backend. - plan = convert_sql(rel.query, self._backend_with_tempview) + plan = convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in plan.relations[0].root.names: symbol.output_fields.append(field_name) diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index a5d4d14..939644c 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -6,16 +6,12 @@ from substrait.gen.proto import plan_pb2 -def convert_sql(sql: str, backend=None) -> plan_pb2.Plan: +def convert_sql(sql: str) -> plan_pb2.Plan: """Convert SQL into a Substrait plan.""" plan = plan_pb2.Plan() - # If backend is not provided or is not a DuckDBBackend, set one up. - # DuckDB is used as the SQL conversion engine. 
- if not isinstance(backend, backend_selector.DuckDBBackend): - backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) - backend.register_tpch() - + backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) + backend.register_tpch() connection = backend.get_connection() proto_bytes = connection.get_substrait(query=sql).fetchone()[0] plan.ParseFromString(proto_bytes) diff --git a/src/gateway/server.py b/src/gateway/server.py index 9b287dc..e6b14b4 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -13,7 +13,7 @@ from pyspark.sql.connect.proto import types_pb2 from substrait.gen.proto import algebra_pb2, plan_pb2 -from gateway.backends.backend_options import BackendOptions +from gateway.backends import backend_selector from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import arrow, datafusion, duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter @@ -82,18 +82,16 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: return types_pb2.DataType(struct=types_pb2.DataType.Struct(fields=fields)) -def create_dataframe_view(rel: pb2.Plan, conversion_options, backend) -> algebra_pb2.Rel: +def create_dataframe_view(rel: pb2.Plan, backend) -> algebra_pb2.Rel: """Register the temporary dataframe.""" dataframe_view_name = rel.command.create_dataframe_view.name read_data_source_relation = rel.command.create_dataframe_view.input.read.data_source format = read_data_source_relation.format path = read_data_source_relation.paths[0] - - if not backend: - backend = find_backend(BackendOptions(conversion_options.backend.backend, False)) backend.register_table(dataframe_view_name, path, format) - return backend + return None + class Statistics: """Statistics about the requests made to the server.""" @@ -144,11 +142,15 @@ def __init__(self, *args, **kwargs): """Initialize the SparkConnect service.""" # This is the central point for configuring the behavior of the service. 
self._options = duck_db() - self._backend_with_tempview = None - self._tempview_session_id = None + self._backend = None self._converter = None self._statistics = Statistics() + def _InitializeExecution(self): + """Initialize the execution of the Plan by setting the backend.""" + if not self._backend: + self._backend = find_backend(self._options.backend) + def ExecutePlan( self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: @@ -156,28 +158,26 @@ def ExecutePlan( self._statistics.execute_requests += 1 self._statistics.add_request(request) _LOGGER.info('ExecutePlan: %s', request) + self._InitializeExecution() + if not self._converter: + self._converter = SparkSubstraitConverter(self._options) + self._converter.set_backend(self._backend) match request.plan.WhichOneof('op_type'): case 'root': - if not self._converter and self._tempview_session_id != request.session_id: - self._converter = SparkSubstraitConverter(self._options) substrait = self._converter.convert_plan(request.plan) case 'command': match request.plan.command.WhichOneof('command_type'): case 'sql_command': - if (self._backend_with_tempview and - self._tempview_session_id == request.session_id): - substrait = convert_sql(request.plan.command.sql_command.sql, - self._backend_with_tempview) - else: - substrait = convert_sql(request.plan.command.sql_command.sql) + if "CREATE" in request.plan.command.sql_command.sql: + connection = self._backend.get_connection() + connection.execute(request.plan.command.sql_command.sql) + return + substrait = convert_sql(request.plan.command.sql_command.sql) case 'create_dataframe_view': - if not self._converter and self._tempview_session_id != request.session_id: - self._converter = SparkSubstraitConverter(self._options) - self._backend_with_tempview = create_dataframe_view( - request.plan, self._options, self._backend_with_tempview) - self._tempview_session_id = request.session_id - self._converter.set_tempview_backend(self._backend_with_tempview) - + create_dataframe_view(request.plan, self._backend) + yield pb2.ExecutePlanResponse( + session_id=request.session_id, + result_complete=pb2.ExecutePlanResponse.ResultComplete()) return case _: type = request.plan.command.WhichOneof("command_type") @@ -185,13 +185,11 @@ def ExecutePlan( case _: raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) - if self._backend_with_tempview and self._tempview_session_id == request.session_id: - backend = self._backend_with_tempview - else: - backend = find_backend(self._options.backend) - backend.register_tpch() + # TODO: Register the TPCH data for datafusion through the fixture. 
+ if isinstance(self._backend, backend_selector.DatafusionBackend): + self._backend.register_tpch() self._statistics.add_plan(substrait) - results = backend.execute(substrait) + results = self._backend.execute(substrait) _LOGGER.debug(' results are: %s', results) if not self._options.implement_show_string and request.plan.WhichOneof( @@ -231,17 +229,17 @@ def AnalyzePlan(self, request, context): self._statistics.analyze_requests += 1 self._statistics.add_request(request) _LOGGER.info('AnalyzePlan: %s', request) + self._InitializeExecution() if request.schema: if not self._converter: self._converter = SparkSubstraitConverter(self._options) + self._converter.set_backend(self._backend) substrait = self._converter.convert_plan(request.schema.plan) - if self._backend_with_tempview and self._tempview_session_id == request.session_id: - backend = self._backend_with_tempview - else: - backend = find_backend(self._options.backend) - backend.register_tpch() + # TODO: Register the TPCH data for datafusion through the fixture. + if isinstance(self._backend, backend_selector.DatafusionBackend): + self._backend.register_tpch() self._statistics.add_plan(substrait) - results = backend.execute(substrait) + results = self._backend.execute(substrait) _LOGGER.debug(' results are: %s', results) return pb2.AnalyzePlanResponse( session_id=request.session_id, diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 5e2221b..cd5648c 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -126,31 +126,31 @@ def users_dataframe(spark_session, schema_users, users_location): .parquet(users_location) -def _register_table(spark_session: SparkSession, name: str) -> None: +def _register_table(spark_session: SparkSession, source: str, name: str) -> None: location = Backend.find_tpch() / name - spark_session.sql( - f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' - f'OPTIONS ( path "{location}" )') + match source: + case 'spark': + spark_session.sql( + f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' + f'OPTIONS ( path "{location}" )') + case 'gateway-over-duckdb': + files = Backend.expand_location(location) + if not files: + raise ValueError(f"No parquet files found at {location}") + files_str = ', '.join([f"'{f}'" for f in files]) + files_sql = f"CREATE OR REPLACE TABLE {name} AS FROM read_parquet([{files_str}])" + spark_session.sql(files_sql) @pytest.fixture(scope='function') def spark_session_with_tpch_dataset(spark_session: SparkSession, source: str) -> SparkSession: """Add the TPC-H dataset to the current spark session.""" - if source == 'spark': - _register_table(spark_session, 'customer') - _register_table(spark_session, 'lineitem') - _register_table(spark_session, 'nation') - _register_table(spark_session, 'orders') - _register_table(spark_session, 'part') - _register_table(spark_session, 'partsupp') - _register_table(spark_session, 'region') - _register_table(spark_session, 'supplier') - return spark_session - - -@pytest.fixture(scope='function') -def spark_session_with_customer_dataset(spark_session: SparkSession, source: str) -> SparkSession: - """Add the TPC-H dataset to the current spark session.""" - if source == 'spark': - _register_table(spark_session, 'customer') + _register_table(spark_session, source, 'customer') + _register_table(spark_session, source, 'lineitem') + _register_table(spark_session, source, 'nation') + _register_table(spark_session, source, 'orders') + 
_register_table(spark_session, source, 'part') + _register_table(spark_session, source, 'partsupp') + _register_table(spark_session, source, 'region') + _register_table(spark_session, source, 'supplier') return spark_session diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 9a72843..7331b0e 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -143,16 +143,16 @@ def test_data_source_filter(self, spark_session): outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() assert len(outcome) == 29968 - def test_table(self, spark_session_with_customer_dataset): - outcome = spark_session_with_customer_dataset.table('customer').collect() + def test_table(self, spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.table('customer').collect() assert len(outcome) == 149999 - def test_table_schema(self, spark_session_with_customer_dataset): - schema = spark_session_with_customer_dataset.table('customer').schema + def test_table_schema(self, spark_session_with_tpch_dataset): + schema = spark_session_with_tpch_dataset.table('customer').schema assert len(schema) == 8 - def test_table_filter(self, spark_session_with_customer_dataset): - customer_dataframe = spark_session_with_customer_dataset.table('customer') + def test_table_filter(self, spark_session_with_tpch_dataset): + customer_dataframe = spark_session_with_tpch_dataset.table('customer') outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() assert len(outcome) == 29968 From a511a72e7e1c001b7aee9620192c46e3915ddebb Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 9 May 2024 14:10:49 -0700 Subject: [PATCH 43/58] feat: add duckdb table describe method (#70) --- src/gateway/backends/duckdb_backend.py | 25 ++++++++++++++++ src/gateway/converter/spark_to_substrait.py | 7 +---- src/gateway/server.py | 3 ++ src/gateway/tests/test_dataframe_api.py | 32 ++++++++++++++------- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index af5e9b9..3993575 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -8,6 +8,19 @@ from gateway.backends.backend import Backend +_DUCKDB_TO_ARROW = { + 'BOOLEAN': pa.bool_(), + 'TINYINT': pa.int8(), + 'SMALLINT': pa.int16(), + 'INTEGER': pa.int32(), + 'BIGINT': pa.int64(), + 'FLOAT': pa.float32(), + 'DOUBLE': pa.float64(), + 'DATE': pa.date32(), + 'TIMESTAMP': pa.timestamp('ns'), + 'VARCHAR': pa.string(), +} + # pylint: disable=fixme class DuckDBBackend(Backend): @@ -58,3 +71,15 @@ def register_table( files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" self._connection.execute(files_sql) + + def describe_table(self, name: str): + """Asks the backend to describe the given table.""" + result = self._connection.table(name).describe() + + fields = [] + for name, field_type in zip(result.columns, result.types, strict=False): + if name == 'aggr': + # This isn't a real column. 
+ continue + fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)])) + return pa.schema(fields) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 82c6352..892d808 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -454,12 +454,7 @@ def convert_read_named_table_relation( """Convert a read named table relation to a Substrait relation.""" table_name = rel.unparsed_identifier - # An ADBC backend is required in order to get the arrow schema - temp_backend = find_backend(BackendOptions(self._conversion_options.backend.backend, - use_adbc=True)) - tpch_location = temp_backend.find_tpch() - temp_backend.register_table(table_name, tpch_location / table_name) - arrow_schema = temp_backend.describe_table(table_name) + arrow_schema = self._backend.describe_table(table_name) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) diff --git a/src/gateway/server.py b/src/gateway/server.py index e6b14b4..9f4a687 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -171,6 +171,9 @@ def ExecutePlan( if "CREATE" in request.plan.command.sql_command.sql: connection = self._backend.get_connection() connection.execute(request.plan.command.sql_command.sql) + yield pb2.ExecutePlanResponse( + session_id=request.session_id, + result_complete=pb2.ExecutePlanResponse.ResultComplete()) return substrait = convert_sql(request.plan.command.sql_command.sql) case 'create_dataframe_view': diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 7331b0e..2a2955a 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -15,11 +15,7 @@ def mark_dataframe_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb': - if originalname == 'test_with_column' or originalname == 'test_cast': - request.node.add_marker(pytest.mark.xfail(reason='DuckDB column binding error')) - elif originalname in [ - 'test_create_or_replace_temp_view', 'test_create_or_replace_multiple_temp_views']: - request.node.add_marker(pytest.mark.xfail(reason='ADBC DuckDB from_substrait error')) + request.node.add_marker(pytest.mark.xfail(reason='DuckDB column binding error')) elif source == 'gateway-over-datafusion': if originalname in [ 'test_data_source_schema', 'test_data_source_filter', 'test_table', 'test_table_schema', @@ -140,11 +136,16 @@ def test_data_source_schema(self, spark_session): def test_data_source_filter(self, spark_session): location_customer = str(Backend.find_tpch() / 'customer') customer_dataframe = spark_session.read.parquet(location_customer) - outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + + with utilizes_valid_plans(spark_session): + outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + assert len(outcome) == 29968 def test_table(self, spark_session_with_tpch_dataset): - outcome = spark_session_with_tpch_dataset.table('customer').collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.table('customer').collect() + assert len(outcome) == 149999 def test_table_schema(self, spark_session_with_tpch_dataset): @@ -153,14 +154,20 @@ def test_table_schema(self, spark_session_with_tpch_dataset): def test_table_filter(self, 
spark_session_with_tpch_dataset): customer_dataframe = spark_session_with_tpch_dataset.table('customer') - outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + assert len(outcome) == 29968 def test_create_or_replace_temp_view(self, spark_session): location_customer = str(Backend.find_tpch() / 'customer') df_customer = spark_session.read.parquet(location_customer) df_customer.createOrReplaceTempView("mytempview") - outcome = spark_session.table('mytempview').collect() + + with utilizes_valid_plans(spark_session): + outcome = spark_session.table('mytempview').collect() + assert len(outcome) == 149999 def test_create_or_replace_multiple_temp_views(self, spark_session): @@ -168,6 +175,9 @@ def test_create_or_replace_multiple_temp_views(self, spark_session): df_customer = spark_session.read.parquet(location_customer) df_customer.createOrReplaceTempView("mytempview1") df_customer.createOrReplaceTempView("mytempview2") - outcome1 = spark_session.table('mytempview1').collect() - outcome2 = spark_session.table('mytempview2').collect() + + with utilizes_valid_plans(spark_session): + outcome1 = spark_session.table('mytempview1').collect() + outcome2 = spark_session.table('mytempview2').collect() + assert len(outcome1) == len(outcome2) == 149999 From a1c920fdceaaf4ac14e9600f1e63719deac1abed Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 9 May 2024 21:19:44 -0700 Subject: [PATCH 44/58] feat: fix with_columns for DuckDB (#71) This adds a workaround that projects and emits all columns instead of relying on the pass through fields. --- src/gateway/converter/conversion_options.py | 2 + src/gateway/converter/data/00001.splan | 348 ++++++++++++++++++-- src/gateway/converter/spark_to_substrait.py | 6 + src/gateway/tests/test_dataframe_api.py | 4 +- 4 files changed, 330 insertions(+), 30 deletions(-) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index f9d8b2a..5ac46bd 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -17,6 +17,7 @@ def __init__(self, backend: BackendOptions = None): self.use_emits_instead_of_direct = False self.use_switch_expressions_where_possible = True self.use_duckdb_regexp_matches_function = False + self.duckdb_project_emit_workaround = False self.safety_project_read_relations = False self.return_names_with_types = False @@ -48,4 +49,5 @@ def duck_db(): options.return_names_with_types = True options.use_switch_expressions_where_possible = False options.use_duckdb_regexp_matches_function = True + options.duckdb_project_emit_workaround = True return options diff --git a/src/gateway/converter/data/00001.splan b/src/gateway/converter/data/00001.splan index eeac539..b9ef5e9 100644 --- a/src/gateway/converter/data/00001.splan +++ b/src/gateway/converter/data/00001.splan @@ -78,15 +78,15 @@ relations { project { common { emit { - output_mapping: 0 - output_mapping: 1 - output_mapping: 2 - output_mapping: 3 - output_mapping: 4 - output_mapping: 5 - output_mapping: 6 - output_mapping: 7 - output_mapping: 8 + output_mapping: 11 + output_mapping: 12 + output_mapping: 13 + output_mapping: 14 + output_mapping: 15 + output_mapping: 16 + output_mapping: 17 + output_mapping: 18 + output_mapping: 19 output_mapping: 10 } } @@ -94,32 +94,32 @@ relations { project { common { emit { - output_mapping: 0 - 
output_mapping: 1 - output_mapping: 2 - output_mapping: 3 - output_mapping: 4 - output_mapping: 5 - output_mapping: 6 + output_mapping: 11 + output_mapping: 12 + output_mapping: 13 + output_mapping: 14 + output_mapping: 15 + output_mapping: 16 + output_mapping: 17 output_mapping: 10 - output_mapping: 8 - output_mapping: 9 + output_mapping: 18 + output_mapping: 19 } } input { project { common { emit { - output_mapping: 0 - output_mapping: 1 - output_mapping: 2 - output_mapping: 3 - output_mapping: 4 - output_mapping: 5 + output_mapping: 11 + output_mapping: 12 + output_mapping: 13 + output_mapping: 14 + output_mapping: 15 + output_mapping: 16 output_mapping: 10 - output_mapping: 7 - output_mapping: 8 - output_mapping: 9 + output_mapping: 17 + output_mapping: 18 + output_mapping: 19 } } input { @@ -232,6 +232,104 @@ relations { } } } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 3 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 4 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 5 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 7 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 8 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 9 + } + } + root_reference { + } + } + } } } expressions { @@ -254,6 +352,104 @@ relations { } } } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 3 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 4 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 5 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 6 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 8 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 9 + } + } + root_reference { + } + } + } } } expressions { @@ -276,6 +472,104 @@ relations { } } } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 3 + } + } + root_reference { + } + } + } + expressions { + selection { + 
direct_reference { + struct_field { + field: 4 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 5 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 6 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 7 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 8 + } + } + root_reference { + } + } + } } } condition { diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 892d808..eb63417 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -928,6 +928,12 @@ def convert_with_columns_relation( symbol.output_fields.append(name) project.common.CopyFrom(self.create_common_relation()) if remapped: + if self._conversion_options.duckdb_project_emit_workaround: + for field_number in range(len(symbol.input_fields)): + if field_number == mapping[field_number]: + project.expressions.append(field_reference(field_number)) + mapping[field_number] = len(symbol.input_fields) + ( + len(project.expressions)) - 1 for item in mapping: project.common.emit.output_mapping.append(item) return algebra_pb2.Rel(project=project) diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 2a2955a..b41c610 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -14,9 +14,7 @@ def mark_dataframe_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb': - request.node.add_marker(pytest.mark.xfail(reason='DuckDB column binding error')) - elif source == 'gateway-over-datafusion': + if source == 'gateway-over-datafusion': if originalname in [ 'test_data_source_schema', 'test_data_source_filter', 'test_table', 'test_table_schema', 'test_table_filter', 'test_create_or_replace_temp_view', From 76d18437d82dda025a96afaf0afb5f31cc915ea3 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Fri, 10 May 2024 11:54:20 -0400 Subject: [PATCH 45/58] feat: add datafusion describe table (#69) --- src/gateway/backends/datafusion_backend.py | 25 ++++++++++++++++++++- src/gateway/converter/spark_to_substrait.py | 12 +++------- src/gateway/server.py | 9 +++++--- src/gateway/tests/test_dataframe_api.py | 5 +---- 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index 950af4d..b350af8 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -9,6 +9,19 @@ from gateway.converter.rename_functions import RenameFunctionsForDatafusion from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable +_DATAFUSION_TO_ARROW = { + 'Boolean': pa.bool_(), + 'Int8': pa.int8(), + 'Int16': pa.int16(), + 'Int32': pa.int32(), + 'Int64': pa.int64(), + 'Float32': pa.float32(), + 'Float64': pa.float64(), + 'Date32': pa.date32(), + 'Timestamp(Nanosecond, None)': pa.timestamp('ns'), + 'Utf8': pa.string(), +} + # pylint: disable=import-outside-toplevel class DatafusionBackend(Backend): @@ -71,4 +84,14 @@ def register_table( # of deregistering it. 
if self._connection.table_exist(name): self._connection.deregister_table(name) - self._connection.register_parquet(name, files[0]) + self._connection.register_parquet(name, str(location)) + + def describe_table(self, table_name: str): + """Asks the backend to describe the given table.""" + result = self._connection.sql(f"describe {table_name}").to_arrow_table().to_pylist() + + fields = [] + for index in range(len(result)): + fields.append(pa.field(result[index]['column_name'], + _DATAFUSION_TO_ARROW[result[index]['data_type']])) + return pa.schema(fields) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index eb63417..e9357cf 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -10,8 +10,6 @@ import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 -from gateway.backends.backend_options import BackendOptions -from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function from gateway.converter.sql_to_substrait import convert_sql @@ -554,13 +552,9 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) if not schema: - backend = find_backend(BackendOptions(self._conversion_options.backend.backend, True)) - try: - backend.register_table(TABLE_NAME, rel.paths[0], rel.format) - arrow_schema = backend.describe_table(TABLE_NAME) - schema = self.convert_arrow_schema(arrow_schema) - finally: - backend.drop_table(TABLE_NAME) + self._backend.register_table(TABLE_NAME, rel.paths[0], rel.format) + arrow_schema = self._backend.describe_table(TABLE_NAME) + schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: symbol.output_fields.append(field_name) diff --git a/src/gateway/server.py b/src/gateway/server.py index 9f4a687..d3d0bf4 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -159,6 +159,9 @@ def ExecutePlan( self._statistics.add_request(request) _LOGGER.info('ExecutePlan: %s', request) self._InitializeExecution() + # TODO: Register the TPCH data for datafusion through the fixture. + if isinstance(self._backend, backend_selector.DatafusionBackend): + self._backend.register_tpch() if not self._converter: self._converter = SparkSubstraitConverter(self._options) self._converter.set_backend(self._backend) @@ -233,14 +236,14 @@ def AnalyzePlan(self, request, context): self._statistics.add_request(request) _LOGGER.info('AnalyzePlan: %s', request) self._InitializeExecution() + # TODO: Register the TPCH data for datafusion through the fixture. + if isinstance(self._backend, backend_selector.DatafusionBackend): + self._backend.register_tpch() if request.schema: if not self._converter: self._converter = SparkSubstraitConverter(self._options) self._converter.set_backend(self._backend) substrait = self._converter.convert_plan(request.schema.plan) - # TODO: Register the TPCH data for datafusion through the fixture. 
- if isinstance(self._backend, backend_selector.DatafusionBackend): - self._backend.register_tpch() self._statistics.add_plan(substrait) results = self._backend.execute(substrait) _LOGGER.debug(' results are: %s', results) diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index b41c610..eba2a67 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -15,10 +15,7 @@ def mark_dataframe_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-datafusion': - if originalname in [ - 'test_data_source_schema', 'test_data_source_filter', 'test_table', 'test_table_schema', - 'test_table_filter', 'test_create_or_replace_temp_view', - 'test_create_or_replace_multiple_temp_views',]: + if originalname in ['test_data_source_filter']: request.node.add_marker(pytest.mark.xfail(reason='Gateway internal iterating error')) else: pytest.importorskip("datafusion.substrait") From 74b36aef8b21be7fad9cdd9e320568bced8fc270 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 10 May 2024 12:15:59 -0700 Subject: [PATCH 46/58] feat: add location to the substrait validator error messages (#73) --- src/gateway/tests/plan_validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gateway/tests/plan_validator.py b/src/gateway/tests/plan_validator.py index 3030385..d3dfeed 100644 --- a/src/gateway/tests/plan_validator.py +++ b/src/gateway/tests/plan_validator.py @@ -14,9 +14,9 @@ def validate_plan(json_plan: str): issues = [] for issue in diagnostics: if issue.adjusted_level >= substrait_validator.Diagnostic.LEVEL_ERROR: - issues.append(issue.msg) + issues.append([issue.msg, substrait_validator.path_to_string(issue.path)]) if issues: - issues_as_text = '\n'.join(f' → {issue}' for issue in issues) + issues_as_text = '\n'.join(f' → {issue[0]}\n at {issue[1]}' for issue in issues) pytest.fail(f'Validation failed. 
Issues:\n{issues_as_text}\n\nPlan:\n{substrait_plan}\n', pytrace=False) From 5abafb15f68d6680e226c71f017fe4613f813e3f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 10 May 2024 12:56:35 -0700 Subject: [PATCH 47/58] feat: remove the server side table registration workaround (#72) --- src/gateway/backends/adbc_backend.py | 5 +- src/gateway/backends/backend.py | 31 +++--------- src/gateway/backends/backend_options.py | 6 +-- src/gateway/backends/backend_selector.py | 8 +-- src/gateway/backends/datafusion_backend.py | 8 ++- src/gateway/backends/duckdb_backend.py | 26 +++++++++- src/gateway/converter/conversion_options.py | 8 +-- src/gateway/converter/data/count.sql-splan | 3 -- src/gateway/converter/spark_to_substrait.py | 14 +++--- .../converter/spark_to_substrait_test.py | 10 +++- src/gateway/converter/sql_to_substrait.py | 14 ++---- src/gateway/server.py | 28 ++++------- src/gateway/tests/conftest.py | 49 ++++++++++--------- src/gateway/tests/test_dataframe_api.py | 10 ++-- 14 files changed, 110 insertions(+), 110 deletions(-) diff --git a/src/gateway/backends/adbc_backend.py b/src/gateway/backends/adbc_backend.py index 57a40f1..5085240 100644 --- a/src/gateway/backends/adbc_backend.py +++ b/src/gateway/backends/adbc_backend.py @@ -9,8 +9,7 @@ from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend -from gateway.backends.backend_options import Backend as backend_engine -from gateway.backends.backend_options import BackendOptions +from gateway.backends.backend_options import BackendEngine, BackendOptions def _import(handle): @@ -20,7 +19,7 @@ def _import(handle): def _get_backend_driver(options: BackendOptions) -> tuple[str, str]: """Get the driver and entry point for the specified backend.""" match options.backend: - case backend_engine.DUCKDB: + case BackendEngine.DUCKDB: driver = duckdb.duckdb.__file__ entry_point = "duckdb_adbc_init" case _: diff --git a/src/gateway/backends/backend.py b/src/gateway/backends/backend.py index eb3c4ab..c6f8dba 100644 --- a/src/gateway/backends/backend.py +++ b/src/gateway/backends/backend.py @@ -35,6 +35,10 @@ def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> """Register the given table with the backend.""" raise NotImplementedError() + def describe_files(self, paths: list[str]): + """Asks the backend to describe the given files.""" + raise NotImplementedError() + def describe_table(self, name: str): """Asks the backend to describe the given table.""" raise NotImplementedError() @@ -43,6 +47,10 @@ def drop_table(self, name: str) -> None: """Asks the backend to drop the given table.""" raise NotImplementedError() + def convert_sql(self, sql: str) -> plan_pb2.Plan: + """Convert SQL into a Substrait plan.""" + raise NotImplementedError() + @staticmethod def expand_location(location: Path | str) -> list[str]: """Expand the location of a file or directory into a list of files.""" @@ -50,26 +58,3 @@ def expand_location(location: Path | str) -> list[str]: path = Path(location) files = Path(location).resolve().glob('*.parquet') if path.is_dir() else [path] return sorted(str(f) for f in files) - - @staticmethod - def find_tpch() -> Path: - """Find the location of the TPCH dataset.""" - current_location = Path('.').resolve() - while current_location != Path('/'): - location = current_location / 'third_party' / 'tpch' / 'parquet' - if location.exists(): - return location.resolve() - current_location = current_location.parent - raise ValueError('TPCH dataset not found') - - def register_tpch(self): - 
"""Register the entire TPC-H dataset.""" - tpch_location = Backend.find_tpch() - self.register_table('customer', tpch_location / 'customer') - self.register_table('lineitem', tpch_location / 'lineitem') - self.register_table('nation', tpch_location / 'nation') - self.register_table('orders', tpch_location / 'orders') - self.register_table('part', tpch_location / 'part') - self.register_table('partsupp', tpch_location / 'partsupp') - self.register_table('region', tpch_location / 'region') - self.register_table('supplier', tpch_location / 'supplier') diff --git a/src/gateway/backends/backend_options.py b/src/gateway/backends/backend_options.py index 466e3fe..61133e3 100644 --- a/src/gateway/backends/backend_options.py +++ b/src/gateway/backends/backend_options.py @@ -4,7 +4,7 @@ from enum import Enum -class Backend(Enum): +class BackendEngine(Enum): """Represents the different backends we have support for.""" ARROW = 1 @@ -20,10 +20,10 @@ def __str__(self): class BackendOptions: """Holds all the possible backend options.""" - backend: Backend + backend: BackendEngine use_adbc: bool - def __init__(self, backend: Backend, use_adbc: bool = False): + def __init__(self, backend: BackendEngine, use_adbc: bool = False): """Create a BackendOptions structure.""" self.backend = backend self.use_adbc = use_adbc diff --git a/src/gateway/backends/backend_selector.py b/src/gateway/backends/backend_selector.py index fa88625..c6b6696 100644 --- a/src/gateway/backends/backend_selector.py +++ b/src/gateway/backends/backend_selector.py @@ -3,7 +3,7 @@ from gateway.backends import backend from gateway.backends.adbc_backend import AdbcBackend from gateway.backends.arrow_backend import ArrowBackend -from gateway.backends.backend_options import Backend, BackendOptions +from gateway.backends.backend_options import BackendEngine, BackendOptions from gateway.backends.datafusion_backend import DatafusionBackend from gateway.backends.duckdb_backend import DuckDBBackend @@ -11,11 +11,11 @@ def find_backend(options: BackendOptions) -> backend.Backend: """Given a backend enum, returns an instance of the correct Backend descendant.""" match options.backend: - case Backend.ARROW: + case BackendEngine.ARROW: return ArrowBackend(options) - case Backend.DATAFUSION: + case BackendEngine.DATAFUSION: return DatafusionBackend(options) - case Backend.DUCKDB: + case BackendEngine.DUCKDB: if options.use_adbc: return AdbcBackend(options) return DuckDBBackend(options) diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index b350af8..fdbb8eb 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -74,7 +74,7 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: self._connection.deregister_table(table_name) def register_table( - self, name: str, location: Path, file_format: str = 'parquet' + self, name: str, location: Path, file_format: str = 'parquet' ) -> None: """Register the given table with the backend.""" files = Backend.expand_location(location) @@ -86,6 +86,12 @@ def register_table( self._connection.deregister_table(name) self._connection.register_parquet(name, str(location)) + def describe_files(self, paths: list[str]): + """Asks the backend to describe the given files.""" + # TODO -- Use the ListingTable API to resolve the combined schema. 
+ df = self._connection.read_parquet(paths[0]) + return df.schema() + def describe_table(self, table_name: str): """Asks the backend to describe the given table.""" result = self._connection.sql(f"describe {table_name}").to_arrow_table().to_pylist() diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index 3993575..88dc966 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -72,14 +72,36 @@ def register_table( self._connection.execute(files_sql) + def describe_files(self, paths: list[str]): + """Asks the backend to describe the given files.""" + files = paths + if len(paths) == 1: + files = self.expand_location(paths[0]) + df = self._connection.read_parquet(files) + + fields = [] + for name, field_type in zip(df.columns, df.types, strict=False): + if name == 'aggr': + # This isn't a real column. + continue + fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)])) + return pa.schema(fields) + def describe_table(self, name: str): """Asks the backend to describe the given table.""" - result = self._connection.table(name).describe() + df = self._connection.table(name).describe() fields = [] - for name, field_type in zip(result.columns, result.types, strict=False): + for name, field_type in zip(df.columns, df.types, strict=False): if name == 'aggr': # This isn't a real column. continue fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)])) return pa.schema(fields) + + def convert_sql(self, sql: str) -> plan_pb2.Plan: + """Convert SQL into a Substrait plan.""" + plan = plan_pb2.Plan() + proto_bytes = self._connection.get_substrait(query=sql).fetchone()[0] + plan.ParseFromString(proto_bytes) + return plan diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 5ac46bd..0901ee8 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -2,7 +2,7 @@ """Tracks conversion related options.""" import dataclasses -from gateway.backends.backend_options import Backend, BackendOptions +from gateway.backends.backend_options import BackendEngine, BackendOptions # pylint: disable=too-many-instance-attributes @@ -29,7 +29,7 @@ def __init__(self, backend: BackendOptions = None): def arrow(): """Return standard options to connect to the Acero backend.""" - options = ConversionOptions(backend=BackendOptions(Backend.ARROW)) + options = ConversionOptions(backend=BackendOptions(BackendEngine.ARROW)) options.needs_scheme_in_path_uris = True options.return_names_with_types = True options.implement_show_string = False @@ -40,12 +40,12 @@ def arrow(): def datafusion(): """Return standard options to connect to a Datafusion backend.""" - return ConversionOptions(backend=BackendOptions(Backend.DATAFUSION)) + return ConversionOptions(backend=BackendOptions(BackendEngine.DATAFUSION)) def duck_db(): """Return standard options to connect to a DuckDB backend.""" - options = ConversionOptions(backend=BackendOptions(Backend.DUCKDB)) + options = ConversionOptions(backend=BackendOptions(BackendEngine.DUCKDB)) options.return_names_with_types = True options.use_switch_expressions_where_possible = False options.use_duckdb_regexp_matches_function = True diff --git a/src/gateway/converter/data/count.sql-splan b/src/gateway/converter/data/count.sql-splan index 58ade53..49eb7f5 100644 --- a/src/gateway/converter/data/count.sql-splan +++ b/src/gateway/converter/data/count.sql-splan @@ -1,6 +1,3 @@ -extension_uris { - uri: 
"urn:arrow:substrait_simple_extension_function" -} extensions { extension_function { function_anchor: 1 diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index e9357cf..4731def 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -12,7 +12,6 @@ import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function -from gateway.converter.sql_to_substrait import convert_sql from gateway.converter.substrait_builder import ( aggregate_relation, bigint_literal, @@ -43,8 +42,6 @@ from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 from substrait.gen.proto.extensions import extensions_pb2 -TABLE_NAME = "my_table" - # ruff: noqa: RUF005 class SparkSubstraitConverter: @@ -61,10 +58,12 @@ def __init__(self, options: ConversionOptions): self._saved_extension_uris = {} self._saved_extensions = {} self._backend = None + self._sql_backend = None - def set_backend(self, backend) -> None: - """Save the backend being used to create the temporary dataframe.""" + def set_backends(self, backend, sql_backend) -> None: + """Save the backends being used to resolve tables and convert to SQL.""" self._backend = backend + self._sql_backend = sql_backend def lookup_function_by_name(self, name: str) -> ExtensionFunction: """Find the function reference for a given Spark function name.""" @@ -552,8 +551,7 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) if not schema: - self._backend.register_table(TABLE_NAME, rel.paths[0], rel.format) - arrow_schema = self._backend.describe_table(TABLE_NAME) + arrow_schema = self._backend.describe_files([str(path) for path in rel.paths]) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: @@ -988,7 +986,7 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: """Convert a Spark SQL relation into a Substrait relation.""" # TODO -- Handle multithreading in the case with a persistent backend. 
- plan = convert_sql(rel.query) + plan = self._sql_backend.convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in plan.relations[0].root.names: symbol.output_fields.append(field_name) diff --git a/src/gateway/converter/spark_to_substrait_test.py b/src/gateway/converter/spark_to_substrait_test.py index 685c71d..17909da 100644 --- a/src/gateway/converter/spark_to_substrait_test.py +++ b/src/gateway/converter/spark_to_substrait_test.py @@ -3,10 +3,11 @@ from pathlib import Path import pytest +from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter -from gateway.converter.sql_to_substrait import convert_sql from gateway.demo.mystream_database import create_mystream_database, delete_mystream_database +from gateway.tests.conftest import find_tpch from google.protobuf import text_format from pyspark.sql.connect.proto import base_pb2 as spark_base_pb2 from substrait.gen.proto import plan_pb2 @@ -41,8 +42,10 @@ def test_plan_conversion(request, path): substrait_plan = text_format.Parse(splan_prototext, plan_pb2.Plan()) options = duck_db() + backend = find_backend(options.backend) options.implement_show_string = False convert = SparkSubstraitConverter(options) + convert.set_backends(backend, backend) substrait = convert.convert_plan(spark_plan) if request.config.getoption('rebuild_goldens'): @@ -80,7 +83,10 @@ def test_sql_conversion(request, path): splan_prototext = file.read() substrait_plan = text_format.Parse(splan_prototext, plan_pb2.Plan()) - substrait = convert_sql(str(sql)) + options = duck_db() + backend = find_backend(options.backend) + backend.register_table('customer', find_tpch() / 'customer') + substrait = backend.convert_sql(str(sql)) if request.config.getoption('rebuild_goldens'): if substrait != substrait_plan: diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index 939644c..5894398 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -1,21 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to convert SparkConnect plans to Substrait plans.""" -from gateway.backends import backend_selector -from gateway.backends.backend_options import Backend, BackendOptions +from gateway.backends.backend import Backend from gateway.converter.add_extension_uris import AddExtensionUris from substrait.gen.proto import plan_pb2 -def convert_sql(sql: str) -> plan_pb2.Plan: +def convert_sql(backend: Backend, sql: str) -> plan_pb2.Plan: """Convert SQL into a Substrait plan.""" - plan = plan_pb2.Plan() - - backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) - backend.register_tpch() - connection = backend.get_connection() - proto_bytes = connection.get_substrait(query=sql).fetchone()[0] - plan.ParseFromString(proto_bytes) + plan = backend.convert_sql(sql) + # Perform various fixes to make the plan more compatible. # TODO -- Remove this after the SQL converter is fixed. 
AddExtensionUris().visit_plan(plan) diff --git a/src/gateway/server.py b/src/gateway/server.py index d3d0bf4..3e61b21 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -13,11 +13,11 @@ from pyspark.sql.connect.proto import types_pb2 from substrait.gen.proto import algebra_pb2, plan_pb2 -from gateway.backends import backend_selector +from gateway.backends.backend import Backend +from gateway.backends.backend_options import BackendEngine, BackendOptions from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import arrow, datafusion, duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter -from gateway.converter.sql_to_substrait import convert_sql _LOGGER = logging.getLogger(__name__) @@ -142,7 +142,8 @@ def __init__(self, *args, **kwargs): """Initialize the SparkConnect service.""" # This is the central point for configuring the behavior of the service. self._options = duck_db() - self._backend = None + self._backend: Backend | None = None + self._sql_backend: Backend | None = None self._converter = None self._statistics = Statistics() @@ -150,6 +151,9 @@ def _InitializeExecution(self): """Initialize the execution of the Plan by setting the backend.""" if not self._backend: self._backend = find_backend(self._options.backend) + self._sql_backend = find_backend(BackendOptions(BackendEngine.DUCKDB, False)) + self._converter = SparkSubstraitConverter(self._options) + self._converter.set_backends(self._backend, self._sql_backend) def ExecutePlan( self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext) -> Generator[ @@ -159,12 +163,6 @@ def ExecutePlan( self._statistics.add_request(request) _LOGGER.info('ExecutePlan: %s', request) self._InitializeExecution() - # TODO: Register the TPCH data for datafusion through the fixture. - if isinstance(self._backend, backend_selector.DatafusionBackend): - self._backend.register_tpch() - if not self._converter: - self._converter = SparkSubstraitConverter(self._options) - self._converter.set_backend(self._backend) match request.plan.WhichOneof('op_type'): case 'root': substrait = self._converter.convert_plan(request.plan) @@ -178,9 +176,11 @@ def ExecutePlan( session_id=request.session_id, result_complete=pb2.ExecutePlanResponse.ResultComplete()) return - substrait = convert_sql(request.plan.command.sql_command.sql) + substrait = self._sql_backend.convert_sql( + request.plan.command.sql_command.sql) case 'create_dataframe_view': create_dataframe_view(request.plan, self._backend) + create_dataframe_view(request.plan, self._sql_backend) yield pb2.ExecutePlanResponse( session_id=request.session_id, result_complete=pb2.ExecutePlanResponse.ResultComplete()) @@ -192,8 +192,6 @@ def ExecutePlan( raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) # TODO: Register the TPCH data for datafusion through the fixture. - if isinstance(self._backend, backend_selector.DatafusionBackend): - self._backend.register_tpch() self._statistics.add_plan(substrait) results = self._backend.execute(substrait) _LOGGER.debug(' results are: %s', results) @@ -236,13 +234,7 @@ def AnalyzePlan(self, request, context): self._statistics.add_request(request) _LOGGER.info('AnalyzePlan: %s', request) self._InitializeExecution() - # TODO: Register the TPCH data for datafusion through the fixture. 
- if isinstance(self._backend, backend_selector.DatafusionBackend): - self._backend.register_tpch() if request.schema: - if not self._converter: - self._converter = SparkSubstraitConverter(self._options) - self._converter.set_backend(self._backend) substrait = self._converter.convert_plan(request.schema.plan) self._statistics.add_plan(substrait) results = self._backend.execute(substrait) diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index cd5648c..21360f5 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -4,7 +4,6 @@ from pathlib import Path import pytest -from gateway.backends.backend import Backend from gateway.demo.mystream_database import ( create_mystream_database, delete_mystream_database, @@ -126,31 +125,33 @@ def users_dataframe(spark_session, schema_users, users_location): .parquet(users_location) -def _register_table(spark_session: SparkSession, source: str, name: str) -> None: - location = Backend.find_tpch() / name - match source: - case 'spark': - spark_session.sql( - f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' - f'OPTIONS ( path "{location}" )') - case 'gateway-over-duckdb': - files = Backend.expand_location(location) - if not files: - raise ValueError(f"No parquet files found at {location}") - files_str = ', '.join([f"'{f}'" for f in files]) - files_sql = f"CREATE OR REPLACE TABLE {name} AS FROM read_parquet([{files_str}])" - spark_session.sql(files_sql) +def find_tpch() -> Path: + """Find the location of the TPC-H dataset.""" + current_location = Path('.').resolve() + while current_location != Path('/'): + location = current_location / 'third_party' / 'tpch' / 'parquet' + if location.exists(): + return location.resolve() + current_location = current_location.parent + raise ValueError('TPC-H dataset not found') + + +def _register_table(spark_session: SparkSession, name: str) -> None: + """Registers a TPC-H table with the given name into spark_session.""" + location = find_tpch() / name + df = spark_session.read.parquet(str(location)) + df.createOrReplaceTempView(name) @pytest.fixture(scope='function') -def spark_session_with_tpch_dataset(spark_session: SparkSession, source: str) -> SparkSession: +def spark_session_with_tpch_dataset(spark_session: SparkSession) -> SparkSession: """Add the TPC-H dataset to the current spark session.""" - _register_table(spark_session, source, 'customer') - _register_table(spark_session, source, 'lineitem') - _register_table(spark_session, source, 'nation') - _register_table(spark_session, source, 'orders') - _register_table(spark_session, source, 'part') - _register_table(spark_session, source, 'partsupp') - _register_table(spark_session, source, 'region') - _register_table(spark_session, source, 'supplier') + _register_table(spark_session, 'customer') + _register_table(spark_session, 'lineitem') + _register_table(spark_session, 'nation') + _register_table(spark_session, 'orders') + _register_table(spark_session, 'part') + _register_table(spark_session, 'partsupp') + _register_table(spark_session, 'region') + _register_table(spark_session, 'supplier') return spark_session diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index eba2a67..236a210 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the Spark to Substrait Gateway server.""" import pytest -from gateway.backends.backend import 
Backend +from gateway.tests.conftest import find_tpch from gateway.tests.plan_validator import utilizes_valid_plans from hamcrest import assert_that, equal_to from pyspark import Row @@ -124,12 +124,12 @@ def test_cast(self, users_dataframe): assertDataFrameEqual(outcome, expected) def test_data_source_schema(self, spark_session): - location_customer = str(Backend.find_tpch() / 'customer') + location_customer = str(find_tpch() / 'customer') schema = spark_session.read.parquet(location_customer).schema assert len(schema) == 8 def test_data_source_filter(self, spark_session): - location_customer = str(Backend.find_tpch() / 'customer') + location_customer = str(find_tpch() / 'customer') customer_dataframe = spark_session.read.parquet(location_customer) with utilizes_valid_plans(spark_session): @@ -156,7 +156,7 @@ def test_table_filter(self, spark_session_with_tpch_dataset): assert len(outcome) == 29968 def test_create_or_replace_temp_view(self, spark_session): - location_customer = str(Backend.find_tpch() / 'customer') + location_customer = str(find_tpch() / 'customer') df_customer = spark_session.read.parquet(location_customer) df_customer.createOrReplaceTempView("mytempview") @@ -166,7 +166,7 @@ def test_create_or_replace_temp_view(self, spark_session): assert len(outcome) == 149999 def test_create_or_replace_multiple_temp_views(self, spark_session): - location_customer = str(Backend.find_tpch() / 'customer') + location_customer = str(find_tpch() / 'customer') df_customer = spark_session.read.parquet(location_customer) df_customer.createOrReplaceTempView("mytempview1") df_customer.createOrReplaceTempView("mytempview2") From 15bd09d5e3a20f691100a0dcbfe5c5c7bb2c18e5 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Fri, 10 May 2024 20:08:03 -0400 Subject: [PATCH 48/58] fix: update datafusion execute to register all files (#74) --- src/gateway/backends/datafusion_backend.py | 7 +++---- src/gateway/tests/test_dataframe_api.py | 6 +----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index fdbb8eb..88baa1d 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -47,10 +47,9 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: registered_tables = set() for files in file_groups: table_name = files[0] - for file in files[1]: - if table_name not in registered_tables: - self.register_table(table_name, file) - registered_tables.add(files[0]) + location = Path(files[1][0]).parent + self.register_table(table_name, location) + registered_tables.add(table_name) RenameFunctionsForDatafusion().visit_plan(plan) diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 236a210..dc38392 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -13,12 +13,8 @@ def mark_dataframe_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') - originalname = request.keywords.node.originalname if source == 'gateway-over-datafusion': - if originalname in ['test_data_source_filter']: - request.node.add_marker(pytest.mark.xfail(reason='Gateway internal iterating error')) - else: - pytest.importorskip("datafusion.substrait") + pytest.importorskip("datafusion.substrait") # pylint: disable=missing-function-docstring From 6878a40fc6fcf0216519a2373aca43ccb6a381a5 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Sat, 11 
May 2024 16:25:02 -0700 Subject: [PATCH 49/58] feat: fix output fields expected coming out of join relations (#75) --- src/gateway/converter/spark_to_substrait.py | 2 +- src/gateway/tests/test_dataframe_api.py | 18 ++++++++++++++++++ .../tests/test_tpch_with_dataframe_api.py | 10 +++++----- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 4731def..05b5263 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -84,7 +84,7 @@ def update_field_references(self, plan_id: int) -> None: source_symbol = self._symbol_table.get_symbol(plan_id) current_symbol = self._symbol_table.get_symbol(self._current_plan_id) current_symbol.input_fields.extend(source_symbol.output_fields) - current_symbol.output_fields.extend(current_symbol.input_fields) + current_symbol.output_fields.extend(source_symbol.output_fields) def find_field_by_name(self, field_name: str) -> int | None: """Look up the field name in the current set of field references.""" diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index dc38392..d8901e1 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -119,6 +119,24 @@ def test_cast(self, users_dataframe): assertDataFrameEqual(outcome, expected) + def test_join(self, spark_session_with_tpch_dataset): + expected = [ + Row(n_nationkey=5, n_name='ETHIOPIA', n_regionkey=0, + n_comment='ven packages wake quickly. regu', s_suppkey=2, + s_name='Supplier#000000002', s_address='89eJ5ksX3ImxJQBvxObC,', s_nationkey=5, + s_phone='15-679-861-2259', s_acctbal=4032.68, + s_comment=' slyly bold instructions. idle dependen'), + ] + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + nation = spark_session_with_tpch_dataset.table('nation') + supplier = spark_session_with_tpch_dataset.table('supplier') + + nat = nation.join(supplier, col('n_nationkey') == col('s_nationkey')) + outcome = nat.filter(col('s_suppkey') == 2).limit(1).collect() + + assertDataFrameEqual(outcome, expected) + def test_data_source_schema(self, spark_session): location_customer = str(find_tpch() / 'customer') schema = spark_session.read.parquet(location_customer).schema diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 6b338dc..deb49f2 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -16,11 +16,11 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_02', 'test_query_03', 'test_query_04', 'test_query_05', 'test_query_07', - 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', - 'test_query_13', 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', - 'test_query_18', 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) + 'test_query_02', 'test_query_03', 'test_query_04', 'test_query_05', 'test_query_07', + 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', + 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', + 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: + 
request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) if source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) From d18c4d0dd97816851d0b695befe1bd9c2434ffaf Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 14 May 2024 17:47:41 -0700 Subject: [PATCH 50/58] feat: fix conversion of spark project to substrait (#76) The SparkConnect project relation does not pass any fields through except the fields it generates. The previous behavior mimicked that of Substrait which was to pass all of the input fields through as well. The SparkConnect behavior is now used. This PR also now emits all plans emitted in failed tests, avoids duplicate names in joins (usually the join keys), and retains the name of projected field references. --- src/gateway/converter/spark_functions.py | 6 ++++- src/gateway/converter/spark_to_substrait.py | 16 +++++++++--- src/gateway/tests/plan_validator.py | 11 ++++---- .../tests/test_tpch_with_dataframe_api.py | 25 +++++++++++++------ 4 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index bc78ded..3790c82 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -121,7 +121,11 @@ def __lt__(self, obj) -> bool: i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'max': ExtensionFunction( - '/functions_aggregate.yaml', 'max:i64', type_pb2.Type( + '/unknown.yaml', 'max:i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'min': ExtensionFunction( + '/unknown.yaml', 'min:i64', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'string_agg': ExtensionFunction( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 05b5263..f712ad5 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -83,14 +83,19 @@ def update_field_references(self, plan_id: int) -> None: """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) current_symbol = self._symbol_table.get_symbol(self._current_plan_id) - current_symbol.input_fields.extend(source_symbol.output_fields) - current_symbol.output_fields.extend(source_symbol.output_fields) + original_output_fields = current_symbol.output_fields + for symbol in source_symbol.output_fields: + new_name = symbol + while new_name in original_output_fields: + new_name = new_name + '_dup' + current_symbol.input_fields.append(new_name) + current_symbol.output_fields.append(new_name) def find_field_by_name(self, field_name: str) -> int | None: """Look up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) try: - return current_symbol.output_fields.index(field_name) + return current_symbol.input_fields.index(field_name) except ValueError: return None @@ -1056,11 +1061,16 @@ def convert_project_relation( project.expressions.append(self.convert_expression(expr)) if expr.HasField('alias'): name = expr.alias.name[0] + elif expr.WhichOneof('expr_type') == 'unresolved_attribute': + name = expr.unresolved_attribute.unparsed_identifier else: name = f'generated_field_{field_number}' symbol.generated_fields.append(name) 
symbol.output_fields.append(name) project.common.CopyFrom(self.create_common_relation()) + symbol.output_fields = symbol.generated_fields + for field_number in range(len(rel.expressions)): + project.common.emit.output_mapping.append(field_number + len(symbol.input_fields)) return algebra_pb2.Rel(project=project) def convert_subquery_alias_relation(self, diff --git a/src/gateway/tests/plan_validator.py b/src/gateway/tests/plan_validator.py index d3dfeed..eb505ee 100644 --- a/src/gateway/tests/plan_validator.py +++ b/src/gateway/tests/plan_validator.py @@ -35,14 +35,15 @@ def utilizes_valid_plans(session): except SparkConnectGrpcException as e: exception = e if session.conf.get('spark-substrait-gateway.backend', 'spark') == 'spark': + if exception: + raise exception return plan_count = int(session.conf.get('spark-substrait-gateway.plan_count')) - first_plan = None + plans_as_text = [] for i in range(plan_count): plan = session.conf.get(f'spark-substrait-gateway.plan.{i + 1}') - if first_plan is None: - first_plan = plan + plans_as_text.append( f'Plan #{i+1}:\n{plan}\n') validate_plan(plan) if exception: - pytest.fail(f'Exception raised during plan validation: {exception.message}\n\n' - f'First Plan:\n{first_plan}\n', pytrace=False) + pytest.fail(f'Exception raised during execution: {exception.message}\n\n' + + '\n\n'.join(plans_as_text), pytrace=False) diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index deb49f2..c61ee02 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -15,13 +15,24 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_02', 'test_query_03', 'test_query_04', 'test_query_05', 'test_query_07', - 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', - 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', 'test_query_18', - 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) - if source == 'gateway-over-datafusion': + if source == 'gateway-over-duckdb': + if originalname in [ 'test_query_03', 'test_query_18']: + request.node.add_marker(pytest.mark.xfail(reason='Date time type mismatch')) + elif originalname in ['test_query_04']: + request.node.add_marker(pytest.mark.xfail(reason='Incorrect calculation')) + elif originalname in[ 'test_query_07', 'test_query_08', 'test_query_09']: + request.node.add_marker(pytest.mark.xfail(reason='Substring argument mismatch')) + elif originalname in ['test_query_12']: + request.node.add_marker(pytest.mark.xfail(reason='Missing nullability information')) + elif originalname in ['test_query_15']: + request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) + elif originalname in ['test_query_16', 'test_query_21']: + request.node.add_marker(pytest.mark.xfail(reason='Disctinct argument behavior')) + elif originalname in ['test_query_20']: + request.node.add_marker(pytest.mark.xfail(reason='Unknown validation error')) + elif originalname in ['test_query_22']: + request.node.add_marker(pytest.mark.xfail(reason='Schema determination for null')) + elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") 
request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) From ac0f97b86daa3d0de923dcf0a9b11147602323e3 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 15 May 2024 09:20:49 -0700 Subject: [PATCH 51/58] feat: various fixes for arrow backend support (#77) This PR does not yet add arrow to CI because there aren't enough passing tests to make doing so worthwhile. --- environment.yml | 2 +- src/gateway/converter/conversion_options.py | 1 - src/gateway/converter/data/00001.splan | 2 ++ src/gateway/converter/rename_functions.py | 31 +++++++++++++++++++++ src/gateway/converter/spark_functions.py | 8 +++--- src/gateway/converter/spark_to_substrait.py | 5 +++- src/gateway/converter/substrait_builder.py | 19 +++++++------ src/gateway/tests/conftest.py | 22 +++++++++------ 8 files changed, 67 insertions(+), 23 deletions(-) diff --git a/environment.yml b/environment.yml index 9d04cd4..254b805 100644 --- a/environment.yml +++ b/environment.yml @@ -18,7 +18,7 @@ dependencies: - pip: - adbc_driver_manager - cargo - - pyarrow >= 13.0.0 + - pyarrow >= 16.0.0 - duckdb == 0.10.1 - datafusion >= 36.0.0 - pyspark diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 0901ee8..df516d7 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -32,7 +32,6 @@ def arrow(): options = ConversionOptions(backend=BackendOptions(BackendEngine.ARROW)) options.needs_scheme_in_path_uris = True options.return_names_with_types = True - options.implement_show_string = False options.backend.use_arrow_uri_workaround = True options.safety_project_read_relations = True return options diff --git a/src/gateway/converter/data/00001.splan b/src/gateway/converter/data/00001.splan index b9ef5e9..1dad38f 100644 --- a/src/gateway/converter/data/00001.splan +++ b/src/gateway/converter/data/00001.splan @@ -350,6 +350,7 @@ relations { } } } + failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION } } expressions { @@ -470,6 +471,7 @@ relations { } } } + failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION } } expressions { diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index 34f1c9c..71ac19b 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -2,6 +2,7 @@ """A library to search Substrait plan for local files.""" from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor from substrait.gen.proto import plan_pb2 +from substrait.gen.proto.extensions import extensions_pb2 # pylint: disable=no-member,fixme @@ -46,6 +47,20 @@ def __init__(self, use_uri_workaround=False): self._use_uri_workaround = use_uri_workaround super().__init__() + def _find_arrow_uri_reference(self, plan: plan_pb2.Plan) -> int: + """Find the URI reference for the Arrow workaround.""" + biggest_reference = -1 + for extension in plan.extension_uris: + if extension.uri == 'urn:arrow:substrait_simple_extension_function': + return extension.extension_uri_anchor + if extension.extension_uri_anchor > biggest_reference: + biggest_reference = extension.extension_uri_anchor + plan.extension_uris.append(extensions_pb2.SimpleExtensionURI( + extension_uri_anchor=biggest_reference + 1, + uri='urn:arrow:substrait_simple_extension_function')) + self._extensions[biggest_reference + 1] = 'urn:arrow:substrait_simple_extension_function' + return biggest_reference + 1 + def normalize_extension_uris(self, plan: plan_pb2.Plan) -> None: """Normalize the 
URI.""" for extension in plan.extension_uris: @@ -83,7 +98,23 @@ def visit_plan(self, plan: plan_pb2.Plan) -> None: changed = False if name == 'char_length': changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) name = 'utf8_length' + elif name == 'max': + changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) + elif name == 'gt': + changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) + name = 'greater' + elif name == 'lt': + changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) + name = 'less' if not changed: continue diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 3790c82..261cb58 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -42,19 +42,19 @@ def __lt__(self, obj) -> bool: bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '<=': ExtensionFunction( - '/functions_comparison.yaml', 'lte:str_str', type_pb2.Type( + '/functions_comparison.yaml', 'lte:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '>=': ExtensionFunction( - '/functions_comparison.yaml', 'gte:str_str', type_pb2.Type( + '/functions_comparison.yaml', 'gte:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '<': ExtensionFunction( - '/functions_comparison.yaml', 'lt:str_str', type_pb2.Type( + '/functions_comparison.yaml', 'lt:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '>': ExtensionFunction( - '/functions_comparison.yaml', 'gt:str_str', type_pb2.Type( + '/functions_comparison.yaml', 'gt:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '+': ExtensionFunction( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index f712ad5..90d7507 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -367,7 +367,9 @@ def convert_type(self, spark_type: spark_types_pb2.DataType) -> type_pb2.Type: def convert_cast_expression( self, cast: spark_exprs_pb2.Expression.Cast) -> algebra_pb2.Expression: """Convert a Spark cast expression into a Substrait cast expression.""" - cast_rel = algebra_pb2.Expression.Cast(input=self.convert_expression(cast.expr)) + cast_rel = algebra_pb2.Expression.Cast( + input=self.convert_expression(cast.expr), + failure_behavior=algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION) match cast.WhichOneof('cast_to_type'): case 'type': cast_rel.type.CopyFrom(self.convert_type(cast.type)) @@ -457,6 +459,7 @@ def convert_read_named_table_relation( table_name = rel.unparsed_identifier arrow_schema = self._backend.describe_table(table_name) + schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 6eca52d..01bc3c8 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -76,8 +76,9 @@ def cast_operation(expression: algebra_pb2.Expression, output_type: 
type_pb2.Type) -> algebra_pb2.Expression: """Construct a Substrait cast expression.""" return algebra_pb2.Expression( - cast=algebra_pb2.Expression.Cast(input=expression, type=output_type) - ) + cast=algebra_pb2.Expression.Cast( + input=expression, type=output_type, + failure_behavior=algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION)) def if_then_else_operation(if_expr: algebra_pb2.Expression, then_expr: algebra_pb2.Expression, @@ -108,7 +109,8 @@ def max_agg_function(function_info: ExtensionFunction, return algebra_pb2.AggregateFunction( function_reference=function_info.anchor, output_type=function_info.output_type, - arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number))]) + arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number))], + phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT) def string_concat_agg_function(function_info: ExtensionFunction, @@ -119,7 +121,8 @@ def string_concat_agg_function(function_info: ExtensionFunction, function_reference=function_info.anchor, output_type=function_info.output_type, arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number)), - algebra_pb2.FunctionArgument(value=string_literal(separator))]) + algebra_pb2.FunctionArgument(value=string_literal(separator))], + phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT) def least_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression, @@ -192,7 +195,7 @@ def minus_function(function_info: ExtensionFunction, def repeat_function(function_info: ExtensionFunction, string: str, - count: algebra_pb2.Expression) -> algebra_pb2.AggregateFunction: + count: algebra_pb2.Expression) -> algebra_pb2.Expression: """Construct a Substrait concat expression.""" return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( @@ -204,7 +207,7 @@ def repeat_function(function_info: ExtensionFunction, def lpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, - pad_string: str = ' ') -> algebra_pb2.AggregateFunction: + pad_string: str = ' ') -> algebra_pb2.Expression: """Construct a Substrait concat expression.""" # TODO -- Avoid a cast if we don't need it. cast_type = string_type() @@ -221,7 +224,7 @@ def lpad_function(function_info: ExtensionFunction, def rpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, - pad_string: str = ' ') -> algebra_pb2.AggregateFunction: + pad_string: str = ' ') -> algebra_pb2.Expression: """Construct a Substrait concat expression.""" # TODO -- Avoid a cast if we don't need it. 
cast_type = string_type() @@ -239,7 +242,7 @@ def rpad_function(function_info: ExtensionFunction, def regexp_strpos_function(function_info: ExtensionFunction, input: algebra_pb2.Expression, pattern: algebra_pb2.Expression, position: algebra_pb2.Expression, - occurrence: algebra_pb2.Expression) -> algebra_pb2.AggregateFunction: + occurrence: algebra_pb2.Expression) -> algebra_pb2.Expression: """Construct a Substrait regex substring expression.""" return algebra_pb2.Expression(scalar_function=algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index 21360f5..a34267f 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -10,7 +10,6 @@ get_mystream_schema, ) from gateway.server import serve -from pyspark.sql.pandas.types import from_arrow_schema from pyspark.sql.session import SparkSession @@ -78,19 +77,20 @@ def gateway_server(): @pytest.fixture(scope='function') -def users_location() -> str: +def users_location(manage_database) -> str: """Provides the location of the users database.""" return str(Path('users.parquet').resolve()) @pytest.fixture(scope='function') -def schema_users(): +def schema_users(manage_database): """Provides the schema of the users database.""" return get_mystream_schema('users') @pytest.fixture(scope='session', params=['spark', + 'gateway-over-arrow', 'gateway-over-duckdb', 'gateway-over-datafusion', ]) @@ -118,11 +118,17 @@ def spark_session(source): # pylint: disable=redefined-outer-name @pytest.fixture(scope='function') -def users_dataframe(spark_session, schema_users, users_location): - """Provides a ready to go dataframe over the users database.""" - return spark_session.read.format('parquet') \ - .schema(from_arrow_schema(schema_users)) \ - .parquet(users_location) +def spark_session_with_users_dataset(spark_session, schema_users, users_location): + """Provides the spark session with the users database already loaded.""" + df = spark_session.read.parquet(users_location) + df.createOrReplaceTempView('users') + return spark_session + + +@pytest.fixture(scope='function') +def users_dataframe(spark_session_with_users_dataset): + """Provides a ready to go users dataframe.""" + return spark_session_with_users_dataset.table('users') def find_tpch() -> Path: From bf4ecefa29b68df13defcdc32546edaeef44e5a1 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 15 May 2024 18:38:14 -0700 Subject: [PATCH 52/58] feat: fix various name and datetime issues (#79) Fixes field name handling (aggregate and project expression aliases now retain their names on output). Fixes issues with datetime/date type comparison differences when comparing dataframes at test time. Corrects deduplicate behavior to physically return one field instead of two. 
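A rough usage sketch of the new comparison helper added in this patch; the column names and literal values below are hypothetical and not taken from the diff:

    import datetime

    from gateway.tests.compare_dataframes import assert_dataframes_equal
    from pyspark import Row

    # The expected side uses datetime.date while the engine returned
    # datetime.datetime; align_schema() inside the helper converts the
    # datetime values in the outcome to dates before assertDataFrameEqual
    # runs with atol=1e-2.
    expected = [Row(o_orderdate=datetime.date(1995, 3, 13), revenue=123.45)]
    outcome = [Row(o_orderdate=datetime.datetime(1995, 3, 13), revenue=123.449)]
    assert_dataframes_equal(outcome, expected)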
--- src/gateway/converter/spark_to_substrait.py | 13 ++- src/gateway/tests/compare_dataframes.py | 41 ++++++++++ .../tests/test_tpch_with_dataframe_api.py | 81 +++++++++---------- 3 files changed, 90 insertions(+), 45 deletions(-) create mode 100644 src/gateway/tests/compare_dataframes.py diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 90d7507..48ed259 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -341,6 +341,7 @@ def convert_unresolved_function( def convert_alias_expression( self, alias: spark_exprs_pb2.Expression.Alias) -> algebra_pb2.Expression: """Convert a Spark alias into a Substrait expression.""" + # We do nothing here and let the magic happen in the calling project relation. return self.convert_expression(alias.expr) def convert_type_str(self, spark_type_str: str | None) -> type_pb2.Type: @@ -1062,7 +1063,7 @@ def convert_project_relation( symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_number, expr in enumerate(rel.expressions): project.expressions.append(self.convert_expression(expr)) - if expr.HasField('alias'): + if expr.WhichOneof('expr_type') == 'alias': name = expr.alias.name[0] elif expr.WhichOneof('expr_type') == 'unresolved_attribute': name = expr.unresolved_attribute.unparsed_identifier @@ -1079,9 +1080,10 @@ def convert_project_relation( def convert_subquery_alias_relation(self, rel: spark_relations_pb2.SubqueryAlias) -> algebra_pb2.Rel: """Convert a Spark subquery alias relation into a Substrait relation.""" - # TODO -- Utilize rel.alias somehow. result = self.convert_relation(rel.input) self.update_field_references(rel.input.common.plan_id) + symbol = self._symbol_table.get_symbol(self._current_plan_id) + symbol.output_fields[-1] = rel.alias return result def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> algebra_pb2.Rel: @@ -1104,7 +1106,12 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> output_type=type_pb2.Type(bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.NULLABILITY_REQUIRED))))) symbol.generated_fields.append(field) - return algebra_pb2.Rel(aggregate=aggregate) + aggr = algebra_pb2.Rel(aggregate=aggregate) + project = project_relation( + aggr, [field_reference(idx) for idx in range(len(symbol.input_fields))]) + for idx in range(len(symbol.input_fields)): + project.project.common.emit.output_mapping.append(idx) + return project def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: """Convert a Spark relation into a Substrait one.""" diff --git a/src/gateway/tests/compare_dataframes.py b/src/gateway/tests/compare_dataframes.py new file mode 100644 index 0000000..ef8a33d --- /dev/null +++ b/src/gateway/tests/compare_dataframes.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Routines for comparing dataframes.""" +import datetime + +from pyspark import Row +from pyspark.testing import assertDataFrameEqual + + +def have_same_schema(outcome: list[Row], expected: list[Row]): + """Returns True if the two dataframes have the same schema.""" + return all(type(a) is type(b) for a, b in zip(outcome[0], expected[0], strict=False)) + + +def align_schema(source_df: list[Row], schema_df: list[Row]): + """Returns a copy of source_df with the fields changed to match schema_df.""" + schema = schema_df[0] + + if have_same_schema(source_df, schema_df): + return source_df + + new_source_df = [] + for row in source_df: + 
new_row = {} + for field_name, field_value in schema.asDict().items(): + if (type(row[field_name] is not type(field_value)) and + isinstance(field_value, datetime.date)): + new_row[field_name] = row[field_name].date() + else: + new_row[field_name] = row[field_name] + + new_source_df.append(Row(**new_row)) + + return new_source_df + + +def assert_dataframes_equal(outcome: list[Row], expected: list[Row]): + """Asserts that two dataframes are equal ignoring column names and date formats.""" + # Create a copy of the dataframes to avoid modifying the original ones + modified_outcome = align_schema(outcome, expected) + + assertDataFrameEqual(modified_outcome, expected, atol=1e-2) diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index c61ee02..3b96435 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -4,10 +4,10 @@ import pyspark import pytest +from gateway.tests.compare_dataframes import assert_dataframes_equal from gateway.tests.plan_validator import utilizes_valid_plans from pyspark import Row from pyspark.sql.functions import avg, col, count, countDistinct, desc, try_sum, when -from pyspark.testing import assertDataFrameEqual @pytest.fixture(autouse=True) @@ -16,19 +16,15 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb': - if originalname in [ 'test_query_03', 'test_query_18']: - request.node.add_marker(pytest.mark.xfail(reason='Date time type mismatch')) - elif originalname in ['test_query_04']: - request.node.add_marker(pytest.mark.xfail(reason='Incorrect calculation')) - elif originalname in[ 'test_query_07', 'test_query_08', 'test_query_09']: + if originalname in[ 'test_query_07', 'test_query_08', 'test_query_09']: request.node.add_marker(pytest.mark.xfail(reason='Substring argument mismatch')) - elif originalname in ['test_query_12']: + elif originalname in ['test_query_12', 'test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='Missing nullability information')) elif originalname in ['test_query_15']: request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) elif originalname in ['test_query_16', 'test_query_21']: - request.node.add_marker(pytest.mark.xfail(reason='Disctinct argument behavior')) - elif originalname in ['test_query_20']: + request.node.add_marker(pytest.mark.xfail(reason='Distinct argument behavior')) + elif originalname in ['test_query_19', 'test_query_20']: request.node.add_marker(pytest.mark.xfail(reason='Unknown validation error')) elif originalname in ['test_query_22']: request.node.add_marker(pytest.mark.xfail(reason='Schema determination for null')) @@ -66,7 +62,7 @@ def test_query_01(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('l_returnflag', 'l_linestatus').limit(1).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_02(self, spark_session_with_tpch_dataset): expected = [ @@ -107,7 +103,7 @@ def test_query_02(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort( desc('s_acctbal'), 'n_name', 's_name', 'p_partkey').limit(2).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_03(self, spark_session_with_tpch_dataset): expected = [ @@ -144,7 +140,7 @@ def 
test_query_03(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort(desc('revenue'), 'o_orderdate').limit(5).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_04(self, spark_session_with_tpch_dataset): expected = [ @@ -171,7 +167,7 @@ def test_query_04(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('o_orderpriority').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_05(self, spark_session_with_tpch_dataset): expected = [ @@ -207,7 +203,7 @@ def test_query_05(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('revenue').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_06(self, spark_session_with_tpch_dataset): expected = [ @@ -222,9 +218,9 @@ def test_query_06(self, spark_session_with_tpch_dataset): (col('l_discount') >= 0.05) & (col('l_discount') <= 0.07) & (col('l_quantity') < 24)).agg( - try_sum(col('l_extendedprice') * col('l_discount'))).alias('revenue') + try_sum(col('l_extendedprice') * col('l_discount'))).alias('revenue').collect() - assertDataFrameEqual(outcome, expected, atol=1e-2) + assert_dataframes_equal(outcome, expected) def test_query_07(self, spark_session_with_tpch_dataset): expected = [ @@ -263,7 +259,7 @@ def test_query_07(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('supp_nation', 'cust_nation', 'l_year').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_08(self, spark_session_with_tpch_dataset): expected = [ @@ -308,7 +304,7 @@ def test_query_08(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('o_year').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_09(self, spark_session_with_tpch_dataset): # TODO -- Verify the corretness of these results against another version of the dataset. 
@@ -343,7 +339,7 @@ def test_query_09(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('n_name', desc('o_year')).limit(5).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_10(self, spark_session_with_tpch_dataset): expected = [ @@ -382,15 +378,15 @@ def test_query_10(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort(desc('revenue')).limit(2).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_11(self, spark_session_with_tpch_dataset): expected = [ - Row(ps_partkey=129760, value=17538456.86), - Row(ps_partkey=166726, value=16503353.92), - Row(ps_partkey=191287, value=16474801.97), - Row(ps_partkey=161758, value=16101755.54), - Row(ps_partkey=34452, value=15983844.72), + Row(ps_partkey=129760, part_value=17538456.86), + Row(ps_partkey=166726, part_value=16503353.92), + Row(ps_partkey=191287, part_value=16474801.97), + Row(ps_partkey=161758, part_value=16101755.54), + Row(ps_partkey=34452, part_value=15983844.72), ] with utilizes_valid_plans(spark_session_with_tpch_dataset): @@ -411,7 +407,7 @@ def test_query_11(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort(desc('part_value')).limit(5).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_12(self, spark_session_with_tpch_dataset): expected = [ @@ -443,7 +439,7 @@ def test_query_12(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('l_shipmode').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_13(self, spark_session_with_tpch_dataset): # TODO -- Verify the corretness of these results against another version of the dataset. 
@@ -465,7 +461,7 @@ def test_query_13(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort(desc('custdist'), desc('c_count')).limit(3).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_14(self, spark_session_with_tpch_dataset): expected = [ @@ -482,13 +478,14 @@ def test_query_14(self, spark_session_with_tpch_dataset): 'p_type', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).agg( try_sum(when(col('p_type').contains('PROMO'), col('value'))) * 100 / try_sum( col('value')) - ).alias('promo_revenue') + ).alias('promo_revenue').collect() - assertDataFrameEqual(outcome, expected, atol=1e-2) + assert_dataframes_equal(outcome, expected) def test_query_15(self, spark_session_with_tpch_dataset): expected = [ - Row(s_suppkey=8449, s_name='Supplier#000008449', s_address='Wp34zim9qYFbVctdW'), + Row(s_suppkey=8449, s_name='Supplier#000008449', s_address='Wp34zim9qYFbVctdW', + s_phone='20-469-856-8873', total=1772627.21), ] with utilizes_valid_plans(spark_session_with_tpch_dataset): @@ -508,7 +505,7 @@ def test_query_15(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('s_suppkey').limit(1).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_16(self, spark_session_with_tpch_dataset): expected = [ @@ -537,7 +534,7 @@ def test_query_16(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort( desc('supplier_cnt'), 'p_brand', 'p_type', 'p_size').limit(3).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_17(self, spark_session_with_tpch_dataset): expected = [ @@ -557,9 +554,9 @@ def test_query_17(self, spark_session_with_tpch_dataset): col('p_partkey').alias('key'), 'avg_quantity').join( fpart, col('key') == fpart.p_partkey).filter( col('l_quantity') < col('avg_quantity')).agg( - try_sum('l_extendedprice') / 7).alias('avg_yearly') + try_sum('l_extendedprice') / 7).alias('avg_yearly').collect() - assertDataFrameEqual(outcome, expected, atol=1e-2) + assert_dataframes_equal(outcome, expected) def test_query_18(self, spark_session_with_tpch_dataset): expected = [ @@ -586,11 +583,11 @@ def test_query_18(self, spark_session_with_tpch_dataset): 'l_quantity', 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', 'o_totalprice').groupBy( 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', 'o_totalprice').agg( - try_sum('l_quantity')) + try_sum('l_quantity')).alias('sum_l_quantity') sorted_outcome = outcome.sort(desc('o_totalprice'), 'o_orderdate').limit(2).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_19(self, spark_session_with_tpch_dataset): expected = [ @@ -617,9 +614,9 @@ def test_query_19(self, spark_session_with_tpch_dataset): (col('l_quantity') >= 20) & (col('l_quantity') <= 30) & (col('p_size') >= 1) & (col('p_size') <= 15))).select( (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).agg( - try_sum('volume').alias('revenue')) + try_sum('volume').alias('revenue')).collect() - assertDataFrameEqual(outcome, expected, atol=1e-2) + assert_dataframes_equal(outcome, expected) def test_query_20(self, spark_session_with_tpch_dataset): expected = [ @@ -653,7 +650,7 @@ def test_query_20(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('s_name').limit(3).collect() - 
assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_21(self, spark_session_with_tpch_dataset): # TODO -- Verify the corretness of these results against another version of the dataset. @@ -705,7 +702,7 @@ def test_query_21(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort(desc('numwait'), 's_name').limit(5).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_22(self, spark_session_with_tpch_dataset): expected = [ @@ -738,4 +735,4 @@ def test_query_22(self, spark_session_with_tpch_dataset): sorted_outcome = outcome.sort('cntrycode').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) From 085745531e33d1ba28c40901ea3736186795d83c Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Thu, 16 May 2024 00:42:19 -0400 Subject: [PATCH 53/58] fix: query 12 and get query 22 further along (#80) --- src/gateway/converter/spark_functions.py | 4 ++++ src/gateway/converter/spark_to_substrait.py | 7 +++++-- src/gateway/converter/substrait_builder.py | 9 +++++++++ src/gateway/server.py | 2 ++ src/gateway/tests/test_tpch_with_dataframe_api.py | 4 ++-- 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 261cb58..afffb0d 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -57,6 +57,10 @@ def __lt__(self, obj) -> bool: '/functions_comparison.yaml', 'gt:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'isnull': ExtensionFunction( + '/functions_comparison.yaml', 'is_null:int', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '+': ExtensionFunction( '/functions_arithmetic.yaml', 'add:i64_i64', type_pb2.Type( i64=type_pb2.Type.I64( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 48ed259..a003d60 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -16,6 +16,7 @@ aggregate_relation, bigint_literal, bool_literal, + bool_type, cast_operation, concat, equal_function, @@ -196,7 +197,7 @@ def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2 if expr.WhichOneof('rex_type') == 'literal': match expr.literal.WhichOneof('literal_type'): case 'boolean': - return type_pb2.Type(bool=type_pb2.Type.Boolean()) + return bool_type() case 'i8': return type_pb2.Type(i8=type_pb2.Type.I8()) case 'i16': @@ -238,10 +239,12 @@ def convert_when_function( getattr(ifthen, 'else').CopyFrom( self.convert_expression(when.arguments[len(when.arguments) - 1])) else: + nullable_literal = self.determine_type_of_expression(ifthen.ifs[-1].then) + nullable_literal.bool.nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE getattr(ifthen, 'else').CopyFrom( algebra_pb2.Expression( literal=algebra_pb2.Expression.Literal( - null=self.determine_type_of_expression(ifthen.ifs[-1].then)))) + null=nullable_literal))) return algebra_pb2.Expression(if_then=ifthen) diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 01bc3c8..276f9ba 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -259,6 
+259,15 @@ def bool_literal(val: bool) -> algebra_pb2.Expression: return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(boolean=val)) +def bool_type(required: bool = True) -> algebra_pb2.Expression: + """Construct a Substrait boolean type.""" + if required: + nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED + else: + nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE + return type_pb2.Type(bool=type_pb2.Type.Boolean(nullability=nullability)) + + def string_literal(val: str) -> algebra_pb2.Expression: """Construct a Substrait string literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(string=val)) diff --git a/src/gateway/server.py b/src/gateway/server.py index 3e61b21..512a082 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -71,6 +71,8 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: data_type = types_pb2.DataType(timestamp=types_pb2.DataType.Timestamp()) elif field.type == pa.date32(): data_type = types_pb2.DataType(date=types_pb2.DataType.Date()) + elif field.type == pa.null(): + data_type = types_pb2.DataType(null=types_pb2.DataType.NULL()) else: raise NotImplementedError( 'Conversion from Arrow schema to Spark schema not yet implemented ' diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 3b96435..f4ce4b3 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -18,7 +18,7 @@ def mark_tests_as_xfail(request): if source == 'gateway-over-duckdb': if originalname in[ 'test_query_07', 'test_query_08', 'test_query_09']: request.node.add_marker(pytest.mark.xfail(reason='Substring argument mismatch')) - elif originalname in ['test_query_12', 'test_query_14']: + elif originalname in ['test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='Missing nullability information')) elif originalname in ['test_query_15']: request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) @@ -27,7 +27,7 @@ def mark_tests_as_xfail(request): elif originalname in ['test_query_19', 'test_query_20']: request.node.add_marker(pytest.mark.xfail(reason='Unknown validation error')) elif originalname in ['test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='Schema determination for null')) + request.node.add_marker(pytest.mark.xfail(reason='Unsupported expression type 0')) elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) From 91bae8c527a87ed0e30bbb17b878b1b1e90a2c9e Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 16 May 2024 08:05:08 -0700 Subject: [PATCH 54/58] feat: fix handling of joins without join conditions (#81) Also fixes the way tables are described in DuckDB (this allows date types to show up as dates). 
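For context on the table-description change: DuckDB's relation-level describe() returns summary statistics rather than the declared schema, which is why date columns previously lost their types, whereas a DESCRIBE statement reports the declared column types directly. The sketch below is illustrative only; the table name and the small type map are stand-ins for the gateway's real _DUCKDB_TO_ARROW table, not copies of it.

    import duckdb
    import pyarrow as pa

    # Illustrative stand-in for the gateway's _DUCKDB_TO_ARROW mapping.
    DUCKDB_TO_ARROW = {'BIGINT': pa.int64(), 'VARCHAR': pa.string(), 'DATE': pa.date32()}

    con = duckdb.connect()
    con.execute("CREATE TABLE orders (o_orderkey BIGINT, o_orderdate DATE)")

    # DESCRIBE reports the declared types, so DATE columns stay dates.
    df = con.execute('DESCRIBE orders').fetchdf()
    schema = pa.schema([pa.field(name, DUCKDB_TO_ARROW[str(col_type)])
                        for name, col_type in zip(df.column_name, df.column_type, strict=False)])
    print(schema)  # o_orderkey: int64 / o_orderdate: date32[day]
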
--- src/gateway/backends/backend_options.py | 1 + src/gateway/backends/duckdb_backend.py | 25 ++++++++++--------- src/gateway/converter/conversion_options.py | 1 + src/gateway/converter/spark_to_substrait.py | 8 ++++-- src/gateway/converter/substrait_builder.py | 2 +- src/gateway/server.py | 12 ++++----- .../tests/test_tpch_with_dataframe_api.py | 8 +++--- 7 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/gateway/backends/backend_options.py b/src/gateway/backends/backend_options.py index 61133e3..d1c05e3 100644 --- a/src/gateway/backends/backend_options.py +++ b/src/gateway/backends/backend_options.py @@ -29,3 +29,4 @@ def __init__(self, backend: BackendEngine, use_adbc: bool = False): self.use_adbc = use_adbc self.use_arrow_uri_workaround = False + self.use_duckdb_python_api = False diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index 88dc966..586a013 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -31,6 +31,7 @@ def __init__(self, options): self._connection = None super().__init__(options) self.create_connection() + self._use_duckdb_python_api = options.use_duckdb_python_api def create_connection(self): """Create a connection to the backend.""" @@ -58,19 +59,22 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: return pa.Table.from_pandas(df=df) def register_table( - self, - table_name: str, - location: Path, - file_format: str = "parquet" + self, + table_name: str, + location: Path, + file_format: str = "parquet" ) -> None: """Register the given table with the backend.""" files = Backend.expand_location(location) if not files: raise ValueError(f"No parquet files found at {location}") - files_str = ', '.join([f"'{f}'" for f in files]) - files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" - self._connection.execute(files_sql) + if self._use_duckdb_python_api: + self._connection.register(table_name, self._connection.read_parquet(files)) + else: + files_str = ', '.join([f"'{f}'" for f in files]) + files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" + self._connection.execute(files_sql) def describe_files(self, paths: list[str]): """Asks the backend to describe the given files.""" @@ -89,13 +93,10 @@ def describe_files(self, paths: list[str]): def describe_table(self, name: str): """Asks the backend to describe the given table.""" - df = self._connection.table(name).describe() + df = self._connection.execute(f'DESCRIBE {name}').fetchdf() fields = [] - for name, field_type in zip(df.columns, df.types, strict=False): - if name == 'aggr': - # This isn't a real column. 
- continue + for name, field_type in zip(df.column_name, df.column_type, strict=False): fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)])) return pa.schema(fields) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index df516d7..1ae902f 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -49,4 +49,5 @@ def duck_db(): options.use_switch_expressions_where_possible = False options.use_duckdb_regexp_matches_function = True options.duckdb_project_emit_workaround = True + options.backend.use_duckdb_python_api = False return options diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index a003d60..e1998ae 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -240,7 +240,9 @@ def convert_when_function( self.convert_expression(when.arguments[len(when.arguments) - 1])) else: nullable_literal = self.determine_type_of_expression(ifthen.ifs[-1].then) - nullable_literal.bool.nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE + kind = nullable_literal.WhichOneof('kind') + getattr(nullable_literal, kind).nullability = ( + type_pb2.Type.Nullability.NULLABILITY_NULLABLE) getattr(ifthen, 'else').CopyFrom( algebra_pb2.Expression( literal=algebra_pb2.Expression.Literal( @@ -1041,12 +1043,14 @@ def convert_cross_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_ if rel.HasField('join_condition'): raise ValueError('Cross joins do not support having a join condition.') join.common.CopyFrom(self.create_common_relation()) - return algebra_pb2.Rel(join=join) + return algebra_pb2.Rel(cross=join) def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Rel: """Convert a Spark join relation into a Substrait join relation.""" if rel.join_type == spark_relations_pb2.Join.JOIN_TYPE_CROSS: return self.convert_cross_join_relation(rel) + if not rel.HasField('join_condition') and not rel.using_columns: + return self.convert_cross_join_relation(rel) join = algebra_pb2.JoinRel(left=self.convert_relation(rel.left), right=self.convert_relation(rel.right)) self.update_field_references(rel.left.common.plan_id) diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 276f9ba..9d3469b 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -259,7 +259,7 @@ def bool_literal(val: bool) -> algebra_pb2.Expression: return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(boolean=val)) -def bool_type(required: bool = True) -> algebra_pb2.Expression: +def bool_type(required: bool = True) -> type_pb2.Type: """Construct a Substrait boolean type.""" if required: nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED diff --git a/src/gateway/server.py b/src/gateway/server.py index 512a082..644af15 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -11,7 +11,7 @@ import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc from google.protobuf.json_format import MessageToJson from pyspark.sql.connect.proto import types_pb2 -from substrait.gen.proto import algebra_pb2, plan_pb2 +from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend from gateway.backends.backend_options import BackendEngine, BackendOptions @@ -56,7 +56,7 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: elif 
field.type == pa.int8(): data_type = types_pb2.DataType(byte=types_pb2.DataType.Byte()) elif field.type == pa.int16(): - data_type = types_pb2.DataType(integer=types_pb2.DataType.Short()) + data_type = types_pb2.DataType(short=types_pb2.DataType.Short()) elif field.type == pa.int32(): data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer()) elif field.type == pa.int64(): @@ -84,15 +84,13 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: return types_pb2.DataType(struct=types_pb2.DataType.Struct(fields=fields)) -def create_dataframe_view(rel: pb2.Plan, backend) -> algebra_pb2.Rel: +def create_dataframe_view(rel: pb2.Plan, backend) -> None: """Register the temporary dataframe.""" dataframe_view_name = rel.command.create_dataframe_view.name read_data_source_relation = rel.command.create_dataframe_view.input.read.data_source - format = read_data_source_relation.format + fmt = read_data_source_relation.format path = read_data_source_relation.paths[0] - backend.register_table(dataframe_view_name, path, format) - - return None + backend.register_table(dataframe_view_name, path, fmt) class Statistics: diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index f4ce4b3..fcf0371 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -16,18 +16,16 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb': - if originalname in[ 'test_query_07', 'test_query_08', 'test_query_09']: + if originalname in ['test_query_07', 'test_query_08', 'test_query_09']: request.node.add_marker(pytest.mark.xfail(reason='Substring argument mismatch')) elif originalname in ['test_query_14']: - request.node.add_marker(pytest.mark.xfail(reason='Missing nullability information')) + request.node.add_marker(pytest.mark.xfail(reason='If/then branches w/ different types')) elif originalname in ['test_query_15']: request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) elif originalname in ['test_query_16', 'test_query_21']: request.node.add_marker(pytest.mark.xfail(reason='Distinct argument behavior')) elif originalname in ['test_query_19', 'test_query_20']: request.node.add_marker(pytest.mark.xfail(reason='Unknown validation error')) - elif originalname in ['test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='Unsupported expression type 0')) elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) @@ -720,7 +718,7 @@ def test_query_22(self, spark_session_with_tpch_dataset): orders = spark_session_with_tpch_dataset.table('orders') fcustomer = customer.select( - 'c_acctbal', 'c_custkey', (col('c_phone').substr(0, 2)).alias('cntrycode')).filter( + 'c_acctbal', 'c_custkey', (col('c_phone').substr(1, 2)).alias('cntrycode')).filter( col('cntrycode').isin(['13', '31', '23', '29', '30', '18', '17'])) avg_customer = fcustomer.filter(col('c_acctbal') > 0.00).agg( From 25001ea47ddec6776fb0863bf748f8eb0886328a Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 16 May 2024 10:00:46 -0700 Subject: [PATCH 55/58] feat: address substr type issues by always casting its input (#82) --- src/gateway/converter/spark_to_substrait.py | 4 ++++ src/gateway/tests/test_tpch_with_dataframe_api.py | 12 +++++------- 2 files changed, 9 
insertions(+), 7 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index e1998ae..f662ed6 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -341,6 +341,10 @@ def convert_unresolved_function( raise NotImplementedError( 'Treating arguments as distinct is not supported for unresolved functions.') func.output_type.CopyFrom(function_def.output_type) + if unresolved_function.function_name == 'substring': + original_argument = func.arguments[0] + func.arguments[0].CopyFrom(algebra_pb2.FunctionArgument( + value=cast_operation(original_argument.value, string_type()))) return algebra_pb2.Expression(scalar_function=func) def convert_alias_expression( diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index fcf0371..558ac39 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -16,15 +16,13 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb': - if originalname in ['test_query_07', 'test_query_08', 'test_query_09']: - request.node.add_marker(pytest.mark.xfail(reason='Substring argument mismatch')) - elif originalname in ['test_query_14']: + if originalname in ['test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='If/then branches w/ different types')) elif originalname in ['test_query_15']: request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) elif originalname in ['test_query_16', 'test_query_21']: request.node.add_marker(pytest.mark.xfail(reason='Distinct argument behavior')) - elif originalname in ['test_query_19', 'test_query_20']: + elif originalname in ['test_query_08', 'test_query_19', 'test_query_20']: request.node.add_marker(pytest.mark.xfail(reason='Unknown validation error')) elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") @@ -250,7 +248,7 @@ def test_query_07(self, spark_session_with_tpch_dataset): suppNation, col('o_orderkey') == suppNation.l_orderkey).filter( (col('supp_nation') == 'FRANCE') & (col('cust_nation') == 'GERMANY') | ( col('supp_nation') == 'GERMANY') & (col('cust_nation') == 'FRANCE')).select( - 'supp_nation', 'cust_nation', col('l_shipdate').substr(0, 4).alias('l_year'), + 'supp_nation', 'cust_nation', col('l_shipdate').substr(1, 4).alias('l_year'), (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( 'supp_nation', 'cust_nation', 'l_year').agg( try_sum('volume').alias('revenue')) @@ -294,7 +292,7 @@ def test_query_08(self, spark_session_with_tpch_dataset): 'c_custkey').join(forder, col('c_custkey') == col('o_custkey')).select( 'o_orderkey', 'o_orderdate').join(line, col('o_orderkey') == line.l_orderkey).select( - col('n_name'), col('o_orderdate').substr(0, 4).alias('o_year'), + col('n_name'), col('o_orderdate').substr(1, 4).alias('o_year'), col('volume')).withColumn('case_volume', when(col('n_name') == 'BRAZIL', col('volume')).otherwise( 0)).groupBy('o_year').agg( @@ -330,7 +328,7 @@ def test_query_09(self, spark_session_with_tpch_dataset): partsupp, (col('l_suppkey') == partsupp.ps_suppkey) & ( col('l_partkey') == partsupp.ps_partkey)).join( orders, col('l_orderkey') == orders.o_orderkey).select( - 'n_name', col('o_orderdate').substr(0, 4).alias('o_year'), + 'n_name', col('o_orderdate').substr(1, 
4).alias('o_year'), (col('l_extendedprice') * (1 - col('l_discount')) - ( col('ps_supplycost') * col('l_quantity'))).alias('amount')).groupBy( 'n_name', 'o_year').agg(try_sum('amount').alias('sum_profit')) From cfeb6aea4e6ecdff6cd8df41b061f27359d44ab6 Mon Sep 17 00:00:00 2001 From: Richard Tia Date: Thu, 16 May 2024 15:09:58 -0400 Subject: [PATCH 56/58] fix: query 14 (#83) --- src/gateway/converter/spark_to_substrait.py | 2 +- src/gateway/tests/test_tpch_with_dataframe_api.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index f662ed6..6bb8276 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -220,7 +220,7 @@ def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2 return expr.scalar_function.output_type if expr.WhichOneof('rex_type') == 'selection': # TODO -- Figure out how to determine the type of a field reference. - return type_pb2.Type(i32=type_pb2.Type.I32()) + return type_pb2.Type(i64=type_pb2.Type.I64()) raise NotImplementedError( 'Type determination not implemented for expressions of type ' f'{expr.WhichOneof("rex_type")}.') diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 558ac39..b5e038c 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -16,9 +16,7 @@ def mark_tests_as_xfail(request): source = request.getfixturevalue('source') originalname = request.keywords.node.originalname if source == 'gateway-over-duckdb': - if originalname in ['test_query_14']: - request.node.add_marker(pytest.mark.xfail(reason='If/then branches w/ different types')) - elif originalname in ['test_query_15']: + if originalname in ['test_query_15']: request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) elif originalname in ['test_query_16', 'test_query_21']: request.node.add_marker(pytest.mark.xfail(reason='Distinct argument behavior')) From 2611f1686d03e7a7666252c06b8257db701506ae Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 17 May 2024 07:01:24 -0700 Subject: [PATCH 57/58] feat: ignore internal substrait validator errors (#84) I hand decoded one of the failures and it didn't contain any validation issues so it is relatively safe to ignore these failures for now. 
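A small sketch of the kind of hand inspection described above, assuming plan_bytes holds the serialized Substrait plan that tripped the validator; it round-trips the bytes through the generated protobuf classes so the plan can be reviewed as JSON.

    from google.protobuf import json_format
    from substrait.gen.proto import plan_pb2

    def dump_plan(plan_bytes: bytes) -> str:
        """Render a serialized Substrait plan as JSON for manual review."""
        plan = plan_pb2.Plan()
        plan.ParseFromString(plan_bytes)
        return json_format.MessageToJson(plan)
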
--- src/gateway/tests/plan_validator.py | 9 +++++++-- src/gateway/tests/test_tpch_with_dataframe_api.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/gateway/tests/plan_validator.py b/src/gateway/tests/plan_validator.py index eb505ee..787f21e 100644 --- a/src/gateway/tests/plan_validator.py +++ b/src/gateway/tests/plan_validator.py @@ -1,16 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager +import google.protobuf.message import pytest import substrait_validator from google.protobuf import json_format from pyspark.errors.exceptions.connect import SparkConnectGrpcException -from substrait.gen.proto import plan_pb2 +from substrait_validator.substrait import plan_pb2 def validate_plan(json_plan: str): substrait_plan = json_format.Parse(json_plan, plan_pb2.Plan()) - diagnostics = substrait_validator.plan_to_diagnostics(substrait_plan.SerializeToString()) + try: + diagnostics = substrait_validator.plan_to_diagnostics(substrait_plan.SerializeToString()) + except google.protobuf.message.DecodeError: + # Probable protobuf mismatch internal to Substrait Validator, ignore for now. + return issues = [] for issue in diagnostics: if issue.adjusted_level >= substrait_validator.Diagnostic.LEVEL_ERROR: diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index b5e038c..b3c8854 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -20,8 +20,8 @@ def mark_tests_as_xfail(request): request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) elif originalname in ['test_query_16', 'test_query_21']: request.node.add_marker(pytest.mark.xfail(reason='Distinct argument behavior')) - elif originalname in ['test_query_08', 'test_query_19', 'test_query_20']: - request.node.add_marker(pytest.mark.xfail(reason='Unknown validation error')) + elif originalname in ['test_query_08']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) From 311679a9a95582a7aa1c9827ebb7920818a21a92 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 17 May 2024 17:59:22 -0700 Subject: [PATCH 58/58] feat: implement distinct aggregate functions (#85) --- src/gateway/converter/spark_functions.py | 4 ++++ src/gateway/converter/spark_to_substrait.py | 14 +++++++++++--- src/gateway/tests/test_tpch_with_dataframe_api.py | 4 ++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index afffb0d..721cf50 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -164,6 +164,10 @@ def __lt__(self, obj) -> bool: '/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'approx_count_distinct': ExtensionFunction( + '/functions_aggregate_approx.yaml', 'approx_count_distinct:any', + type_pb2.Type(i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'any_value': ExtensionFunction( '/functions_aggregate_generic.yaml', 'any_value:any', type_pb2.Type( i64=type_pb2.Type.I64( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 6bb8276..7960c36 
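Roughly, what this change emits for a distinct aggregate such as Spark's countDistinct: the distinctness lands on the Substrait AggregateFunction's invocation field rather than on a separate operator. In the sketch below the function_reference is a placeholder anchor, and arguments/output_type are omitted for brevity.

    from substrait.gen.proto import algebra_pb2

    func = algebra_pb2.AggregateFunction(
        function_reference=1,  # placeholder anchor for an aggregate such as count:any
        phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT,
        invocation=algebra_pb2.AggregateFunction.AGGREGATION_INVOCATION_DISTINCT)
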
100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -337,9 +337,6 @@ def convert_unresolved_function( break func.arguments.append( algebra_pb2.FunctionArgument(value=self.convert_expression(arg))) - if unresolved_function.is_distinct: - raise NotImplementedError( - 'Treating arguments as distinct is not supported for unresolved functions.') func.output_type.CopyFrom(function_def.output_type) if unresolved_function.function_name == 'substring': original_argument = func.arguments[0] @@ -439,12 +436,23 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex f'Unexpected expression type: {expr.WhichOneof("expr_type")}') return result + def is_distinct(self, expr: spark_exprs_pb2.Expression) -> bool: + """Determine if the expression is distinct.""" + if expr.WhichOneof( + 'expr_type') == 'unresolved_function' and expr.unresolved_function.is_distinct: + return True + if expr.WhichOneof('expr_type') == 'alias': + return self.is_distinct(expr.alias.expr) + return False + def convert_expression_to_aggregate_function( self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.AggregateFunction: """Convert a SparkConnect expression to a Substrait expression.""" func = algebra_pb2.AggregateFunction( phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT) + if self.is_distinct(expr): + func.invocation = algebra_pb2.AggregateFunction.AGGREGATION_INVOCATION_DISTINCT expression = self.convert_expression(expr) match expression.WhichOneof('rex_type'): case 'scalar_function': diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index b3c8854..3ee0fcb 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -18,10 +18,10 @@ def mark_tests_as_xfail(request): if source == 'gateway-over-duckdb': if originalname in ['test_query_15']: request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) - elif originalname in ['test_query_16', 'test_query_21']: - request.node.add_marker(pytest.mark.xfail(reason='Distinct argument behavior')) elif originalname in ['test_query_08']: request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) + elif originalname == 'test_query_16': + request.node.add_marker(pytest.mark.xfail(reason='results differ')) elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error'))