feat: fix the remaining datafusion test failures (#5)
While this fixes the datafusion tests locally, a packaging issue still causes CI to fail.
The tests have been left in an xfail state so that the change is available for development
while the packaging issue is investigated.
EpsilonPrime authored Apr 5, 2024
1 parent 961ce89 commit a490133
Showing 10 changed files with 81 additions and 52 deletions.
21 changes: 11 additions & 10 deletions .github/workflows/test.yml
```diff
@@ -15,24 +15,25 @@ jobs:
     strategy:
       matrix:
         os: [macos-latest, ubuntu-latest]
-        python: ["3.12"]
+        python: ["3.10"]
     runs-on: ${{ matrix.os }}
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
         with:
           submodules: recursive
-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install packages and test dependencies
+        uses: conda-incubator/setup-miniconda@v3
         with:
+          activate-environment: spark-substrait-gateway-env
+          environment-file: environment.yml
           python-version: ${{ matrix.python }}
-      - name: Install package and test dependencies
+          auto-activate-base: false
+      - name: Build
+        shell: bash -el {0}
         run: |
-          export PIP_INTEROP_ENABLED=compatible
-          export PIP_SCRIPT=scripts/pip_verbose.sh
-          $CONDA/bin/conda env update --file environment.yml --name base
-          python -m pip install --upgrade pip
-          python -m pip install ".[test]"
+          pip install -e .
       - name: Run tests
+        shell: bash -el {0}
         run: |
-          $CONDA/bin/pytest
+          pytest
```
9 changes: 4 additions & 5 deletions environment.yml
```diff
@@ -4,12 +4,11 @@ channels:
 dependencies:
   - pip
   - pre-commit
-  - protobuf >= 4.21.6
+  - protobuf >= 4.21.6, < 5.0.0
   - grpcio >= 1.48.1
   - grpcio-status
   - grpcio-tools
-  - pytest >= 7.0.0
-  - python >= 3.8.1
+  - pytest >= 8.0.0
   - setuptools >= 61.0.0
   - setuptools_scm >= 6.2.0
   - python-substrait >= 0.14.1
@@ -19,9 +18,9 @@ dependencies:
   - pip:
       - adbc_driver_manager
       - cargo
-      - pyarrow
+      - pyarrow >= 13.0.0
       - duckdb == 0.10.1
-      - datafusion
+      - datafusion >= 36.0.0
       - pyspark
       - pandas >= 1.0.5
       - pyhamcrest
```
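The tightened pins are the whole of this change. As a quick local sanity check (not part of the commit), the versions installed in the active environment can be compared against the new pins using only the standard library:

```python
# Prints the installed versions of the newly pinned packages. Assumes they
# are installed in the current environment (e.g. via environment.yml).
from importlib.metadata import PackageNotFoundError, version

for package in ('protobuf', 'pytest', 'pyarrow', 'datafusion', 'duckdb'):
    try:
        # Expect protobuf >= 4.21.6, < 5.0.0; pytest >= 8; pyarrow >= 13;
        # datafusion >= 36; duckdb == 0.10.1 under the updated file.
        print(package, version(package))
    except PackageNotFoundError:
        print(package, 'not installed')
```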
4 changes: 2 additions & 2 deletions pyproject.toml
```diff
@@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "[email protected]
 license = {text = "Apache-2.0"}
 readme = "README.md"
 requires-python = ">=3.8.1"
-dependencies = ["protobuf >= 3.20", "datafusion"]
+dependencies = ["protobuf >= 3.20", "datafusion >= 36.0.0", "pyarrow >= 15.0.2"]
 dynamic = ["version"]
 
 [tool.setuptools_scm]
@@ -29,12 +29,12 @@ respect-gitignore = true
 target-version = "py310"
 # never autoformat upstream or generated code
 exclude = ["third_party/", "src/spark/connect"]
-# do not autofix the following (will still get flagged in lint)
 
 [lint]
 unfixable = [
     "F401", # unused imports
     "T201", # print statements
+    "E712", # truth comparison checks
 ]
 
 [tool.pylint.MASTER]
```
3 changes: 0 additions & 3 deletions scripts/pip_verbose.sh

This file was deleted.

11 changes: 9 additions & 2 deletions src/gateway/adbc/backend.py
```diff
@@ -10,6 +10,7 @@
 from substrait.gen.proto import plan_pb2
 
 from gateway.adbc.backend_options import BackendOptions, Backend
+from gateway.converter.rename_functions import RenameFunctions
 from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable
 
 
@@ -38,7 +39,6 @@ def execute_with_duckdb_over_adbc(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Ta
     # pylint: disable=import-outside-toplevel
     def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table:
         """Executes the given Substrait plan against Datafusion."""
-        import datafusion
         import datafusion.substrait
 
         ctx = datafusion.SessionContext()
@@ -52,15 +52,22 @@ def execute_with_datafusion(self, plan: 'plan_pb2.Plan') -> pyarrow.lib.Table:
                 ctx.register_parquet(table_name, file)
             registered_tables.add(files[0])
 
+        RenameFunctions().visit_plan(plan)
+
         try:
             plan_data = plan.SerializeToString()
             substrait_plan = datafusion.substrait.substrait.serde.deserialize_bytes(plan_data)
             logical_plan = datafusion.substrait.substrait.consumer.from_substrait_plan(
                 ctx, substrait_plan
             )
 
-            # Create a DataFrame from a deserialized logical plan
+            # Create a DataFrame from a deserialized logical plan.
             df_result = ctx.create_dataframe_from_logical_plan(logical_plan)
+            for column_number, column_name in enumerate(df_result.schema().names):
+                df_result = df_result.with_column_renamed(
+                    column_name,
+                    plan.relations[0].root.names[column_number]
+                )
             return df_result.to_arrow_table()
         finally:
             for table_name in registered_tables:
```
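Datafusion assigns its own names to the output columns, so the new loop renames them, by position, to the names stored at the Substrait plan root. A minimal, self-contained sketch of that positional rename, with made-up names standing in for df_result.schema().names and plan.relations[0].root.names:

```python
# Stand-ins for the two name lists used by the loop above.
datafusion_names = ['col0', 'col1']     # what Datafusion produced
root_names = ['user_id', 'user_name']   # what the Substrait plan root records

# Effectively what the with_column_renamed() loop performs.
renames = dict(zip(datafusion_names, root_names))
print(renames)  # {'col0': 'user_id', 'col1': 'user_name'}
```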
22 changes: 22 additions & 0 deletions src/gateway/converter/rename_functions.py
```diff
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A library to rename Substrait functions to match what Datafusion expects."""
+from substrait.gen.proto import plan_pb2
+
+from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor
+
+
+# pylint: disable=no-member,fixme
+class RenameFunctions(SubstraitPlanVisitor):
+    """Renames Substrait functions to match what Datafusion expects."""
+
+    def visit_plan(self, plan: plan_pb2.Plan) -> None:
+        """Modifies the provided plan so that functions are Datafusion compatible."""
+        super().visit_plan(plan)
+
+        for extension in plan.extensions:
+            if extension.WhichOneof('mapping_type') != 'extension_function':
+                continue
+
+            # TODO -- Take the URI references into account.
+            if extension.extension_function.name == 'substring':
+                extension.extension_function.name = 'substr'
```
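A short demonstration of the new visitor, assuming the gateway package and the generated Substrait protos are importable; the extension declaration is rewritten in place:

```python
from substrait.gen.proto import plan_pb2

from gateway.converter.rename_functions import RenameFunctions

# Build a plan whose only extension declares the Spark-style 'substring'.
plan = plan_pb2.Plan()
extension = plan.extensions.add()
extension.extension_function.name = 'substring'

RenameFunctions().visit_plan(plan)
assert plan.extensions[0].extension_function.name == 'substr'
```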
6 changes: 2 additions & 4 deletions src/gateway/converter/spark_to_substrait.py
```diff
@@ -507,8 +507,6 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a
         concat_func = self.lookup_function_by_name('concat')
         repeat_func = self.lookup_function_by_name('repeat')
         lpad_func = self.lookup_function_by_name('lpad')
-        least_func = self.lookup_function_by_name('least')
-        greatest_func = self.lookup_function_by_name('greatest')
         greater_func = self.lookup_function_by_name('>')
         minus_func = self.lookup_function_by_name('-')
 
@@ -537,8 +535,8 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a
         # Find the maximum we will use based on the truncate, max size, and column name length.
         project2 = project_relation(
             aggregate1,
-            [greatest_function(greatest_func,
-                               least_function(least_func, field_reference(column_number),
+            [greatest_function(greater_func,
+                               least_function(greater_func, field_reference(column_number),
                                               bigint_literal(rel.truncate)),
              strlen(strlen_func,
                     string_literal(symbol.input_fields[column_number]))) for
```
46 changes: 23 additions & 23 deletions src/gateway/converter/substrait_builder.py
```diff
@@ -124,28 +124,24 @@ def string_concat_agg_function(function_info: ExtensionFunction,
                                    algebra_pb2.FunctionArgument(value=string_literal(separator))])
 
 
-def least_function(function_info: ExtensionFunction,
-                   expr1: algebra_pb2.Expression,
+def least_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression,
                    expr2: algebra_pb2.Expression) -> algebra_pb2.Expression:
     """Constructs a Substrait min expression."""
-    return algebra_pb2.Expression(scalar_function=
-                                  algebra_pb2.Expression.ScalarFunction(
-                                      function_reference=function_info.anchor,
-                                      output_type=function_info.output_type,
-                                      arguments=[algebra_pb2.FunctionArgument(value=expr1),
-                                                 algebra_pb2.FunctionArgument(value=expr2)]))
+    return if_then_else_operation(
+        greater_function(greater_function_info, expr1, expr2),
+        expr2,
+        expr1
+    )
 
 
-def greatest_function(function_info: ExtensionFunction,
-                      expr1: algebra_pb2.Expression,
+def greatest_function(greater_function_info: ExtensionFunction, expr1: algebra_pb2.Expression,
                       expr2: algebra_pb2.Expression) -> algebra_pb2.Expression:
-    """Constructs a Substrait min expression."""
-    return algebra_pb2.Expression(scalar_function=
-                                  algebra_pb2.Expression.ScalarFunction(
-                                      function_reference=function_info.anchor,
-                                      output_type=function_info.output_type,
-                                      arguments=[algebra_pb2.FunctionArgument(value=expr1),
-                                                 algebra_pb2.FunctionArgument(value=expr2)]))
+    """Constructs a Substrait max expression."""
+    return if_then_else_operation(
+        greater_function(greater_function_info, expr1, expr2),
+        expr1,
+        expr2
+    )
 
 
 def greater_or_equal_function(function_info: ExtensionFunction,
```
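least_function and greatest_function are now assembled from just a greater-than function and Substrait's if/then/else, which is presumably why the caller in spark_to_substrait.py above only resolves '>'. In plain Python the two constructions reduce to these identities:

```python
def least(a, b):
    # if a > b then b else a, mirroring the if_then_else_operation call
    return b if a > b else a

def greatest(a, b):
    # if a > b then a else b
    return a if a > b else b

assert least(3, 7) == 3
assert greatest(3, 7) == 7
```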
```diff
@@ -200,30 +196,34 @@ def lpad_function(function_info: ExtensionFunction,
                   expression: algebra_pb2.Expression, count: algebra_pb2.Expression,
                   pad_string: str = ' ') -> algebra_pb2.AggregateFunction:
     """Constructs a Substrait concat expression."""
+    # TODO -- Avoid a cast if we don't need it.
+    cast_type = string_type()
     return algebra_pb2.Expression(scalar_function=
                                   algebra_pb2.Expression.ScalarFunction(
                                       function_reference=function_info.anchor,
                                       output_type=function_info.output_type,
                                       arguments=[
-                                          algebra_pb2.FunctionArgument(value=cast_operation(expression, varchar_type())),
+                                          algebra_pb2.FunctionArgument(value=cast_operation(expression, cast_type)),
                                           algebra_pb2.FunctionArgument(value=cast_operation(count, integer_type())),
                                           algebra_pb2.FunctionArgument(
-                                              value=cast_operation(string_literal(pad_string), varchar_type()))]))
+                                              value=cast_operation(string_literal(pad_string), cast_type))]))
 
 
 def rpad_function(function_info: ExtensionFunction,
                   expression: algebra_pb2.Expression, count: algebra_pb2.Expression,
                   pad_string: str = ' ') -> algebra_pb2.AggregateFunction:
     """Constructs a Substrait concat expression."""
+    # TODO -- Avoid a cast if we don't need it.
+    cast_type = string_type()
     return algebra_pb2.Expression(scalar_function=
                                   algebra_pb2.Expression.ScalarFunction(
                                       function_reference=function_info.anchor,
                                       output_type=function_info.output_type,
                                       arguments=[
-                                          algebra_pb2.FunctionArgument(value=cast_operation(expression, varchar_type())),
+                                          algebra_pb2.FunctionArgument(value=cast_operation(expression, cast_type)),
                                           algebra_pb2.FunctionArgument(value=cast_operation(count, integer_type())),
                                           algebra_pb2.FunctionArgument(
-                                              value=cast_operation(string_literal(pad_string), varchar_type()))]))
+                                              value=cast_operation(string_literal(pad_string), cast_type))]))
 
 
 def string_literal(val: str) -> algebra_pb2.Expression:
```
```diff
@@ -245,13 +245,13 @@ def string_type(required: bool = True) -> type_pb2.Type:
     return type_pb2.Type(string=type_pb2.Type.String(nullability=nullability))
 
 
-def varchar_type(required: bool = True) -> type_pb2.Type:
+def varchar_type(length: int = 1000, required: bool = True) -> type_pb2.Type:
     """Constructs a Substrait varchar type."""
     if required:
         nullability = type_pb2.Type.Nullability.NULLABILITY_REQUIRED
     else:
         nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE
-    return type_pb2.Type(varchar=type_pb2.Type.VarChar(nullability=nullability))
+    return type_pb2.Type(varchar=type_pb2.Type.VarChar(length=length, nullability=nullability))
 
 
 def integer_type(required: bool = True) -> type_pb2.Type:
```
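A small sketch of what the updated helper now emits, using the same generated proto package as the diff; the VarChar message carries the explicit length (default 1000) alongside its nullability:

```python
from substrait.gen.proto import type_pb2

# Equivalent to calling varchar_type() after this change.
varchar = type_pb2.Type(varchar=type_pb2.Type.VarChar(
    length=1000,
    nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))
print(varchar)
```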
7 changes: 5 additions & 2 deletions src/gateway/server.py
```diff
@@ -50,13 +50,16 @@ def convert_pyarrow_schema_to_spark(schema: pyarrow.Schema) -> types_pb2.DataTyp
     for field in schema:
         if field.type == pyarrow.bool_():
             data_type = types_pb2.DataType(boolean=types_pb2.DataType.Boolean())
+        elif field.type == pyarrow.int32():
+            data_type = types_pb2.DataType(integer=types_pb2.DataType.Integer())
         elif field.type == pyarrow.int64():
             data_type = types_pb2.DataType(long=types_pb2.DataType.Long())
         elif field.type == pyarrow.string():
             data_type = types_pb2.DataType(string=types_pb2.DataType.String())
         else:
-            raise ValueError(
-                f'Unsupported arrow schema to Spark schema conversion type: {field.type}')
+            raise NotImplementedError(
+                'Conversion from Arrow schema to Spark schema not yet implemented '
+                f'for type: {field.type}')
 
         struct_field = types_pb2.DataType.StructField(name=field.name, data_type=data_type)
         fields.append(struct_field)
```
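A self-contained sketch of the same dispatch pattern, with plain strings standing in for the Spark protobuf types so it runs without the gateway; int32 now maps to an integer type instead of raising:

```python
import pyarrow

# Mirrors the if/elif chain above; pyarrow.DataType values are hashable.
_ARROW_TO_SPARK = {
    pyarrow.bool_(): 'boolean',
    pyarrow.int32(): 'integer',   # newly supported by this commit
    pyarrow.int64(): 'long',
    pyarrow.string(): 'string',
}

def spark_type_name(arrow_type: pyarrow.DataType) -> str:
    try:
        return _ARROW_TO_SPARK[arrow_type]
    except KeyError:
        raise NotImplementedError(
            'Conversion from Arrow schema to Spark schema not yet implemented '
            f'for type: {arrow_type}') from None

print(spark_type_name(pyarrow.int32()))  # integer
```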
4 changes: 3 additions & 1 deletion src/gateway/tests/conftest.py
```diff
@@ -71,7 +71,9 @@ def schema_users():
 @pytest.fixture(scope='module',
                 params=['spark',
                         pytest.param('gateway-over-duckdb', marks=pytest.mark.xfail),
-                        pytest.param('gateway-over-datafusion', marks=pytest.mark.xfail)])
+                        pytest.param('gateway-over-datafusion',
+                                     marks=pytest.mark.xfail(
+                                         reason='Datafusion Substrait missing in CI'))])
 def spark_session(request):
     """Provides spark sessions connecting to various backends."""
     match request.param:
```
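Since the commit leans on xfail to keep CI green while the packaging issue is researched, here is a standalone sketch (hypothetical fixture and test names, not from the repo) of how an xfail marker with a reason attaches to one fixture parameter; the marked backend still runs, but its failures are reported as xfailed instead of breaking the build:

```python
import pytest


@pytest.fixture(params=[
    'spark',
    pytest.param('gateway-over-datafusion',
                 marks=pytest.mark.xfail(reason='Datafusion Substrait missing in CI')),
])
def backend(request):
    """Yields each backend name; param marks propagate to dependent tests."""
    return request.param


def test_backend_name(backend):
    # Passes for 'spark'; fails, and is therefore xfailed, for the marked param.
    assert backend == 'spark'
```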
