Skip to content

Commit

Permalink
chore: merge recent changes (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime authored May 20, 2024
2 parents dc88946 + d8206e7 commit f12a099
Show file tree
Hide file tree
Showing 23 changed files with 1,509 additions and 625 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies:
- pip:
- adbc_driver_manager
- cargo
- pyarrow >= 13.0.0
- pyarrow >= 16.0.0
- duckdb == 0.10.1
- datafusion >= 36.0.0
- pyspark
Expand Down
7 changes: 3 additions & 4 deletions src/gateway/backends/adbc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from substrait.gen.proto import plan_pb2

from gateway.backends.backend import Backend
from gateway.backends.backend_options import Backend as backend_engine
from gateway.backends.backend_options import BackendOptions
from gateway.backends.backend_options import BackendEngine, BackendOptions


def _import(handle):
Expand All @@ -20,7 +19,7 @@ def _import(handle):
def _get_backend_driver(options: BackendOptions) -> tuple[str, str]:
"""Get the driver and entry point for the specified backend."""
match options.backend:
case backend_engine.DUCKDB:
case BackendEngine.DUCKDB:
driver = duckdb.duckdb.__file__
entry_point = "duckdb_adbc_init"
case _:
Expand Down Expand Up @@ -62,7 +61,7 @@ def register_table(self, name: str, path: Path, file_format: str = 'parquet') ->
file_paths = sorted([str(fp) for fp in file_paths])
# TODO: Support multiple paths.
reader = pq.ParquetFile(file_paths[0])
self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode="create")
self._connection.cursor().adbc_ingest(name, reader.iter_batches(), mode='create')

def describe_table(self, table_name: str):
"""Asks the backend to describe the given table."""
Expand Down
8 changes: 8 additions & 0 deletions src/gateway/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def register_table(self, name: str, path: Path, file_format: str = 'parquet') ->
"""Register the given table with the backend."""
raise NotImplementedError()

def describe_files(self, paths: list[str]):
"""Asks the backend to describe the given files."""
raise NotImplementedError()

def describe_table(self, name: str):
"""Asks the backend to describe the given table."""
raise NotImplementedError()
Expand All @@ -43,6 +47,10 @@ def drop_table(self, name: str) -> None:
"""Asks the backend to drop the given table."""
raise NotImplementedError()

def convert_sql(self, sql: str) -> plan_pb2.Plan:
"""Convert SQL into a Substrait plan."""
raise NotImplementedError()

@staticmethod
def expand_location(location: Path | str) -> list[str]:
"""Expand the location of a file or directory into a list of files."""
Expand Down
11 changes: 8 additions & 3 deletions src/gateway/backends/backend_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,29 @@
from enum import Enum


class Backend(Enum):
class BackendEngine(Enum):
"""Represents the different backends we have support for."""

ARROW = 1
DATAFUSION = 2
DUCKDB = 3

def __str__(self):
"""Return the string representation of the backend."""
return self.name.lower()


@dataclasses.dataclass
class BackendOptions:
"""Holds all the possible backend options."""

backend: Backend
backend: BackendEngine
use_adbc: bool

def __init__(self, backend: Backend, use_adbc: bool = False):
def __init__(self, backend: BackendEngine, use_adbc: bool = False):
"""Create a BackendOptions structure."""
self.backend = backend
self.use_adbc = use_adbc

self.use_arrow_uri_workaround = False
self.use_duckdb_python_api = False
8 changes: 4 additions & 4 deletions src/gateway/backends/backend_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
from gateway.backends import backend
from gateway.backends.adbc_backend import AdbcBackend
from gateway.backends.arrow_backend import ArrowBackend
from gateway.backends.backend_options import Backend, BackendOptions
from gateway.backends.backend_options import BackendEngine, BackendOptions
from gateway.backends.datafusion_backend import DatafusionBackend
from gateway.backends.duckdb_backend import DuckDBBackend


def find_backend(options: BackendOptions) -> backend.Backend:
"""Given a backend enum, returns an instance of the correct Backend descendant."""
match options.backend:
case Backend.ARROW:
case BackendEngine.ARROW:
return ArrowBackend(options)
case Backend.DATAFUSION:
case BackendEngine.DATAFUSION:
return DatafusionBackend(options)
case Backend.DUCKDB:
case BackendEngine.DUCKDB:
if options.use_adbc:
return AdbcBackend(options)
return DuckDBBackend(options)
Expand Down
51 changes: 44 additions & 7 deletions src/gateway/backends/datafusion_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@
from gateway.converter.rename_functions import RenameFunctionsForDatafusion
from gateway.converter.replace_local_files import ReplaceLocalFilesWithNamedTable

