feat: add the latest tag to the docker image #79

Merged · 6 commits · Aug 22, 2024
2 changes: 2 additions & 0 deletions .dockerignore
@@ -2,3 +2,5 @@
 helm-chart
 tls
 .github
+Dockerfile
+.gitignore
3 changes: 3 additions & 0 deletions .github/workflows/build_docker.yml
@@ -32,6 +32,9 @@ jobs:
   uses: docker/metadata-action@v5
   with:
     images: ${{ env.IMAGE_NAME }}
+    tags: |
+      # set latest tag for default branch
+      type=raw,value=latest,enable={{is_default_branch}}

 - name: Set up QEMU
   uses: docker/setup-qemu-action@v3
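With docker/metadata-action, the added tags rule emits a latest tag only when the build runs on the default branch ({{is_default_branch}} evaluates to true there and false elsewhere). A quick way to sanity-check the published tag after a merge, assuming $IMAGE_NAME holds the same image name the workflow uses:

    $ # Pull the moving "latest" tag (only published from the default branch)
    $ docker pull "$IMAGE_NAME:latest"
    $ # Inspect the digest and platforms that "latest" currently points at
    $ docker buildx imagetools inspect "$IMAGE_NAME:latest"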
4 changes: 3 additions & 1 deletion Dockerfile
@@ -47,4 +47,6 @@ RUN pip install .
 # Expose the gRPC port
 EXPOSE 50051

-ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "spark-substrait-gateway-env", "python", "src/gateway/server.py"]
+ENV GENERATE_CLIENT_DEMO_DATA="true"
+
+ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "spark-substrait-gateway-env", "scripts/start_demo_server.sh"]
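The image now generates the demo dataset on startup by default; since GENERATE_CLIENT_DEMO_DATA is an ordinary environment variable, it can be overridden at run time. A minimal sketch, with $IMAGE_NAME as a placeholder for whatever image name this workflow publishes:

    $ # Start the gateway but skip demo-data generation (placeholder image name)
    $ docker run --rm -p 50051:50051 \
        -e GENERATE_CLIENT_DEMO_DATA="false" \
        "$IMAGE_NAME:latest"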
1 change: 1 addition & 0 deletions pyproject.toml
@@ -83,3 +83,4 @@ spark-substrait-gateway-server = "gateway.server:click_serve"
 spark-substrait-client-demo = "gateway.demo.client_demo:click_run_demo"
 spark-substrait-create-tls-keypair = "gateway.setup.tls_utilities:click_create_tls_keypair"
 spark-substrait-create-jwt = "gateway.utilities.create_jwt:main"
+spark-substrait-create-client-demo-data = "gateway.demo.generate_tpch_parquet_data:click_generate_tpch_parquet_data"
Contributor

Do we also want to include the TPC-DS data or would we rather keep the footprint of the image small?

Contributor Author

We could potentially add it - I targeted TPC-H because that is what the client demo needed. Would doing it in a later PR be OK, if needed?

Contributor

Yep, no need to overbuild this.

14 changes: 14 additions & 0 deletions scripts/start_demo_server.sh
@@ -0,0 +1,14 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0

# This script starts the Spark Substrait Gateway demo server.
# It creates demo TPC-H data (scale factor 1, roughly 1 GB) and then starts the server.

set -e

if [ "$(echo "${GENERATE_CLIENT_DEMO_DATA}" | tr '[:upper:]' '[:lower:]')" = "true" ]; then
  echo "Generating client demo TPC-H data..."
  spark-substrait-create-client-demo-data
fi

spark-substrait-gateway-server
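Because the check lower-cases the variable first, the gate is case-insensitive: "true", "True", and "TRUE" all trigger generation, and anything else (including leaving the variable unset) skips it. For example, outside the container:

    $ # Skips data generation, then starts the gateway server
    $ GENERATE_CLIENT_DEMO_DATA=False ./scripts/start_demo_server.sh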
33 changes: 22 additions & 11 deletions src/gateway/demo/client_demo.py
@@ -3,6 +3,7 @@

 import logging
 import os
+import sys
 from pathlib import Path

 import click
@@ -11,30 +12,40 @@

 from gateway.config import SERVER_PORT

-_LOGGER = logging.getLogger(__name__)
+# Setup logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)-8s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S %Z",
+    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
+    stream=sys.stdout,
+)
+_LOGGER = logging.getLogger()
+
+# Constants
+CLIENT_DEMO_DATA_LOCATION = Path("data") / "tpch" / "parquet"


-def find_tpch() -> Path:
+def find_tpch(raise_error_if_not_exists: bool) -> Path:
     """Find the location of the TPCH dataset."""
-    location = Path("third_party") / "tpch" / "parquet"
-    if location.exists():
-        return location
-    raise ValueError("TPCH dataset not found")
+    location = CLIENT_DEMO_DATA_LOCATION
+    if raise_error_if_not_exists and not location.exists():
+        raise ValueError("TPCH dataset not found")
+    return location


 # pylint: disable=fixme
-def get_customer_database(spark_session: SparkSession) -> DataFrame:
+def get_customer_database(spark_session: SparkSession, use_gateway: bool) -> DataFrame:
     """Register the TPC-H customer database."""
-    location_customer = str(find_tpch() / "customer")
+    location_customer = str(find_tpch(raise_error_if_not_exists=(not use_gateway)) / "customer")

     return spark_session.read.parquet(location_customer, mergeSchema=False)


 # pylint: disable=fixme
 # ruff: noqa: T201
-def execute_query(spark_session: SparkSession) -> None:
+def execute_query(spark_session: SparkSession, use_gateway: bool) -> None:
     """Run a single sample query against the gateway."""
-    df_customer = get_customer_database(spark_session)
+    df_customer = get_customer_database(spark_session=spark_session, use_gateway=use_gateway)

     df_customer.createOrReplaceTempView("customer")

@@ -76,7 +87,7 @@ def run_demo(
         spark = SparkSession.builder.remote(f"sc://{host}:{port}/{uri_parameters}").getOrCreate()
     else:
         spark = SparkSession.builder.master("local").getOrCreate()
-    execute_query(spark)
+    execute_query(spark_session=spark, use_gateway=use_gateway)


@click.command()
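With the server running (locally or in the container), the demo can then be exercised via the console script registered in pyproject.toml. Its CLI flags are collapsed out of this diff, so this sketch assumes the defaults reach the gateway on localhost:

    $ # Run the sample customer query through the gateway (flags elided in the diff; defaults assumed)
    $ spark-substrait-client-demo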
128 changes: 128 additions & 0 deletions src/gateway/demo/generate_tpch_parquet_data.py
@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
"""A utility module for generating TPC-H parquet data for the client demo."""

import logging
import os
import shutil
import sys
from pathlib import Path

import click
import duckdb

from .client_demo import CLIENT_DEMO_DATA_LOCATION

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)-8s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S %Z",
    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
    stream=sys.stdout,
)
_LOGGER = logging.getLogger()


def execute_query(conn: duckdb.DuckDBPyConnection, query: str):
    """Execute and log a query with a DuckDB connection."""
    _LOGGER.info(msg=f"Executing SQL: '{query}'")
    conn.execute(query=query)


def get_printable_number(num: float):
    """Return a number in a printable format."""
    return f"{num:.9g}"


def generate_tpch_parquet_data(
    tpch_scale_factor: int, data_directory: str, overwrite: bool
) -> Path:
    """Generate a TPC-H parquet dataset."""
    _LOGGER.info(
        msg=(
            "Creating a TPC-H parquet dataset - with parameters: "
            f"--tpch-scale-factor={tpch_scale_factor} "
            f"--data-directory='{data_directory}' "
            f"--overwrite={overwrite}"
        )
    )

    # Output the database version
    _LOGGER.info(msg=f"Using DuckDB Version: {duckdb.__version__}")

    # Get an in-memory DuckDB database connection
    conn = duckdb.connect()

    # Load the TPCH extension needed to generate the data...
    conn.load_extension(extension="tpch")

    # Generate the data
    execute_query(conn=conn, query=f"CALL dbgen(sf={tpch_scale_factor})")

    # Export the data
    target_directory = Path(data_directory)

    if target_directory.exists():
        if overwrite:
            _LOGGER.warning(msg=f"Directory: {target_directory.as_posix()} exists, removing...")
            shutil.rmtree(path=target_directory.as_posix())
        else:
            raise RuntimeError(f"Directory: {target_directory.as_posix()} exists, aborting.")

    target_directory.mkdir(parents=True, exist_ok=True)
    execute_query(
        conn=conn, query=f"EXPORT DATABASE '{target_directory.as_posix()}' (FORMAT PARQUET)"
    )

    _LOGGER.info(msg=f"Wrote out parquet data to path: '{target_directory.as_posix()}'")

    # Restructure the contents of the directory so that each file is in its own directory
    for filename in target_directory.glob(pattern="*.parquet"):
        file = Path(filename)
        table_name = file.stem
        table_directory = target_directory / table_name
        table_directory.mkdir(parents=True, exist_ok=True)

        if file.name not in ("nation.parquet", "region.parquet"):
            new_file_name = f"{table_name}.1.parquet"
        else:
            new_file_name = file.name

        file.rename(target=table_directory / new_file_name)

    _LOGGER.info(msg="All done.")

    return target_directory


@click.command()
@click.option(
    "--tpch-scale-factor",
    type=float,
    default=1,
    show_default=True,
    required=True,
    help="The TPC-H scale factor to generate.",
)
@click.option(
    "--data-directory",
    type=str,
    default=CLIENT_DEMO_DATA_LOCATION.as_posix(),
    show_default=True,
    required=True,
    help="The target output data directory to put the files into",
)
@click.option(
    "--overwrite/--no-overwrite",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Can we overwrite the target directory if it already exists...",
)
def click_generate_tpch_parquet_data(tpch_scale_factor: int, data_directory: str, overwrite: bool):
    """Provide a click interface for generating TPC-H parquet data."""
    generate_tpch_parquet_data(**locals())


if __name__ == "__main__":
    click_generate_tpch_parquet_data()
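The new console script wires directly to click_generate_tpch_parquet_data, so the dataset can also be produced outside the container. A hedged example using the options defined above, with a small scale factor to keep the output light:

    $ # Generate TPC-H data at scale factor 0.1 into the default demo location,
    $ # replacing any previous run
    $ spark-substrait-create-client-demo-data \
        --tpch-scale-factor=0.1 \
        --data-directory=data/tpch/parquet \
        --overwrite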