From 44984f196f7911156b1b32603276c81990087a3c Mon Sep 17 00:00:00 2001
From: Sebastian Eckweiler
Date: Tue, 15 Oct 2024 22:28:48 +0200
Subject: [PATCH] [SEDONA-663] Support spark connect python api (#1639)

* initial successful test

* try add docker-compose based tests

* 3.5 only

* comment classic tests

* try fix yaml

* skip other workflows

* skip other workflows

* try fix if check

* fix path

* cd to python folder

* skip sparkContext with SPARK_REMOTE

* fix type check

* refactor somewhat

* Revert "skip other workflows"

This reverts commit 7eb9b6ea

* back to full matrix

* add license header, fix missing whitespace

* Add a simple docstring to SedonaFunction

* uncomment build step

* need sql extensions

* run pre-commit

* fix lint/pre-commit

* Update .github/workflows/python.yml

Co-authored-by: John Bampton

* adjust spelling

* use UnresolvedFunction instead of CallFunction

* revert Pipfile to master rev

---------

Co-authored-by: John Bampton
---
 .github/workflows/python.yml         | 17 ++++++++++++
 python/sedona/spark/SedonaContext.py | 14 ++++++++--
 python/sedona/sql/connect.py         | 40 ++++++++++++++++++++++++++++
 python/sedona/sql/dataframe_api.py   | 35 +++++++++++++++++++-----
 python/tests/test_base.py            | 31 +++++++++++++++++----
 5 files changed, 123 insertions(+), 14 deletions(-)
 create mode 100644 python/sedona/sql/connect.py

diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index e7d1002d94..04fa4f7fc9 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -153,3 +153,20 @@ jobs:
           SPARK_VERSION: ${{ matrix.spark }}
           HADOOP_VERSION: ${{ matrix.hadoop }}
         run: (export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION};export PYTHONPATH=$SPARK_HOME/python;cd python;pipenv run pytest tests)
+      - env:
+          SPARK_VERSION: ${{ matrix.spark }}
+          HADOOP_VERSION: ${{ matrix.hadoop }}
+        run: |
+          if [ ! -f "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/sbin/start-connect-server.sh" ]
+          then
+            echo "Skipping connect tests for Spark $SPARK_VERSION"
+            exit
+          fi
+
+          export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}
+          export PYTHONPATH=$SPARK_HOME/python
+          export SPARK_REMOTE=local
+
+          cd python
+          pipenv install "pyspark[connect]==${SPARK_VERSION}"
+          pipenv run pytest tests/sql/test_dataframe_api.py
diff --git a/python/sedona/spark/SedonaContext.py b/python/sedona/spark/SedonaContext.py
index 5cba5df624..49db2e47aa 100644
--- a/python/sedona/spark/SedonaContext.py
+++ b/python/sedona/spark/SedonaContext.py
@@ -21,6 +21,13 @@
 from sedona.register.geo_registrator import PackageImporter
 from sedona.utils import KryoSerializer, SedonaKryoRegistrator
 
+try:
+    from pyspark.sql.utils import is_remote
+except ImportError:
+
+    def is_remote():
+        return False
+
 
 @attr.s
 class SedonaContext:
@@ -34,8 +41,11 @@ def create(cls, spark: SparkSession) -> SparkSession:
         :return: SedonaContext which is an instance of SparkSession
         """
         spark.sql("SELECT 1 as geom").count()
-        PackageImporter.import_jvm_lib(spark._jvm)
-        spark._jvm.SedonaContext.create(spark._jsparkSession, "python")
+
+        # with Spark Connect there is no local JVM
+        if not is_remote():
+            PackageImporter.import_jvm_lib(spark._jvm)
+            spark._jvm.SedonaContext.create(spark._jsparkSession, "python")
         return spark
 
     @classmethod
diff --git a/python/sedona/sql/connect.py b/python/sedona/sql/connect.py
new file mode 100644
index 0000000000..3470996308
--- /dev/null
+++ b/python/sedona/sql/connect.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any, Iterable, List
+
+import pyspark.sql.connect.functions as f
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.expressions import UnresolvedFunction
+
+
+# mimic semantics of _convert_argument_to_java_column
+def _convert_argument_to_connect_column(arg: Any) -> Column:
+    if isinstance(arg, Column):
+        return arg
+    elif isinstance(arg, str):
+        return f.col(arg)
+    elif isinstance(arg, Iterable):
+        return f.array(*[_convert_argument_to_connect_column(x) for x in arg])
+    else:
+        return f.lit(arg)
+
+
+def call_sedona_function_connect(function_name: str, args: List[Any]) -> Column:
+
+    expressions = [_convert_argument_to_connect_column(arg)._expr for arg in args]
+    return Column(UnresolvedFunction(function_name, expressions))
diff --git a/python/sedona/sql/dataframe_api.py b/python/sedona/sql/dataframe_api.py
index 4f79878ba1..2f56dfffa5 100644
--- a/python/sedona/sql/dataframe_api.py
+++ b/python/sedona/sql/dataframe_api.py
@@ -24,8 +24,23 @@
 from pyspark.sql import Column, SparkSession
 from pyspark.sql import functions as f
 
-ColumnOrName = Union[Column, str]
-ColumnOrNameOrNumber = Union[Column, str, float, int]
+try:
+    from pyspark.sql.connect.column import Column as ConnectColumn
+    from pyspark.sql.utils import is_remote
+except ImportError:
+    # be backwards compatible with Spark < 3.4
+    def is_remote():
+        return False
+
+    class ConnectColumn:
+        pass
+
+else:
+    from sedona.sql.connect import call_sedona_function_connect
+
+
+ColumnOrName = Union[Column, ConnectColumn, str]
+ColumnOrNameOrNumber = Union[Column, ConnectColumn, str, float, int]
 
 
 def _convert_argument_to_java_column(arg: Any) -> Column:
@@ -49,13 +64,15 @@ def call_sedona_function(
         )
 
     # apparently a Column is an Iterable so we need to check for it explicitly
-    if (
-        (not isinstance(args, Iterable))
-        or isinstance(args, str)
-        or isinstance(args, Column)
+    if (not isinstance(args, Iterable)) or isinstance(
+        args, (str, Column, ConnectColumn)
     ):
         args = [args]
 
+    # in spark-connect environments use connect API
+    if is_remote():
+        return call_sedona_function_connect(function_name, args)
+
     args = map(_convert_argument_to_java_column, args)
 
     jobject = getattr(spark._jvm, object_name)
@@ -86,6 +103,10 @@ def _get_type_list(annotated_type: Type) -> Tuple[Type, ...]:
     else:
         valid_types = (annotated_type,)
 
+    # functions accepting a Column should also accept the Spark Connect sort of Column
+    if Column in valid_types:
+        valid_types = valid_types + (ConnectColumn,)
+
     return valid_types
 
 
@@ -159,7 +180,7 @@ def validated_function(*args, **kwargs) -> Column:
         # all arguments are Columns or strings are always legal, so only check types when one of the arguments is not a column
         if not all(
             [
-                isinstance(x, Column) or isinstance(x, str)
+                isinstance(x, (Column, ConnectColumn)) or isinstance(x, str)
                 for x in itertools.chain(args, kwargs.values())
             ]
         ):
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index e45a6e9f6d..4bfbb86b00 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -14,22 +14,43 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import os
 from tempfile import mkdtemp
 
+import pyspark
+
 from sedona.spark import *
 from sedona.utils.decorators import classproperty
 
+SPARK_REMOTE = os.getenv("SPARK_REMOTE")
+
 
 class TestBase:
     @classproperty
     def spark(self):
         if not hasattr(self, "__spark"):
-            spark = SedonaContext.create(
-                SedonaContext.builder().master("local[*]").getOrCreate()
-            )
-            spark.sparkContext.setCheckpointDir(mkdtemp())
+
+            builder = SedonaContext.builder()
+            if SPARK_REMOTE:
+                builder = (
+                    builder.remote(SPARK_REMOTE)
+                    .config(
+                        "spark.jars.packages",
+                        f"org.apache.spark:spark-connect_2.12:{pyspark.__version__}",
+                    )
+                    .config(
+                        "spark.sql.extensions",
+                        "org.apache.sedona.sql.SedonaSqlExtensions",
+                    )
+                )
+            else:
+                builder = builder.master("local[*]")
+
+            spark = SedonaContext.create(builder.getOrCreate())
+
+            if not SPARK_REMOTE:
+                spark.sparkContext.setCheckpointDir(mkdtemp())
             setattr(self, "__spark", spark)
         return getattr(self, "__spark")
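
For illustration, a minimal sketch of how a client would exercise the new Spark Connect path once this patch is applied. The sc://localhost:15002 endpoint and the ST_Point call are assumptions for the example, not part of the patch; the Connect server must already be running with the Sedona jars and org.apache.sedona.sql.SedonaSqlExtensions configured, as the test setup above does via spark.jars.packages and spark.sql.extensions.

    # Hypothetical client-side usage; endpoint and column names are example values.
    from sedona.spark import SedonaContext
    from sedona.sql import st_constructors as stc

    # SedonaContext.create() skips the JVM bridge when is_remote() is true, and
    # call_sedona_function routes ST_Point through call_sedona_function_connect.
    spark = SedonaContext.create(
        SedonaContext.builder().remote("sc://localhost:15002").getOrCreate()
    )

    df = spark.range(3).select(stc.ST_Point("id", "id").alias("geom"))
    df.show()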