_DATAFUSION_TO_ARROW = {
'Boolean': pa.bool_(),
'Int8': pa.int8(),
'Int16': pa.int16(),
'Int32': pa.int32(),
'Int64': pa.int64(),
'Float32': pa.float32(),
'Float64': pa.float64(),
'Date32': pa.date32(),
'Timestamp(Nanosecond, None)': pa.timestamp('ns'),
'Utf8': pa.string(),
}


# pylint: disable=import-outside-toplevel
class DatafusionBackend(Backend):
Expand All @@ -23,6 +36,7 @@ def __init__(self, options):
def create_connection(self) -> None:
"""Create a connection to the backend."""
import datafusion

self._connection = datafusion.SessionContext()

def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table:
Expand All @@ -33,10 +47,9 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table:
registered_tables = set()
for files in file_groups:
table_name = files[0]
for file in files[1]:
if table_name not in registered_tables:
self.register_table(table_name, file)
registered_tables.add(files[0])
location = Path(files[1][0]).parent
self.register_table(table_name, location)
registered_tables.add(table_name)

RenameFunctionsForDatafusion().visit_plan(plan)

Expand All @@ -59,7 +72,31 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table:
for table_name in registered_tables:
self._connection.deregister_table(table_name)

def register_table(self, name: str, path: Path, file_format: str = 'parquet') -> None:
def register_table(
self, name: str, location: Path, file_format: str = 'parquet'
) -> None:
"""Register the given table with the backend."""
files = Backend.expand_location(path)
self._connection.register_parquet(name, files[0])
files = Backend.expand_location(location)
if not files:
raise ValueError(f"No parquet files found at {location}")
# TODO: Add options to skip table registration if it already exists instead
# of deregistering it.
if self._connection.table_exist(name):
self._connection.deregister_table(name)
self._connection.register_parquet(name, str(location))

def describe_files(self, paths: list[str]):
"""Asks the backend to describe the given files."""
# TODO -- Use the ListingTable API to resolve the combined schema.
df = self._connection.read_parquet(paths[0])
return df.schema()

def describe_table(self, table_name: str):
"""Asks the backend to describe the given table."""
result = self._connection.sql(f"describe {table_name}").to_arrow_table().to_pylist()

fields = []
for index in range(len(result)):
fields.append(pa.field(result[index]['column_name'],
_DATAFUSION_TO_ARROW[result[index]['data_type']]))
return pa.schema(fields)
61 changes: 57 additions & 4 deletions src/gateway/backends/duckdb_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@

from gateway.backends.backend import Backend

_DUCKDB_TO_ARROW = {
'BOOLEAN': pa.bool_(),
'TINYINT': pa.int8(),
'SMALLINT': pa.int16(),
'INTEGER': pa.int32(),
'BIGINT': pa.int64(),
'FLOAT': pa.float32(),
'DOUBLE': pa.float64(),
'DATE': pa.date32(),
'TIMESTAMP': pa.timestamp('ns'),
'VARCHAR': pa.string(),
}


# pylint: disable=fixme
class DuckDBBackend(Backend):
Expand All @@ -18,6 +31,7 @@ def __init__(self, options):
self._connection = None
super().__init__(options)
self.create_connection()
self._use_duckdb_python_api = options.use_duckdb_python_api

def create_connection(self):
"""Create a connection to the backend."""
Expand All @@ -44,12 +58,51 @@ def execute(self, plan: plan_pb2.Plan) -> pa.lib.Table:
df = query_result.df()
return pa.Table.from_pandas(df=df)

def register_table(self, table_name: str, location: Path, file_format: str = 'parquet') -> None:
def register_table(
self,
table_name: str,
location: Path,
file_format: str = "parquet"
) -> None:
"""Register the given table with the backend."""
files = Backend.expand_location(location)
if not files:
raise ValueError(f"No parquet files found at {location}")
files_str = ', '.join([f"'{f}'" for f in files])
files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])"

