feat: add support for colRegex (#35)
EpsilonPrime authored Jun 19, 2024
1 parent bf2c6cc commit 97a3303
Showing 4 changed files with 49 additions and 3 deletions.
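
For readers unfamiliar with the feature: colRegex is the PySpark DataFrame method that selects every column whose name matches a backtick-quoted regular expression. A minimal sketch of the client-side call this commit teaches the gateway to handle, reusing the column names and pattern from the new test below (the data values are illustrative):

    # Illustrative only: the PySpark call that produces the
    # 'unresolved_regex' expression handled by this commit.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [(1, 1, 'a', 'a'), (2, 2, 'b', 'b'), (3, 3, 'c', 'c')],
        schema=['a1', 'c', 'col', 'col2'])

    # Selects every column whose name matches the quoted regex:
    # here 'a1' and 'col2' match, while 'c' and 'col' do not.
    rows = df.select(df.colRegex("`(c.l|a)?[0-9]`")).collect()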
environment.yml (1 change: 0 additions & 1 deletion)

@@ -13,7 +13,6 @@ dependencies:
   - setuptools_scm >= 6.2.0
   - python-substrait >= 0.14.1
   - mypy-protobuf
-  - numpy < 2.0.0
   - types-protobuf
   - numpy < 2.0.0
   - Faker
src/gateway/converter/spark_to_substrait.py (20 changes: 18 additions & 2 deletions)
@@ -4,6 +4,7 @@
 import json
 import operator
 import pathlib
+import re
 
 import pyarrow as pa
 import pyspark.sql.connect.proto.base_pb2 as spark_pb2
@@ -472,7 +473,7 @@ def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Ex
                 result = self.convert_cast_expression(expr.cast)
             case 'unresolved_regex':
                 raise NotImplementedError(
-                    'unresolved_regex expression type not supported')
+                    'colRegex is only supported at the top level of an expression')
             case 'sort_order':
                 raise NotImplementedError(
                     'sort_order expression type not supported')
@@ -1353,6 +1354,21 @@ def convert_project_relation(
         self.update_field_references(rel.input.common.plan_id)
         symbol = self._symbol_table.get_symbol(self._current_plan_id)
         for field_number, expr in enumerate(rel.expressions):
+            if expr.WhichOneof('expr_type') == 'unresolved_regex':
+                regex = expr.unresolved_regex.col_name.replace('`', '')
+                matcher = re.compile(regex)
+                found = False
+                for column in symbol.input_fields:
+                    if matcher.match(column):
+                        project.expressions.append(
+                            field_reference(symbol.input_fields.index(column)))
+                        symbol.generated_fields.append(column)
+                        symbol.output_fields.append(column)
+                        found = True
+                if not found:
+                    raise ValueError(
+                        f'No columns match the regex {regex} in plan id {self._current_plan_id}')
+                continue
             project.expressions.append(self.convert_expression(expr))
             if expr.WhichOneof('expr_type') == 'alias':
                 name = expr.alias.name[0]
@@ -1364,7 +1380,7 @@
                 symbol.output_fields.append(name)
         project.common.CopyFrom(self.create_common_relation())
         symbol.output_fields = symbol.generated_fields
-        for field_number in range(len(rel.expressions)):
+        for field_number in range(len(symbol.output_fields)):
             project.common.emit.output_mapping.append(field_number + len(symbol.input_fields))
         return algebra_pb2.Rel(project=project)

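Two details in the hunks above are worth spelling out. First, since a colRegex expression appends one output field per matched column, counting rel.expressions would undercount the emitted fields, hence the switch to symbol.output_fields in the final emit loop. Second, the matching itself is ordinary Python re. A rough standalone sketch of the new branch, with a plain list standing in for the converter's symbol table and the returned indexes standing in for field_reference calls:

    import re

    def match_columns(col_name: str, input_fields: list[str]) -> list[int]:
        # Mirror the new 'unresolved_regex' branch: strip the backticks that
        # Spark wraps around the pattern, then keep the index of every input
        # field whose name matches from the start.
        regex = col_name.replace('`', '')
        matcher = re.compile(regex)
        indexes = [index for index, column in enumerate(input_fields)
                   if matcher.match(column)]
        if not indexes:
            raise ValueError(f'No columns match the regex {regex}')
        return indexes

    # The pattern from the new test selects a1 and col2 but not c or col.
    print(match_columns('`(c.l|a)?[0-9]`', ['a1', 'c', 'col', 'col2']))  # [0, 3]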
src/gateway/converter/substrait_builder.py (12 changes: 12 additions & 0 deletions)
@@ -203,6 +203,18 @@ def add_function(function_info: ExtensionFunction,
             algebra_pb2.FunctionArgument(value=expr2)]))
 
 
+def and_function(function_info: ExtensionFunction,
+                 expr1: algebra_pb2.Expression,
+                 expr2: algebra_pb2.Expression) -> algebra_pb2.Expression:
+    """Construct a Substrait and expression (binary)."""
+    return algebra_pb2.Expression(scalar_function=
+        algebra_pb2.Expression.ScalarFunction(
+            function_reference=function_info.anchor,
+            output_type=function_info.output_type,
+            arguments=[algebra_pb2.FunctionArgument(value=expr1),
+                       algebra_pb2.FunctionArgument(value=expr2)]))
+
+
 def minus_function(function_info: ExtensionFunction,
                    expr1: algebra_pb2.Expression,
                    expr2: algebra_pb2.Expression) -> algebra_pb2.Expression:
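The new and_function helper is not exercised elsewhere in this diff; it simply mirrors the file's other binary builders such as add_function and minus_function. A hedged usage sketch, assuming and_info is an ExtensionFunction already registered for the Substrait boolean 'and' function and lhs/rhs are previously built boolean expressions:

    # Sketch only; and_info, lhs and rhs are assumed as described above.
    combined = and_function(and_info, lhs, rhs)
    # 'combined' is an algebra_pb2.Expression wrapping a ScalarFunction
    # with two FunctionArgument values, usable e.g. as a filter condition.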
src/gateway/tests/test_dataframe_api.py (19 changes: 19 additions & 0 deletions)
@@ -566,6 +566,25 @@ def test_dropduplicates(self, spark_session):
 
         assertDataFrameEqual(outcome, expected)
 
+    def test_colregex(self, spark_session, caplog):
+        expected = [
+            Row(a1=1, col2='a'),
+            Row(a1=2, col2='b'),
+            Row(a1=3, col2='c'),
+        ]
+
+        int_array = pa.array([1, 2, 3], type=pa.int32())
+        string_array = pa.array(['a', 'b', 'c'], type=pa.string())
+        table = pa.Table.from_arrays([int_array, int_array, string_array, string_array],
+                                     names=['a1', 'c', 'col', 'col2'])
+
+        df = create_parquet_table(spark_session, 'mytesttable1', table)
+
+        with utilizes_valid_plans(df, caplog):
+            outcome = df.select(df.colRegex("`(c.l|a)?[0-9]`")).collect()
+
+        assertDataFrameEqual(outcome, expected)
+
     def test_subtract(self, spark_session_with_tpch_dataset):
         expected = [
             Row(n_nationkey=21, n_name='VIETNAM', n_regionkey=2,
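One subtlety when reading the converter change together with this test: the gateway compares names with re.match, which anchors only at the start of the column name. A quick illustration with hypothetical column names:

    import re

    matcher = re.compile('(c.l|a)?[0-9]')
    print(bool(matcher.match('a1x')))  # True: 'a1' matches, trailing 'x' is ignored
    print(bool(matcher.match('xa1')))  # False: nothing matches at the start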
