Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Datafusion support for rlike #19

Merged
merged 4 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/gateway/converter/conversion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, backend: BackendOptions = None):
self.use_emits_instead_of_direct = False
self.use_switch_expressions_where_possible = True
self.use_duckdb_regexp_matches_function = False
self.use_regexp_like_function = False
self.duckdb_project_emit_workaround = False
self.safety_project_read_relations = False

Expand All @@ -41,6 +42,7 @@ def datafusion():
"""Return standard options to connect to a Datafusion backend."""
options = ConversionOptions(backend=BackendOptions(BackendEngine.DATAFUSION))
options.use_switch_expressions_where_possible = False
options.use_regexp_like_function = True
return options


Expand Down
10 changes: 10 additions & 0 deletions src/gateway/converter/spark_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ def __lt__(self, obj) -> bool:
'/functions_string.yaml', 'rpad:str_i64_str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'regexp_strpos': ExtensionFunction(
'/functions_string.yaml', 'regexp_strpos:str_str_i64_i64', type_pb2.Type(
i64=type_pb2.Type.I64(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'regexp_like': ExtensionFunction(
'/functions_string.yaml', 'regexp_like:str_str', type_pb2.Type(
bool=type_pb2.Type.Boolean(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'count': ExtensionFunction(
'/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type(
i64=type_pb2.Type.I64(
Expand Down Expand Up @@ -190,6 +198,8 @@ def __lt__(self, obj) -> bool:
def lookup_spark_function(name: str, options: ConversionOptions) -> ExtensionFunction:
    """Return a Substrait function given a spark function name.

    Args:
        name: The Spark function name, used as the key into SPARK_SUBSTRAIT_MAPPING.
        options: Conversion options. When ``return_names_with_types`` is false the
            returned function name is cut at the first ':' (e.g. the mapping entry
            'regexp_like:str_str' is returned as 'regexp_like').

    Returns:
        A shallow copy of the matching mapping entry, with its name adjusted
        according to ``options``.

    Raises:
        ValueError: If ``name`` has no entry in the mapping table.
    """
    import copy

    definition = SPARK_SUBSTRAIT_MAPPING.get(name)
    if definition is None:
        raise ValueError(f'Function {name} not found in the Spark to Substrait mapping table.')
    # Work on a copy: rewriting ``name`` on the shared table entry would strip the
    # type suffix for every later lookup, including ones made with
    # ``return_names_with_types`` enabled.
    definition = copy.copy(definition)
    if not options.return_names_with_types:
        definition.name = definition.name.split(':', 1)[0]
    return definition
7 changes: 7 additions & 0 deletions src/gateway/converter/spark_to_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
max_agg_function,
minus_function,
project_relation,
regexp_like_function,
regexp_strpos_function,
repeat_function,
string_concat_agg_function,
Expand Down Expand Up @@ -305,6 +306,12 @@ def convert_rlike_function(
],
output_type=regexp_matches_func.output_type))

if self._conversion_options.use_regexp_like_function:
regexp_like_func = self.lookup_function_by_name('regexp_like')
return regexp_like_function(regexp_like_func,
self.convert_expression(in_.arguments[0]),
self.convert_expression(in_.arguments[1]))

regexp_strpos_func = self.lookup_function_by_name('regexp_strpos')
greater_func = self.lookup_function_by_name('>')

Expand Down
17 changes: 17 additions & 0 deletions src/gateway/converter/substrait_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,23 @@ def regexp_strpos_function(function_info: ExtensionFunction,
algebra_pb2.FunctionArgument(value=position)]))


def regexp_like_function(function_info: ExtensionFunction,
                         input: algebra_pb2.Expression,
                         pattern: algebra_pb2.Expression,
                         flags: str | None = None) -> algebra_pb2.Expression:
    """Build a Substrait scalar-function expression for a regexp_like call.

    Args:
        function_info: Supplies the function anchor and the output type.
        input: Expression passed as the first function argument.
        pattern: Expression passed as the second function argument.
        flags: Optional flags string; when given it is appended as a third
            argument in the form of a string literal.

    Returns:
        An expression wrapping the assembled scalar function call.
    """
    call = algebra_pb2.Expression.ScalarFunction(
        function_reference=function_info.anchor,
        output_type=function_info.output_type,
        arguments=[algebra_pb2.FunctionArgument(value=operand)
                   for operand in (input, pattern)])
    # The flags argument is only emitted when the caller explicitly provides one.
    if flags is not None:
        call.arguments.append(algebra_pb2.FunctionArgument(value=string_literal(flags)))
    return algebra_pb2.Expression(scalar_function=call)


def bool_literal(val: bool) -> algebra_pb2.Expression:
    """Wrap a Python bool in a Substrait literal expression.

    Args:
        val: The boolean value to store in the literal.

    Returns:
        An expression whose ``literal`` field carries ``val`` as its boolean.
    """
    literal = algebra_pb2.Expression.Literal(boolean=val)
    return algebra_pb2.Expression(literal=literal)
Expand Down
5 changes: 4 additions & 1 deletion src/gateway/tests/compare_dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def align_schema(source_df: list[Row], schema_df: list[Row]):
for field_name, field_value in schema.asDict().items():
if (type(row[field_name] is not type(field_value)) and
isinstance(field_value, datetime.date)):
new_row[field_name] = row[field_name].date()
if row[field_name] is None:
new_row[field_name] = row[field_name]
else:
new_row[field_name] = row[field_name].date()
else:
new_row[field_name] = row[field_name]

Expand Down
15 changes: 6 additions & 9 deletions src/gateway/tests/test_tpch_with_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def mark_tests_as_xfail(request):
request.node.add_marker(pytest.mark.xfail(reason='distinct not supported'))
elif source == 'gateway-over-datafusion':
pytest.importorskip("datafusion.substrait")
if originalname in ['test_query_01']:
if originalname in ['test_query_01', 'test_query_16', 'test_query_18']:
request.node.add_marker(pytest.mark.xfail(reason='Results mismatch'))
elif originalname in ['test_query_03', 'test_query_10', 'test_query_20']:
request.node.add_marker(pytest.mark.xfail(reason='Schema mismatch'))
Expand All @@ -36,15 +36,12 @@ def mark_tests_as_xfail(request):
reason='Cannot create filter with non-boolean predicate - substr function'))
elif originalname in ['test_query_11']:
request.node.add_marker(pytest.mark.xfail(reason='Duplicate field in schema'))
elif originalname in ['test_query_14']:
elif originalname in ['test_query_08', 'test_query_14']:
request.node.add_marker(pytest.mark.xfail(reason='Sum not implemented'))
elif originalname in ['test_query_15', 'test_query_19', 'test_query_22']:
request.node.add_marker(pytest.mark.xfail(reason='no oneof "rex_type" field'))
elif originalname in ['test_query_18']:
request.node.add_marker(pytest.mark.xfail(
reason='Error with assigning date attribute during schema alignment'))
else:
request.node.add_marker(pytest.mark.xfail(reason='gateway internal error'))
elif originalname in ['test_query_15']:
request.node.add_marker(pytest.mark.xfail(reason='No results (float vs decimal)'))
elif originalname in ['test_query_17']:
request.node.add_marker(pytest.mark.xfail(reason='Avg not implemented'))


class TestTpchWithDataFrameAPI:
Expand Down
Loading