self._connection.execute(files_sql)
if self._use_duckdb_python_api:
self._connection.register(table_name, self._connection.read_parquet(files))
else:
files_str = ', '.join([f"'{f}'" for f in files])
files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet([{files_str}])"
self._connection.execute(files_sql)

def describe_files(self, paths: list[str]):
"""Asks the backend to describe the given files."""
files = paths
if len(paths) == 1:
files = self.expand_location(paths[0])
df = self._connection.read_parquet(files)

fields = []
for name, field_type in zip(df.columns, df.types, strict=False):
if name == 'aggr':
# This isn't a real column.
continue
fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)]))
return pa.schema(fields)

def describe_table(self, name: str):
"""Asks the backend to describe the given table."""
df = self._connection.execute(f'DESCRIBE {name}').fetchdf()

fields = []
for name, field_type in zip(df.column_name, df.column_type, strict=False):
fields.append(pa.field(name, _DUCKDB_TO_ARROW[str(field_type)]))
return pa.schema(fields)

def convert_sql(self, sql: str) -> plan_pb2.Plan:
"""Convert SQL into a Substrait plan."""
plan = plan_pb2.Plan()
proto_bytes = self._connection.get_substrait(query=sql).fetchone()[0]
plan.ParseFromString(proto_bytes)
return plan
30 changes: 30 additions & 0 deletions src/gateway/converter/add_extension_uris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
"""A library to search Substrait plan for local files."""
from gateway.converter.substrait_plan_visitor import SubstraitPlanVisitor
from substrait.gen.proto import plan_pb2
from substrait.gen.proto.extensions import extensions_pb2


# pylint: disable=E1101,no-member
class AddExtensionUris(SubstraitPlanVisitor):
"""Ensures that the plan has extension URI definitions for all references."""

def visit_plan(self, plan: plan_pb2.Plan) -> None:
"""Modify the provided plan so that all functions have URI references."""
super().visit_plan(plan)

known_uris: list[int] = []
for uri in plan.extension_uris:
known_uris.append(uri.extension_uri_anchor)

for extension in plan.extensions:
if extension.WhichOneof('mapping_type') != 'extension_function':
continue

if extension.extension_function.extension_uri_reference not in known_uris:
# TODO -- Make sure this hack occurs at most once.
uri = extensions_pb2.SimpleExtensionURI(
uri='urn:arrow:substrait_simple_extension_function',
extension_uri_anchor=extension.extension_function.extension_uri_reference)
plan.extension_uris.append(uri)
known_uris.append(extension.extension_function.extension_uri_reference)
14 changes: 9 additions & 5 deletions src/gateway/converter/conversion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Tracks conversion related options."""
import dataclasses

from gateway.backends.backend_options import Backend, BackendOptions
from gateway.backends.backend_options import BackendEngine, BackendOptions


# pylint: disable=too-many-instance-attributes
Expand All @@ -17,6 +17,8 @@ def __init__(self, backend: BackendOptions = None):
self.use_emits_instead_of_direct = False
self.use_switch_expressions_where_possible = True
self.use_duckdb_regexp_matches_function = False
self.duckdb_project_emit_workaround = False
self.safety_project_read_relations = False

self.return_names_with_types = False

Expand All @@ -27,23 +29,25 @@ def __init__(self, backend: BackendOptions = None):

def arrow():
"""Return standard options to connect to the Acero backend."""
options = ConversionOptions(backend=BackendOptions(Backend.ARROW))
options = ConversionOptions(backend=BackendOptions(BackendEngine.ARROW))
options.needs_scheme_in_path_uris = True
options.return_names_with_types = True
options.implement_show_string = False
options.backend.use_arrow_uri_workaround = True
options.safety_project_read_relations = True
return options


def datafusion():
"""Return standard options to connect to a Datafusion backend."""
return ConversionOptions(backend=BackendOptions(Backend.DATAFUSION))
return ConversionOptions(backend=BackendOptions(BackendEngine.DATAFUSION))


def duck_db():
"""Return standard options to connect to a DuckDB backend."""
options = ConversionOptions(backend=BackendOptions(Backend.DUCKDB))
options = ConversionOptions(backend=BackendOptions(BackendEngine.DUCKDB))
options.return_names_with_types = True
options.use_switch_expressions_where_possible = False
options.use_duckdb_regexp_matches_function = True
options.duckdb_project_emit_workaround = True
options.backend.use_duckdb_python_api = False
return options
Loading

0 comments on commit f12a099

Please sign in to comment.