diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index a44bc13..72983ba 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -17,6 +17,7 @@ def __init__(self, backend: BackendOptions = None): self.use_emits_instead_of_direct = False self.use_switch_expressions_where_possible = True self.use_duckdb_regexp_matches_function = False + self.use_regexp_like_function = False self.duckdb_project_emit_workaround = False self.safety_project_read_relations = False @@ -41,6 +42,7 @@ def datafusion(): """Return standard options to connect to a Datafusion backend.""" options = ConversionOptions(backend=BackendOptions(BackendEngine.DATAFUSION)) options.use_switch_expressions_where_possible = False + options.use_regexp_like_function = True return options diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py index 721cf50..9ea4c35 100644 --- a/src/gateway/converter/spark_functions.py +++ b/src/gateway/converter/spark_functions.py @@ -160,6 +160,14 @@ def __lt__(self, obj) -> bool: '/functions_string.yaml', 'rpad:str_i64_str', type_pb2.Type( string=type_pb2.Type.String( nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'regexp_strpos': ExtensionFunction( + '/functions_string.yaml', 'regexp_strpos:str_str_i64_i64', type_pb2.Type( + i64=type_pb2.Type.I64( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), + 'regexp_like': ExtensionFunction( + '/functions_string.yaml', 'regexp_like:str_str', type_pb2.Type( + bool=type_pb2.Type.Boolean( + nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))), 'count': ExtensionFunction( '/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type( i64=type_pb2.Type.I64( @@ -190,6 +198,8 @@ def __lt__(self, obj) -> bool: def lookup_spark_function(name: str, options: ConversionOptions) -> ExtensionFunction: """Return a Substrait function given a spark function name.""" definition = SPARK_SUBSTRAIT_MAPPING.get(name) + if definition is None: + raise ValueError(f'Function {name} not found in the Spark to Substrait mapping table.') if not options.return_names_with_types: definition.name = definition.name.split(':', 1)[0] return definition diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 82d3223..3256787 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -32,6 +32,7 @@ max_agg_function, minus_function, project_relation, + regexp_like_function, regexp_strpos_function, repeat_function, string_concat_agg_function, @@ -305,6 +306,12 @@ def convert_rlike_function( ], output_type=regexp_matches_func.output_type)) + if self._conversion_options.use_regexp_like_function: + regexp_like_func = self.lookup_function_by_name('regexp_like') + return regexp_like_function(regexp_like_func, + self.convert_expression(in_.arguments[0]), + self.convert_expression(in_.arguments[1])) + regexp_strpos_func = self.lookup_function_by_name('regexp_strpos') greater_func = self.lookup_function_by_name('>') diff --git a/src/gateway/converter/substrait_builder.py b/src/gateway/converter/substrait_builder.py index 9d3469b..edf15de 100644 --- a/src/gateway/converter/substrait_builder.py +++ b/src/gateway/converter/substrait_builder.py @@ -254,6 +254,23 @@ def regexp_strpos_function(function_info: ExtensionFunction, algebra_pb2.FunctionArgument(value=position)])) +def regexp_like_function(function_info: ExtensionFunction, + input: algebra_pb2.Expression, + pattern: algebra_pb2.Expression, + flags: str | None = None) -> algebra_pb2.Expression: + """Construct a Substrait regex like expression.""" + result = algebra_pb2.Expression(scalar_function=algebra_pb2.Expression.ScalarFunction( + function_reference=function_info.anchor, + output_type=function_info.output_type, + arguments=[ + algebra_pb2.FunctionArgument(value=input), + algebra_pb2.FunctionArgument(value=pattern)])) + if flags is not None: + result.scalar_function.arguments.append( + algebra_pb2.FunctionArgument(value=string_literal(flags))) + return result + + def bool_literal(val: bool) -> algebra_pb2.Expression: """Construct a Substrait boolean literal expression.""" return algebra_pb2.Expression(literal=algebra_pb2.Expression.Literal(boolean=val)) diff --git a/src/gateway/tests/compare_dataframes.py b/src/gateway/tests/compare_dataframes.py index ef8a33d..7ba79d5 100644 --- a/src/gateway/tests/compare_dataframes.py +++ b/src/gateway/tests/compare_dataframes.py @@ -24,7 +24,10 @@ def align_schema(source_df: list[Row], schema_df: list[Row]): for field_name, field_value in schema.asDict().items(): if (type(row[field_name] is not type(field_value)) and isinstance(field_value, datetime.date)): - new_row[field_name] = row[field_name].date() + if row[field_name] is None: + new_row[field_name] = row[field_name] + else: + new_row[field_name] = row[field_name].date() else: new_row[field_name] = row[field_name] diff --git a/src/gateway/tests/test_tpch_with_dataframe_api.py b/src/gateway/tests/test_tpch_with_dataframe_api.py index a2bbc2a..56c294b 100644 --- a/src/gateway/tests/test_tpch_with_dataframe_api.py +++ b/src/gateway/tests/test_tpch_with_dataframe_api.py @@ -22,7 +22,7 @@ def mark_tests_as_xfail(request): request.node.add_marker(pytest.mark.xfail(reason='distinct not supported')) elif source == 'gateway-over-datafusion': pytest.importorskip("datafusion.substrait") - if originalname in ['test_query_01']: + if originalname in ['test_query_01', 'test_query_16', 'test_query_18']: request.node.add_marker(pytest.mark.xfail(reason='Results mismatch')) elif originalname in ['test_query_03', 'test_query_10', 'test_query_20']: request.node.add_marker(pytest.mark.xfail(reason='Schema mismatch')) @@ -36,15 +36,12 @@ def mark_tests_as_xfail(request): reason='Cannot create filter with non-boolean predicate - substr function')) elif originalname in ['test_query_11']: request.node.add_marker(pytest.mark.xfail(reason='Duplicate field in schema')) - elif originalname in ['test_query_14']: + elif originalname in ['test_query_08', 'test_query_14']: request.node.add_marker(pytest.mark.xfail(reason='Sum not implemented')) - elif originalname in ['test_query_15', 'test_query_19', 'test_query_22']: - request.node.add_marker(pytest.mark.xfail(reason='no oneof "rex_type" field')) - elif originalname in ['test_query_18']: - request.node.add_marker(pytest.mark.xfail( - reason='Error with assigning date attribute during schema alignment')) - else: - request.node.add_marker(pytest.mark.xfail(reason='gateway internal error')) + elif originalname in ['test_query_15']: + request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)')) + elif originalname in ['test_query_17']: + request.node.add_marker(pytest.mark.xfail(reason='Avg not implemented')) class TestTpchWithDataFrameAPI: