[SEDONA-663] Support Spark Connect Python API (#1639)
* initial successful test

* try add docker-compose based tests

* 3.5 only

* comment classic tests

* try fix yaml

* skip other workflows

* skip other workflows

* try fix if check

* fix path

* cd to python folder

* skip sparkContext with SPARK_REMOTE

* fix type check

* refactor somewhat

* Revert "skip other workflows"

This reverts commit 7eb9b6e

* back to full matrix

* add license header, fix missing whitespace

* Add a simple docstring to SedonaFunction

* uncomment build step

* need sql extensions

* run pre-commit

* fix lint/pre-commit

* Update .github/workflows/python.yml

Co-authored-by: John Bampton <[email protected]>

* adjust spelling

* use UnresolvedFunction instead of CallFunction

* revert Pipfile to master rev

---------

Co-authored-by: John Bampton <[email protected]>
sebbegg and jbampton authored Oct 15, 2024
1 parent 4c7da62 commit 44984f1
Showing 5 changed files with 123 additions and 14 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/python.yml
@@ -153,3 +153,20 @@ jobs:
          SPARK_VERSION: ${{ matrix.spark }}
          HADOOP_VERSION: ${{ matrix.hadoop }}
        run: (export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION};export PYTHONPATH=$SPARK_HOME/python;cd python;pipenv run pytest tests)
      - env:
          SPARK_VERSION: ${{ matrix.spark }}
          HADOOP_VERSION: ${{ matrix.hadoop }}
        run: |
          if [ ! -f "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/sbin/start-connect-server.sh" ]
          then
            echo "Skipping connect tests for Spark $SPARK_VERSION"
            exit
          fi
          export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}
          export PYTHONPATH=$SPARK_HOME/python
          export SPARK_REMOTE=local
          cd python
          pipenv install "pyspark[connect]==${SPARK_VERSION}"
          pipenv run pytest tests/sql/test_dataframe_api.py
14 changes: 12 additions & 2 deletions python/sedona/spark/SedonaContext.py
@@ -21,6 +21,13 @@
from sedona.register.geo_registrator import PackageImporter
from sedona.utils import KryoSerializer, SedonaKryoRegistrator

try:
    from pyspark.sql.utils import is_remote
except ImportError:

    def is_remote():
        return False


@attr.s
class SedonaContext:
@@ -34,8 +41,11 @@ def create(cls, spark: SparkSession) -> SparkSession:
        :return: SedonaContext which is an instance of SparkSession
        """
        spark.sql("SELECT 1 as geom").count()
        PackageImporter.import_jvm_lib(spark._jvm)
        spark._jvm.SedonaContext.create(spark._jsparkSession, "python")

        # with Spark Connect there is no local JVM
        if not is_remote():
            PackageImporter.import_jvm_lib(spark._jvm)
            spark._jvm.SedonaContext.create(spark._jsparkSession, "python")
        return spark

    @classmethod
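For orientation, a minimal usage sketch of this change: creating a Sedona session against a Spark Connect endpoint, where no local JVM exists. The connect URL is an illustrative assumption and not part of this commit; it presumes a connect server that already has the Sedona jars and SQL extensions configured.

# hypothetical sketch: Sedona over Spark Connect (URL is illustrative)
from sedona.spark import SedonaContext

spark = SedonaContext.builder().remote("sc://localhost:15002").getOrCreate()
sedona = SedonaContext.create(spark)  # with is_remote() True, the JVM registration path is skipped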
40 changes: 40 additions & 0 deletions python/sedona/sql/connect.py
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Iterable, List

import pyspark.sql.connect.functions as f
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedFunction


# mimic semantics of _convert_argument_to_java_column
def _convert_argument_to_connect_column(arg: Any) -> Column:
    if isinstance(arg, Column):
        return arg
    elif isinstance(arg, str):
        return f.col(arg)
    elif isinstance(arg, Iterable):
        return f.array(*[_convert_argument_to_connect_column(x) for x in arg])
    else:
        return f.lit(arg)


def call_sedona_function_connect(function_name: str, args: List[Any]) -> Column:

    expressions = [_convert_argument_to_connect_column(arg)._expr for arg in args]
    return Column(UnresolvedFunction(function_name, expressions))
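As a usage illustration only (the DataFrame and column names below are hypothetical), the helper builds a connect Column wrapping an UnresolvedFunction, which the server resolves against the Sedona functions registered there:

# illustrative sketch, assuming an active Spark Connect session and a DataFrame
# `df` with hypothetical "lon" and "lat" columns
from sedona.sql.connect import call_sedona_function_connect

geom = call_sedona_function_connect("ST_Point", ["lon", "lat"])
df.select(geom.alias("geom")).show()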
35 changes: 28 additions & 7 deletions python/sedona/sql/dataframe_api.py
@@ -24,8 +24,23 @@
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f

ColumnOrName = Union[Column, str]
ColumnOrNameOrNumber = Union[Column, str, float, int]
try:
    from pyspark.sql.connect.column import Column as ConnectColumn
    from pyspark.sql.utils import is_remote
except ImportError:
    # be backwards compatible with Spark < 3.4
    def is_remote():
        return False

    class ConnectColumn:
        pass

else:
    from sedona.sql.connect import call_sedona_function_connect


ColumnOrName = Union[Column, ConnectColumn, str]
ColumnOrNameOrNumber = Union[Column, ConnectColumn, str, float, int]


def _convert_argument_to_java_column(arg: Any) -> Column:
@@ -49,13 +64,15 @@ def call_sedona_function(
        )

    # apparently a Column is an Iterable so we need to check for it explicitly
    if (
        (not isinstance(args, Iterable))
        or isinstance(args, str)
        or isinstance(args, Column)
    if (not isinstance(args, Iterable)) or isinstance(
        args, (str, Column, ConnectColumn)
    ):
        args = [args]

    # in spark-connect environments use connect API
    if is_remote():
        return call_sedona_function_connect(function_name, args)

    args = map(_convert_argument_to_java_column, args)

    jobject = getattr(spark._jvm, object_name)
@@ -86,6 +103,10 @@ def _get_type_list(annotated_type: Type) -> Tuple[Type, ...]:
    else:
        valid_types = (annotated_type,)

    # functions accepting a Column should also accept the Spark Connect sort of Column
    if Column in valid_types:
        valid_types = valid_types + (ConnectColumn,)

    return valid_types


@@ -159,7 +180,7 @@ def validated_function(*args, **kwargs) -> Column:
        # all arguments are Columns or strings are always legal, so only check types when one of the arguments is not a column
        if not all(
            [
                isinstance(x, Column) or isinstance(x, str)
                isinstance(x, (Column, ConnectColumn)) or isinstance(x, str)
                for x in itertools.chain(args, kwargs.values())
            ]
        ):
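The net effect, sketched below under the assumption of a Sedona-enabled session named `spark` (classic or connect) and hypothetical column names: the same DataFrame API call goes through call_sedona_function, which dispatches to the connect helper when is_remote() is true and to the JVM path otherwise.

# behavior sketch, not part of this diff; st_constructors is Sedona's existing
# DataFrame API module, the column names and literals are hypothetical
from sedona.sql import st_constructors as stc

df = spark.sql("SELECT -122.33 AS lon, 47.61 AS lat")
df.select(stc.ST_Point("lon", "lat").alias("geom")).show()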
31 changes: 26 additions & 5 deletions python/tests/test_base.py
@@ -14,22 +14,43 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
from tempfile import mkdtemp

import pyspark

from sedona.spark import *
from sedona.utils.decorators import classproperty

SPARK_REMOTE = os.getenv("SPARK_REMOTE")


class TestBase:

    @classproperty
    def spark(self):
        if not hasattr(self, "__spark"):
            spark = SedonaContext.create(
                SedonaContext.builder().master("local[*]").getOrCreate()
            )
            spark.sparkContext.setCheckpointDir(mkdtemp())

            builder = SedonaContext.builder()
            if SPARK_REMOTE:
                builder = (
                    builder.remote(SPARK_REMOTE)
                    .config(
                        "spark.jars.packages",
                        f"org.apache.spark:spark-connect_2.12:{pyspark.__version__}",
                    )
                    .config(
                        "spark.sql.extensions",
                        "org.apache.sedona.sql.SedonaSqlExtensions",
                    )
                )
            else:
                builder = builder.master("local[*]")

            spark = SedonaContext.create(builder.getOrCreate())

            if not SPARK_REMOTE:
                spark.sparkContext.setCheckpointDir(mkdtemp())

            setattr(self, "__spark", spark)
        return getattr(self, "__spark")
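For completeness, a minimal sketch of how a test consumes the cached session above; the test class and assertion are illustrative, not part of this commit. Running the suite with SPARK_REMOTE set (the CI step uses SPARK_REMOTE=local) makes the classproperty build a connect-backed session instead of a local one.

# illustrative only: a subclass transparently gets either the local or the
# connect-backed session from the TestBase classproperty
from tests.test_base import TestBase


class TestConnectSmoke(TestBase):
    def test_select_one(self):
        assert self.spark.sql("SELECT 1 AS one").collect()[0][0] == 1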
