feat: finish column add, rename, and delete relations #28

Merged
merged 10 commits on Jun 14, 2024
1 change: 1 addition & 0 deletions src/gateway/converter/conversion_options.py
@@ -20,6 +20,7 @@ def __init__(self, backend: BackendOptions = None):
self.use_first_value_as_any_value = False
self.use_regexp_like_function = False
self.duckdb_project_emit_workaround = False
self.drop_emit_workaround = True
self.safety_project_read_relations = False
self.use_duckdb_struct_name_behavior = False
self.fetch_return_all_workaround = True
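
The new drop_emit_workaround flag defaults to True, so convert_drop_relation (below) routes every kept column through an explicit emit output_mapping instead of relying on the backend to accept a plain project. A minimal sketch of turning the workaround off, assuming the options class defined in conversion_options.py is named ConversionOptions and that its import path mirrors the file layout (neither is shown in this diff):

    # Hypothetical: the class name and import path are assumptions, not shown in the diff.
    from gateway.converter.conversion_options import ConversionOptions

    options = ConversionOptions()           # drop_emit_workaround defaults to True
    options.drop_emit_workaround = False    # let the backend consume a plain project
                                            # without the extra emit mapping
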
138 changes: 69 additions & 69 deletions src/gateway/converter/data/00001.splan
@@ -78,6 +78,7 @@ relations {
project {
common {
emit {
output_mapping: 10
output_mapping: 11
output_mapping: 12
output_mapping: 13
@@ -87,21 +88,20 @@
output_mapping: 17
output_mapping: 18
output_mapping: 19
output_mapping: 10
}
}
input {
project {
common {
emit {
output_mapping: 10
output_mapping: 11
output_mapping: 12
output_mapping: 13
output_mapping: 14
output_mapping: 15
output_mapping: 16
output_mapping: 17
output_mapping: 10
output_mapping: 18
output_mapping: 19
}
@@ -110,13 +110,13 @@
project {
common {
emit {
output_mapping: 10
output_mapping: 11
output_mapping: 12
output_mapping: 13
output_mapping: 14
output_mapping: 15
output_mapping: 16
output_mapping: 10
output_mapping: 17
output_mapping: 18
output_mapping: 19
@@ -202,36 +202,6 @@
}
}
}
expressions {
scalar_function {
function_reference: 1
output_type {
string {
nullability: NULLABILITY_REQUIRED
}
}
arguments {
value {
selection {
direct_reference {
struct_field {
field: 6
}
}
root_reference {
}
}
}
}
arguments {
value {
literal {
string: "; "
}
}
}
}
}
expressions {
selection {
direct_reference {
@@ -298,21 +268,40 @@
}
}
expressions {
selection {
direct_reference {
struct_field {
field: 7
scalar_function {
function_reference: 1
output_type {
string {
nullability: NULLABILITY_REQUIRED
}
}
root_reference {
arguments {
value {
selection {
direct_reference {
struct_field {
field: 6
}
}
root_reference {
}
}
}
}
arguments {
value {
literal {
string: "; "
}
}
}
}
}
expressions {
selection {
direct_reference {
struct_field {
field: 8
field: 7
}
}
root_reference {
@@ -323,34 +312,24 @@
selection {
direct_reference {
struct_field {
field: 9
field: 8
}
}
root_reference {
}
}
}
}
}
expressions {
cast {
type {
i32 {
nullability: NULLABILITY_REQUIRED
}
}
input {
expressions {
selection {
direct_reference {
struct_field {
field: 7
field: 9
}
}
root_reference {
}
}
}
failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION
}
}
expressions {
@@ -430,37 +409,38 @@
}
}
expressions {
selection {
direct_reference {
struct_field {
field: 8
cast {
type {
i32 {
nullability: NULLABILITY_REQUIRED
}
}
root_reference {
input {
selection {
direct_reference {
struct_field {
field: 7
}
}
root_reference {
}
}
}
failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION
}
}
expressions {
selection {
direct_reference {
struct_field {
field: 9
field: 8
}
}
root_reference {
}
}
}
}
}
expressions {
cast {
type {
bool {
nullability: NULLABILITY_REQUIRED
}
}
input {
expressions {
selection {
direct_reference {
struct_field {
@@ -471,7 +451,6 @@
}
}
}
failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION
}
}
expressions {
@@ -572,6 +551,27 @@
}
}
}
expressions {
cast {
type {
bool {
nullability: NULLABILITY_REQUIRED
}
}
input {
selection {
direct_reference {
struct_field {
field: 9
}
}
root_reference {
}
}
}
failure_behavior: FAILURE_BEHAVIOR_THROW_EXCEPTION
}
}
}
}
condition {
81 changes: 52 additions & 29 deletions src/gateway/converter/spark_to_substrait.py
@@ -1111,44 +1111,63 @@ def convert_with_columns_relation(
project = algebra_pb2.ProjectRel(input=input_rel)
self.update_field_references(rel.input.common.plan_id)
symbol = self._symbol_table.get_symbol(self._current_plan_id)
remapped = False
mapping = list(range(len(symbol.input_fields)))
field_number = len(symbol.input_fields)
proposed_expressions = [field_reference(i) for i in range(len(symbol.input_fields))]
for alias in rel.aliases:
if len(alias.name) != 1:
raise ValueError('every column alias must have exactly one name')
raise ValueError('Only one name part is supported in an alias.')
name = alias.name[0]
project.expressions.append(self.convert_expression(alias.expr))
if name in symbol.input_fields:
remapped = True
mapping[symbol.input_fields.index(name)] = len(symbol.input_fields) + (
len(project.expressions)) - 1
proposed_expressions[symbol.input_fields.index(name)] = self.convert_expression(
alias.expr)
else:
mapping.append(field_number)
field_number += 1
proposed_expressions.append(self.convert_expression(alias.expr))
symbol.generated_fields.append(name)
symbol.output_fields.append(name)
project.common.CopyFrom(self.create_common_relation())
if remapped:
if self._conversion_options.duckdb_project_emit_workaround:
for field_number in range(len(symbol.input_fields)):
if field_number == mapping[field_number]:
project.expressions.append(field_reference(field_number))
mapping[field_number] = len(symbol.input_fields) + (
len(project.expressions)) - 1
for item in mapping:
project.common.emit.output_mapping.append(item)
project.expressions.extend(proposed_expressions)
for i in range(len(proposed_expressions)):
project.common.emit.output_mapping.append(len(symbol.input_fields) + i)
return algebra_pb2.Rel(project=project)

def convert_with_columns_renamed_relation(
self, rel: spark_relations_pb2.WithColumnsRenamed) -> algebra_pb2.Rel:
"""Update the columns names based on the Spark with columns renamed relation."""
input_rel = self.convert_relation(rel.input)
symbol = self._symbol_table.get_symbol(self._current_plan_id)
self.update_field_references(rel.input.common.plan_id)
symbol.output_fields.clear()
if hasattr(rel, 'renames'):
aliases = {r.col_name: r.new_col_name for r in rel.renames}
else:
aliases = rel.rename_columns_map
for field_name in symbol.input_fields:
if field_name in aliases:
symbol.output_fields.append(aliases[field_name])
else:
symbol.output_fields.append(field_name)
return input_rel

def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Rel:
"""Convert a drop relation into a Substrait project relation."""
input_rel = self.convert_relation(rel.input)
project = algebra_pb2.ProjectRel(input=input_rel)
self.update_field_references(rel.input.common.plan_id)
symbol = self._symbol_table.get_symbol(self._current_plan_id)
if rel.columns:
column_names = [c.unresolved_attribute.unparsed_identifier for c in rel.columns]
else:
new_expressions = []
for field_number in range(len(symbol.input_fields)):
new_expressions.append(field_reference(field_number))
project.common.emit.output_mapping.append(
field_number + len(symbol.input_fields))
new_expressions.extend(list(project.expressions))
del project.expressions[:]
project.expressions.extend(new_expressions)
for field_number in range(len(symbol.generated_fields)):
project.common.emit.output_mapping.append(field_number + len(symbol.input_fields))
column_names = rel.column_names
symbol.output_fields.clear()
for field_number, field_name in enumerate(symbol.input_fields):
if field_name not in column_names:
symbol.output_fields.append(field_name)
if self._conversion_options.drop_emit_workaround:
project.common.emit.output_mapping.append(len(project.expressions))
project.expressions.append(field_reference(field_number))
else:
project.expressions.append(field_reference(field_number))
if not project.expressions:
raise ValueError(f"No columns remaining after drop in plan id {self._current_plan_id}")
return algebra_pb2.Rel(project=project)

def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.Rel:
@@ -1410,6 +1429,10 @@ def convert_relation(self, rel: spark_relations_pb2.Relation) -> algebra_pb2.Rel
result = self.convert_show_string_relation(rel.show_string)
case 'with_columns':
result = self.convert_with_columns_relation(rel.with_columns)
case 'with_columns_renamed':
result = self.convert_with_columns_renamed_relation(rel.with_columns_renamed)
case 'drop':
result = self.convert_drop_relation(rel.drop)
case 'to_df':
result = self.convert_to_df_relation(rel.to_df)
case 'local_relation':
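
Taken together with the new match arms above, these converters cover the Spark Connect WithColumns, WithColumnsRenamed, and Drop relations. An illustrative PySpark sketch of the client-side calls that produce those relations; the connection URL, table, and column names are placeholders, not taken from this PR:

    # Hypothetical example: URL, table, and column names are placeholders.
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.remote('sc://localhost:50051').getOrCreate()
    df = spark.table('customer')

    result = (df
              .withColumn('contact', F.concat_ws('; ', 'name', 'phone'))  # WithColumns -> convert_with_columns_relation
              .withColumnRenamed('acctbal', 'account_balance')            # WithColumnsRenamed -> convert_with_columns_renamed_relation
              .drop('comment'))                                           # Drop -> convert_drop_relation
    result.show()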