Skip to content

Commit

Permalink
feat: add support for 29 common string functions (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime authored Jun 21, 2024
1 parent b904e75 commit 22d7c4b
Show file tree
Hide file tree
Showing 3 changed files with 526 additions and 1 deletion.
80 changes: 80 additions & 0 deletions src/gateway/converter/spark_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,18 @@ def __lt__(self, obj) -> bool:
'/functions_string.yaml', 'regexp_matches:str_str', type_pb2.Type(
bool=type_pb2.Type.Boolean(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'substr': ExtensionFunction(
'/functions_string.yaml', 'substring:str_int_int', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'substring': ExtensionFunction(
'/functions_string.yaml', 'substring:str_int_int', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'instr': ExtensionFunction(
'/functions_string.yaml', 'strpos:str_str', type_pb2.Type(
i64=type_pb2.Type.I64(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'startswith': ExtensionFunction(
'/functions_string.yaml', 'starts_with:str_str', type_pb2.Type(
bool=type_pb2.Type.Boolean(
Expand Down Expand Up @@ -164,6 +172,46 @@ def __lt__(self, obj) -> bool:
'/functions_string.yaml', 'string_agg:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'btrim': ExtensionFunction(
'/functions_string.yaml', 'trim:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'ltrim': ExtensionFunction(
'/functions_string.yaml', 'ltrim:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'rtrim': ExtensionFunction(
'/functions_string.yaml', 'rtrim:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'trim': ExtensionFunction(
'/functions_string.yaml', 'trim:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'lcase': ExtensionFunction(
'/functions_string.yaml', 'lower:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'lower': ExtensionFunction(
'/functions_string.yaml', 'lower:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'ucase': ExtensionFunction(
'/functions_string.yaml', 'upper:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'upper': ExtensionFunction(
'/functions_string.yaml', 'upper:str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'left': ExtensionFunction(
'/functions_string.yaml', 'left:int_str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'right': ExtensionFunction(
'/functions_string.yaml', 'right:int_str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'least': ExtensionFunction(
'/functions_comparison.yaml', 'least:i64', type_pb2.Type(
i64=type_pb2.Type.I64(
Expand All @@ -184,10 +232,18 @@ def __lt__(self, obj) -> bool:
'/functions_string.yaml', 'concat:str_str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'concat_ws': ExtensionFunction(
'/functions_string.yaml', 'concat_ws:str_str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'repeat': ExtensionFunction(
'/functions_string.yaml', 'repeat:str_i64', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'replace': ExtensionFunction(
'/functions_string.yaml', 'replace:str_str_str', type_pb2.Type(
string=type_pb2.Type.String(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'lpad': ExtensionFunction(
'/functions_string.yaml', 'lpad:str_i64_str', type_pb2.Type(
string=type_pb2.Type.String(
Expand All @@ -200,10 +256,34 @@ def __lt__(self, obj) -> bool:
'/functions_string.yaml', 'regexp_strpos:str_str_i64_i64', type_pb2.Type(
i64=type_pb2.Type.I64(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'rlike': ExtensionFunction(
'/functions_string.yaml', 'regexp_like:str_str', type_pb2.Type(
bool=type_pb2.Type.Boolean(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'regexp': ExtensionFunction(
'/functions_string.yaml', 'regexp_like:str_str', type_pb2.Type(
bool=type_pb2.Type.Boolean(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'regexp_like': ExtensionFunction(
'/functions_string.yaml', 'regexp_like:str_str', type_pb2.Type(
bool=type_pb2.Type.Boolean(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'bit_length': ExtensionFunction(
'/functions_string.yaml', 'bit_length:str', type_pb2.Type(
i64=type_pb2.Type.I64(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'character_length': ExtensionFunction(
'/functions_string.yaml', 'char_length:str', type_pb2.Type(
i64=type_pb2.Type.I64(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'char_length': ExtensionFunction(
'/functions_string.yaml', 'char_length:str', type_pb2.Type(
i64=type_pb2.Type.I64(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'octet_length': ExtensionFunction(
'/functions_string.yaml', 'octet_length:str', type_pb2.Type(
i64=type_pb2.Type.I64(
nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
'count': ExtensionFunction(
'/functions_aggregate_generic.yaml', 'count:any', type_pb2.Type(
i64=type_pb2.Type.I64(
Expand Down
2 changes: 1 addition & 1 deletion src/gateway/converter/spark_to_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def convert_unresolved_function(
return self.convert_when_function(unresolved_function)
if unresolved_function.function_name == 'in':
return self.convert_in_function(unresolved_function)
if unresolved_function.function_name == 'rlike':
if unresolved_function.function_name in ['rlike', 'regexp', 'regexp_like']:
return self.convert_rlike_function(unresolved_function)
if unresolved_function.function_name == 'nanvl':
return self.convert_nanvl_function(unresolved_function)
Expand Down
Loading

0 comments on commit 22d7c4b

Please sign in to comment.