Ruff format
prmoore77 committed Aug 22, 2024
1 parent 4c761c0 commit df03fa4
Showing 2 changed files with 41 additions and 48 deletions.
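Every hunk in this commit is mechanical: Ruff's Black-compatible formatter ("ruff format") rewrites single-quoted strings to double quotes, joins hanging-indent argument lists that fit within the line limit, splits the ones that do not onto 4-space-indented continuation lines with trailing commas, and dedents closing parentheses. A tiny illustration of the rewrite (not code from this repository):

import logging

# Before formatting, in the old style these files used:
#     log = logging.getLogger('demo')
#     log.info('scale factor: %s',
#              1
#              )
# After "ruff format":
log = logging.getLogger("demo")
log.info("scale factor: %s", 1)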
src/gateway/demo/client_demo.py (15 additions, 21 deletions)
@@ -12,11 +12,12 @@
 from pyspark.sql.functions import col
 
 # Setup logging
-logging.basicConfig(format='%(asctime)s - %(levelname)-8s %(message)s',
-                    datefmt='%Y-%m-%d %H:%M:%S %Z',
-                    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
-                    stream=sys.stdout
-                    )
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)-8s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S %Z",
+    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
+    stream=sys.stdout,
+)
 _LOGGER = logging.getLogger()
 
 # Constants
@@ -33,9 +34,7 @@ def find_tpch(raise_error_if_not_exists: bool) -> Path:
 
 
 # pylint: disable=fixme
-def get_customer_database(spark_session: SparkSession,
-                          use_gateway: bool
-                          ) -> DataFrame:
+def get_customer_database(spark_session: SparkSession, use_gateway: bool) -> DataFrame:
     """Register the TPC-H customer database."""
     location_customer = str(find_tpch(raise_error_if_not_exists=(not use_gateway)) / "customer")
 
@@ -44,12 +43,9 @@ def get_customer_database(spark_session: SparkSession,
 
 # pylint: disable=fixme
 # ruff: noqa: T201
-def execute_query(spark_session: SparkSession,
-                  use_gateway: bool) -> None:
+def execute_query(spark_session: SparkSession, use_gateway: bool) -> None:
     """Run a single sample query against the gateway."""
-    df_customer = get_customer_database(spark_session=spark_session,
-                                        use_gateway=use_gateway
-                                        )
+    df_customer = get_customer_database(spark_session=spark_session, use_gateway=use_gateway)
 
     df_customer.createOrReplaceTempView("customer")
 
@@ -65,11 +61,11 @@ def execute_query(spark_session: SparkSession,
 
 
 def run_demo(
-        use_gateway: bool = True,
-        host: str = "localhost",
-        port: int = SERVER_PORT,
-        use_tls: bool = False,
-        token: str | None = None,
+    use_gateway: bool = True,
+    host: str = "localhost",
+    port: int = SERVER_PORT,
+    use_tls: bool = False,
+    token: str | None = None,
 ):
     """Run a small Spark Substrait Gateway client demo."""
     logging.basicConfig(level=logging.INFO, encoding="utf-8")
@@ -91,9 +87,7 @@ def run_demo(
         spark = SparkSession.builder.remote(f"sc://{host}:{port}/{uri_parameters}").getOrCreate()
     else:
         spark = SparkSession.builder.master("local").getOrCreate()
-    execute_query(spark_session=spark,
-                  use_gateway=use_gateway
-                  )
+    execute_query(spark_session=spark, use_gateway=use_gateway)
 
 
 @click.command()
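The one non-trivial piece of client_demo.py visible above is the Spark Connect URI the demo builds from its host, port, TLS, and token options (the f-string sc://{host}:{port}/{uri_parameters}). Below is a minimal sketch of how such a URI is typically assembled; the parameter names use_ssl and token are standard Spark Connect channel options rather than values taken from this diff, the helper name is hypothetical, and the port default is an assumption (the demo passes its own SERVER_PORT constant):

# Hypothetical helper, for illustration only: builds a Spark Connect remote
# URI of the form sc://host:port/;use_ssl=true;token=...
# "use_ssl" and "token" are standard Spark Connect parameters; the demo's
# actual uri_parameters construction falls outside the hunks shown above.
def build_remote_uri(
    host: str = "localhost",
    port: int = 50051,  # assumed default; the demo uses SERVER_PORT
    use_tls: bool = False,
    token: str | None = None,
) -> str:
    params = []
    if use_tls:
        params.append("use_ssl=true")
    if token is not None:
        params.append(f"token={token}")
    uri_parameters = (";" + ";".join(params)) if params else ""
    return f"sc://{host}:{port}/{uri_parameters}"

# Example: build_remote_uri(use_tls=True, token="abc")
# -> "sc://localhost:50051/;use_ssl=true;token=abc"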
src/gateway/demo/generate_tpch_parquet_data.py (26 additions, 27 deletions)
@@ -10,35 +10,35 @@
 from .client_demo import CLIENT_DEMO_DATA_LOCATION
 
 # Setup logging
-logging.basicConfig(format='%(asctime)s - %(levelname)-8s %(message)s',
-                    datefmt='%Y-%m-%d %H:%M:%S %Z',
-                    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
-                    stream=sys.stdout
-                    )
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)-8s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S %Z",
+    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
+    stream=sys.stdout,
+)
 _LOGGER = logging.getLogger()
 
 
-def execute_query(conn: duckdb.DuckDBPyConnection,
-                  query: str
-                  ):
+def execute_query(conn: duckdb.DuckDBPyConnection, query: str):
     _LOGGER.info(msg=f"Executing SQL: '{query}'")
     conn.execute(query=query)
 
 
 def get_printable_number(num: float):
-    return '{:.9g}'.format(num)
+    return "{:.9g}".format(num)
 
 
-def generate_tpch_parquet_data(tpch_scale_factor: int,
-                               data_directory: str,
-                               overwrite: bool
-                               ) -> Path:
-    _LOGGER.info(msg=("Creating a TPC-H parquet dataset - with parameters: "
-                      f"--tpch-scale-factor={tpch_scale_factor} "
-                      f"--data-directory='{data_directory}' "
-                      f"--overwrite={overwrite}"
-                      )
-                 )
+def generate_tpch_parquet_data(
+    tpch_scale_factor: int, data_directory: str, overwrite: bool
+) -> Path:
+    _LOGGER.info(
+        msg=(
+            "Creating a TPC-H parquet dataset - with parameters: "
+            f"--tpch-scale-factor={tpch_scale_factor} "
+            f"--data-directory='{data_directory}' "
+            f"--overwrite={overwrite}"
+        )
+    )
 
     # Output the database version
     _LOGGER.info(msg=f"Using DuckDB Version: {duckdb.__version__}")
@@ -63,7 +63,9 @@ def generate_tpch_parquet_data(tpch_scale_factor: int,
         raise RuntimeError(f"Directory: {target_directory.as_posix()} exists, aborting.")
 
     target_directory.mkdir(parents=True, exist_ok=True)
-    execute_query(conn=conn, query=f"EXPORT DATABASE '{target_directory.as_posix()}' (FORMAT PARQUET)")
+    execute_query(
+        conn=conn, query=f"EXPORT DATABASE '{target_directory.as_posix()}' (FORMAT PARQUET)"
+    )
 
     _LOGGER.info(msg=f"Wrote out parquet data to path: '{target_directory.as_posix()}'")
 
@@ -93,28 +95,25 @@ def generate_tpch_parquet_data(tpch_scale_factor: int,
     default=1,
     show_default=True,
     required=True,
-    help="The TPC-H scale factor to generate."
+    help="The TPC-H scale factor to generate.",
 )
 @click.option(
     "--data-directory",
     type=str,
     default=CLIENT_DEMO_DATA_LOCATION.as_posix(),
     show_default=True,
     required=True,
-    help="The target output data directory to put the files into"
+    help="The target output data directory to put the files into",
 )
 @click.option(
     "--overwrite/--no-overwrite",
     type=bool,
     default=False,
     show_default=True,
     required=True,
-    help="Can we overwrite the target directory if it already exists..."
+    help="Can we overwrite the target directory if it already exists...",
 )
-def click_generate_tpch_parquet_data(tpch_scale_factor: int,
-                                     data_directory: str,
-                                     overwrite: bool
-                                     ):
+def click_generate_tpch_parquet_data(tpch_scale_factor: int, data_directory: str, overwrite: bool):
     generate_tpch_parquet_data(**locals())
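For context on what generate_tpch_parquet_data drives (most of its body falls outside the hunks shown), the script generates TPC-H tables with DuckDB and exports them as Parquet, as the EXPORT DATABASE statement above indicates. A minimal, self-contained sketch of the same technique; the tpch extension calls are standard DuckDB SQL, while the output path is a placeholder and the script's exact statements may differ:

import duckdb

# Generate TPC-H tables in an in-memory DuckDB database, then export them
# as Parquet files. INSTALL/LOAD tpch and CALL dbgen(sf=...) are standard
# DuckDB; '/tmp/tpch_sf1' is a placeholder path for this sketch.
conn = duckdb.connect()
conn.execute("INSTALL tpch")
conn.execute("LOAD tpch")
conn.execute("CALL dbgen(sf=1)")  # scale factor 1 is roughly 1 GB of data
conn.execute("EXPORT DATABASE '/tmp/tpch_sf1' (FORMAT PARQUET)")
conn.close()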
