From b99a261fee12850ef42e972e4b708f2456fa2064 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 31 May 2024 12:51:19 -0700 Subject: [PATCH 1/4] Handle rlike for Datafusion --- src/gateway/converter/conversion_options.py | 2 ++ src/gateway/converter/rename_functions.py | 2 ++ src/gateway/converter/spark_functions.py | 14 +++++++++ src/gateway/converter/spark_to_substrait.py | 7 +++++ src/gateway/converter/substrait_builder.py | 34 +++++++++++++++++++++ 5 files changed, 59 insertions(+) diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index a44bc13..72983ba 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -17,6 +17,7 @@ def __init__(self, backend: BackendOptions = None): self.use_emits_instead_of_direct = False self.use_switch_expressions_where_possible = True self.use_duckdb_regexp_matches_function = False + self.use_regexp_like_function = False self.duckdb_project_emit_workaround = False self.safety_project_read_relations = False @@ -41,6 +42,7 @@ def datafusion(): """Return standard options to connect to a Datafusion backend.""" options = ConversionOptions(backend=BackendOptions(BackendEngine.DATAFUSION)) options.use_switch_expressions_where_possible = False + options.use_regexp_like_function = True return options diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index 71ac19b..fed6e9f 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -35,6 +35,8 @@ def visit_plan(self, plan: plan_pb2.Plan) -> None: extension.extension_function.name = 'instr' elif extension.extension_function.name == 'extract': extension.extension_function.name = 'date_part' + elif extension.extension_function.name == 'like': + extension.extension_function.name = 'regexp_like' # pylint: disable=no-member,fixme diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 721cf50..fa19067 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -160,6 +160,18 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'rpad:str_i64_str', type_pb2.Type( string=type_pb2.Type.String( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'regexp_strpos': ExtensionFunction( + '/functions_string.yaml', 'regexp_strpos:str_str_i64_i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'like': ExtensionFunction( + '/functions_string.yaml', 'like:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'regexp_like': ExtensionFunction( + '/functions_string.yaml', 'regexp_like:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'count': ExtensionFunction( '/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type( i64=type_pb2.Type.I64( @@ -190,6 +202,8 @@ def __lt__(self, obj) -> bool: def lookup_spark_function(name: str, options: ConversionOptions) -> ExtensionFunction: """Return a Substrait function given a spark function name.""" definition = SPARK_SUBSTRAIT_MAPPING.get(name) + if definition is None: + raise ValueError(f'Function {name} not found in the Spark to Substrait mapping table.') if not options.return_names_with_types: definition.name = definition.name.split(':', 1)[0] return definition diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 82d3223..3256787 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -32,6 +32,7 @@ max_agg_function, minus_function, project_relation, + regexp_like_function, regexp_strpos_function, repeat_function, string_concat_agg_function, @@ -305,6 +306,12 @@ def convert_rlike_function( ], output_type=regexp_matches_func.output_type)) + if self._conversion_options.use_regexp_like_function: + regexp_like_func = self.lookup_function_by_name('regexp_like') + return regexp_like_function(regexp_like_func, + self.convert_expression(in_.arguments[0]), + self.convert_expression(in_.arguments[1])) + regexp_strpos_func = self.lookup_function_by_name('regexp_strpos') greater_func = self.lookup_function_by_name('>') diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 9d3469b..765c578 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -254,6 +254,40 @@ def regexp_strpos_function(function_info: ExtensionFunction, algebra_pb2.FunctionArgument(value=position)])) +def like_function(function_info: ExtensionFunction, + input: algebra_pb2.Expression, + pattern: algebra_pb2.Expression, + flags: str | None = None) -> algebra_pb2.Expression: + """Construct a Substrait like expression.""" + result = algebra_pb2.Expression(scalar_function=algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[ + algebra_pb2.FunctionArgument(value=input), + algebra_pb2.FunctionArgument(value=pattern)])) + if flags is not None: + result.scalar_function.arguments.append( + algebra_pb2.FunctionArgument(value=string_literal(flags))) + return result + + +def regexp_like_function(function_info: ExtensionFunction, + input: algebra_pb2.Expression, + pattern: algebra_pb2.Expression, + flags: str | None = None) -> algebra_pb2.Expression: + """Construct a Substrait regex like expression.""" + result = algebra_pb2.Expression(scalar_function=algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[ + algebra_pb2.FunctionArgument(value=input), + algebra_pb2.FunctionArgument(value=pattern)])) + if flags is not None: + result.scalar_function.arguments.append( + algebra_pb2.FunctionArgument(value=string_literal(flags))) + return result + + def bool_literal(val: bool) -> algebra_pb2.Expression: """Construct a Substrait boolean literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(boolean=val)) From 6760340815703bfca69816881390a475c62c1862 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 31 May 2024 13:23:06 -0700 Subject: [PATCH 2/4] Update current list of passing/failing Datafusion tests. --- src/gateway/tests/compare_dataframes.py | 5 ++++- src/gateway/tests/test_tpch_with_dataframe_api.py | 13 ++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/gateway/tests/compare_dataframes.py b/src/gateway/tests/compare_dataframes.py index ef8a33d..7ba79d5 100644 --- a/src/gateway/tests/compare_dataframes.py +++ b/src/gateway/tests/compare_dataframes.py @@ -24,7 +24,10 @@ def align_schema(source_df: list[Row], schema_df: list[Row]): for field_name, field_value in schema.asDict().items(): if (type(row[field_name] is not type(field_value)) and isinstance(field_value, datetime.date)): - new_row[field_name] = row[field_name].date() + if row[field_name] is None: + new_row[field_name] = row[field_name] + else: + new_row[field_name] = row[field_name].date() else: new_row[field_name] = row[field_name] diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index a2bbc2a..c83b0ca 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -22,7 +22,7 @@ def mark_tests_as_xfail(request): request.node.add_marker(pytest.mark.xfail(reason='distinct not supported')) elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") - if originalname in ['test_query_01']: + if originalname in ['test_query_01', 'test_query_16', 'test_query_18']: request.node.add_marker(pytest.mark.xfail(reason='Results mismatch')) elif originalname in ['test_query_03', 'test_query_10', 'test_query_20']: request.node.add_marker(pytest.mark.xfail(reason='Schema mismatch')) @@ -36,15 +36,10 @@ def mark_tests_as_xfail(request): reason='Cannot create filter with non-boolean predicate - substr function')) elif originalname in ['test_query_11']: request.node.add_marker(pytest.mark.xfail(reason='Duplicate field in schema')) - elif originalname in ['test_query_14']: + elif originalname in ['test_query_08', 'test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='Sum not implemented')) - elif originalname in ['test_query_15', 'test_query_19', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='no oneof "rex_type" field')) - elif originalname in ['test_query_18']: - request.node.add_marker(pytest.mark.xfail( - reason='Error with assigning date attribute during schema alignment')) - else: - request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) + elif originalname in ['test_query_17']: + request.node.add_marker(pytest.mark.xfail(reason='Avg not implemented')) class TestTpchWithDataFrameAPI: From 9222763c9cbdc8bcf931c86c014b5e6be148b9ce Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 31 May 2024 13:25:04 -0700 Subject: [PATCH 3/4] Removed currently unused implementation of like --- src/gateway/converter/rename_functions.py | 2 -- src/gateway/converter/spark_functions.py | 4 ---- src/gateway/converter/substrait_builder.py | 17 ----------------- 3 files changed, 23 deletions(-) diff --git a/src/gateway/converter/rename_functions.py b/src/gateway/converter/rename_functions.py index fed6e9f..71ac19b 100644 --- a/src/gateway/converter/rename_functions.py +++ b/src/gateway/converter/rename_functions.py @@ -35,8 +35,6 @@ def visit_plan(self, plan: plan_pb2.Plan) -> None: extension.extension_function.name = 'instr' elif extension.extension_function.name == 'extract': extension.extension_function.name = 'date_part' - elif extension.extension_function.name == 'like': - extension.extension_function.name = 'regexp_like' # pylint: disable=no-member,fixme diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index fa19067..9ea4c35 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -164,10 +164,6 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'regexp_strpos:str_str_i64_i64', type_pb2.Type( i64=type_pb2.Type.I64( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), - 'like': ExtensionFunction( - '/functions_string.yaml', 'like:str_str', type_pb2.Type( - bool=type_pb2.Type.Boolean( - nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'regexp_like': ExtensionFunction( '/functions_string.yaml', 'regexp_like:str_str', type_pb2.Type( bool=type_pb2.Type.Boolean( diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 765c578..edf15de 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -254,23 +254,6 @@ def regexp_strpos_function(function_info: ExtensionFunction, algebra_pb2.FunctionArgument(value=position)])) -def like_function(function_info: ExtensionFunction, - input: algebra_pb2.Expression, - pattern: algebra_pb2.Expression, - flags: str | None = None) -> algebra_pb2.Expression: - """Construct a Substrait like expression.""" - result = algebra_pb2.Expression(scalar_function=algebra_pb2.Expression.ScalarFunction( - function_reference=function_info.anchor, - output_type=function_info.output_type, - arguments=[ - algebra_pb2.FunctionArgument(value=input), - algebra_pb2.FunctionArgument(value=pattern)])) - if flags is not None: - result.scalar_function.arguments.append( - algebra_pb2.FunctionArgument(value=string_literal(flags))) - return result - - def regexp_like_function(function_info: ExtensionFunction, input: algebra_pb2.Expression, pattern: algebra_pb2.Expression, From b5a0bc4e60de43d8f773757dbc6959cc1ab4076f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 31 May 2024 13:26:27 -0700 Subject: [PATCH 4/4] added flaky float/decimal marker --- src/gateway/tests/test_tpch_with_dataframe_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index c83b0ca..56c294b 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -38,6 +38,8 @@ def mark_tests_as_xfail(request): request.node.add_marker(pytest.mark.xfail(reason='Duplicate field in schema')) elif originalname in ['test_query_08', 'test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='Sum not implemented')) + elif originalname in ['test_query_15']: + request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) elif originalname in ['test_query_17']: request.node.add_marker(pytest.mark.xfail(reason='Avg not implemented'))