Skip to content

Commit

Permalink
feat: implement basic struct handling (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime authored Oct 9, 2024
1 parent 06a505f commit 23e2fd7
Show file tree
Hide file tree
Showing 10 changed files with 368 additions and 104 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ or Velox.
### Locally
To run the gateway locally, you need to set up a Python (Conda) environment.

To run the Spark tests you will need Java installed.

Ensure you have [Miniconda](https://docs.anaconda.com/miniconda/miniconda-install/) and [Rust/Cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html) installed.

Once that is done - run these steps from a bash terminal:
```bash
git clone --recursive https://github.com/<your-fork>/spark-substrait-gateway.git
git clone https://github.com/<your-fork>/spark-substrait-gateway.git
cd spark-substrait-gateway
conda init bash
. ~/.bashrc
Expand Down
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- setuptools >= 61.0.0
- setuptools_scm >= 6.2.0
- mypy-protobuf
- types-protobuf >= 4.25.0, < 5.0.0
- types-protobuf >= 5.0.0
- numpy < 2.0.0
- Faker
- pip:
Expand All @@ -27,7 +27,7 @@ dependencies:
- substrait == 0.21.0
- substrait-validator
- pytest-timeout
- protobuf >= 4.25.3, < 5.0.0
- protobuf >= 5.0.0
- cryptography == 43.0.*
- click == 8.1.*
- pyjwt == 2.8.*
Expand Down
57 changes: 57 additions & 0 deletions src/backends/arrow_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
"""Routines to manipulate arrow tables."""
import pyarrow as pa


def _rename_type(data_type: pa.DataType, names: list[str],
                 position: int) -> tuple[pa.DataType, int]:
    """Compute ``data_type`` with its struct fields renamed from ``names``.

    Names are consumed depth-first starting at ``names[position]``: every
    struct field takes one name, immediately followed by the names for that
    field's own children.

    Args:
        data_type: The arrow type to rename.
        names: The flattened list of names (never modified).
        position: Index of the next unused entry in ``names``.

    Returns:
        The renamed type and the index of the first name not consumed.

    Raises:
        ValueError: If there are not enough names or the type is an
            unsupported complex type.
        NotImplementedError: If the type is a list or a map.
    """
    if data_type.num_fields > len(names) - position:
        raise ValueError('Insufficient number of names provided to reapply names.')
    if pa.types.is_list(data_type):
        raise NotImplementedError('Reapplying names to lists not yet supported')
    if pa.types.is_map(data_type):
        raise NotImplementedError('Reapplying names to maps not yet supported')
    if not pa.types.is_struct(data_type):
        if data_type.num_fields != 0:
            raise ValueError(f'Unsupported complex type: {data_type}')
        return data_type, position

    renamed_fields = []
    for index in range(data_type.num_fields):
        # Earlier siblings may have used up names for their own children, so
        # availability has to be rechecked for every field.
        if position >= len(names):
            raise ValueError('Insufficient number of names provided to reapply names.')
        this_name = names[position]
        child_type, position = _rename_type(
            data_type.field(index).type, names, position + 1)
        renamed_fields.append(pa.field(this_name, child_type))
    return pa.struct(renamed_fields), position


def _rebuild_struct(array: pa.Array, new_type: pa.DataType) -> pa.Array:
    """Rebuild a single (non-chunked) array so that its type is ``new_type``.

    ``new_type`` must have the same shape as ``array.type`` and differ only in
    struct field names.  Non-struct arrays and field-less structs have nothing
    to rename and are returned untouched.
    """
    if not pa.types.is_struct(new_type) or new_type.num_fields == 0:
        return array

    fields = [new_type.field(index) for index in range(new_type.num_fields)]
    children = [
        _rebuild_struct(array.field(index), fields[index].type)
        for index in range(new_type.num_fields)
    ]
    if array.null_count:
        # Rebuilding from the children alone would turn null structs into
        # non-null ones, so carry the struct-level validity across explicitly.
        return pa.StructArray.from_arrays(children, fields=fields, mask=array.is_null())
    return pa.StructArray.from_arrays(children, fields=fields)


def _reapply_names_to_type(
        array: "pa.Array | pa.ChunkedArray",
        names: list[str]) -> "tuple[pa.Array | pa.ChunkedArray, list[str]]":
    """Rename the struct fields inside ``array`` using ``names``.

    The name of ``array`` itself is not part of ``names``; only the names of
    its (possibly nested) struct fields are consumed, in depth-first order.

    Args:
        array: A table column (chunked array) or a plain array.
        names: Names available for the fields of ``array``.  The list is not
            modified.

    Returns:
        The renamed array and a new list holding the names left over.  A
        chunked input keeps all of its chunks; arrays without struct fields
        are returned unchanged.

    Raises:
        ValueError: If there are not enough names or the type is an
            unsupported complex type.
        NotImplementedError: If a list or map type is encountered.
    """
    new_type, next_position = _rename_type(array.type, names, 0)
    remaining_names = list(names[next_position:])

    if not pa.types.is_struct(array.type) or array.type.num_fields == 0:
        return array, remaining_names

    if isinstance(array, pa.ChunkedArray):
        # Every chunk has to be rebuilt; looking only at the first one would
        # drop rows (or fail outright for a column with no chunks).
        renamed = pa.chunked_array(
            [_rebuild_struct(chunk, new_type) for chunk in array.chunks],
            type=new_type)
    else:
        renamed = _rebuild_struct(array, new_type)
    return renamed, remaining_names


def reapply_names(table: pa.Table, names: list[str]) -> pa.Table:
    """Apply the provided names to the given table recursively.

    ``names`` is the flattened, depth-first list of column names: each column
    name is followed by the names of its nested struct fields (if any).

    Args:
        table: The table whose columns and struct fields should be renamed.
        names: The names to apply.  The sequence is copied and never modified,
            so it is safe to pass a field that belongs to another object.

    Returns:
        A new table with the same data and the new names.

    Raises:
        ValueError: If too few or too many names are provided.
    """
    new_arrays = []
    new_schema = []

    # Work on a private copy: names are consumed with pop() below and the
    # caller's sequence must survive the call.
    remaining_names = list(names)
    for column in table.columns:
        if not remaining_names:
            raise ValueError('Insufficient number of names provided to reapply names.')

        this_name = remaining_names.pop(0)

        new_array, remaining_names = _reapply_names_to_type(column, remaining_names)
        new_arrays.append(new_array)
        new_schema.append(pa.field(this_name, new_array.type))

    if remaining_names:
        raise ValueError('Too many names provided to reapply names.')

    return pa.Table.from_arrays(new_arrays, schema=pa.schema(new_schema))
4 changes: 3 additions & 1 deletion src/backends/duckdb_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from substrait.gen.proto import plan_pb2

from backends.backend import Backend
from src.backends.arrow_tools import reapply_names
from transforms.rename_functions import RenameFunctionsForDuckDB


Expand Down Expand Up @@ -73,7 +74,8 @@ def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table:
query_result = self._connection.from_substrait(proto=plan_data)
except Exception as err:
raise ValueError(f"DuckDB Execution Error: {err}") from err
return query_result.arrow()
arrow = query_result.arrow()
return reapply_names(arrow, plan.relations[0].root.names)

def register_table(
self,
Expand Down
80 changes: 80 additions & 0 deletions src/backends/tests/arrow_tools_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import pyarrow as pa
import pytest

from src.backends.arrow_tools import reapply_names


@dataclass
class TestCase:
    """A single parametrized scenario for ``reapply_names``."""

    # The class name starts with "Test", so pytest would otherwise try (and
    # fail, with a warning) to collect it as a test class.
    __test__ = False

    # Identifier shown in the pytest output.
    name: str
    # Table handed to reapply_names.
    input: pa.Table
    # Flattened names handed to reapply_names.
    names: list[str]
    # Table the call should return; ignored when ``fail`` is set.
    expected: pa.Table
    # True when the call is expected to raise ValueError.
    fail: bool = False


# Scenarios exercised by TestArrowTools.test_reapply_names.  For entries with
# fail=True only the raised error matters; their expected table is unused.
cases: list[TestCase] = [
    TestCase(
        name='empty table',
        input=pa.Table.from_arrays([]),
        names=[],
        expected=pa.Table.from_arrays([]),
    ),
    TestCase(
        name='too many names',
        input=pa.Table.from_arrays([]),
        names=['a', 'b'],
        expected=pa.Table.from_arrays([]),
        fail=True,
    ),
    TestCase(
        name='normal columns',
        input=pa.Table.from_pydict(
            {"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]},
            schema=pa.schema({"name": pa.string(), "age": pa.int32()}),
        ),
        names=['renamed_name', 'renamed_age'],
        expected=pa.Table.from_pydict(
            {
                "renamed_name": [None, "Joe", "Sarah", None],
                "renamed_age": [99, None, 42, None],
            },
            schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()}),
        ),
    ),
    TestCase(
        name='too few names',
        input=pa.Table.from_pydict(
            {"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]},
            schema=pa.schema({"name": pa.string(), "age": pa.int32()}),
        ),
        names=['renamed_name'],
        expected=pa.Table.from_pydict(
            {
                "renamed_name": [None, "Joe", "Sarah", None],
                "renamed_age": [99, None, 42, None],
            },
            schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()}),
        ),
        fail=True,
    ),
    TestCase(
        name='struct column',
        input=pa.Table.from_arrays(
            [
                pa.array(
                    [{"": 1, "b": "b"}],
                    type=pa.struct([("", pa.int64()), ("b", pa.string())]),
                ),
            ],
            names=["r"],
        ),
        names=['r', 'a', 'b'],
        expected=pa.Table.from_arrays(
            [
                pa.array(
                    [{"a": 1, "b": "b"}],
                    type=pa.struct([("a", pa.int64()), ("b", pa.string())]),
                ),
            ],
            names=["r"],
        ),
    ),
    # TODO -- Test nested structs.
    # TODO -- Test a list.
    # TODO -- Test a map.
    # TODO -- Test a mixture of complex and simple types.
]


class TestArrowTools:
    """Tests the functionality of the arrow tools package."""

    @pytest.mark.parametrize(
        "case", cases, ids=lambda case: case.name
    )
    def test_reapply_names(self, case):
        """Check reapply_names against one scenario from ``cases``."""
        if case.fail:
            with pytest.raises(ValueError):
                reapply_names(case.input, case.names)
            return

        # No try/except here: an unexpected ValueError should surface with its
        # own traceback rather than as a confusing "None == expected" failure.
        assert reapply_names(case.input, case.names) == case.expected

2 changes: 1 addition & 1 deletion src/gateway/converter/conversion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ConversionOptions:
"""Holds all the possible conversion options."""

def __init__(self, backend: BackendOptions = None):
def __init__(self, backend: BackendOptions):
"""Initialize the conversion options."""
self.use_named_table_workaround = False
self.needs_scheme_in_path_uris = False
Expand Down
Loading

0 comments on commit 23e2fd7

Please sign in to comment.