From ae390cdbc67bf780ead56591c79c9a7c7a47a34b Mon Sep 17 00:00:00 2001
From: sundy-li <543950155@qq.com>
Date: Sun, 1 Sep 2024 11:36:53 +0800
Subject: [PATCH] feat(udf): support batch mode

---
 python/databend_udf/udf.py | 23 +++++++++++++++++++----
 python/example/server.py   | 15 +++++++++++++++
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index 9775fd1..247b386 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -59,9 +59,10 @@ class ScalarFunction(UserDefinedFunction):
     _io_threads: Optional[int]
     _executor: Optional[ThreadPoolExecutor]
     _skip_null: bool
+    _batch_mode: bool
 
     def __init__(
-        self, func, input_types, result_type, name=None, io_threads=None, skip_null=None
+        self, func, input_types, result_type, name=None, io_threads=None, skip_null=None, batch_mode=False
     ):
         self._func = func
         self._input_schema = pa.schema(
@@ -78,6 +79,7 @@ def __init__(
             func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
         )
         self._io_threads = io_threads
+        self._batch_mode = batch_mode
         self._executor = (
             ThreadPoolExecutor(max_workers=self._io_threads)
             if self._io_threads is not None
@@ -98,7 +100,11 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
             _input_process_func(_list_field(field))(array)
             for array, field in zip(inputs, self._input_schema)
         ]
-        if self._executor is not None:
+
+        # evaluate the function in batch mode (whole columns at once) or row by row
+        if self._batch_mode:
+            column = self._func(*inputs)
+        elif self._executor is not None:
             # concurrently evaluate the function for each row
             if self._skip_null:
                 tasks = []
@@ -113,7 +119,6 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
                 ]
             column = [future.result() for future in tasks]
         else:
-            # evaluate the function for each row
            if self._skip_null:
                 column = []
                 for row in range(batch.num_rows):
@@ -140,6 +145,7 @@ def udf(
     name: Optional[str] = None,
     io_threads: Optional[int] = None,
     skip_null: Optional[bool] = False,
+    batch_mode: Optional[bool] = False,
 ) -> Callable:
     """
     Annotation for creating a user-defined scalar function.
@@ -153,6 +159,7 @@ def udf(
     - skip_null: A boolean value specifying whether to skip NULL value. If it is
         set to True, NULL values will not be passed to the function, and the
         corresponding return value is set to NULL. Default to False.
+    - batch_mode: A boolean value specifying whether the function takes whole columns as lists instead of single rows. Default to False.
 
     Example:
     ```
@@ -170,6 +177,13 @@ def external_api(x):
         response = requests.get(my_endpoint + '?param=' + x)
         return response["data"]
     ```
+
+    Batch mode example:
+    ```
+    @udf(input_types=['INT', 'INT'], result_type='INT', batch_mode=True)
+    def gcd(x, y):
+        return [math.gcd(x_i, y_i) for x_i, y_i in zip(x, y)]
+    ```
     """
 
     if io_threads is not None and io_threads > 1:
@@ -180,10 +194,11 @@ def external_api(x):
             name,
             io_threads=io_threads,
             skip_null=skip_null,
+            batch_mode=batch_mode,
         )
     else:
         return lambda f: ScalarFunction(
-            f, input_types, result_type, name, skip_null=skip_null
+            f, input_types, result_type, name, skip_null=skip_null, batch_mode=batch_mode
         )
 
 
diff --git a/python/example/server.py b/python/example/server.py
index 42ad118..8d68898 100644
--- a/python/example/server.py
+++ b/python/example/server.py
@@ -54,6 +54,20 @@ def gcd(x: int, y: int) -> int:
         (x, y) = (y, x % y)
     return x
 
+@udf(
+    name="gcd_batch",
+    input_types=["INT", "INT"],
+    result_type="INT",
+    batch_mode=True,
+)
+def gcd_batch(x: list[int], y: list[int]) -> list[int]:
+    def gcd_single(x_i, y_i):
+        if x_i is None or y_i is None:
+            return None
+        while y_i != 0:
+            (x_i, y_i) = (y_i, x_i % y_i)
+        return x_i
+    return [gcd_single(x_i, y_i) for x_i, y_i in zip(x, y)]
 
 @udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR")
 def split_and_join(s: str, split_s: str, join_s: str) -> str:
@@ -303,6 +317,7 @@ def wait_concurrent(x):
     udf_server.add_function(binary_reverse)
     udf_server.add_function(bool_select)
     udf_server.add_function(gcd)
+    udf_server.add_function(gcd_batch)
     udf_server.add_function(split_and_join)
     udf_server.add_function(decimal_div)
     udf_server.add_function(hex_to_dec)
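
A quick illustration of what batch mode enables (not part of the patch): with batch_mode=True the function receives each input column as a Python list and returns one list of results, so per-batch work such as a single HTTP call can replace one call per row. In the sketch below, the lookup URL and the JSON payload shapes are hypothetical; only the @udf decorator and the batch_mode flag come from this patch.

```
# Sketch of a batch-mode UDF calling a hypothetical HTTP lookup service once per batch.
# Assumptions: LOOKUP_URL and the {"keys": [...]} / {"values": [...]} payload shapes are
# made up for illustration; the @udf decorator and batch_mode flag are from this patch.
import requests

from databend_udf.udf import udf  # module path as in python/databend_udf/udf.py

LOOKUP_URL = "http://localhost:8000/lookup"  # hypothetical endpoint


@udf(input_types=["VARCHAR"], result_type="VARCHAR", batch_mode=True)
def lookup_batch(keys: list) -> list:
    # `keys` is the whole input column as a Python list; return the whole result column.
    response = requests.post(LOOKUP_URL, json={"keys": keys})
    return response.json()["values"]
```

Such a function would be registered with the example server like the others, e.g. udf_server.add_function(lookup_batch).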