diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 379d440..a37724c 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -34,6 +34,7 @@ jobs:
           python -m pip install build
           python -m build
       - uses: pypa/gh-action-pypi-publish@release/v1
+        if: github.event_name == 'push'
         with:
           packages-dir: python/dist/
           skip-existing: true
diff --git a/python/.gitignore b/python/.gitignore
index 8be2526..0ea8ccd 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -2,3 +2,4 @@ venv/
 *.egg-info/
 dist/
 __pycache__/
+Pipfile.lock
diff --git a/python/Pipfile b/python/Pipfile
new file mode 100644
index 0000000..031e0f2
--- /dev/null
+++ b/python/Pipfile
@@ -0,0 +1,14 @@
+[[source]]
+url = "https://pypi.org/simple"
+verify_ssl = true
+name = "pypi"
+
+[packages]
+databend-udf = {file = "."}
+
+[dev-packages]
+flake8 = "*"
+black = "*"
+
+[requires]
+python_version = "3.12"
diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index 98f88e2..4c1baa3 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import json
+import logging
 import inspect
-import traceback
 from concurrent.futures import ThreadPoolExecutor
 from typing import Iterator, Callable, Optional, Union, List, Dict
 
@@ -24,11 +24,13 @@
 # comes from Databend
 MAX_DECIMAL128_PRECISION = 38
 MAX_DECIMAL256_PRECISION = 76
-EXTENSION_KEY = "Extension"
-ARROW_EXT_TYPE_VARIANT = "Variant"
+EXTENSION_KEY = b"Extension"
+ARROW_EXT_TYPE_VARIANT = b"Variant"
 
 TIMESTAMP_UINT = "us"
 
+logger = logging.getLogger(__name__)
+
 
 class UserDefinedFunction:
     """
@@ -92,8 +94,8 @@ def __init__(
     def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
         inputs = [[v.as_py() for v in array] for array in batch]
         inputs = [
-            _process_func(pa.list_(type), False)(array)
-            for array, type in zip(inputs, self._input_schema.types)
+            _input_process_func(_list_field(field))(array)
+            for array, field in zip(inputs, self._input_schema)
         ]
         if self._executor is not None:
             # concurrently evaluate the function for each row
@@ -122,7 +124,7 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
                 for row in range(batch.num_rows)
             ]
 
-        column = _process_func(pa.list_(self._result_schema.types[0]), True)(column)
+        column = _output_process_func(_list_field(self._result_schema.field(0)))(column)
 
         array = pa.array(column, type=self._result_schema.types[0])
         yield pa.RecordBatch.from_arrays([array], schema=self._result_schema)
@@ -231,7 +233,7 @@ def do_exchange(self, context, descriptor, reader, writer):
                 for output_batch in udf.eval_batch(batch.data):
                     writer.write_batch(output_batch)
         except Exception as e:
-            print(traceback.print_exc())
+            logger.exception(e)
             raise e
 
     def add_function(self, udf: UserDefinedFunction):
@@ -249,97 +251,112 @@ def add_function(self, udf: UserDefinedFunction):
             f"RETURNS {output_type} LANGUAGE python "
             f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';"
         )
-        print(f"added function: {name}, corresponding SQL:\n{sql}\n")
+        logger.info(f"added function: {name}, SQL:\n{sql}\n")
 
     def serve(self):
         """Start the server."""
-        print(f"listening on {self._location}")
+        logger.info(f"listening on {self._location}")
         super(UDFServer, self).serve()
 
 
-def _null_func(*args):
-    return None
-
-
-def _process_func(type: pa.DataType, output: bool) -> Callable:
+def _input_process_func(field: pa.Field) -> Callable:
     """
-    Return a function to process input or output value.
-
-    For input type:
-    - String=pa.string(): bytes -> str
-    - Tuple=pa.struct(): dict -> tuple
-    - Json=pa.large_binary(): bytes -> Any
-    - Map=pa.map_(): list[tuple(k,v)] -> dict
+    Return a function to process input value.
 
-    For output type:
-    - Json=pa.large_binary(): Any -> str
-    - Map=pa.map_(): dict -> list[tuple(k,v)]
+    - Tuple=pa.struct(): dict -> tuple
+    - Json=pa.large_binary(): bytes -> Any
+    - Map=pa.map_(): list[tuple(k,v)] -> dict
     """
-    if pa.types.is_list(type):
-        func = _process_func(type.value_type, output)
+    if pa.types.is_list(field.type):
+        func = _input_process_func(field.type.value_field)
         return (
-            lambda array: [(func(v) if v is not None else None) for v in array]
+            lambda array: [func(v) if v is not None else None for v in array]
             if array is not None
             else None
         )
-    if pa.types.is_struct(type):
-        funcs = [_process_func(field.type, output) for field in type]
-        if output:
-            return (
-                lambda tup: tuple(
-                    (func(v) if v is not None else None) for v, func in zip(tup, funcs)
-                )
-                if tup is not None
-                else None
-            )
-        else:
-            # the input value of struct type is a dict
-            # we convert it into tuple here
-            return (
-                lambda map: tuple(
-                    (func(v) if v is not None else None)
-                    for v, func in zip(map.values(), funcs)
-                )
-                if map is not None
-                else None
+    if pa.types.is_struct(field.type):
+        funcs = [_input_process_func(f) for f in field.type]
+        # the input value of struct type is a dict
+        # we convert it into tuple here
+        return (
+            lambda map: tuple(
+                func(v) if v is not None else None
+                for v, func in zip(map.values(), funcs)
             )
-    if pa.types.is_map(type):
+            if map is not None
+            else None
+        )
+    if pa.types.is_map(field.type):
         funcs = [
-            _process_func(type.key_type, output),
-            _process_func(type.item_type, output),
+            _input_process_func(field.type.key_field),
+            _input_process_func(field.type.item_field),
         ]
-        if output:
-            # dict -> list[tuple[k,v]]
-            return (
-                lambda map: [
-                    tuple(func(v) for v, func in zip(item, funcs))
-                    for item in map.items()
-                ]
-                if map is not None
-                else None
+        # list[tuple[k,v]] -> dict
+        return (
+            lambda array: dict(
+                tuple(func(v) for v, func in zip(item, funcs)) for item in array
             )
-        else:
-            # list[tuple[k,v]] -> dict
-            return (
-                lambda array: dict(
-                    tuple(func(v) for v, func in zip(item, funcs)) for item in array
-                )
-                if array is not None
-                else None
+            if array is not None
+            else None
+        )
+    if pa.types.is_large_binary(field.type):
+        if _field_is_variant(field):
+            return lambda v: json.loads(v) if v is not None else None
+
+    return lambda v: v
+
+
+def _output_process_func(field: pa.Field) -> Callable:
+    """
+    Return a function to process output value.
+
+    - Json=pa.large_binary(): Any -> str
+    - Map=pa.map_(): dict -> list[tuple(k,v)]
+    """
+    if pa.types.is_list(field.type):
+        func = _output_process_func(field.type.value_field)
+        return (
+            lambda array: [func(v) if v is not None else None for v in array]
+            if array is not None
+            else None
+        )
+    if pa.types.is_struct(field.type):
+        funcs = [_output_process_func(f) for f in field.type]
+        return (
+            lambda tup: tuple(
+                func(v) if v is not None else None for v, func in zip(tup, funcs)
             )
+            if tup is not None
+            else None
+        )
+    if pa.types.is_map(field.type):
+        funcs = [
+            _output_process_func(field.type.key_field),
+            _output_process_func(field.type.item_field),
+        ]
+        # dict -> list[tuple[k,v]]
+        return (
+            lambda map: [
+                tuple(func(v) for v, func in zip(item, funcs)) for item in map.items()
+            ]
+            if map is not None
+            else None
+        )
+    if pa.types.is_large_binary(field.type):
+        if _field_is_variant(field):
+            return lambda v: json.dumps(_ensure_str(v)) if v is not None else None
 
-    if pa.types.is_string(type) and not output:
-        # string type is converted to LargeBinary in Databend,
-        # we cast it back to string here
-        return lambda v: v.decode("utf-8") if v is not None else None
-    if pa.types.is_large_binary(type):
-        if output:
-            return lambda v: json.dumps(v) if v is not None else None
-        else:
-            return lambda v: json.loads(v) if v is not None else None
     return lambda v: v
 
 
+def _null_func(*args):
+    return None
+
+
+def _list_field(field: pa.Field) -> pa.Field:
+    return pa.field("", pa.list_(field))
+
+
 def _to_list(x):
     if isinstance(x, list):
         return x
@@ -347,6 +364,25 @@ def _to_list(x):
         return [x]
 
 
+def _ensure_str(x):
+    if isinstance(x, bytes):
+        return x.decode("utf-8")
+    elif isinstance(x, list):
+        return [_ensure_str(v) for v in x]
+    elif isinstance(x, dict):
+        return {_ensure_str(k): _ensure_str(v) for k, v in x.items()}
+    else:
+        return x
+
+
+def _field_is_variant(field: pa.Field) -> bool:
+    if field.metadata is None:
+        return False
+    if field.metadata.get(EXTENSION_KEY) == ARROW_EXT_TYPE_VARIANT:
+        return True
+    return False
+
+
 def _to_arrow_field(t: Union[str, pa.DataType]) -> pa.Field:
     """
     Convert a string or pyarrow.DataType to pyarrow.Field.
@@ -401,7 +437,9 @@ def _type_str_to_arrow_field_inner(type_str: str) -> pa.Field:
     elif type_str in ("DATETIME", "TIMESTAMP"):
         return pa.field("", pa.timestamp(TIMESTAMP_UINT), False)
     elif type_str in ("STRING", "VARCHAR", "CHAR", "CHARACTER", "TEXT"):
-        return pa.field("", pa.string(), False)
+        return pa.field("", pa.large_utf8(), False)
+    elif type_str in ("BINARY",):
+        return pa.field("", pa.large_binary(), False)
     elif type_str in ("VARIANT", "JSON"):
         # In Databend, JSON type is identified by the "EXTENSION" key in the metadata.
         return pa.field(
@@ -460,20 +498,21 @@ def _arrow_field_to_string(field: pa.Field) -> str:
     """
     Convert a `pyarrow.Field` to a SQL data type string.
     """
-    type_str = _data_type_to_string(field.type)
+    type_str = _field_type_to_string(field)
     return f"{type_str} NOT NULL" if not field.nullable else type_str
 
 
 def _inner_field_to_string(field: pa.Field) -> str:
     # inner field default is NOT NULL in databend
-    type_str = _data_type_to_string(field.type)
+    type_str = _field_type_to_string(field)
     return f"{type_str} NULL" if field.nullable else type_str
 
 
-def _data_type_to_string(t: pa.DataType) -> str:
+def _field_type_to_string(field: pa.Field) -> str:
     """
     Convert a `pyarrow.DataType` to a SQL data type string.
""" + t = field.type if pa.types.is_boolean(t): return "BOOLEAN" elif pa.types.is_int8(t): @@ -502,10 +541,13 @@ def _data_type_to_string(t: pa.DataType) -> str: return "DATE" elif pa.types.is_timestamp(t): return "TIMESTAMP" - elif pa.types.is_string(t): + elif pa.types.is_large_unicode(t): return "VARCHAR" elif pa.types.is_large_binary(t): - return "VARIANT" + if _field_is_variant(field): + return "VARIANT" + else: + return "BINARY" elif pa.types.is_list(t): return f"ARRAY({_inner_field_to_string(t.value_field)})" elif pa.types.is_map(t): diff --git a/python/example/server.py b/python/example/server.py new file mode 100644 index 0000000..42ad118 --- /dev/null +++ b/python/example/server.py @@ -0,0 +1,322 @@ +# Copyright 2023 RisingWave Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import datetime +from decimal import Decimal +import time +from typing import List, Dict, Any, Tuple, Optional + +from databend_udf import udf, UDFServer + +logging.basicConfig(level=logging.INFO) + + +@udf(input_types=["TINYINT", "SMALLINT", "INT", "BIGINT"], result_type="BIGINT") +def add_signed(a, b, c, d): + return a + b + c + d + + +@udf(input_types=["UINT8", "UINT16", "UINT32", "UINT64"], result_type="UINT64") +def add_unsigned(a, b, c, d): + return a + b + c + d + + +@udf(input_types=["FLOAT", "DOUBLE"], result_type="DOUBLE") +def add_float(a, b): + return a + b + + +@udf(input_types=["BOOLEAN", "BIGINT", "BIGINT"], result_type="BIGINT") +def bool_select(condition, a, b): + return a if condition else b + + +@udf( + name="gcd", + input_types=["INT", "INT"], + result_type="INT", + skip_null=True, +) +def gcd(x: int, y: int) -> int: + while y != 0: + (x, y) = (y, x % y) + return x + + +@udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR") +def split_and_join(s: str, split_s: str, join_s: str) -> str: + return join_s.join(s.split(split_s)) + + +@udf(input_types=["BINARY"], result_type="BINARY") +def binary_reverse(s: bytes) -> bytes: + return s[::-1] + + +@udf(input_types="VARCHAR", result_type="DECIMAL(36, 18)") +def hex_to_dec(hex: str) -> Decimal: + hex = hex.strip() + + dec = Decimal(0) + while hex: + chunk = hex[:16] + chunk_value = int(hex[:16], 16) + dec = dec * (1 << (4 * len(chunk))) + chunk_value + chunk_len = len(chunk) + hex = hex[chunk_len:] + return dec + + +@udf(input_types=["DECIMAL(36, 18)", "DECIMAL(36, 18)"], result_type="DECIMAL(72, 28)") +def decimal_div(v1: Decimal, v2: Decimal) -> Decimal: + result = v1 / v2 + return result.quantize(Decimal("0." 
+ "0" * 28)) + + +@udf(input_types=["DATE", "INT"], result_type="DATE") +def add_days_py(dt: datetime.date, days: int): + return dt + datetime.timedelta(days=days) + + +@udf(input_types=["TIMESTAMP", "INT"], result_type="TIMESTAMP") +def add_hours_py(dt: datetime.datetime, hours: int): + return dt + datetime.timedelta(hours=hours) + + +@udf(input_types=["ARRAY(VARCHAR)", "INT"], result_type="VARCHAR") +def array_access(array: List[str], idx: int) -> Optional[str]: + if idx == 0 or idx > len(array): + return None + return array[idx - 1] + + +@udf( + input_types=["ARRAY(INT64 NULL)", "INT64"], + result_type="INT NOT NULL", + skip_null=False, +) +def array_index_of(array: List[int], item: int): + if array is None: + return 0 + + try: + return array.index(item) + 1 + except ValueError: + return 0 + + +@udf(input_types=["MAP(VARCHAR,VARCHAR)", "VARCHAR"], result_type="VARCHAR") +def map_access(map: Dict[str, str], key: str) -> str: + return map[key] if key in map else None + + +@udf(input_types=["VARIANT", "VARCHAR"], result_type="VARIANT") +def json_access(data: Any, key: str) -> Any: + return data[key] + + +@udf(input_types=["ARRAY(VARIANT)"], result_type="VARIANT") +def json_concat(list: List[Any]) -> Any: + return list + + +@udf( + input_types=["TUPLE(ARRAY(VARIANT NULL), INT, VARCHAR)", "INT", "INT"], + result_type="TUPLE(VARIANT NULL, VARIANT NULL)", +) +def tuple_access( + tup: Tuple[List[Any], int, str], idx1: int, idx2: int +) -> Tuple[Any, Any]: + v1 = None if idx1 == 0 or idx1 > len(tup) else tup[idx1 - 1] + v2 = None if idx2 == 0 or idx2 > len(tup) else tup[idx2 - 1] + return v1, v2 + + +ALL_SCALAR_TYPES = [ + "BOOLEAN", + "TINYINT", + "SMALLINT", + "INT", + "BIGINT", + "UINT8", + "UINT16", + "UINT32", + "UINT64", + "FLOAT", + "DOUBLE", + "DATE", + "TIMESTAMP", + "VARCHAR", + "VARIANT", +] + + +@udf( + input_types=ALL_SCALAR_TYPES, + result_type=f"TUPLE({','.join(f'{t} NULL' for t in ALL_SCALAR_TYPES)})", +) +def return_all( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, +): + return ( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, + ) + + +@udf( + input_types=[f"ARRAY({t})" for t in ALL_SCALAR_TYPES], + result_type=f"TUPLE({','.join(f'ARRAY({t})' for t in ALL_SCALAR_TYPES)})", +) +def return_all_arrays( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, +): + return ( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, + ) + + +@udf( + input_types=[f"{t} NOT NULL" for t in ALL_SCALAR_TYPES], + result_type=f"TUPLE({','.join(f'{t}' for t in ALL_SCALAR_TYPES)})", +) +def return_all_non_nullable( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, +): + return ( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, + ) + + +@udf(input_types=["INT"], result_type="INT") +def wait(x): + time.sleep(0.1) + return x + + +@udf(input_types=["INT"], result_type="INT", io_threads=32) +def wait_concurrent(x): + time.sleep(0.1) + return x + + +if __name__ == "__main__": + udf_server = UDFServer("0.0.0.0:8815") + udf_server.add_function(add_signed) + udf_server.add_function(add_unsigned) + udf_server.add_function(add_float) + udf_server.add_function(binary_reverse) + udf_server.add_function(bool_select) + 
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 4699c9c..bd430c8 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -7,7 +7,7 @@ classifiers = [
 description = "Databend UDF Server"
 license = { text = "Apache-2.0" }
 name = "databend-udf"
-version = "0.1.4"
+version = "0.2.0"
 readme = "README.md"
 requires-python = ">=3.7"
 dependencies = ["pyarrow"]
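
Because the server now reports through the standard `logging` module instead of `print`, an application embedding `databend-udf` has to configure logging to see the startup and registration messages (including the generated `CREATE FUNCTION` SQL, which `add_function` logs at INFO level). A minimal sketch against the API exercised in this diff; the `echo` function and the listen address are placeholders:

```python
import logging

from databend_udf import udf, UDFServer

# Without this, the "listening on ..." and "added function ..." messages
# logged by the server are silently dropped.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)


@udf(input_types=["VARCHAR"], result_type="VARCHAR")
def echo(s: str) -> str:
    return s


if __name__ == "__main__":
    server = UDFServer("0.0.0.0:8815")
    server.add_function(echo)
    server.serve()
```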