diff --git a/environment.yml b/environment.yml index 9d04cd4..254b805 100644 --- a/environment.yml +++ b/environment.yml @@ -18,7 +18,7 @@ dependencies: - pip: - adbc_driver_manager - cargo - - pyarrow >= 13.0.0 + - pyarrow >= 16.0.0 - duckdb == 0.10.1 - datafusion >= 36.0.0 - pyspark diff --git a/src/gateway/backends/adbc_backend.py b/src/gateway/backends/adbc_backend.py index e89e7fb..5085240 100644 --- a/src/gateway/backends/adbc_backend.py +++ b/src/gateway/backends/adbc_backend.py @@ -9,8 +9,7 @@ from substrait.gen.proto import plan_pb2 from gateway.backends.backend import Backend -from gateway.backends.backend_options import Backend as backend_engine -from gateway.backends.backend_options import BackendOptions +from gateway.backends.backend_options import BackendEngine, BackendOptions def _import(handle): @@ -20,7 +19,7 @@ def _import(handle): def _get_backend_driver(options: BackendOptions) -> tuple[str, str]: """Get the driver and entry point for the specified backend.""" match options.backend: - case backend_engine.DUCKDB: + case BackendEngine.DUCKDB: driver = duckdb.duckdb.__file__ entry_point = "duckdb_adbc_init" case _: @@ -62,7 +61,7 @@ def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> file_paths = sorted([str(fp) for fp in file_paths]) # TODO: Support multiple paths. reader = pq.ParquetFile(file_paths[0]) - self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode="create") + self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode='create') def describe_table(self, table_name: str): """Asks the backend to describe the given table.""" diff --git a/src/gateway/backends/backend.py b/src/gateway/backends/backend.py index eb3c4ab..9f74ac1 100644 --- a/src/gateway/backends/backend.py +++ b/src/gateway/backends/backend.py @@ -35,6 +35,10 @@ def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> """Register the given table with the backend.""" raise NotImplementedError() + def describe_files(self, paths: list[str]): + """Asks the backend to describe the given files.""" + raise NotImplementedError() + def describe_table(self, name: str): """Asks the backend to describe the given table.""" raise NotImplementedError() @@ -43,6 +47,10 @@ def drop_table(self, name: str) -> None: """Asks the backend to drop the given table.""" raise NotImplementedError() + def convert_sql(self, sql: str) -> plan_pb2.Plan: + """Convert SQL into a Substrait plan.""" + raise NotImplementedError() + @staticmethod def expand_location(location: Path | str) -> list[str]: """Expand the location of a file or directory into a list of files.""" diff --git a/src/gateway/backends/backend_options.py b/src/gateway/backends/backend_options.py index 5f0a578..d1c05e3 100644 --- a/src/gateway/backends/backend_options.py +++ b/src/gateway/backends/backend_options.py @@ -4,24 +4,29 @@ from enum import Enum -class Backend(Enum): +class BackendEngine(Enum): """Represents the different backends we have support for.""" ARROW = 1 DATAFUSION = 2 DUCKDB = 3 + def __str__(self): + """Return the string representation of the backend.""" + return self.name.lower() + @dataclasses.dataclass class BackendOptions: """Holds all the possible backend options.""" - backend: Backend + backend: BackendEngine use_adbc: bool - def __init__(self, backend: Backend, use_adbc: bool = False): + def __init__(self, backend: BackendEngine, use_adbc: bool = False): """Create a BackendOptions structure.""" self.backend = backend self.use_adbc = use_adbc 
self.use_arrow_uri_workaround = False + self.use_duckdb_python_api = False diff --git a/src/gateway/backends/backend_selector.py b/src/gateway/backends/backend_selector.py index fa88625..c6b6696 100644 --- a/src/gateway/backends/backend_selector.py +++ b/src/gateway/backends/backend_selector.py @@ -3,7 +3,7 @@ from gateway.backends import backend from gateway.backends.adbc_backend import AdbcBackend from gateway.backends.arrow_backend import ArrowBackend -from gateway.backends.backend_options import Backend, BackendOptions +from gateway.backends.backend_options import BackendEngine, BackendOptions from gateway.backends.datafusion_backend import DatafusionBackend from gateway.backends.duckdb_backend import DuckDBBackend @@ -11,11 +11,11 @@ def find_backend(options: BackendOptions) -> backend.Backend: """Given a backend enum, returns an instance of the correct Backend descendant.""" match options.backend: - case Backend.ARROW: + case BackendEngine.ARROW: return ArrowBackend(options) - case Backend.DATAFUSION: + case BackendEngine.DATAFUSION: return DatafusionBackend(options) - case Backend.DUCKDB: + case BackendEngine.DUCKDB: if options.use_adbc: return AdbcBackend(options) return DuckDBBackend(options) diff --git a/src/gateway/backends/datafusion_backend.py b/src/gateway/backends/datafusion_backend.py index c722198..88baa1d 100644 --- a/src/gateway/backends/datafusion_backend.py +++ b/src/gateway/backends/datafusion_backend.py @@ -9,6 +9,19 @@ from gateway.converter.rename_functions import RenameFunctionsForDatafusion from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable +_DATAFUSION_TO_ARROW = { + 'Boolean': pa.bool_(), + 'Int8': pa.int8(), + 'Int16': pa.int16(), + 'Int32': pa.int32(), + 'Int64': pa.int64(), + 'Float32': pa.float32(), + 'Float64': pa.float64(), + 'Date32': pa.date32(), + 'Timestamp(Nanosecond, None)': pa.timestamp('ns'), + 'Utf8': pa.string(), +} + # pylint: disable=import-outside-toplevel class DatafusionBackend(Backend): @@ -23,6 +36,7 @@ def __init__(self, options): def create_connection(self) -> None: """Create a connection to the backend.""" import datafusion + self._connection = datafusion.SessionContext() def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: @@ -33,10 +47,9 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: registered_tables = set() for files in file_groups: table_name = files[0] - for file in files[1]: - if table_name not in registered_tables: - self.register_table(table_name, file) - registered_tables.add(files[0]) + location = Path(files[1][0]).parent + self.register_table(table_name, location) + registered_tables.add(table_name) RenameFunctionsForDatafusion().visit_plan(plan) @@ -59,7 +72,31 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: for table_name in registered_tables: self._connection.deregister_table(table_name) - def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None: + def register_table( + self, name: str, location: Path, file_format: str = 'parquet' + ) -> None: """Register the given table with the backend.""" - files = Backend.expand_location(path) - self._connection.register_parquet(name, files[0]) + files = Backend.expand_location(location) + if not files: + raise ValueError(f"No parquet files found at {location}") + # TODO: Add options to skip table registration if it already exists instead + # of deregistering it. 
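+        # For now an existing registration is simply replaced, mirroring the CREATE OR REPLACE behavior of the DuckDB backend.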
+ if self._connection.table_exist(name): + self._connection.deregister_table(name) + self._connection.register_parquet(name, str(location)) + + def describe_files(self, paths: list[str]): + """Asks the backend to describe the given files.""" + # TODO -- Use the ListingTable API to resolve the combined schema. + df = self._connection.read_parquet(paths[0]) + return df.schema() + + def describe_table(self, table_name: str): + """Asks the backend to describe the given table.""" + result = self._connection.sql(f"describe {table_name}").to_arrow_table().to_pylist() + + fields = [] + for index in range(len(result)): + fields.append(pa.field(result[index]['column_name'], + _DATAFUSION_TO_ARROW[result[index]['data_type']])) + return pa.schema(fields) diff --git a/src/gateway/backends/duckdb_backend.py b/src/gateway/backends/duckdb_backend.py index c35078c..586a013 100644 --- a/src/gateway/backends/duckdb_backend.py +++ b/src/gateway/backends/duckdb_backend.py @@ -8,6 +8,19 @@ from gateway.backends.backend import Backend +_DUCKDB_TO_ARROW = { + 'BOOLEAN': pa.bool_(), + 'TINYINT': pa.int8(), + 'SMALLINT': pa.int16(), + 'INTEGER': pa.int32(), + 'BIGINT': pa.int64(), + 'FLOAT': pa.float32(), + 'DOUBLE': pa.float64(), + 'DATE': pa.date32(), + 'TIMESTAMP': pa.timestamp('ns'), + 'VARCHAR': pa.string(), +} + # pylint: disable=fixme class DuckDBBackend(Backend): @@ -18,6 +31,7 @@ def __init__(self, options): self._connection = None super().__init__(options) self.create_connection() + self._use_duckdb_python_api = options.use_duckdb_python_api def create_connection(self): """Create a connection to the backend.""" @@ -44,12 +58,51 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table: df = query_result.df() return pa.Table.from_pandas(df=df) - def register_table(self, table_name: str, location: Path, file_format: str = 'parquet') -> None: + def register_table( + self, + table_name: str, + location: Path, + file_format: str = "parquet" + ) -> None: """Register the given table with the backend.""" files = Backend.expand_location(location) if not files: raise ValueError(f"No parquet files found at {location}") - files_str = ', '.join([f"'{f}'" for f in files]) - files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" - self._connection.execute(files_sql) + if self._use_duckdb_python_api: + self._connection.register(table_name, self._connection.read_parquet(files)) + else: + files_str = ', '.join([f"'{f}'" for f in files]) + files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])" + self._connection.execute(files_sql) + + def describe_files(self, paths: list[str]): + """Asks the backend to describe the given files.""" + files = paths + if len(paths) == 1: + files = self.expand_location(paths[0]) + df = self._connection.read_parquet(files) + + fields = [] + for name, field_type in zip(df.columns, df.types, strict=False): + if name == 'aggr': + # This isn't a real column. 
+ continue + fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)])) + return pa.schema(fields) + + def describe_table(self, name: str): + """Asks the backend to describe the given table.""" + df = self._connection.execute(f'DESCRIBE {name}').fetchdf() + + fields = [] + for name, field_type in zip(df.column_name, df.column_type, strict=False): + fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)])) + return pa.schema(fields) + + def convert_sql(self, sql: str) -> plan_pb2.Plan: + """Convert SQL into a Substrait plan.""" + plan = plan_pb2.Plan() + proto_bytes = self._connection.get_substrait(query=sql).fetchone()[0] + plan.ParseFromString(proto_bytes) + return plan diff --git a/src/gateway/converter/add_extension_uris.py b/src/gateway/converter/add_extension_uris.py new file mode 100644 index 0000000..ae15053 --- /dev/null +++ b/src/gateway/converter/add_extension_uris.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A library to add missing extension URI definitions to a Substrait plan.""" +from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor +from substrait.gen.proto import plan_pb2 +from substrait.gen.proto.extensions import extensions_pb2 + + +# pylint: disable=E1101,no-member +class AddExtensionUris(SubstraitPlanVisitor): + """Ensures that the plan has extension URI definitions for all references.""" + + def visit_plan(self, plan: plan_pb2.Plan) -> None: + """Modify the provided plan so that all functions have URI references.""" + super().visit_plan(plan) + + known_uris: list[int] = [] + for uri in plan.extension_uris: + known_uris.append(uri.extension_uri_anchor) + + for extension in plan.extensions: + if extension.WhichOneof('mapping_type') != 'extension_function': + continue + + if extension.extension_function.extension_uri_reference not in known_uris: + # TODO -- Make sure this hack occurs at most once.
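+                # Reusing the dangling reference as the anchor keeps the existing extension_function entries pointing at a valid URI.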
+ uri = extensions_pb2.SimpleExtensionURI( + uri='urn:arrow:substrait_simple_extension_function', + extension_uri_anchor=extension.extension_function.extension_uri_reference) + plan.extension_uris.append(uri) + known_uris.append(extension.extension_function.extension_uri_reference) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index 3235720..1ae902f 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -2,7 +2,7 @@ """Tracks conversion related options.""" import dataclasses -from gateway.backends.backend_options import Backend, BackendOptions +from gateway.backends.backend_options import BackendEngine, BackendOptions # pylint: disable=too-many-instance-attributes @@ -17,6 +17,8 @@ def __init__(self, backend: BackendOptions = None): self.use_emits_instead_of_direct = False self.use_switch_expressions_where_possible = True self.use_duckdb_regexp_matches_function = False + self.duckdb_project_emit_workaround = False + self.safety_project_read_relations = False self.return_names_with_types = False @@ -27,23 +29,25 @@ def __init__(self, backend: BackendOptions = None): def arrow(): """Return standard options to connect to the Acero backend.""" - options = ConversionOptions(backend=BackendOptions(Backend.ARROW)) + options = ConversionOptions(backend=BackendOptions(BackendEngine.ARROW)) options.needs_scheme_in_path_uris = True options.return_names_with_types = True - options.implement_show_string = False options.backend.use_arrow_uri_workaround = True + options.safety_project_read_relations = True return options def datafusion(): """Return standard options to connect to a Datafusion backend.""" - return ConversionOptions(backend=BackendOptions(Backend.DATAFUSION)) + return ConversionOptions(backend=BackendOptions(BackendEngine.DATAFUSION)) def duck_db(): """Return standard options to connect to a DuckDB backend.""" - options = ConversionOptions(backend=BackendOptions(Backend.DUCKDB)) + options = ConversionOptions(backend=BackendOptions(BackendEngine.DUCKDB)) options.return_names_with_types = True options.use_switch_expressions_where_possible = False options.use_duckdb_regexp_matches_function = True + options.duckdb_project_emit_workaround = True + options.backend.use_duckdb_python_api = False return options diff --git a/src/gateway/converter/data/00001.splan b/src/gateway/converter/data/00001.splan index eeac539..1dad38f 100644 --- a/src/gateway/converter/data/00001.splan +++ b/src/gateway/converter/data/00001.splan @@ -78,15 +78,15 @@ relations { project { common { emit { - output_mapping: 0 - output_mapping: 1 - output_mapping: 2 - output_mapping: 3 - output_mapping: 4 - output_mapping: 5 - output_mapping: 6 - output_mapping: 7 - output_mapping: 8 + output_mapping: 11 + output_mapping: 12 + output_mapping: 13 + output_mapping: 14 + output_mapping: 15 + output_mapping: 16 + output_mapping: 17 + output_mapping: 18 + output_mapping: 19 output_mapping: 10 } } @@ -94,32 +94,32 @@ relations { project { common { emit { - output_mapping: 0 - output_mapping: 1 - output_mapping: 2 - output_mapping: 3 - output_mapping: 4 - output_mapping: 5 - output_mapping: 6 + output_mapping: 11 + output_mapping: 12 + output_mapping: 13 + output_mapping: 14 + output_mapping: 15 + output_mapping: 16 + output_mapping: 17 output_mapping: 10 - output_mapping: 8 - output_mapping: 9 + output_mapping: 18 + output_mapping: 19 } } input { project { common { emit { - output_mapping: 0 - output_mapping: 1 - 
output_mapping: 2 - output_mapping: 3 - output_mapping: 4 - output_mapping: 5 + output_mapping: 11 + output_mapping: 12 + output_mapping: 13 + output_mapping: 14 + output_mapping: 15 + output_mapping: 16 output_mapping: 10 - output_mapping: 7 - output_mapping: 8 - output_mapping: 9 + output_mapping: 17 + output_mapping: 18 + output_mapping: 19 } } input { @@ -232,6 +232,104 @@ relations { } } } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 3 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 4 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 5 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 7 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 8 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 9 + } + } + root_reference { + } + } + } } } expressions { @@ -252,6 +350,105 @@ relations { } } } + failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION + } + } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 3 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 4 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 5 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 6 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 8 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 9 + } + } + root_reference { + } } } } @@ -274,6 +471,105 @@ relations { } } } + failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION + } + } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 3 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 4 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 5 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 6 + } + } + root_reference { + } + } + } + expressions { + 
selection { + direct_reference { + struct_field { + field: 7 + } + } + root_reference { + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 8 + } + } + root_reference { + } } } } diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index 34f1c9c..71ac19b 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -2,6 +2,7 @@ """A library to search Substrait plan for local files.""" from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor from substrait.gen.proto import plan_pb2 +from substrait.gen.proto.extensions import extensions_pb2 # pylint: disable=no-member,fixme @@ -46,6 +47,20 @@ def __init__(self, use_uri_workaround=False): self._use_uri_workaround = use_uri_workaround super().__init__() + def _find_arrow_uri_reference(self, plan: plan_pb2.Plan) -> int: + """Find the URI reference for the Arrow workaround.""" + biggest_reference = -1 + for extension in plan.extension_uris: + if extension.uri == 'urn:arrow:substrait_simple_extension_function': + return extension.extension_uri_anchor + if extension.extension_uri_anchor > biggest_reference: + biggest_reference = extension.extension_uri_anchor + plan.extension_uris.append(extensions_pb2.SimpleExtensionURI( + extension_uri_anchor=biggest_reference + 1, + uri='urn:arrow:substrait_simple_extension_function')) + self._extensions[biggest_reference + 1] = 'urn:arrow:substrait_simple_extension_function' + return biggest_reference + 1 + def normalize_extension_uris(self, plan: plan_pb2.Plan) -> None: """Normalize the URI.""" for extension in plan.extension_uris: @@ -83,7 +98,23 @@ def visit_plan(self, plan: plan_pb2.Plan) -> None: changed = False if name == 'char_length': changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) name = 'utf8_length' + elif name == 'max': + changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) + elif name == 'gt': + changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) + name = 'greater' + elif name == 'lt': + changed = True + extension.extension_function.extension_uri_reference = ( + self._find_arrow_uri_reference(plan)) + name = 'less' if not changed: continue diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index bc78ded..721cf50 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -42,19 +42,23 @@ def __lt__(self, obj) -> bool: bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '<=': ExtensionFunction( - '/functions_comparison.yaml', 'lte:str_str', type_pb2.Type( + '/functions_comparison.yaml', 'lte:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '>=': ExtensionFunction( - '/functions_comparison.yaml', 'gte:str_str', type_pb2.Type( + '/functions_comparison.yaml', 'gte:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '<': ExtensionFunction( - '/functions_comparison.yaml', 'lt:str_str', type_pb2.Type( + '/functions_comparison.yaml', 'lt:i64_i64', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '>': ExtensionFunction( - '/functions_comparison.yaml', 'gt:str_str', 
type_pb2.Type( + '/functions_comparison.yaml', 'gt:i64_i64', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'isnull': ExtensionFunction( + '/functions_comparison.yaml', 'is_null:int', type_pb2.Type( bool=type_pb2.Type.Boolean( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), '+': ExtensionFunction( @@ -121,7 +125,11 @@ def __lt__(self, obj) -> bool: i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'max': ExtensionFunction( - '/functions_aggregate.yaml', 'max:i64', type_pb2.Type( + '/unknown.yaml', 'max:i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'min': ExtensionFunction( + '/unknown.yaml', 'min:i64', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'string_agg': ExtensionFunction( @@ -156,6 +164,10 @@ def __lt__(self, obj) -> bool: '/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'approx_count_distinct': ExtensionFunction( + '/functions_aggregate_approx.yaml', 'approx_count_distinct:any', + type_pb2.Type(i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'any_value': ExtensionFunction( '/functions_aggregate_generic.yaml', 'any_value:any', type_pb2.Type( i64=type_pb2.Type.I64( diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index f1ed07f..7960c36 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -10,15 +10,13 @@ import pyspark.sql.connect.proto.expressions_pb2 as spark_exprs_pb2 import pyspark.sql.connect.proto.relations_pb2 as spark_relations_pb2 import pyspark.sql.connect.proto.types_pb2 as spark_types_pb2 -from gateway.backends.backend_options import BackendOptions -from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import ConversionOptions from gateway.converter.spark_functions import ExtensionFunction, lookup_spark_function -from gateway.converter.sql_to_substrait import convert_sql from gateway.converter.substrait_builder import ( aggregate_relation, bigint_literal, bool_literal, + bool_type, cast_operation, concat, equal_function, @@ -45,8 +43,6 @@ from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2 from substrait.gen.proto.extensions import extensions_pb2 -TABLE_NAME = "my_table" - # ruff: noqa: RUF005 class SparkSubstraitConverter: @@ -62,6 +58,13 @@ def __init__(self, options: ConversionOptions): self._seen_generated_names = {} self._saved_extension_uris = {} self._saved_extensions = {} + self._backend = None + self._sql_backend = None + + def set_backends(self, backend, sql_backend) -> None: + """Save the backends being used to resolve tables and convert to SQL.""" + self._backend = backend + self._sql_backend = sql_backend def lookup_function_by_name(self, name: str) -> ExtensionFunction: """Find the function reference for a given Spark function name.""" @@ -81,14 +84,19 @@ def update_field_references(self, plan_id: int) -> None: """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) current_symbol = self._symbol_table.get_symbol(self._current_plan_id) - current_symbol.input_fields.extend(source_symbol.output_fields) - 
current_symbol.output_fields.extend(current_symbol.input_fields) + original_output_fields = current_symbol.output_fields + for symbol in source_symbol.output_fields: + new_name = symbol + while new_name in original_output_fields: + new_name = new_name + '_dup' + current_symbol.input_fields.append(new_name) + current_symbol.output_fields.append(new_name) def find_field_by_name(self, field_name: str) -> int | None: """Look up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) try: - return current_symbol.output_fields.index(field_name) + return current_symbol.input_fields.index(field_name) except ValueError: return None @@ -185,11 +193,11 @@ def convert_unresolved_attribute( root_reference=algebra_pb2.Expression.FieldReference.RootReference())) def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2.Type: - """Determine the type of a Substrait expression.""" + """Determine the type of the Substrait expression.""" if expr.WhichOneof('rex_type') == 'literal': match expr.literal.WhichOneof('literal_type'): case 'boolean': - return type_pb2.Type(bool=type_pb2.Type.Boolean()) + return bool_type() case 'i8': return type_pb2.Type(i8=type_pb2.Type.I8()) case 'i16': @@ -212,7 +220,7 @@ def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2 return expr.scalar_function.output_type if expr.WhichOneof('rex_type') == 'selection': # TODO -- Figure out how to determine the type of a field reference. - return type_pb2.Type(i32=type_pb2.Type.I32()) + return type_pb2.Type(i64=type_pb2.Type.I64()) raise NotImplementedError( 'Type determination not implemented for expressions of type ' f'{expr.WhichOneof("rex_type")}.') @@ -231,10 +239,14 @@ def convert_when_function( getattr(ifthen, 'else').CopyFrom( self.convert_expression(when.arguments[len(when.arguments) - 1])) else: + nullable_literal = self.determine_type_of_expression(ifthen.ifs[-1].then) + kind = nullable_literal.WhichOneof('kind') + getattr(nullable_literal, kind).nullability = ( + type_pb2.Type.Nullability.NULLABILITY_NULLABLE) getattr(ifthen, 'else').CopyFrom( algebra_pb2.Expression( literal=algebra_pb2.Expression.Literal( - null=self.determine_type_of_expression(ifthen.ifs[-1].then)))) + null=nullable_literal))) return algebra_pb2.Expression(if_then=ifthen) @@ -325,15 +337,17 @@ def convert_unresolved_function( break func.arguments.append( algebra_pb2.FunctionArgument(value=self.convert_expression(arg))) - if unresolved_function.is_distinct: - raise NotImplementedError( - 'Treating arguments as distinct is not supported for unresolved functions.') func.output_type.CopyFrom(function_def.output_type) + if unresolved_function.function_name == 'substring': + original_argument = func.arguments[0] + func.arguments[0].CopyFrom(algebra_pb2.FunctionArgument( + value=cast_operation(original_argument.value, string_type()))) return algebra_pb2.Expression(scalar_function=func) def convert_alias_expression( self, alias: spark_exprs_pb2.Expression.Alias) -> algebra_pb2.Expression: """Convert a Spark alias into a Substrait expression.""" + # We do nothing here and let the magic happen in the calling project relation. 
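+        # (The wrapping project relation reads alias.name when it assigns output field names; see convert_project_relation.)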
return self.convert_expression(alias.expr) def convert_type_str(self, spark_type_str: str | None) -> type_pb2.Type: @@ -360,7 +374,9 @@ def convert_type(self, spark_type: spark_types_pb2.DataType) -> type_pb2.Type: def convert_cast_expression( self, cast: spark_exprs_pb2.Expression.Cast) -> algebra_pb2.Expression: """Convert a Spark cast expression into a Substrait cast expression.""" - cast_rel = algebra_pb2.Expression.Cast(input=self.convert_expression(cast.expr)) + cast_rel = algebra_pb2.Expression.Cast( + input=self.convert_expression(cast.expr), + failure_behavior=algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION) match cast.WhichOneof('cast_to_type'): case 'type': cast_rel.type.CopyFrom(self.convert_type(cast.type)) @@ -420,12 +436,23 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex f'Unexpected expression type: {expr.WhichOneof("expr_type")}') return result + def is_distinct(self, expr: spark_exprs_pb2.Expression) -> bool: + """Determine if the expression is distinct.""" + if expr.WhichOneof( + 'expr_type') == 'unresolved_function' and expr.unresolved_function.is_distinct: + return True + if expr.WhichOneof('expr_type') == 'alias': + return self.is_distinct(expr.alias.expr) + return False + def convert_expression_to_aggregate_function( self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.AggregateFunction: """Convert a SparkConnect expression to a Substrait expression.""" func = algebra_pb2.AggregateFunction( phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT) + if self.is_distinct(expr): + func.invocation = algebra_pb2.AggregateFunction.AGGREGATION_INVOCATION_DISTINCT expression = self.convert_expression(expr) match expression.WhichOneof('rex_type'): case 'scalar_function': @@ -449,10 +476,8 @@ def convert_read_named_table_relation( """Convert a read named table relation to a Substrait relation.""" table_name = rel.unparsed_identifier - backend = find_backend(BackendOptions(self._conversion_options.backend.backend, True)) - tpch_location = backend.find_tpch() - backend.register_table(table_name, tpch_location / table_name) - arrow_schema = backend.describe_table(table_name) + arrow_schema = self._backend.describe_table(table_name) + schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) @@ -462,7 +487,8 @@ def convert_read_named_table_relation( return algebra_pb2.Rel( read=algebra_pb2.ReadRel( base_schema=schema, - named_table=algebra_pb2.ReadRel.NamedTable(names=[table_name]))) + named_table=algebra_pb2.ReadRel.NamedTable(names=[table_name]), + common=self.create_common_relation())) def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: """Convert the Spark JSON schema string into a Substrait named type structure.""" @@ -551,13 +577,8 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al local = algebra_pb2.ReadRel.LocalFiles() schema = self.convert_schema(rel.schema) if not schema: - backend = find_backend(BackendOptions(self._conversion_options.backend.backend, True)) - try: - backend.register_table(TABLE_NAME, rel.paths[0], rel.format) - arrow_schema = backend.describe_table(TABLE_NAME) - schema = self.convert_arrow_schema(arrow_schema) - finally: - backend.drop_table(TABLE_NAME) + arrow_schema = self._backend.describe_files([str(path) for path in rel.paths]) + schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: 
symbol.output_fields.append(field_name) @@ -565,7 +586,8 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al return algebra_pb2.Rel( read=algebra_pb2.ReadRel(base_schema=schema, named_table=algebra_pb2.ReadRel.NamedTable( - names=['demotable']))) + names=['demotable']), + common=self.create_common_relation())) if pathlib.Path(rel.paths[0]).is_dir(): file_paths = glob.glob(f'{rel.paths[0]}/*{rel.format}') else: @@ -600,7 +622,19 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al case _: raise NotImplementedError(f'Unexpected file format: {rel.format}') local.items.append(file_or_files) - return algebra_pb2.Rel(read=algebra_pb2.ReadRel(base_schema=schema, local_files=local)) + result = algebra_pb2.Rel(read=algebra_pb2.ReadRel(base_schema=schema, local_files=local, + common=self.create_common_relation())) + if not self._conversion_options.safety_project_read_relations: + return result + + project = algebra_pb2.ProjectRel( + input=result, + common=algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct())) + for field_number in range(len(symbol.output_fields)): + project.expressions.append(field_reference(field_number)) + project.common.emit.output_mapping.append(field_number) + + return algebra_pb2.Rel(project=project) def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: """Create the common metadata relation used by all relations.""" @@ -627,7 +661,6 @@ def convert_read_relation(self, rel: spark_relations_pb2.Read) -> algebra_pb2.Re result = self.convert_read_data_source_relation(rel.data_source) case _: raise ValueError(f'Unexpected read type: {rel.WhichOneof("read_type")}') - result.read.common.CopyFrom(self.create_common_relation()) return result def convert_filter_relation(self, rel: spark_relations_pb2.Filter) -> algebra_pb2.Rel: @@ -700,6 +733,10 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge symbol.generated_fields.append(self.determine_expression_name(expr)) symbol.output_fields.clear() symbol.output_fields.extend(symbol.generated_fields) + if len(rel.grouping_expressions) > 1: + # Hide the grouping source from the downstream relations. + for i in range(len(rel.grouping_expressions) + len(rel.aggregate_expressions)): + aggregate.common.emit.output_mapping.append(i) return algebra_pb2.Rel(aggregate=aggregate) # pylint: disable=too-many-locals,pointless-string-statement @@ -909,6 +946,12 @@ def convert_with_columns_relation( symbol.output_fields.append(name) project.common.CopyFrom(self.create_common_relation()) if remapped: + if self._conversion_options.duckdb_project_emit_workaround: + for field_number in range(len(symbol.input_fields)): + if field_number == mapping[field_number]: + project.expressions.append(field_reference(field_number)) + mapping[field_number] = len(symbol.input_fields) + ( + len(project.expressions)) - 1 for item in mapping: project.common.emit.output_mapping.append(item) return algebra_pb2.Rel(project=project) @@ -968,7 +1011,8 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: """Convert a Spark SQL relation into a Substrait relation.""" - plan = convert_sql(rel.query) + # TODO -- Handle multithreading in the case with a persistent backend. 
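+        # The dedicated SQL backend translates the embedded query directly into a Substrait plan.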
+ plan = self._sql_backend.convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in plan.relations[0].root.names: symbol.output_fields.append(field_name) @@ -1011,12 +1055,14 @@ def convert_cross_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_ if rel.HasField('join_condition'): raise ValueError('Cross joins do not support having a join condition.') join.common.CopyFrom(self.create_common_relation()) - return algebra_pb2.Rel(join=join) + return algebra_pb2.Rel(cross=join) def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Rel: """Convert a Spark join relation into a Substrait join relation.""" if rel.join_type == spark_relations_pb2.Join.JOIN_TYPE_CROSS: return self.convert_cross_join_relation(rel) + if not rel.HasField('join_condition') and not rel.using_columns: + return self.convert_cross_join_relation(rel) join = algebra_pb2.JoinRel(left=self.convert_relation(rel.left), right=self.convert_relation(rel.right)) self.update_field_references(rel.left.common.plan_id) @@ -1036,21 +1082,27 @@ def convert_project_relation( symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_number, expr in enumerate(rel.expressions): project.expressions.append(self.convert_expression(expr)) - if expr.HasField('alias'): + if expr.WhichOneof('expr_type') == 'alias': name = expr.alias.name[0] + elif expr.WhichOneof('expr_type') == 'unresolved_attribute': + name = expr.unresolved_attribute.unparsed_identifier else: name = f'generated_field_{field_number}' symbol.generated_fields.append(name) symbol.output_fields.append(name) project.common.CopyFrom(self.create_common_relation()) + symbol.output_fields = symbol.generated_fields + for field_number in range(len(rel.expressions)): + project.common.emit.output_mapping.append(field_number + len(symbol.input_fields)) return algebra_pb2.Rel(project=project) def convert_subquery_alias_relation(self, rel: spark_relations_pb2.SubqueryAlias) -> algebra_pb2.Rel: """Convert a Spark subquery alias relation into a Substrait relation.""" - # TODO -- Utilize rel.alias somehow. 
result = self.convert_relation(rel.input) self.update_field_references(rel.input.common.plan_id) + symbol = self._symbol_table.get_symbol(self._current_plan_id) + symbol.output_fields[-1] = rel.alias return result def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> algebra_pb2.Rel: @@ -1070,9 +1122,15 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> function_reference=any_value_func.anchor, arguments=[algebra_pb2.FunctionArgument(value=field_reference(idx))], phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT, - output_type=type_pb2.Type(bool=type_pb2.Type.Boolean())))) + output_type=type_pb2.Type(bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.NULLABILITY_REQUIRED))))) symbol.generated_fields.append(field) - return algebra_pb2.Rel(aggregate=aggregate) + aggr = algebra_pb2.Rel(aggregate=aggregate) + project = project_relation( + aggr, [field_reference(idx) for idx in range(len(symbol.input_fields))]) + for idx in range(len(symbol.input_fields)): + project.project.common.emit.output_mapping.append(idx) + return project def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel: """Convert a Spark relation into a Substrait one.""" diff --git a/src/gateway/converter/spark_to_substrait_test.py b/src/gateway/converter/spark_to_substrait_test.py index 685c71d..17909da 100644 --- a/src/gateway/converter/spark_to_substrait_test.py +++ b/src/gateway/converter/spark_to_substrait_test.py @@ -3,10 +3,11 @@ from pathlib import Path import pytest +from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter -from gateway.converter.sql_to_substrait import convert_sql from gateway.demo.mystream_database import create_mystream_database, delete_mystream_database +from gateway.tests.conftest import find_tpch from google.protobuf import text_format from pyspark.sql.connect.proto import base_pb2 as spark_base_pb2 from substrait.gen.proto import plan_pb2 @@ -41,8 +42,10 @@ def test_plan_conversion(request, path): substrait_plan = text_format.Parse(splan_prototext, plan_pb2.Plan()) options = duck_db() + backend = find_backend(options.backend) options.implement_show_string = False convert = SparkSubstraitConverter(options) + convert.set_backends(backend, backend) substrait = convert.convert_plan(spark_plan) if request.config.getoption('rebuild_goldens'): @@ -80,7 +83,10 @@ def test_sql_conversion(request, path): splan_prototext = file.read() substrait_plan = text_format.Parse(splan_prototext, plan_pb2.Plan()) - substrait = convert_sql(str(sql)) + options = duck_db() + backend = find_backend(options.backend) + backend.register_table('customer', find_tpch() / 'customer') + substrait = backend.convert_sql(str(sql)) if request.config.getoption('rebuild_goldens'): if substrait != substrait_plan: diff --git a/src/gateway/converter/sql_to_substrait.py b/src/gateway/converter/sql_to_substrait.py index 0b12c2e..5894398 100644 --- a/src/gateway/converter/sql_to_substrait.py +++ b/src/gateway/converter/sql_to_substrait.py @@ -1,18 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 """Routines to convert SparkConnect plans to Substrait plans.""" -from gateway.backends import backend_selector -from gateway.backends.backend_options import Backend, BackendOptions +from gateway.backends.backend import Backend +from gateway.converter.add_extension_uris import AddExtensionUris from 
substrait.gen.proto import plan_pb2 -def convert_sql(sql: str) -> plan_pb2.Plan: +def convert_sql(backend: Backend, sql: str) -> plan_pb2.Plan: """Convert SQL into a Substrait plan.""" - result = plan_pb2.Plan() + plan = backend.convert_sql(sql) - backend = backend_selector.find_backend(BackendOptions(Backend.DUCKDB, False)) - backend.register_tpch() + # Perform various fixes to make the plan more compatible. + # TODO -- Remove this after the SQL converter is fixed. + AddExtensionUris().visit_plan(plan) - connection = backend.get_connection() - proto_bytes = connection.get_substrait(query=sql).fetchone()[0] - result.ParseFromString(proto_bytes) - return result + return plan diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 924b6bc..9d3469b 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -76,8 +76,9 @@ def cast_operation(expression: algebra_pb2.Expression, output_type: type_pb2.Type) -> algebra_pb2.Expression: """Construct a Substrait cast expression.""" return algebra_pb2.Expression( - cast=algebra_pb2.Expression.Cast(input=expression, type=output_type) - ) + cast=algebra_pb2.Expression.Cast( + input=expression, type=output_type, + failure_behavior=algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION)) def if_then_else_operation(if_expr: algebra_pb2.Expression, then_expr: algebra_pb2.Expression, @@ -97,7 +98,8 @@ def field_reference(field_number: int) -> algebra_pb2.Expression: selection=algebra_pb2.Expression.FieldReference( direct_reference=algebra_pb2.Expression.ReferenceSegment( struct_field=algebra_pb2.Expression.ReferenceSegment.StructField( - field=field_number)))) + field=field_number)), + root_reference=algebra_pb2.Expression.FieldReference.RootReference())) def max_agg_function(function_info: ExtensionFunction, @@ -107,7 +109,8 @@ def max_agg_function(function_info: ExtensionFunction, return algebra_pb2.AggregateFunction( function_reference=function_info.anchor, output_type=function_info.output_type, - arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number))]) + arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number))], + phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT) def string_concat_agg_function(function_info: ExtensionFunction, @@ -118,7 +121,8 @@ def string_concat_agg_function(function_info: ExtensionFunction, function_reference=function_info.anchor, output_type=function_info.output_type, arguments=[algebra_pb2.FunctionArgument(value=field_reference(field_number)), - algebra_pb2.FunctionArgument(value=string_literal(separator))]) + algebra_pb2.FunctionArgument(value=string_literal(separator))], + phase=algebra_pb2.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT) def least_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression, @@ -191,7 +195,7 @@ def minus_function(function_info: ExtensionFunction, def repeat_function(function_info: ExtensionFunction, string: str, - count: algebra_pb2.Expression) -> algebra_pb2.AggregateFunction: + count: algebra_pb2.Expression) -> algebra_pb2.Expression: """Construct a Substrait concat expression.""" return algebra_pb2.Expression(scalar_function= algebra_pb2.Expression.ScalarFunction( @@ -203,7 +207,7 @@ def repeat_function(function_info: ExtensionFunction, def lpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, - pad_string: str = ' ') -> 
algebra_pb2.AggregateFunction: + pad_string: str = ' ') -> algebra_pb2.Expression: """Construct a Substrait concat expression.""" # TODO -- Avoid a cast if we don't need it. cast_type = string_type() @@ -220,7 +224,7 @@ def lpad_function(function_info: ExtensionFunction, def rpad_function(function_info: ExtensionFunction, expression: algebra_pb2.Expression, count: algebra_pb2.Expression, - pad_string: str = ' ') -> algebra_pb2.AggregateFunction: + pad_string: str = ' ') -> algebra_pb2.Expression: """Construct a Substrait concat expression.""" # TODO -- Avoid a cast if we don't need it. cast_type = string_type() @@ -238,7 +242,7 @@ def rpad_function(function_info: ExtensionFunction, def regexp_strpos_function(function_info: ExtensionFunction, input: algebra_pb2.Expression, pattern: algebra_pb2.Expression, position: algebra_pb2.Expression, - occurrence: algebra_pb2.Expression) -> algebra_pb2.AggregateFunction: + occurrence: algebra_pb2.Expression) -> algebra_pb2.Expression: """Construct a Substrait regex substring expression.""" return algebra_pb2.Expression(scalar_function=algebra_pb2.Expression.ScalarFunction( function_reference=function_info.anchor, @@ -255,6 +259,15 @@ def bool_literal(val: bool) -> algebra_pb2.Expression: return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(boolean=val)) +def bool_type(required: bool = True) -> type_pb2.Type: + """Construct a Substrait boolean type.""" + if required: + nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED + else: + nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE + return type_pb2.Type(bool=type_pb2.Type.Boolean(nullability=nullability)) + + def string_literal(val: str) -> algebra_pb2.Expression: """Construct a Substrait string literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(string=val)) diff --git a/src/gateway/server.py b/src/gateway/server.py index b063ece..b2028f2 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -9,12 +9,15 @@ import pyarrow as pa import pyspark.sql.connect.proto.base_pb2 as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc +from google.protobuf.json_format import MessageToJson from pyspark.sql.connect.proto import types_pb2 +from substrait.gen.proto import plan_pb2 +from gateway.backends.backend import Backend +from gateway.backends.backend_options import BackendEngine, BackendOptions from gateway.backends.backend_selector import find_backend from gateway.converter.conversion_options import arrow, datafusion, duck_db from gateway.converter.spark_to_substrait import SparkSubstraitConverter -from gateway.converter.sql_to_substrait import convert_sql _LOGGER = logging.getLogger(__name__) @@ -53,7 +56,7 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: elif field.type == pa.int8(): data_type = types_pb2.DataType(byte=types_pb2.DataType.Byte()) elif field.type == pa.int16(): - data_type = types_pb2.DataType(integer=types_pb2.DataType.Short()) + data_type = types_pb2.DataType(short=types_pb2.DataType.Short()) elif field.type == pa.int32(): data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer()) elif field.type == pa.int64(): @@ -68,6 +71,8 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: data_type = types_pb2.DataType(timestamp=types_pb2.DataType.Timestamp()) elif field.type == pa.date32(): data_type = types_pb2.DataType(date=types_pb2.DataType.Date()) + elif field.type == pa.null(): + data_type = 
types_pb2.DataType(null=types_pb2.DataType.NULL()) else: raise NotImplementedError( 'Conversion from Arrow schema to Spark schema not yet implemented ' @@ -79,6 +84,55 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType: return types_pb2.DataType(struct=types_pb2.DataType.Struct(fields=fields)) +def create_dataframe_view(rel: pb2.Plan, backend) -> None: + """Register the temporary dataframe.""" + dataframe_view_name = rel.command.create_dataframe_view.name + read_data_source_relation = rel.command.create_dataframe_view.input.read.data_source + fmt = read_data_source_relation.format + path = read_data_source_relation.paths[0] + backend.register_table(dataframe_view_name, path, fmt) + + +class Statistics: + """Statistics about the requests made to the server.""" + + def __init__(self): + """Initialize the statistics.""" + self.config_requests: int = 0 + self.analyze_requests: int = 0 + self.execute_requests: int = 0 + self.add_artifacts_requests: int = 0 + self.artifact_status_requests: int = 0 + self.interrupt_requests: int = 0 + self.reattach_requests: int = 0 + self.release_requests: int = 0 + + self.requests: list[str] = [] + self.plans: list[str] = [] + + def add_request(self, request): + """Remember a request for later introspection.""" + self.requests.append(str(request)) + + def add_plan(self, plan: plan_pb2.Plan): + """Remember a plan for later introspection.""" + self.plans.append(MessageToJson(plan)) + + def reset(self): + """Reset the statistics.""" + self.config_requests = 0 + self.analyze_requests = 0 + self.execute_requests = 0 + self.add_artifacts_requests = 0 + self.artifact_status_requests = 0 + self.interrupt_requests = 0 + self.reattach_requests = 0 + self.release_requests = 0 + + self.requests = [] + self.plans = [] + + # pylint: disable=E1101,fixme class SparkConnectService(pb2_grpc.SparkConnectServiceServicer): """Provides the SparkConnect service.""" @@ -88,29 +142,57 @@ def __init__(self, *args, **kwargs): """Initialize the SparkConnect service.""" # This is the central point for configuring the behavior of the service. 
self._options = duck_db() + self._backend: Backend | None = None + self._sql_backend: Backend | None = None + self._converter = None + self._statistics = Statistics() + + def _InitializeExecution(self): + """Initialize the execution of the Plan by setting the backend.""" + if not self._backend: + self._backend = find_backend(self._options.backend) + self._sql_backend = find_backend(BackendOptions(BackendEngine.DUCKDB, False)) + self._converter = SparkSubstraitConverter(self._options) + self._converter.set_backends(self._backend, self._sql_backend) def ExecutePlan( self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: """Execute the given plan and return the results.""" + self._statistics.execute_requests += 1 + self._statistics.add_request(request) _LOGGER.info('ExecutePlan: %s', request) + self._InitializeExecution() match request.plan.WhichOneof('op_type'): case 'root': - convert = SparkSubstraitConverter(self._options) - substrait = convert.convert_plan(request.plan) + substrait = self._converter.convert_plan(request.plan) case 'command': match request.plan.command.WhichOneof('command_type'): case 'sql_command': - substrait = convert_sql(request.plan.command.sql_command.sql) + if "CREATE" in request.plan.command.sql_command.sql: + connection = self._backend.get_connection() + connection.execute(request.plan.command.sql_command.sql) + yield pb2.ExecutePlanResponse( + session_id=request.session_id, + result_complete=pb2.ExecutePlanResponse.ResultComplete()) + return + substrait = self._sql_backend.convert_sql( + request.plan.command.sql_command.sql) + case 'create_dataframe_view': + create_dataframe_view(request.plan, self._backend) + create_dataframe_view(request.plan, self._sql_backend) + yield pb2.ExecutePlanResponse( + session_id=request.session_id, + result_complete=pb2.ExecutePlanResponse.ResultComplete()) + return case _: type = request.plan.command.WhichOneof("command_type") raise NotImplementedError(f'Unsupported command type: {type}') case _: raise ValueError(f'Unknown plan type: {request.plan}') _LOGGER.debug(' as Substrait: %s', substrait) - backend = find_backend(self._options.backend) - backend.register_tpch() - results = backend.execute(substrait) + self._statistics.add_plan(substrait) + results = self._backend.execute(substrait) _LOGGER.debug(' results are: %s', results) if not self._options.implement_show_string and request.plan.WhichOneof( @@ -147,13 +229,14 @@ def ExecutePlan( def AnalyzePlan(self, request, context): """Analyze the given plan and return the results.""" + self._statistics.analyze_requests += 1 + self._statistics.add_request(request) _LOGGER.info('AnalyzePlan: %s', request) + self._InitializeExecution() if request.schema: - convert = SparkSubstraitConverter(self._options) - substrait = convert.convert_plan(request.schema.plan) - backend = find_backend(self._options.backend) - backend.register_tpch() - results = backend.execute(substrait) + substrait = self._converter.convert_plan(request.schema.plan) + self._statistics.add_plan(substrait) + results = self._backend.execute(substrait) _LOGGER.debug(' results are: %s', results) return pb2.AnalyzePlanResponse( session_id=request.session_id, @@ -163,6 +246,7 @@ def AnalyzePlan(self, request, context): def Config(self, request, context): """Get or set the configuration of the server.""" + self._statistics.config_requests += 1 _LOGGER.info('Config: %s', request) response = pb2.ConfigResponse(session_id=request.session_id) match 
request.operation.WhichOneof('op_type'): @@ -179,23 +263,44 @@ self._options = datafusion() case _: raise ValueError(f'Unknown backend: {pair.value}') + elif pair.key == 'spark-substrait-gateway.reset_statistics': + self._statistics.reset() response.pairs.extend(request.operation.set.pairs) + case 'get': + for key in request.operation.get.keys: + if key == 'spark-substrait-gateway.backend': + response.pairs.add(key=key, value=str(self._options.backend.backend)) + elif key == 'spark-substrait-gateway.plan_count': + response.pairs.add(key=key, value=str(len(self._statistics.plans))) + elif key.startswith('spark-substrait-gateway.plan.'): + index = int(key[len('spark-substrait-gateway.plan.'):]) + if 0 <= index - 1 < len(self._statistics.plans): + response.pairs.add(key=key, value=self._statistics.plans[index - 1]) + else: + raise NotImplementedError(f'Unknown config item: {key}') case 'get_with_default': - response.pairs.extend(request.operation.get_with_default.pairs) + for pair in request.operation.get_with_default.pairs: + if pair.key == 'spark-substrait-gateway.backend': + response.pairs.add(key=pair.key, value=str(self._options.backend.backend)) + else: + response.pairs.append(pair) return response def AddArtifacts(self, request_iterator, context): """Add the given artifacts to the server.""" + self._statistics.add_artifacts_requests += 1 _LOGGER.info('AddArtifacts') return pb2.AddArtifactsResponse() def ArtifactStatus(self, request, context): """Get the status of the given artifact.""" + self._statistics.artifact_status_requests += 1 _LOGGER.info('ArtifactStatus') return pb2.ArtifactStatusesResponse() def Interrupt(self, request, context): """Interrupt the execution of the given plan.""" + self._statistics.interrupt_requests += 1 _LOGGER.info('Interrupt') return pb2.InterruptResponse() @@ -203,6 +308,7 @@ def ReattachExecute( self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext) -> Generator[ pb2.ExecutePlanResponse, None, None]: """Reattach the execution of the given plan.""" + self._statistics.reattach_requests += 1 _LOGGER.info('ReattachExecute') yield pb2.ExecutePlanResponse( session_id=request.session_id, @@ -210,6 +316,7 @@ def ReleaseExecute(self, request, context): """Release the execution of the given plan.""" + self._statistics.release_requests += 1 _LOGGER.info('ReleaseExecute') return pb2.ReleaseExecuteResponse() diff --git a/src/gateway/tests/compare_dataframes.py b/src/gateway/tests/compare_dataframes.py new file mode 100644 index 0000000..ef8a33d --- /dev/null +++ b/src/gateway/tests/compare_dataframes.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Routines for comparing dataframes.""" +import datetime + +from pyspark import Row +from pyspark.testing import assertDataFrameEqual + + +def have_same_schema(outcome: list[Row], expected: list[Row]): + """Returns True if the two dataframes have the same schema.""" + return all(type(a) is type(b) for a, b in zip(outcome[0], expected[0], strict=False)) + + +def align_schema(source_df: list[Row], schema_df: list[Row]): + """Returns a copy of source_df with the fields changed to match schema_df.""" + schema = schema_df[0] + + if have_same_schema(source_df, schema_df): + return source_df + + new_source_df = [] + for row in source_df: + new_row = {} + for field_name, field_value in schema.asDict().items(): + if (type(row[field_name]) is not type(field_value) and + isinstance(field_value, datetime.date)): + new_row[field_name] = 
row[field_name].date() + else: + new_row[field_name] = row[field_name] + + new_source_df.append(Row(**new_row)) + + return new_source_df + + +def assert_dataframes_equal(outcome: list[Row], expected: list[Row]): + """Asserts that two dataframes are equal ignoring column names and date formats.""" + # Create a copy of the dataframes to avoid modifying the original ones + modified_outcome = align_schema(outcome, expected) + + assertDataFrameEqual(modified_outcome, expected, atol=1e-2) diff --git a/src/gateway/tests/conftest.py b/src/gateway/tests/conftest.py index fb305c8..a34267f 100644 --- a/src/gateway/tests/conftest.py +++ b/src/gateway/tests/conftest.py @@ -4,14 +4,12 @@ from pathlib import Path import pytest -from gateway.backends.backend import Backend from gateway.demo.mystream_database import ( create_mystream_database, delete_mystream_database, get_mystream_schema, ) from gateway.server import serve -from pyspark.sql.pandas.types import from_arrow_schema from pyspark.sql.session import SparkSession @@ -62,7 +60,7 @@ def _create_gateway_session(backend: str) -> SparkSession: spark_gateway.stop() -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope='function', autouse=True) def manage_database() -> None: """Creates the mystream database for use throughout all the tests.""" create_mystream_database() @@ -70,7 +68,7 @@ def manage_database() -> None: delete_mystream_database() -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope='function', autouse=True) def gateway_server(): """Starts up a spark to substrait gateway service.""" server = serve(50052, wait=False) @@ -78,20 +76,21 @@ def gateway_server(): server.stop(None) -@pytest.fixture(scope='session') -def users_location() -> str: +@pytest.fixture(scope='function') +def users_location(manage_database) -> str: """Provides the location of the users database.""" return str(Path('users.parquet').resolve()) -@pytest.fixture(scope='session') -def schema_users(): +@pytest.fixture(scope='function') +def schema_users(manage_database): """Provides the schema of the users database.""" return get_mystream_schema('users') @pytest.fixture(scope='session', params=['spark', + 'gateway-over-arrow', 'gateway-over-duckdb', 'gateway-over-datafusion', ]) @@ -100,7 +99,7 @@ def source(request) -> str: return request.param -@pytest.fixture(scope='session') +@pytest.fixture(scope='function') def spark_session(source): """Provides spark sessions connecting to various backends.""" match source: @@ -119,38 +118,46 @@ def spark_session(source): # pylint: disable=redefined-outer-name @pytest.fixture(scope='function') -def users_dataframe(spark_session, schema_users, users_location): - """Provides a ready to go dataframe over the users database.""" - return spark_session.read.format('parquet') \ - .schema(from_arrow_schema(schema_users)) \ - .parquet(users_location) +def spark_session_with_users_dataset(spark_session, schema_users, users_location): + """Provides the spark session with the users database already loaded.""" + df = spark_session.read.parquet(users_location) + df.createOrReplaceTempView('users') + return spark_session -def _register_table(spark_session: SparkSession, name: str) -> None: - location = Backend.find_tpch() / name - spark_session.sql( - f'CREATE OR REPLACE TEMPORARY VIEW {name} USING org.apache.spark.sql.parquet ' - f'OPTIONS ( path "{location}" )') +@pytest.fixture(scope='function') +def users_dataframe(spark_session_with_users_dataset): + """Provides a ready to go users dataframe.""" + return 
spark_session_with_users_dataset.table('users') -@pytest.fixture(scope='function') -def spark_session_with_tpch_dataset(spark_session: SparkSession, source: str) -> SparkSession: - """Add the TPC-H dataset to the current spark session.""" - if source == 'spark': - _register_table(spark_session, 'customer') - _register_table(spark_session, 'lineitem') - _register_table(spark_session, 'nation') - _register_table(spark_session, 'orders') - _register_table(spark_session, 'part') - _register_table(spark_session, 'partsupp') - _register_table(spark_session, 'region') - _register_table(spark_session, 'supplier') - return spark_session +def find_tpch() -> Path: + """Find the location of the TPC-H dataset.""" + current_location = Path('.').resolve() + while current_location != current_location.parent:  # stop at the filesystem root + location = current_location / 'third_party' / 'tpch' / 'parquet' + if location.exists(): + return location.resolve() + current_location = current_location.parent + raise ValueError('TPC-H dataset not found') + + +def _register_table(spark_session: SparkSession, name: str) -> None: + """Registers a TPC-H table with the given name into spark_session.""" + location = find_tpch() / name + df = spark_session.read.parquet(str(location)) + df.createOrReplaceTempView(name) @pytest.fixture(scope='function') -def spark_session_with_customer_dataset(spark_session: SparkSession, source: str) -> SparkSession: +def spark_session_with_tpch_dataset(spark_session: SparkSession) -> SparkSession: """Add the TPC-H dataset to the current spark session.""" - if source == 'spark': - _register_table(spark_session, 'customer') + _register_table(spark_session, 'customer') + _register_table(spark_session, 'lineitem') + _register_table(spark_session, 'nation') + _register_table(spark_session, 'orders') + _register_table(spark_session, 'part') + _register_table(spark_session, 'partsupp') + _register_table(spark_session, 'region') + _register_table(spark_session, 'supplier') return spark_session diff --git a/src/gateway/tests/plan_validator.py b/src/gateway/tests/plan_validator.py new file mode 100644 index 0000000..787f21e --- /dev/null +++ b/src/gateway/tests/plan_validator.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager + +import google.protobuf.message +import pytest +import substrait_validator +from google.protobuf import json_format +from pyspark.errors.exceptions.connect import SparkConnectGrpcException +from substrait_validator.substrait import plan_pb2 + + +def validate_plan(json_plan: str): + substrait_plan = json_format.Parse(json_plan, plan_pb2.Plan()) + try: + diagnostics = substrait_validator.plan_to_diagnostics(substrait_plan.SerializeToString()) + except google.protobuf.message.DecodeError: + # Probable protobuf mismatch internal to Substrait Validator, ignore for now. + return + issues = [] + for issue in diagnostics: + if issue.adjusted_level >= substrait_validator.Diagnostic.LEVEL_ERROR: + issues.append([issue.msg, substrait_validator.path_to_string(issue.path)]) + if issues: + issues_as_text = '\n'.join(f' → {issue[0]}\n at {issue[1]}' for issue in issues) + pytest.fail(f'Validation failed. Issues:\n{issues_as_text}\n\nPlan:\n{substrait_plan}\n', + pytrace=False) + + +@contextmanager +def utilizes_valid_plans(session): + """Validates that the plans generated by the gateway backend pass the Substrait validator.""" + if hasattr(session, 'sparkSession'): + session = session.sparkSession + # Reset the statistics, so we only see the plans that were created during our lifetime.
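# --- Illustrative aside (not part of the patch) ------------------------------
# The context manager around this point works by round-tripping plans through
# the gateway's Config service: the server records every Substrait plan it
# executes and hands the plans back through the 'spark-substrait-gateway.*'
# config keys added earlier in this change. A minimal standalone sketch of the
# same round trip, assuming `spark` is a SparkSession that is already connected
# to the gateway (for example, the one the spark_session fixture returns):
import substrait_validator
from google.protobuf import json_format
from substrait_validator.substrait import plan_pb2


def fetch_and_validate_plans(spark) -> int:
    """Validate every plan the gateway has recorded and return the plan count."""
    plan_count = int(spark.conf.get('spark-substrait-gateway.plan_count'))
    for index in range(1, plan_count + 1):
        # Plans come back as JSON-encoded Substrait, one per 1-based index.
        json_plan = spark.conf.get(f'spark-substrait-gateway.plan.{index}')
        plan = json_format.Parse(json_plan, plan_pb2.Plan())
        for diag in substrait_validator.plan_to_diagnostics(plan.SerializeToString()):
            if diag.adjusted_level >= substrait_validator.Diagnostic.LEVEL_ERROR:
                raise AssertionError(f'plan {index}: {diag.msg}')
    return plan_count
# Usage: run a query, then call fetch_and_validate_plans(spark) to confirm
# every recorded plan passes validation.
# ------------------------------------------------------------------------------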
+ if session.conf.get('spark-substrait-gateway.backend', 'spark') != 'spark': + session.conf.set('spark-substrait-gateway.reset_statistics', None) + try: + exception = None + yield + except SparkConnectGrpcException as e: + exception = e + if session.conf.get('spark-substrait-gateway.backend', 'spark') == 'spark': + if exception: + raise exception + return + plan_count = int(session.conf.get('spark-substrait-gateway.plan_count')) + plans_as_text = [] + for i in range(plan_count): + plan = session.conf.get(f'spark-substrait-gateway.plan.{i + 1}') + plans_as_text.append( f'Plan #{i+1}:\n{plan}\n') + validate_plan(plan) + if exception: + pytest.fail(f'Exception raised during execution: {exception.message}\n\n' + + '\n\n'.join(plans_as_text), pytrace=False) diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index d414ebc..d8901e1 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the Spark to Substrait Gateway server.""" import pytest -from gateway.backends.backend import Backend +from gateway.tests.conftest import find_tpch +from gateway.tests.plan_validator import utilizes_valid_plans from hamcrest import assert_that, equal_to from pyspark import Row from pyspark.sql.functions import col, substring @@ -12,18 +13,8 @@ def mark_dataframe_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') - originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and (originalname == 'test_with_column' or - originalname == 'test_cast'): - request.node.add_marker( - pytest.mark.xfail(reason='DuckDB column binding error')) - elif source == 'gateway-over-datafusion': - if originalname in [ - 'test_data_source_schema', 'test_data_source_filter', 'test_table', 'test_table_schema', - 'test_table_filter']: - request.node.add_marker(pytest.mark.xfail(reason='Gateway internal iterating error')) - else: - pytest.importorskip("datafusion.substrait") + if source == 'gateway-over-datafusion': + pytest.importorskip("datafusion.substrait") # pylint: disable=missing-function-docstring @@ -32,12 +23,16 @@ class TestDataFrameAPI: """Tests of the dataframe side of SparkConnect.""" def test_collect(self, users_dataframe): - outcome = users_dataframe.collect() + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.collect() + assert len(outcome) == 100 # pylint: disable=singleton-comparison def test_filter(self, users_dataframe): - outcome = users_dataframe.filter(col('paid_for_service') == True).collect() + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.filter(col('paid_for_service') == True).collect() + assert len(outcome) == 29 # pylint: disable=singleton-comparison @@ -50,8 +45,10 @@ def test_filter_with_show(self, users_dataframe, capsys): +-------------+---------------+----------------+ ''' - users_dataframe.filter(col('paid_for_service') == True).limit(2).show() - outcome = capsys.readouterr().out + with utilizes_valid_plans(users_dataframe): + users_dataframe.filter(col('paid_for_service') == True).limit(2).show() + outcome = capsys.readouterr().out + assert_that(outcome, equal_to(expected)) # pylint: disable=singleton-comparison @@ -64,8 +61,10 @@ def test_filter_with_show_with_limit(self, users_dataframe, capsys): only showing top 1 row ''' - users_dataframe.filter(col('paid_for_service') == True).show(1) - outcome = 
capsys.readouterr().out + with utilizes_valid_plans(users_dataframe): + users_dataframe.filter(col('paid_for_service') == True).show(1) + outcome = capsys.readouterr().out + assert_that(outcome, equal_to(expected)) # pylint: disable=singleton-comparison @@ -82,7 +81,9 @@ def test_filter_with_show_and_truncate(self, users_dataframe, capsys): assert_that(outcome, equal_to(expected)) def test_count(self, users_dataframe): - outcome = users_dataframe.count() + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.count() + assert outcome == 100 def test_limit(self, users_dataframe): @@ -90,46 +91,102 @@ def test_limit(self, users_dataframe): Row(user_id='user849118289', name='Brooke Jones', paid_for_service=False), Row(user_id='user954079192', name='Collin Frank', paid_for_service=False), ] - outcome = users_dataframe.limit(2).collect() + + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.limit(2).collect() + assertDataFrameEqual(outcome, expected) def test_with_column(self, users_dataframe): expected = [ Row(user_id='user849118289', name='Brooke Jones', paid_for_service=False), ] - outcome = users_dataframe.withColumn( - 'user_id', col('user_id')).limit(1).collect() + + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.withColumn('user_id', col('user_id')).limit(1).collect() + assertDataFrameEqual(outcome, expected) def test_cast(self, users_dataframe): expected = [ Row(user_id=849, name='Brooke Jones', paid_for_service=False), ] - outcome = users_dataframe.withColumn( - 'user_id', - substring(col('user_id'), 5, 3).cast('integer')).limit(1).collect() + + with utilizes_valid_plans(users_dataframe): + outcome = users_dataframe.withColumn( + 'user_id', + substring(col('user_id'), 5, 3).cast('integer')).limit(1).collect() + + assertDataFrameEqual(outcome, expected) + + def test_join(self, spark_session_with_tpch_dataset): + expected = [ + Row(n_nationkey=5, n_name='ETHIOPIA', n_regionkey=0, + n_comment='ven packages wake quickly. regu', s_suppkey=2, + s_name='Supplier#000000002', s_address='89eJ5ksX3ImxJQBvxObC,', s_nationkey=5, + s_phone='15-679-861-2259', s_acctbal=4032.68, + s_comment=' slyly bold instructions. 
idle dependen'), + ] + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + nation = spark_session_with_tpch_dataset.table('nation') + supplier = spark_session_with_tpch_dataset.table('supplier') + + nat = nation.join(supplier, col('n_nationkey') == col('s_nationkey')) + outcome = nat.filter(col('s_suppkey') == 2).limit(1).collect() + assertDataFrameEqual(outcome, expected) def test_data_source_schema(self, spark_session): - location_customer = str(Backend.find_tpch() / 'customer') + location_customer = str(find_tpch() / 'customer') schema = spark_session.read.parquet(location_customer).schema assert len(schema) == 8 def test_data_source_filter(self, spark_session): - location_customer = str(Backend.find_tpch() / 'customer') + location_customer = str(find_tpch() / 'customer') customer_dataframe = spark_session.read.parquet(location_customer) - outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + + with utilizes_valid_plans(spark_session): + outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + assert len(outcome) == 29968 - def test_table(self, spark_session_with_customer_dataset): - outcome = spark_session_with_customer_dataset.table('customer').collect() + def test_table(self, spark_session_with_tpch_dataset): + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.table('customer').collect() + assert len(outcome) == 149999 - def test_table_schema(self, spark_session_with_customer_dataset): - schema = spark_session_with_customer_dataset.table('customer').schema + def test_table_schema(self, spark_session_with_tpch_dataset): + schema = spark_session_with_tpch_dataset.table('customer').schema assert len(schema) == 8 - def test_table_filter(self, spark_session_with_customer_dataset): - customer_dataframe = spark_session_with_customer_dataset.table('customer') - outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + def test_table_filter(self, spark_session_with_tpch_dataset): + customer_dataframe = spark_session_with_tpch_dataset.table('customer') + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = customer_dataframe.filter(col('c_mktsegment') == 'FURNITURE').collect() + assert len(outcome) == 29968 + + def test_create_or_replace_temp_view(self, spark_session): + location_customer = str(find_tpch() / 'customer') + df_customer = spark_session.read.parquet(location_customer) + df_customer.createOrReplaceTempView("mytempview") + + with utilizes_valid_plans(spark_session): + outcome = spark_session.table('mytempview').collect() + + assert len(outcome) == 149999 + + def test_create_or_replace_multiple_temp_views(self, spark_session): + location_customer = str(find_tpch() / 'customer') + df_customer = spark_session.read.parquet(location_customer) + df_customer.createOrReplaceTempView("mytempview1") + df_customer.createOrReplaceTempView("mytempview2") + + with utilizes_valid_plans(spark_session): + outcome1 = spark_session.table('mytempview1').collect() + outcome2 = spark_session.table('mytempview2').collect() + + assert len(outcome1) == len(outcome2) == 149999 diff --git a/src/gateway/tests/test_sql_api.py b/src/gateway/tests/test_sql_api.py index 16f4510..4199770 100644 --- a/src/gateway/tests/test_sql_api.py +++ b/src/gateway/tests/test_sql_api.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest +from gateway.tests.plan_validator import utilizes_valid_plans from hamcrest import assert_that, equal_to from pyspark 
import Row from pyspark.testing import assertDataFrameEqual @@ -19,16 +20,23 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and originalname == 'test_tpch': - path = request.getfixturevalue('path') - if path.stem in ['02', '04', '16', '17', '18', '20', '21', '22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB needs Delim join')) - if path.stem in ['15']: - request.node.add_marker(pytest.mark.xfail(reason='Rounding inconsistency')) + if source == 'gateway-over-duckdb': + if originalname == 'test_tpch': + path = request.getfixturevalue('path') + if path.stem in ['02', '04', '16', '17', '18', '20', '21', '22']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB needs Delim join')) + if path.stem in ['15']: + request.node.add_marker(pytest.mark.xfail(reason='Rounding inconsistency')) + else: + request.node.add_marker(pytest.mark.xfail(reason='Too few names returned')) + else: + request.node.add_marker(pytest.mark.xfail(reason='Too few names returned')) if source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") if originalname == 'test_count': request.node.add_marker(pytest.mark.xfail(reason='COUNT() not implemented')) + if originalname == 'test_limit': + request.node.add_marker(pytest.mark.xfail(reason='Too few names returned')) if originalname in ['test_tpch']: path = request.getfixturevalue('path') if path.stem in ['01']: @@ -62,8 +70,10 @@ class TestSqlAPI: """Tests of the SQL side of SparkConnect.""" def test_count(self, spark_session_with_tpch_dataset): - outcome = spark_session_with_tpch_dataset.sql( - 'SELECT COUNT(*) FROM customer').collect() + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.sql( + 'SELECT COUNT(*) FROM customer').collect() + assert_that(outcome[0][0], equal_to(149999)) def test_limit(self, spark_session_with_tpch_dataset): @@ -74,8 +84,11 @@ def test_limit(self, spark_session_with_tpch_dataset): Row(c_custkey=5, c_phone='13-750-942-6364', c_mktsegment='HOUSEHOLD'), Row(c_custkey=6, c_phone='30-114-968-4951', c_mktsegment='AUTOMOBILE'), ] - outcome = spark_session_with_tpch_dataset.sql( - 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + outcome = spark_session_with_tpch_dataset.sql( + 'SELECT c_custkey, c_phone, c_mktsegment FROM customer LIMIT 5').collect() + assertDataFrameEqual(outcome, expected) @pytest.mark.timeout(60) @@ -90,4 +103,6 @@ def test_tpch(self, spark_session_with_tpch_dataset, path): with open(path, "rb") as file: sql_bytes = file.read() sql = sql_bytes.decode('utf-8') - spark_session_with_tpch_dataset.sql(sql).collect() + + with utilizes_valid_plans(spark_session_with_tpch_dataset): + spark_session_with_tpch_dataset.sql(sql).collect() diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index 3da9fda..3ee0fcb 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -4,9 +4,10 @@ import pyspark import pytest +from gateway.tests.compare_dataframes import assert_dataframes_equal +from gateway.tests.plan_validator import utilizes_valid_plans from pyspark import Row from pyspark.sql.functions import avg, col, count, countDistinct, desc, try_sum, when -from 
pyspark.testing import assertDataFrameEqual @pytest.fixture(autouse=True) @@ -14,13 +15,14 @@ def mark_tests_as_xfail(request): """Marks a subset of tests as expected to be fail.""" source = request.getfixturevalue('source') originalname = request.keywords.node.originalname - if source == 'gateway-over-duckdb' and originalname in [ - 'test_query_02', 'test_query_03', 'test_query_04', 'test_query_05', 'test_query_07', - 'test_query_08', 'test_query_09', 'test_query_10', 'test_query_11', 'test_query_12', - 'test_query_13', 'test_query_14', 'test_query_15', 'test_query_16', 'test_query_17', - 'test_query_18', 'test_query_19', 'test_query_20', 'test_query_21', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) - if source == 'gateway-over-datafusion': + if source == 'gateway-over-duckdb': + if originalname in ['test_query_15']: + request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) + elif originalname in ['test_query_08']: + request.node.add_marker(pytest.mark.xfail(reason='DuckDB binder error')) + elif originalname == 'test_query_16': + request.node.add_marker(pytest.mark.xfail(reason='results differ')) + elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) @@ -37,21 +39,24 @@ def test_query_01(self, spark_session_with_tpch_dataset): avg_price=38273.13, avg_disc=0.05, count_order=1478493), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - outcome = lineitem.filter(col('l_shipdate') <= '1998-09-02').groupBy('l_returnflag', - 'l_linestatus').agg( - try_sum('l_quantity').alias('sum_qty'), - try_sum('l_extendedprice').alias('sum_base_price'), - try_sum(col('l_extendedprice') * (1 - col('l_discount'))).alias('sum_disc_price'), - try_sum(col('l_extendedprice') * (1 - col('l_discount')) * (1 + col('l_tax'))).alias( - 'sum_charge'), - avg('l_quantity').alias('avg_qty'), - avg('l_extendedprice').alias('avg_price'), - avg('l_discount').alias('avg_disc'), - count('*').alias('count_order')) - - sorted_outcome = outcome.sort('l_returnflag', 'l_linestatus').limit(1).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + outcome = lineitem.filter(col('l_shipdate') <= '1998-09-02').groupBy( + 'l_returnflag', 'l_linestatus').agg( + try_sum('l_quantity').alias('sum_qty'), + try_sum('l_extendedprice').alias('sum_base_price'), + try_sum(col('l_extendedprice') * (1 - col('l_discount'))).alias('sum_disc_price'), + try_sum( + col('l_extendedprice') * (1 - col('l_discount')) * (1 + col('l_tax'))).alias( + 'sum_charge'), + avg('l_quantity').alias('avg_qty'), + avg('l_extendedprice').alias('avg_price'), + avg('l_discount').alias('avg_disc'), + count('*').alias('count_order')) + + sorted_outcome = outcome.sort('l_returnflag', 'l_linestatus').limit(1).collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_02(self, spark_session_with_tpch_dataset): expected = [ @@ -65,31 +70,34 @@ def test_query_02(self, spark_session_with_tpch_dataset): s_comment='efully express instructions. 
regular requests against the slyly fin'), ] - part = spark_session_with_tpch_dataset.table('part') - supplier = spark_session_with_tpch_dataset.table('supplier') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - nation = spark_session_with_tpch_dataset.table('nation') - region = spark_session_with_tpch_dataset.table('region') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + part = spark_session_with_tpch_dataset.table('part') + supplier = spark_session_with_tpch_dataset.table('supplier') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + nation = spark_session_with_tpch_dataset.table('nation') + region = spark_session_with_tpch_dataset.table('region') - europe = region.filter(col('r_name') == 'EUROPE').join( - nation, col('r_regionkey') == col('n_regionkey')).join( - supplier, col('n_nationkey') == col('s_nationkey')).join( - partsupp, col('s_suppkey') == col('ps_suppkey')) + europe = region.filter(col('r_name') == 'EUROPE').join( + nation, col('r_regionkey') == col('n_regionkey')).join( + supplier, col('n_nationkey') == col('s_nationkey')).join( + partsupp, col('s_suppkey') == col('ps_suppkey')) - brass = part.filter((col('p_size') == 15) & (col('p_type').endswith('BRASS'))).join( - europe, col('ps_partkey') == col('p_partkey')) + brass = part.filter((col('p_size') == 15) & (col('p_type').endswith('BRASS'))).join( + europe, col('ps_partkey') == col('p_partkey')) - minCost = brass.groupBy(col('ps_partkey')).agg( - pyspark.sql.functions.min('ps_supplycost').alias('min')) + minCost = brass.groupBy(col('ps_partkey')).agg( + pyspark.sql.functions.min('ps_supplycost').alias('min')) - outcome = brass.join(minCost, brass.ps_partkey == minCost.ps_partkey).filter( - col('ps_supplycost') == col('min')).select('s_acctbal', 's_name', 'n_name', 'p_partkey', - 'p_mfgr', 's_address', 's_phone', - 's_comment') + outcome = brass.join(minCost, brass.ps_partkey == minCost.ps_partkey).filter( + col('ps_supplycost') == col('min')).select('s_acctbal', 's_name', 'n_name', + 'p_partkey', + 'p_mfgr', 's_address', 's_phone', + 's_comment') - sorted_outcome = outcome.sort( - desc('s_acctbal'), 'n_name', 's_name', 'p_partkey').limit(2).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + sorted_outcome = outcome.sort( + desc('s_acctbal'), 'n_name', 's_name', 'p_partkey').limit(2).collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_03(self, spark_session_with_tpch_dataset): expected = [ @@ -105,26 +113,28 @@ def test_query_03(self, spark_session_with_tpch_dataset): o_shippriority=0), ] - customer = spark_session_with_tpch_dataset.table('customer') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - orders = spark_session_with_tpch_dataset.table('orders') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + fcust = customer.filter(col('c_mktsegment') == 'BUILDING') + forders = orders.filter(col('o_orderdate') < '1995-03-15') + flineitems = lineitem.filter(lineitem.l_shipdate > '1995-03-15') - fcust = customer.filter(col('c_mktsegment') == 'BUILDING') - forders = orders.filter(col('o_orderdate') < '1995-03-15') - flineitems = lineitem.filter(lineitem.l_shipdate > '1995-03-15') + outcome = fcust.join(forders, col('c_custkey') == forders.o_custkey).select( + 'o_orderkey', 'o_orderdate', 'o_shippriority').join( + flineitems, 
col('o_orderkey') == flineitems.l_orderkey).select( + 'l_orderkey', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), + 'o_orderdate', + 'o_shippriority').groupBy('l_orderkey', 'o_orderdate', 'o_shippriority').agg( + try_sum('volume').alias('revenue')).select( + 'l_orderkey', 'revenue', 'o_orderdate', 'o_shippriority') - outcome = fcust.join(forders, col('c_custkey') == forders.o_custkey).select( - 'o_orderkey', 'o_orderdate', 'o_shippriority').join( - flineitems, col('o_orderkey') == flineitems.l_orderkey).select( - 'l_orderkey', - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), - 'o_orderdate', - 'o_shippriority').groupBy('l_orderkey', 'o_orderdate', 'o_shippriority').agg( - try_sum('volume').alias('revenue')).select( - 'l_orderkey', 'revenue', 'o_orderdate', 'o_shippriority') + sorted_outcome = outcome.sort(desc('revenue'), 'o_orderdate').limit(5).collect() - sorted_outcome = outcome.sort(desc('revenue'), 'o_orderdate').limit(5).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_04(self, spark_session_with_tpch_dataset): expected = [ @@ -135,21 +145,23 @@ def test_query_04(self, spark_session_with_tpch_dataset): Row(o_orderpriority='5-LOW', order_count=10487), ] - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + + forders = orders.filter( + (col('o_orderdate') >= '1993-07-01') & (col('o_orderdate') < '1993-10-01')) + flineitems = lineitem.filter(col('l_commitdate') < col('l_receiptdate')).select( + 'l_orderkey').distinct() - forders = orders.filter( - (col('o_orderdate') >= '1993-07-01') & (col('o_orderdate') < '1993-10-01')) - flineitems = lineitem.filter(col('l_commitdate') < col('l_receiptdate')).select( - 'l_orderkey').distinct() + outcome = flineitems.join( + forders, + col('l_orderkey') == col('o_orderkey')).groupBy('o_orderpriority').agg( + count('o_orderpriority').alias('order_count')) - outcome = flineitems.join( - forders, - col('l_orderkey') == col('o_orderkey')).groupBy('o_orderpriority').agg( - count('o_orderpriority').alias('order_count')) + sorted_outcome = outcome.sort('o_orderpriority').collect() - sorted_outcome = outcome.sort('o_orderpriority').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_05(self, spark_session_with_tpch_dataset): expected = [ @@ -160,47 +172,49 @@ def test_query_05(self, spark_session_with_tpch_dataset): Row(n_name='JAPAN', revenue=45410175.70), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - region = spark_session_with_tpch_dataset.table('region') - supplier = spark_session_with_tpch_dataset.table('supplier') - - forders = orders.filter(col('o_orderdate') >= '1994-01-01').filter( - col('o_orderdate') < '1995-01-01') - - outcome = region.filter(col('r_name') == 'ASIA').join( # r_name = 'ASIA' - nation, col('r_regionkey') == col('n_regionkey')).join( - supplier, col('n_nationkey') == col('s_nationkey')).join( - lineitem, col('s_suppkey') == col('l_suppkey')).select( - 'n_name', 
'l_extendedprice', 'l_discount', 'l_quantity', 'l_orderkey', - 's_nationkey').join(forders, col('l_orderkey') == forders.o_orderkey).join( - customer, (col('o_custkey') == col('c_custkey')) & ( - col('s_nationkey') == col('c_nationkey'))).select( - 'n_name', - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( - 'n_name').agg(try_sum('volume').alias('revenue')) - - sorted_outcome = outcome.sort('revenue').collect() - - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + region = spark_session_with_tpch_dataset.table('region') + supplier = spark_session_with_tpch_dataset.table('supplier') + + forders = orders.filter(col('o_orderdate') >= '1994-01-01').filter( + col('o_orderdate') < '1995-01-01') + + outcome = region.filter(col('r_name') == 'ASIA').join( # r_name = 'ASIA' + nation, col('r_regionkey') == col('n_regionkey')).join( + supplier, col('n_nationkey') == col('s_nationkey')).join( + lineitem, col('s_suppkey') == col('l_suppkey')).select( + 'n_name', 'l_extendedprice', 'l_discount', 'l_quantity', 'l_orderkey', + 's_nationkey').join(forders, col('l_orderkey') == forders.o_orderkey).join( + customer, (col('o_custkey') == col('c_custkey')) & ( + col('s_nationkey') == col('c_nationkey'))).select( + 'n_name', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( + 'n_name').agg(try_sum('volume').alias('revenue')) + + sorted_outcome = outcome.sort('revenue').collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_06(self, spark_session_with_tpch_dataset): expected = [ Row(revenue=123141078.23), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') - outcome = lineitem.filter((col('l_shipdate') >= '1994-01-01') & - (col('l_shipdate') < '1995-01-01') & - (col('l_discount') >= 0.05) & - (col('l_discount') <= 0.07) & - (col('l_quantity') < 24)).agg( - try_sum(col('l_extendedprice') * col('l_discount'))).alias('revenue') + outcome = lineitem.filter((col('l_shipdate') >= '1994-01-01') & + (col('l_shipdate') < '1995-01-01') & + (col('l_discount') >= 0.05) & + (col('l_discount') <= 0.07) & + (col('l_quantity') < 24)).agg( + try_sum(col('l_extendedprice') * col('l_discount'))).alias('revenue').collect() - assertDataFrameEqual(outcome, expected, atol=1e-2) + assert_dataframes_equal(outcome, expected) def test_query_07(self, spark_session_with_tpch_dataset): expected = [ @@ -210,34 +224,36 @@ def test_query_07(self, spark_session_with_tpch_dataset): Row(supp_nation='GERMANY', cust_nation='FRANCE', l_year='1996', revenue=52520549.02), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - supplier = spark_session_with_tpch_dataset.table('supplier') - nation = spark_session_with_tpch_dataset.table('nation') - - fnation = nation.filter((nation.n_name == 'FRANCE') | (nation.n_name == 'GERMANY')) - fline = lineitem.filter( - (col('l_shipdate') >= '1995-01-01') & (col('l_shipdate') <= '1996-12-31')) - - suppNation = fnation.join(supplier, col('n_nationkey') == 
col('s_nationkey')).join( - fline, col('s_suppkey') == col('l_suppkey')).select( - col('n_name').alias('supp_nation'), 'l_orderkey', 'l_extendedprice', 'l_discount', - 'l_shipdate') - - outcome = fnation.join(customer, col('n_nationkey') == col('c_nationkey')).join( - orders, col('c_custkey') == col('o_custkey')).select( - col('n_name').alias('cust_nation'), 'o_orderkey').join( - suppNation, col('o_orderkey') == suppNation.l_orderkey).filter( - (col('supp_nation') == 'FRANCE') & (col('cust_nation') == 'GERMANY') | ( - col('supp_nation') == 'GERMANY') & (col('cust_nation') == 'FRANCE')).select( - 'supp_nation', 'cust_nation', col('l_shipdate').substr(0, 4).alias('l_year'), - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( - 'supp_nation', 'cust_nation', 'l_year').agg( - try_sum('volume').alias('revenue')) - - sorted_outcome = outcome.sort('supp_nation', 'cust_nation', 'l_year').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + supplier = spark_session_with_tpch_dataset.table('supplier') + nation = spark_session_with_tpch_dataset.table('nation') + + fnation = nation.filter((nation.n_name == 'FRANCE') | (nation.n_name == 'GERMANY')) + fline = lineitem.filter( + (col('l_shipdate') >= '1995-01-01') & (col('l_shipdate') <= '1996-12-31')) + + suppNation = fnation.join(supplier, col('n_nationkey') == col('s_nationkey')).join( + fline, col('s_suppkey') == col('l_suppkey')).select( + col('n_name').alias('supp_nation'), 'l_orderkey', 'l_extendedprice', 'l_discount', + 'l_shipdate') + + outcome = fnation.join(customer, col('n_nationkey') == col('c_nationkey')).join( + orders, col('c_custkey') == col('o_custkey')).select( + col('n_name').alias('cust_nation'), 'o_orderkey').join( + suppNation, col('o_orderkey') == suppNation.l_orderkey).filter( + (col('supp_nation') == 'FRANCE') & (col('cust_nation') == 'GERMANY') | ( + col('supp_nation') == 'GERMANY') & (col('cust_nation') == 'FRANCE')).select( + 'supp_nation', 'cust_nation', col('l_shipdate').substr(1, 4).alias('l_year'), + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).groupBy( + 'supp_nation', 'cust_nation', 'l_year').agg( + try_sum('volume').alias('revenue')) + + sorted_outcome = outcome.sort('supp_nation', 'cust_nation', 'l_year').collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_08(self, spark_session_with_tpch_dataset): expected = [ @@ -245,41 +261,44 @@ def test_query_08(self, spark_session_with_tpch_dataset): Row(o_year='1996', mkt_share=0.04), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - part = spark_session_with_tpch_dataset.table('part') - region = spark_session_with_tpch_dataset.table('region') - supplier = spark_session_with_tpch_dataset.table('supplier') - - fregion = region.filter(col('r_name') == 'AMERICA') - forder = orders.filter((col('o_orderdate') >= '1995-01-01') & ( - col('o_orderdate') <= '1996-12-31')) - fpart = part.filter(col('p_type') == 'ECONOMY ANODIZED STEEL') - - nat = nation.join(supplier, col('n_nationkey') == col('s_nationkey')) - - line = lineitem.select( - 'l_partkey', 
'l_suppkey', 'l_orderkey', - (col('l_extendedprice') * (1 - col('l_discount'))).alias( - 'volume')).join( - fpart, col('l_partkey') == fpart.p_partkey).join( - nat, col('l_suppkey') == nat.s_suppkey) - - outcome = nation.join(fregion, col('n_regionkey') == fregion.r_regionkey).select( - 'n_nationkey', 'n_name').join(customer, - col('n_nationkey') == col('c_nationkey')).select( - 'c_custkey').join(forder, col('c_custkey') == col('o_custkey')).select( - 'o_orderkey', 'o_orderdate').join(line, col('o_orderkey') == line.l_orderkey).select( - col('n_name'), col('o_orderdate').substr(0, 4).alias('o_year'), - col('volume')).withColumn('case_volume', - when(col('n_name') == 'BRAZIL', col('volume')).otherwise( - 0)).groupBy('o_year').agg( - (try_sum('case_volume') / try_sum('volume')).alias('mkt_share')) - - sorted_outcome = outcome.sort('o_year').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + region = spark_session_with_tpch_dataset.table('region') + supplier = spark_session_with_tpch_dataset.table('supplier') + + fregion = region.filter(col('r_name') == 'AMERICA') + forder = orders.filter((col('o_orderdate') >= '1995-01-01') & ( + col('o_orderdate') <= '1996-12-31')) + fpart = part.filter(col('p_type') == 'ECONOMY ANODIZED STEEL') + + nat = nation.join(supplier, col('n_nationkey') == col('s_nationkey')) + + line = lineitem.select( + 'l_partkey', 'l_suppkey', 'l_orderkey', + (col('l_extendedprice') * (1 - col('l_discount'))).alias( + 'volume')).join( + fpart, col('l_partkey') == fpart.p_partkey).join( + nat, col('l_suppkey') == nat.s_suppkey) + + outcome = nation.join(fregion, col('n_regionkey') == fregion.r_regionkey).select( + 'n_nationkey', 'n_name').join(customer, + col('n_nationkey') == col('c_nationkey')).select( + 'c_custkey').join(forder, col('c_custkey') == col('o_custkey')).select( + 'o_orderkey', 'o_orderdate').join(line, + col('o_orderkey') == line.l_orderkey).select( + col('n_name'), col('o_orderdate').substr(1, 4).alias('o_year'), + col('volume')).withColumn('case_volume', + when(col('n_name') == 'BRAZIL', col('volume')).otherwise( + 0)).groupBy('o_year').agg( + (try_sum('case_volume') / try_sum('volume')).alias('mkt_share')) + + sorted_outcome = outcome.sort('o_year').collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_09(self, spark_session_with_tpch_dataset): # TODO -- Verify the corretness of these results against another version of the dataset. 
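# --- Illustrative aside (hypothetical snippet, not part of the patch) --------
# The TPC-H queries in this file now slice the year out of a date string with
# substr(1, 4) rather than substr(0, 4). Column.substr is 1-based, so position
# 1 is the first character; a start position of 0 depends on engine-specific
# leniency and is not guaranteed to behave the same once the plan is executed
# by the gateway's backends. A quick local check of the 1-based behaviour,
# assuming a local pyspark installation:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master('local[1]').getOrCreate()
df = spark.createDataFrame([('1995-03-15',)], ['o_orderdate'])
# Prints a single 'o_year' column whose only value is '1995'.
df.select(col('o_orderdate').substr(1, 4).alias('o_year')).show()
spark.stop()
# ------------------------------------------------------------------------------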
@@ -291,28 +310,30 @@ def test_query_09(self, spark_session_with_tpch_dataset): Row(n_name='ARGENTINA', o_year='1994', sum_profit=48268856.35), ] - orders = spark_session_with_tpch_dataset.table('orders') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - part = spark_session_with_tpch_dataset.table('part') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - supplier = spark_session_with_tpch_dataset.table('supplier') - - linePart = part.filter(col('p_name').contains('green')).join( - lineitem, col('p_partkey') == lineitem.l_partkey) - natSup = nation.join(supplier, col('n_nationkey') == supplier.s_nationkey) - - outcome = linePart.join(natSup, col('l_suppkey') == natSup.s_suppkey).join( - partsupp, (col('l_suppkey') == partsupp.ps_suppkey) & ( - col('l_partkey') == partsupp.ps_partkey)).join( - orders, col('l_orderkey') == orders.o_orderkey).select( - 'n_name', col('o_orderdate').substr(0, 4).alias('o_year'), - (col('l_extendedprice') * (1 - col('l_discount')) - ( - col('ps_supplycost') * col('l_quantity'))).alias('amount')).groupBy( - 'n_name', 'o_year').agg(try_sum('amount').alias('sum_profit')) - - sorted_outcome = outcome.sort('n_name', desc('o_year')).limit(5).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + orders = spark_session_with_tpch_dataset.table('orders') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + part = spark_session_with_tpch_dataset.table('part') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') + + linePart = part.filter(col('p_name').contains('green')).join( + lineitem, col('p_partkey') == lineitem.l_partkey) + natSup = nation.join(supplier, col('n_nationkey') == supplier.s_nationkey) + + outcome = linePart.join(natSup, col('l_suppkey') == natSup.s_suppkey).join( + partsupp, (col('l_suppkey') == partsupp.ps_suppkey) & ( + col('l_partkey') == partsupp.ps_partkey)).join( + orders, col('l_orderkey') == orders.o_orderkey).select( + 'n_name', col('o_orderdate').substr(1, 4).alias('o_year'), + (col('l_extendedprice') * (1 - col('l_discount')) - ( + col('ps_supplycost') * col('l_quantity'))).alias('amount')).groupBy( + 'n_name', 'o_year').agg(try_sum('amount').alias('sum_profit')) + + sorted_outcome = outcome.sort('n_name', desc('o_year')).limit(5).collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_10(self, spark_session_with_tpch_dataset): expected = [ @@ -327,55 +348,60 @@ def test_query_10(self, spark_session_with_tpch_dataset): 'pinto beans. 
ironic, idle re'), ] - customer = spark_session_with_tpch_dataset.table('customer') - lineitem = spark_session_with_tpch_dataset.table('lineitem') - nation = spark_session_with_tpch_dataset.table('nation') - orders = spark_session_with_tpch_dataset.table('orders') - - flineitem = lineitem.filter(col('l_returnflag') == 'R') - - outcome = orders.filter( - (col('o_orderdate') >= '1993-10-01') & (col('o_orderdate') < '1994-01-01')).join( - customer, col('o_custkey') == customer.c_custkey).join( - nation, col('c_nationkey') == nation.n_nationkey).join( - flineitem, col('o_orderkey') == flineitem.l_orderkey).select( - 'c_custkey', 'c_name', - (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), - 'c_acctbal', 'n_name', 'c_address', 'c_phone', 'c_comment').groupBy( - 'c_custkey', 'c_name', 'c_acctbal', 'c_phone', 'n_name', 'c_address', 'c_comment').agg( - try_sum('volume').alias('revenue')).select( - 'c_custkey', 'c_name', 'revenue', 'c_acctbal', 'n_name', 'c_address', 'c_phone', - 'c_comment') - - sorted_outcome = outcome.sort(desc('revenue')).limit(2).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + lineitem = spark_session_with_tpch_dataset.table('lineitem') + nation = spark_session_with_tpch_dataset.table('nation') + orders = spark_session_with_tpch_dataset.table('orders') + + flineitem = lineitem.filter(col('l_returnflag') == 'R') + + outcome = orders.filter( + (col('o_orderdate') >= '1993-10-01') & (col('o_orderdate') < '1994-01-01')).join( + customer, col('o_custkey') == customer.c_custkey).join( + nation, col('c_nationkey') == nation.n_nationkey).join( + flineitem, col('o_orderkey') == flineitem.l_orderkey).select( + 'c_custkey', 'c_name', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume'), + 'c_acctbal', 'n_name', 'c_address', 'c_phone', 'c_comment').groupBy( + 'c_custkey', 'c_name', 'c_acctbal', 'c_phone', 'n_name', 'c_address', + 'c_comment').agg( + try_sum('volume').alias('revenue')).select( + 'c_custkey', 'c_name', 'revenue', 'c_acctbal', 'n_name', 'c_address', 'c_phone', + 'c_comment') + + sorted_outcome = outcome.sort(desc('revenue')).limit(2).collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_11(self, spark_session_with_tpch_dataset): expected = [ - Row(ps_partkey=129760, value=17538456.86), - Row(ps_partkey=166726, value=16503353.92), - Row(ps_partkey=191287, value=16474801.97), - Row(ps_partkey=161758, value=16101755.54), - Row(ps_partkey=34452, value=15983844.72), + Row(ps_partkey=129760, part_value=17538456.86), + Row(ps_partkey=166726, part_value=16503353.92), + Row(ps_partkey=191287, part_value=16474801.97), + Row(ps_partkey=161758, part_value=16101755.54), + Row(ps_partkey=34452, part_value=15983844.72), ] - nation = spark_session_with_tpch_dataset.table('nation') - partsupp = spark_session_with_tpch_dataset.table('partsupp') - supplier = spark_session_with_tpch_dataset.table('supplier') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + nation = spark_session_with_tpch_dataset.table('nation') + partsupp = spark_session_with_tpch_dataset.table('partsupp') + supplier = spark_session_with_tpch_dataset.table('supplier') - tmp = nation.filter(col('n_name') == 'GERMANY').join( - supplier, col('n_nationkey') == supplier.s_nationkey).select( - 's_suppkey').join(partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( - 'ps_partkey', (col('ps_supplycost') * 
col('ps_availqty')).alias('value')) + tmp = nation.filter(col('n_name') == 'GERMANY').join( + supplier, col('n_nationkey') == supplier.s_nationkey).select( + 's_suppkey').join(partsupp, col('s_suppkey') == partsupp.ps_suppkey).select( + 'ps_partkey', (col('ps_supplycost') * col('ps_availqty')).alias('value')) - sumRes = tmp.agg(try_sum('value').alias('total_value')) + sumRes = tmp.agg(try_sum('value').alias('total_value')) - outcome = tmp.groupBy('ps_partkey').agg( - (try_sum('value')).alias('part_value')).join( - sumRes, col('part_value') > col('total_value') * 0.0001) + outcome = tmp.groupBy('ps_partkey').agg( + (try_sum('value')).alias('part_value')).join( + sumRes, col('part_value') > col('total_value') * 0.0001) - sorted_outcome = outcome.sort(desc('part_value')).limit(5).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + sorted_outcome = outcome.sort(desc('part_value')).limit(5).collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_12(self, spark_session_with_tpch_dataset): expected = [ @@ -383,26 +409,31 @@ def test_query_12(self, spark_session_with_tpch_dataset): Row(l_shipmode='SHIP', high_line_count=6200, low_line_count=9262), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - orders = spark_session_with_tpch_dataset.table('orders') - - outcome = lineitem.filter( - (col('l_shipmode') == 'MAIL') | (col('l_shipmode') == 'SHIP')).filter( - (col('l_commitdate') < col('l_receiptdate')) & - (col('l_shipdate') < col('l_commitdate')) & - (col('l_receiptdate') >= '1994-01-01') & (col('l_receiptdate') < '1995-01-01')).join( - orders, - col('l_orderkey') == orders.o_orderkey).select( - 'l_shipmode', 'o_orderpriority').groupBy('l_shipmode').agg( - count( - when((col('o_orderpriority') == '1-URGENT') | (col('o_orderpriority') == '2-HIGH'), - True)).alias('high_line_count'), - count( - when((col('o_orderpriority') != '1-URGENT') & (col('o_orderpriority') != '2-HIGH'), - True)).alias('low_line_count')) - - sorted_outcome = outcome.sort('l_shipmode').collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = lineitem.filter( + (col('l_shipmode') == 'MAIL') | (col('l_shipmode') == 'SHIP')).filter( + (col('l_commitdate') < col('l_receiptdate')) & + (col('l_shipdate') < col('l_commitdate')) & + (col('l_receiptdate') >= '1994-01-01') & ( + col('l_receiptdate') < '1995-01-01')).join( + orders, + col('l_orderkey') == orders.o_orderkey).select( + 'l_shipmode', 'o_orderpriority').groupBy('l_shipmode').agg( + count( + when((col('o_orderpriority') == '1-URGENT') | ( + col('o_orderpriority') == '2-HIGH'), + True)).alias('high_line_count'), + count( + when((col('o_orderpriority') != '1-URGENT') & ( + col('o_orderpriority') != '2-HIGH'), + True)).alias('low_line_count')) + + sorted_outcome = outcome.sort('l_shipmode').collect() + + assert_dataframes_equal(sorted_outcome, expected) def test_query_13(self, spark_session_with_tpch_dataset): # TODO -- Verify the corretness of these results against another version of the dataset. 
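# --- Illustrative aside (hypothetical test, not part of the suite) -----------
# Every test in this file now follows the same shape: build and collect the
# query inside utilizes_valid_plans so that each Substrait plan the gateway
# produced gets validated, then compare results with the schema-tolerant
# assert_dataframes_equal helper instead of assertDataFrameEqual. A stripped
# down example of that pattern (the expected values repeat the ETHIOPIA row
# used by test_join earlier in this change):
from gateway.tests.compare_dataframes import assert_dataframes_equal
from gateway.tests.plan_validator import utilizes_valid_plans
from pyspark import Row
from pyspark.sql.functions import col


def test_nation_lookup(spark_session_with_tpch_dataset):
    expected = [Row(n_nationkey=5, n_name='ETHIOPIA', n_regionkey=0)]

    with utilizes_valid_plans(spark_session_with_tpch_dataset):
        nation = spark_session_with_tpch_dataset.table('nation')
        outcome = nation.filter(col('n_nationkey') == 5).select(
            'n_nationkey', 'n_name', 'n_regionkey').collect()

    assert_dataframes_equal(outcome, expected)
# ------------------------------------------------------------------------------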
@@ -412,56 +443,63 @@ def test_query_13(self, spark_session_with_tpch_dataset): Row(c_count=11, custdist=6014), ] - customer = spark_session_with_tpch_dataset.table('customer') - orders = spark_session_with_tpch_dataset.table('orders') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + customer = spark_session_with_tpch_dataset.table('customer') + orders = spark_session_with_tpch_dataset.table('orders') + + outcome = customer.join( + orders, (col('c_custkey') == orders.o_custkey) & ( + ~col('o_comment').rlike('.*special.*requests.*')), 'left_outer').groupBy( + 'o_custkey').agg(count('o_orderkey').alias('c_count')).groupBy( + 'c_count').agg(count('o_custkey').alias('custdist')) - outcome = customer.join( - orders, (col('c_custkey') == orders.o_custkey) & ( - ~col('o_comment').rlike('.*special.*requests.*')), 'left_outer').groupBy( - 'o_custkey').agg(count('o_orderkey').alias('c_count')).groupBy( - 'c_count').agg(count('o_custkey').alias('custdist')) + sorted_outcome = outcome.sort(desc('custdist'), desc('c_count')).limit(3).collect() - sorted_outcome = outcome.sort(desc('custdist'), desc('c_count')).limit(3).collect() - assertDataFrameEqual(sorted_outcome, expected, atol=1e-2) + assert_dataframes_equal(sorted_outcome, expected) def test_query_14(self, spark_session_with_tpch_dataset): expected = [ Row(promo_revenue=16.38), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - part = spark_session_with_tpch_dataset.table('part') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + part = spark_session_with_tpch_dataset.table('part') - outcome = part.join(lineitem, (col('l_partkey') == col('p_partkey')) & - (col('l_shipdate') >= '1995-09-01') & - (col('l_shipdate') < '1995-10-01')).select( - 'p_type', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).agg( - try_sum(when(col('p_type').contains('PROMO'), col('value'))) * 100 / try_sum( - col('value')) - ).alias('promo_revenue') + outcome = part.join(lineitem, (col('l_partkey') == col('p_partkey')) & + (col('l_shipdate') >= '1995-09-01') & + (col('l_shipdate') < '1995-10-01')).select( + 'p_type', (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).agg( + try_sum(when(col('p_type').contains('PROMO'), col('value'))) * 100 / try_sum( + col('value')) + ).alias('promo_revenue').collect() - assertDataFrameEqual(outcome, expected, atol=1e-2) + assert_dataframes_equal(outcome, expected) def test_query_15(self, spark_session_with_tpch_dataset): expected = [ - Row(s_suppkey=8449, s_name='Supplier#000008449', s_address='Wp34zim9qYFbVctdW'), + Row(s_suppkey=8449, s_name='Supplier#000008449', s_address='Wp34zim9qYFbVctdW', + s_phone='20-469-856-8873', total=1772627.21), ] - lineitem = spark_session_with_tpch_dataset.table('lineitem') - supplier = spark_session_with_tpch_dataset.table('supplier') + with utilizes_valid_plans(spark_session_with_tpch_dataset): + lineitem = spark_session_with_tpch_dataset.table('lineitem') + supplier = spark_session_with_tpch_dataset.table('supplier') + + revenue = lineitem.filter((col('l_shipdate') >= '1996-01-01') & + (col('l_shipdate') < '1996-04-01')).select( + 'l_suppkey', + (col('l_extendedprice') * (1 - col('l_discount'))).alias('value')).groupBy( + 'l_suppkey').agg(try_sum('value').alias('total')) - revenue = lineitem.filter((col('l_shipdate') >= '1996-01-01') & - (col('l_shipdate') < '1996-04-01')).select( - 'l_suppkey', (col('l_extendedprice') * (1 - 
col('l_discount'))).alias('value')).groupBy(
-            'l_suppkey').agg(try_sum('value').alias('total'))
+            outcome = revenue.agg(pyspark.sql.functions.max(col('total')).alias('max_total')).join(
+                revenue, col('max_total') == revenue.total).join(
+                supplier, col('l_suppkey') == supplier.s_suppkey).select(
+                's_suppkey', 's_name', 's_address', 's_phone', 'total')
-        outcome = revenue.agg(pyspark.sql.functions.max(col('total')).alias('max_total')).join(
-            revenue, col('max_total') == revenue.total).join(
-            supplier, col('l_suppkey') == supplier.s_suppkey).select(
-            's_suppkey', 's_name', 's_address', 's_phone', 'total')
+            sorted_outcome = outcome.sort('s_suppkey').limit(1).collect()
-        sorted_outcome = outcome.sort('s_suppkey').limit(1).collect()
-        assertDataFrameEqual(sorted_outcome, expected, atol=1e-2)
+        assert_dataframes_equal(sorted_outcome, expected)
     def test_query_16(self, spark_session_with_tpch_dataset):
         expected = [
@@ -470,45 +508,49 @@ def test_query_16(self, spark_session_with_tpch_dataset):
             Row(p_brand='Brand#11', p_type='STANDARD BRUSHED TIN', p_size=23, supplier_cnt=24),
         ]
-        part = spark_session_with_tpch_dataset.table('part')
-        partsupp = spark_session_with_tpch_dataset.table('partsupp')
-        supplier = spark_session_with_tpch_dataset.table('supplier')
+        with utilizes_valid_plans(spark_session_with_tpch_dataset):
+            part = spark_session_with_tpch_dataset.table('part')
+            partsupp = spark_session_with_tpch_dataset.table('partsupp')
+            supplier = spark_session_with_tpch_dataset.table('supplier')
-        fparts = part.filter((col('p_brand') != 'Brand#45') &
-            (~col('p_type').startswith('MEDIUM POLISHED')) &
-            (col('p_size').isin([3, 14, 23, 45, 49, 9, 19, 36]))).select(
-            'p_partkey', 'p_brand', 'p_type', 'p_size')
+            fparts = part.filter((col('p_brand') != 'Brand#45') &
+                (~col('p_type').startswith('MEDIUM POLISHED')) &
+                (col('p_size').isin([3, 14, 23, 45, 49, 9, 19, 36]))).select(
+                'p_partkey', 'p_brand', 'p_type', 'p_size')
-        outcome = supplier.filter(~col('s_comment').rlike('.*Customer.*Complaints.*')).join(
-            partsupp, col('s_suppkey') == partsupp.ps_suppkey).select(
-            'ps_partkey', 'ps_suppkey').join(
-            fparts, col('ps_partkey') == fparts.p_partkey).groupBy(
-            'p_brand', 'p_type', 'p_size').agg(countDistinct('ps_suppkey').alias('supplier_cnt'))
+            outcome = supplier.filter(~col('s_comment').rlike('.*Customer.*Complaints.*')).join(
+                partsupp, col('s_suppkey') == partsupp.ps_suppkey).select(
+                'ps_partkey', 'ps_suppkey').join(
+                fparts, col('ps_partkey') == fparts.p_partkey).groupBy(
+                'p_brand', 'p_type', 'p_size').agg(
+                countDistinct('ps_suppkey').alias('supplier_cnt'))
-        sorted_outcome = outcome.sort(
-            desc('supplier_cnt'), 'p_brand', 'p_type', 'p_size').limit(3).collect()
-        assertDataFrameEqual(sorted_outcome, expected, atol=1e-2)
+            sorted_outcome = outcome.sort(
+                desc('supplier_cnt'), 'p_brand', 'p_type', 'p_size').limit(3).collect()
+
+        assert_dataframes_equal(sorted_outcome, expected)
     def test_query_17(self, spark_session_with_tpch_dataset):
         expected = [
             Row(avg_yearly=348406.02),
         ]
-        lineitem = spark_session_with_tpch_dataset.table('lineitem')
-        part = spark_session_with_tpch_dataset.table('part')
+        with utilizes_valid_plans(spark_session_with_tpch_dataset):
+            lineitem = spark_session_with_tpch_dataset.table('lineitem')
+            part = spark_session_with_tpch_dataset.table('part')
-        fpart = part.filter(
-            (col('p_brand') == 'Brand#23') & (col('p_container') == 'MED BOX')).select(
-            'p_partkey').join(lineitem, col('p_partkey') == lineitem.l_partkey, 'left_outer')
+            fpart = part.filter(
+                (col('p_brand') == 'Brand#23') & (col('p_container') == 'MED BOX')).select(
+                'p_partkey').join(lineitem, col('p_partkey') == lineitem.l_partkey, 'left_outer')
-        outcome = fpart.groupBy('p_partkey').agg(
-            (avg('l_quantity') * 0.2).alias('avg_quantity')).select(
-            col('p_partkey').alias('key'), 'avg_quantity').join(
-            fpart, col('key') == fpart.p_partkey).filter(
-            col('l_quantity') < col('avg_quantity')).agg(
-            try_sum('l_extendedprice') / 7).alias('avg_yearly')
+            outcome = fpart.groupBy('p_partkey').agg(
+                (avg('l_quantity') * 0.2).alias('avg_quantity')).select(
+                col('p_partkey').alias('key'), 'avg_quantity').join(
+                fpart, col('key') == fpart.p_partkey).filter(
+                col('l_quantity') < col('avg_quantity')).agg(
+                try_sum('l_extendedprice') / 7).alias('avg_yearly').collect()
-        assertDataFrameEqual(outcome, expected, atol=1e-2)
+        assert_dataframes_equal(outcome, expected)
     def test_query_18(self, spark_session_with_tpch_dataset):
         expected = [
@@ -520,51 +562,55 @@ def test_query_18(self, spark_session_with_tpch_dataset):
                 o_totalprice=530604.44, sum_l_quantity=317.00),
         ]
-        customer = spark_session_with_tpch_dataset.table('customer')
-        lineitem = spark_session_with_tpch_dataset.table('lineitem')
-        orders = spark_session_with_tpch_dataset.table('orders')
+        with utilizes_valid_plans(spark_session_with_tpch_dataset):
+            customer = spark_session_with_tpch_dataset.table('customer')
+            lineitem = spark_session_with_tpch_dataset.table('lineitem')
+            orders = spark_session_with_tpch_dataset.table('orders')
+
+            outcome = lineitem.groupBy('l_orderkey').agg(
+                try_sum('l_quantity').alias('sum_quantity')).filter(
+                col('sum_quantity') > 300).select(col('l_orderkey').alias('key'),
+                'sum_quantity').join(
+                orders, orders.o_orderkey == col('key')).join(
+                lineitem, col('o_orderkey') == lineitem.l_orderkey).join(
+                customer, col('o_custkey') == customer.c_custkey).select(
+                'l_quantity', 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate',
+                'o_totalprice').groupBy(
+                'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', 'o_totalprice').agg(
+                try_sum('l_quantity')).alias('sum_l_quantity')
-        outcome = lineitem.groupBy('l_orderkey').agg(
-            try_sum('l_quantity').alias('sum_quantity')).filter(
-            col('sum_quantity') > 300).select(col('l_orderkey').alias('key'), 'sum_quantity').join(
-            orders, orders.o_orderkey == col('key')).join(
-            lineitem, col('o_orderkey') == lineitem.l_orderkey).join(
-            customer, col('o_custkey') == customer.c_custkey).select(
-            'l_quantity', 'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate',
-            'o_totalprice').groupBy(
-            'c_name', 'c_custkey', 'o_orderkey', 'o_orderdate', 'o_totalprice').agg(
-            try_sum('l_quantity'))
+            sorted_outcome = outcome.sort(desc('o_totalprice'), 'o_orderdate').limit(2).collect()
-        sorted_outcome = outcome.sort(desc('o_totalprice'), 'o_orderdate').limit(2).collect()
-        assertDataFrameEqual(sorted_outcome, expected, atol=1e-2)
+        assert_dataframes_equal(sorted_outcome, expected)
     def test_query_19(self, spark_session_with_tpch_dataset):
         expected = [
             Row(revenue=3083843.06),
         ]
-        lineitem = spark_session_with_tpch_dataset.table('lineitem')
-        part = spark_session_with_tpch_dataset.table('part')
-
-        outcome = part.join(lineitem, col('l_partkey') == col('p_partkey')).filter(
-            col('l_shipmode').isin(['AIR', 'AIR REG']) & (
-                col('l_shipinstruct') == 'DELIVER IN PERSON')).filter(
-            ((col('p_brand') == 'Brand#12') & (
-                col('p_container').isin(['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) &
-                (col('l_quantity') >= 1) & (col('l_quantity') <= 11) &
-                (col('p_size') >= 1) & (col('p_size') <= 5)) |
-            ((col('p_brand') == 'Brand#23') & (
-                col('p_container').isin(['MED BAG', 'MED BOX', 'MED PKG', 'MED PACK'])) &
-                (col('l_quantity') >= 10) & (col('l_quantity') <= 20) &
-                (col('p_size') >= 1) & (col('p_size') <= 10)) |
-            ((col('p_brand') == 'Brand#34') & (
-                col('p_container').isin(['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) &
-                (col('l_quantity') >= 20) & (col('l_quantity') <= 30) &
-                (col('p_size') >= 1) & (col('p_size') <= 15))).select(
-            (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).agg(
-            try_sum('volume').alias('revenue'))
-
-        assertDataFrameEqual(outcome, expected, atol=1e-2)
+        with utilizes_valid_plans(spark_session_with_tpch_dataset):
+            lineitem = spark_session_with_tpch_dataset.table('lineitem')
+            part = spark_session_with_tpch_dataset.table('part')
+
+            outcome = part.join(lineitem, col('l_partkey') == col('p_partkey')).filter(
+                col('l_shipmode').isin(['AIR', 'AIR REG']) & (
+                    col('l_shipinstruct') == 'DELIVER IN PERSON')).filter(
+                ((col('p_brand') == 'Brand#12') & (
+                    col('p_container').isin(['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) &
+                    (col('l_quantity') >= 1) & (col('l_quantity') <= 11) &
+                    (col('p_size') >= 1) & (col('p_size') <= 5)) |
+                ((col('p_brand') == 'Brand#23') & (
+                    col('p_container').isin(['MED BAG', 'MED BOX', 'MED PKG', 'MED PACK'])) &
+                    (col('l_quantity') >= 10) & (col('l_quantity') <= 20) &
+                    (col('p_size') >= 1) & (col('p_size') <= 10)) |
+                ((col('p_brand') == 'Brand#34') & (
+                    col('p_container').isin(['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) &
+                    (col('l_quantity') >= 20) & (col('l_quantity') <= 30) &
+                    (col('p_size') >= 1) & (col('p_size') <= 15))).select(
+                (col('l_extendedprice') * (1 - col('l_discount'))).alias('volume')).agg(
+                try_sum('volume').alias('revenue')).collect()
+
+        assert_dataframes_equal(outcome, expected)
     def test_query_20(self, spark_session_with_tpch_dataset):
         expected = [
@@ -573,30 +619,32 @@ def test_query_20(self, spark_session_with_tpch_dataset):
             Row(s_name='Supplier#000000205', s_address='rF uV8d0JNEk'),
         ]
-        lineitem = spark_session_with_tpch_dataset.table('lineitem')
-        nation = spark_session_with_tpch_dataset.table('nation')
-        part = spark_session_with_tpch_dataset.table('part')
-        partsupp = spark_session_with_tpch_dataset.table('partsupp')
-        supplier = spark_session_with_tpch_dataset.table('supplier')
+        with utilizes_valid_plans(spark_session_with_tpch_dataset):
+            lineitem = spark_session_with_tpch_dataset.table('lineitem')
+            nation = spark_session_with_tpch_dataset.table('nation')
+            part = spark_session_with_tpch_dataset.table('part')
+            partsupp = spark_session_with_tpch_dataset.table('partsupp')
+            supplier = spark_session_with_tpch_dataset.table('supplier')
+
+            flineitem = lineitem.filter(
+                (col('l_shipdate') >= '1994-01-01') & (col('l_shipdate') < '1995-01-01')).groupBy(
+                'l_partkey', 'l_suppkey').agg(
+                try_sum(col('l_quantity') * 0.5).alias('sum_quantity'))
-        flineitem = lineitem.filter(
-            (col('l_shipdate') >= '1994-01-01') & (col('l_shipdate') < '1995-01-01')).groupBy(
-            'l_partkey', 'l_suppkey').agg(
-            try_sum(col('l_quantity') * 0.5).alias('sum_quantity'))
+            fnation = nation.filter(col('n_name') == 'CANADA')
+            nat_supp = supplier.select('s_suppkey', 's_name', 's_nationkey', 's_address').join(
+                fnation, col('s_nationkey') == fnation.n_nationkey)
-        fnation = nation.filter(col('n_name') == 'CANADA')
-        nat_supp = supplier.select('s_suppkey', 's_name', 's_nationkey', 's_address').join(
-            fnation, col('s_nationkey') == fnation.n_nationkey)
+            outcome = part.filter(col('p_name').startswith('forest')).select('p_partkey').join(
+                partsupp, col('p_partkey') == partsupp.ps_partkey).join(
+                flineitem, (col('ps_suppkey') == flineitem.l_suppkey) & (
+                    col('ps_partkey') == flineitem.l_partkey)).filter(
+                col('ps_availqty') > col('sum_quantity')).select('ps_suppkey').distinct().join(
+                nat_supp, col('ps_suppkey') == nat_supp.s_suppkey).select('s_name', 's_address')
-        outcome = part.filter(col('p_name').startswith('forest')).select('p_partkey').join(
-            partsupp, col('p_partkey') == partsupp.ps_partkey).join(
-            flineitem, (col('ps_suppkey') == flineitem.l_suppkey) & (
-                col('ps_partkey') == flineitem.l_partkey)).filter(
-            col('ps_availqty') > col('sum_quantity')).select('ps_suppkey').distinct().join(
-            nat_supp, col('ps_suppkey') == nat_supp.s_suppkey).select('s_name', 's_address')
+            sorted_outcome = outcome.sort('s_name').limit(3).collect()
-        sorted_outcome = outcome.sort('s_name').limit(3).collect()
-        assertDataFrameEqual(sorted_outcome, expected, atol=1e-2)
+        assert_dataframes_equal(sorted_outcome, expected)
     def test_query_21(self, spark_session_with_tpch_dataset):
         # TODO -- Verify the correctness of these results against another version of the dataset.
         expected = [
@@ -608,44 +656,47 @@ def test_query_21(self, spark_session_with_tpch_dataset):
             Row(s_name='Supplier#000000445', numwait=16),
             Row(s_name='Supplier#000000825', numwait=16),
             Row(s_name='Supplier#000000709', numwait=15),
            Row(s_name='Supplier#000000762', numwait=15),
             Row(s_name='Supplier#000000486', numwait=25),
         ]
-        lineitem = spark_session_with_tpch_dataset.table('lineitem')
-        nation = spark_session_with_tpch_dataset.table('nation')
-        orders = spark_session_with_tpch_dataset.table('orders')
-        supplier = spark_session_with_tpch_dataset.table('supplier')
+        with utilizes_valid_plans(spark_session_with_tpch_dataset):
+            lineitem = spark_session_with_tpch_dataset.table('lineitem')
+            nation = spark_session_with_tpch_dataset.table('nation')
+            orders = spark_session_with_tpch_dataset.table('orders')
+            supplier = spark_session_with_tpch_dataset.table('supplier')
-        fsupplier = supplier.select('s_suppkey', 's_nationkey', 's_name')
+            fsupplier = supplier.select('s_suppkey', 's_nationkey', 's_name')
-        plineitem = lineitem.select('l_suppkey', 'l_orderkey', 'l_receiptdate', 'l_commitdate')
+            plineitem = lineitem.select('l_suppkey', 'l_orderkey', 'l_receiptdate', 'l_commitdate')
-        flineitem = plineitem.filter(col('l_receiptdate') > col('l_commitdate'))
+            flineitem = plineitem.filter(col('l_receiptdate') > col('l_commitdate'))
-        line1 = plineitem.groupBy('l_orderkey').agg(
-            countDistinct('l_suppkey').alias('suppkey_count'),
-            pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select(
-            col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max')
+            line1 = plineitem.groupBy('l_orderkey').agg(
+                countDistinct('l_suppkey').alias('suppkey_count'),
+                pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select(
+                col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max')
-        line2 = flineitem.groupBy('l_orderkey').agg(
-            countDistinct('l_suppkey').alias('suppkey_count'),
-            pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select(
-            col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max')
+            line2 = flineitem.groupBy('l_orderkey').agg(
+                countDistinct('l_suppkey').alias('suppkey_count'),
+                pyspark.sql.functions.max(col('l_suppkey')).alias('suppkey_max')).select(
+                col('l_orderkey').alias('key'), 'suppkey_count', 'suppkey_max')
-        forder = orders.select('o_orderkey', 'o_orderstatus').filter(col('o_orderstatus') == 'F')
+            forder = orders.select('o_orderkey', 'o_orderstatus').filter(
+                col('o_orderstatus') == 'F')
-        outcome = nation.filter(col('n_name') == 'SAUDI ARABIA').join(
-            fsupplier, col('n_nationkey') == fsupplier.s_nationkey).join(
-            flineitem, col('s_suppkey') == flineitem.l_suppkey).join(
-            forder, col('l_orderkey') == forder.o_orderkey).join(
-            line1, col('l_orderkey') == line1.key).filter(
-            (col('suppkey_count') > 1) |
-            ((col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max')))).select(
-            's_name', 'l_orderkey', 'l_suppkey').join(
-            line2, col('l_orderkey') == line2.key, 'left_outer').select(
-            's_name', 'l_orderkey', 'l_suppkey', 'suppkey_count', 'suppkey_max').filter(
-            (col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max'))).groupBy(
-            's_name').agg(count(col('l_suppkey')).alias('numwait'))
+            outcome = nation.filter(col('n_name') == 'SAUDI ARABIA').join(
+                fsupplier, col('n_nationkey') == fsupplier.s_nationkey).join(
+                flineitem, col('s_suppkey') == flineitem.l_suppkey).join(
+                forder, col('l_orderkey') == forder.o_orderkey).join(
+                line1, col('l_orderkey') == line1.key).filter(
+                (col('suppkey_count') > 1) |
+                ((col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max')))).select(
+                's_name', 'l_orderkey', 'l_suppkey').join(
+                line2, col('l_orderkey') == line2.key, 'left_outer').select(
+                's_name', 'l_orderkey', 'l_suppkey', 'suppkey_count', 'suppkey_max').filter(
+                (col('suppkey_count') == 1) & (col('l_suppkey') == col('suppkey_max'))).groupBy(
+                's_name').agg(count(col('l_suppkey')).alias('numwait'))
-        sorted_outcome = outcome.sort(desc('numwait'), 's_name').limit(5).collect()
-        assertDataFrameEqual(sorted_outcome, expected, atol=1e-2)
+            sorted_outcome = outcome.sort(desc('numwait'), 's_name').limit(5).collect()
+
+        assert_dataframes_equal(sorted_outcome, expected)
     def test_query_22(self, spark_session_with_tpch_dataset):
         expected = [
             Row(cntrycode='13', numcust=888, totacctbal=6737713.99),
@@ -658,22 +709,24 @@ def test_query_22(self, spark_session_with_tpch_dataset):
             Row(cntrycode='31', numcust=922, totacctbal=6806670.18),
         ]
-        customer = spark_session_with_tpch_dataset.table('customer')
-        orders = spark_session_with_tpch_dataset.table('orders')
+        with utilizes_valid_plans(spark_session_with_tpch_dataset):
+            customer = spark_session_with_tpch_dataset.table('customer')
+            orders = spark_session_with_tpch_dataset.table('orders')
+
+            fcustomer = customer.select(
+                'c_acctbal', 'c_custkey', (col('c_phone').substr(1, 2)).alias('cntrycode')).filter(
+                col('cntrycode').isin(['13', '31', '23', '29', '30', '18', '17']))
-        fcustomer = customer.select(
-            'c_acctbal', 'c_custkey', (col('c_phone').substr(0, 2)).alias('cntrycode')).filter(
-            col('cntrycode').isin(['13', '31', '23', '29', '30', '18', '17']))
+            avg_customer = fcustomer.filter(col('c_acctbal') > 0.00).agg(
+                avg('c_acctbal').alias('avg_acctbal'))
-        avg_customer = fcustomer.filter(col('c_acctbal') > 0.00).agg(
-            avg('c_acctbal').alias('avg_acctbal'))
+            outcome = orders.groupBy('o_custkey').agg(
+                count('o_custkey')).select('o_custkey').join(
+                fcustomer, col('o_custkey') == fcustomer.c_custkey, 'right_outer').filter(
+                col('o_custkey').isNull()).join(avg_customer).filter(
+                col('c_acctbal') > col('avg_acctbal')).groupBy('cntrycode').agg(
+                count('c_custkey').alias('numcust'), try_sum('c_acctbal'))
-        outcome = orders.groupBy('o_custkey').agg(
-            count('o_custkey')).select('o_custkey').join(
-            fcustomer, col('o_custkey') == fcustomer.c_custkey, 'right_outer').filter(
-            col('o_custkey').isNull()).join(avg_customer).filter(
-            col('c_acctbal') > col('avg_acctbal')).groupBy('cntrycode').agg(
-            count('c_custkey').alias('numcust'), try_sum('c_acctbal'))
+            sorted_outcome = outcome.sort('cntrycode').collect()
-        sorted_outcome = outcome.sort('cntrycode').collect()
-        assertDataFrameEqual(sorted_outcome, expected, atol=1e-2)
+        assert_dataframes_equal(sorted_outcome, expected)
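Note: the reworked tests above lean on two helpers that are not defined anywhere in this patch, utilizes_valid_plans and assert_dataframes_equal. The sketch below is only an illustration of what they plausibly do; the module location, names, and internals are assumptions made here, not the gateway's actual implementation. The one grounded detail is that the removed call sites used pyspark.testing.assertDataFrameEqual with atol=1e-2, so the wrapper simply bakes that tolerance in.

    # Hypothetical sketch only -- the real helpers live elsewhere in the test suite.
    from contextlib import contextmanager

    from pyspark.testing import assertDataFrameEqual


    @contextmanager
    def utilizes_valid_plans(session):
        """Run the enclosed DataFrame code; assumed to also validate the plans the gateway produces."""
        # Assumption: plan capture and validation happen here in the real helper;
        # this stub just yields so the enclosed test code runs unchanged.
        yield


    def assert_dataframes_equal(outcome, expected):
        """Compare collected rows against the expected rows with the tolerance the old asserts used."""
        assertDataFrameEqual(outcome, expected, atol=1e-2)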