diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f430941..d034aeef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Changed - `pw.io.airbyte.read` can now be used with Airbyte connectors implemented in Python without requiring Docker. +- **BREAKING**: UDFs now verify the type of returned values at runtime. If a returned value can be cast to the expected type, the value is cast. If the value does not match the expected type and can't be cast, an error is raised. +- **BREAKING**: The `pw.reducers.ndarray` reducer now requires its input column to have type `float`, `int`, or `Array`. ## [0.13.2] - 2024-07-08 diff --git a/docs/2.developers/7.templates/.interval_over_gaussian_filter/article.py b/docs/2.developers/7.templates/.interval_over_gaussian_filter/article.py index dade0c91..c923d179 100644 --- a/docs/2.developers/7.templates/.interval_over_gaussian_filter/article.py +++ b/docs/2.developers/7.templates/.interval_over_gaussian_filter/article.py @@ -218,7 +218,7 @@ def load_to_pathway(x, y): points_within_50 = time_series.windowby( time_series.x, window=pw.temporal.intervals_over( - at=time_series.x, lower_bound=-50.0, upper_bound=50.0 + at=time_series.x, lower_bound=-50.0, upper_bound=50.0, is_outer=False ), ).reduce( pw.this._pw_window_location, @@ -378,7 +378,7 @@ def smooth_table(table): points_within_50 = table.windowby( table.x, window=pw.temporal.intervals_over( - at=table.x, lower_bound=-50.0, upper_bound=50.0 + at=table.x, lower_bound=-50.0, upper_bound=50.0, is_outer=False ), ).reduce( pw.this._pw_window_location, diff --git a/docs/2.developers/7.templates/.interval_over_upsampling/article.py b/docs/2.developers/7.templates/.interval_over_upsampling/article.py index 45b19d35..e872a219 100644 --- a/docs/2.developers/7.templates/.interval_over_upsampling/article.py +++ b/docs/2.developers/7.templates/.interval_over_upsampling/article.py @@ -194,7 +194,7 @@ def load_to_pathway(x, y): upsampled_stream = data_stream_B.windowby( data_stream_B.x, window=pw.temporal.intervals_over( - at=data_stream_A.x, lower_bound=-100.0, upper_bound=100.0 + at=data_stream_A.x, lower_bound=-100.0, upper_bound=100.0, is_outer=False ), ).reduce( x=pw.this._pw_window_location, diff --git a/python/pathway/engine.pyi b/python/pathway/engine.pyi index d35ffcfe..8ddedb96 100644 --- a/python/pathway/engine.pyi +++ b/python/pathway/engine.pyi @@ -30,7 +30,7 @@ class Pointer(Generic[_T]): def ref_scalar(*args, optional=False) -> Pointer: ... def ref_scalar_with_instance(*args, instance: Value, optional=False) -> Pointer: ... -class PathwayType(Enum): +class PathwayType: ANY: PathwayType STRING: PathwayType INT: PathwayType @@ -40,11 +40,17 @@ class PathwayType(Enum): DATE_TIME_NAIVE: PathwayType DATE_TIME_UTC: PathwayType DURATION: PathwayType - ARRAY: PathwayType + @staticmethod + def array(dim: int | None, wrapped: PathwayType) -> PathwayType: ... JSON: PathwayType - TUPLE: PathwayType + @staticmethod + def tuple(*args: PathwayType) -> PathwayType: ... + @staticmethod + def list(arg: PathwayType) -> PathwayType: ... BYTES: PathwayType PY_OBJECT_WRAPPER: PathwayType + @staticmethod + def optional(arg: PathwayType) -> PathwayType: ... class ConnectorMode(Enum): STATIC: ConnectorMode @@ -133,9 +139,11 @@ class DataRow: self, key: Pointer, values: list[Value], + *, time: int = 0, diff: int = 1, shard: int | None = None, + dtypes: list[PathwayType], ) -> None: ...
class MissingValueError(BaseException): @@ -202,12 +210,16 @@ class BinaryOperator: class Expression: @staticmethod - def const(value: Value) -> Expression: ... + def const(value: Value, dtype: PathwayType) -> Expression: ... @staticmethod def argument(index: int) -> Expression: ... @staticmethod def apply( - fun: Callable, /, *args: Expression, propagate_none=False + fun: Callable, + /, + *args: Expression, + dtype: PathwayType, + propagate_none: bool = False, ) -> Expression: ... @staticmethod def is_none(expr: Expression) -> Expression: ... @@ -480,6 +492,7 @@ class Scope: propagate_none: bool, deterministic: bool, properties: TableProperties, + dtype: PathwayType, ) -> Table: ... def gradual_broadcast( self, @@ -713,7 +726,7 @@ def run_with_new_graph( def unsafe_make_pointer(arg) -> Pointer: ... class DataFormat: - value_fields: Any + value_fields: list[ValueField] def __init__(self, *args, **kwargs): ... @@ -735,7 +748,6 @@ class DataStorage: object_pattern: str mock_events: dict[tuple[str, int], list[SnapshotEvent]] | None table_name: str | None - column_names: list[str] | None def __init__(self, *args, **kwargs): ... class CsvParserSettings: @@ -746,7 +758,7 @@ class AwsS3Settings: class ValueField: name: str - def __init__(self, name: str, type_: PathwayType, *, is_optional: bool = False): ... + def __init__(self, name: str, type_: PathwayType): ... def set_default(self, *args, **kwargs): ... class PythonSubject: diff --git a/python/pathway/internals/_io_helpers.py b/python/pathway/internals/_io_helpers.py index f6b11459..8beea3d6 100644 --- a/python/pathway/internals/_io_helpers.py +++ b/python/pathway/internals/_io_helpers.py @@ -5,7 +5,7 @@ import boto3 import boto3.session -from pathway.internals import api, dtype as dt, schema +from pathway.internals import api, schema from pathway.internals.table import Table from pathway.internals.trace import trace_user_frame @@ -133,8 +133,7 @@ def _format_output_value_fields(table: Table) -> list[api.ValueField]: value_fields.append( api.ValueField( column_name, - column_data.dtype.map_to_engine(), - is_optional=isinstance(column_data.dtype, dt.Optional), + column_data.dtype.to_engine(), ) ) @@ -146,19 +145,11 @@ def _form_value_fields(schema: type[schema.Schema]) -> list[api.ValueField]: default_values = schema.default_values() result = [] - # XXX fix mapping schema types to PathwayType - types = { - name: (dt.unoptionalize(dtype).to_engine(), isinstance(dtype, dt.Optional)) - for name, dtype in schema._dtypes().items() - } + types = {name: dtype.to_engine() for name, dtype in schema._dtypes().items()} for f in schema.column_names(): - simple_type, is_optional = types.get(f, (None, False)) - if ( - simple_type is None - ): # types can contain None if there is field of type None in the schema - simple_type = api.PathwayType.ANY - value_field = api.ValueField(f, simple_type, is_optional=is_optional) + dtype = types.get(f, api.PathwayType.ANY) + value_field = api.ValueField(f, dtype) if f in default_values: value_field.set_default(default_values[f]) result.append(value_field) diff --git a/python/pathway/internals/api.py b/python/pathway/internals/api.py index c8c3c818..fad7893e 100644 --- a/python/pathway/internals/api.py +++ b/python/pathway/internals/api.py @@ -154,6 +154,10 @@ def static_table_from_pandas( ordinary_columns = [ column for column in df.columns if column not in PANDAS_PSEUDOCOLUMNS ] + if column_types: + dtypes = [column_types[c].to_engine() for c in ordinary_columns] + else: + dtypes = [PathwayType.ANY] * 
len(ordinary_columns) if connector_properties is None: column_properties = [] @@ -163,9 +167,7 @@ def static_table_from_pandas( if v is not None: dtype = type(v) break - column_properties.append( - ColumnProperties(dtype=dt.wrap(dtype).map_to_engine()) - ) + column_properties.append(ColumnProperties(dtype=dt.wrap(dtype).to_engine())) connector_properties = ConnectorProperties(column_properties=column_properties) assert len(connector_properties.column_properties) == len( @@ -181,7 +183,9 @@ def static_table_from_pandas( if diff not in [-1, 1]: raise ValueError(f"Column {DIFF_PSEUDOCOLUMN} can only contain 1 and -1.") shard = data[SHARD_PSEUDOCOLUMN][i] if SHARD_PSEUDOCOLUMN in data else None - input_row = DataRow(key, values, time=time, diff=diff, shard=shard) + input_row = DataRow( + key, values, time=time, diff=diff, shard=shard, dtypes=dtypes + ) input_data.append(input_row) return scope.static_table(input_data, connector_properties) diff --git a/python/pathway/internals/datasource.py b/python/pathway/internals/datasource.py index cca872cb..333d00dd 100644 --- a/python/pathway/internals/datasource.py +++ b/python/pathway/internals/datasource.py @@ -30,7 +30,7 @@ def connector_properties(self) -> api.ConnectorProperties: for column in self.schema.columns().values(): columns.append( api.ColumnProperties( - dtype=column.dtype.map_to_engine(), + dtype=column.dtype.to_engine(), append_only=column.append_only, ) ) diff --git a/python/pathway/internals/dtype.py b/python/pathway/internals/dtype.py index b3ca2e43..9eda4533 100644 --- a/python/pathway/internals/dtype.py +++ b/python/pathway/internals/dtype.py @@ -25,11 +25,8 @@ class DType(ABC): _cache: dict[typing.Any, DType] = {} - def to_engine(self) -> api.PathwayType | None: - return None - - def map_to_engine(self) -> api.PathwayType: - return self.to_engine() or api.PathwayType.ANY + @abstractmethod + def to_engine(self) -> api.PathwayType: ... @abstractmethod def is_value_compatible(self, arg) -> bool: ... 
@@ -125,6 +122,9 @@ def __new__(cls) -> _NoneDType: def is_value_compatible(self, arg): return arg is None or isinstance(arg, pd._libs.missing.NAType) + def to_engine(self) -> api.PathwayType: + return api.PathwayType.ANY + @property def typehint(self) -> None: return None @@ -184,6 +184,9 @@ def __new__( def is_value_compatible(self, arg): return callable(arg) + def to_engine(self) -> api.PathwayType: + return api.PathwayType.ANY # also passed to the engine as column properties + @cached_property def typehint(self) -> typing.Any: if isinstance(self.arg_types, EllipsisType): @@ -207,7 +210,7 @@ def _set_args(self, n_dim, wrapped): self.n_dim = n_dim def to_engine(self) -> api.PathwayType: - return api.PathwayType.ARRAY + return api.PathwayType.array(self.n_dim, self.wrapped.to_engine()) def __new__(cls, n_dim, wrapped) -> Array: dtype = wrap(wrapped) @@ -298,6 +301,9 @@ def __repr__(self): def _set_args(self, wrapped): self.wrapped = wrapped + def to_engine(self) -> api.PathwayType: + return api.PathwayType.optional(self.wrapped.to_engine()) + def __new__(cls, arg: DType) -> DType: # type:ignore[misc] arg = wrap(arg) if arg == NONE or isinstance(arg, Optional) or arg == ANY: @@ -327,7 +333,7 @@ def _set_args(self, args): self.args = args def to_engine(self) -> PathwayType: - return api.PathwayType.TUPLE + return api.PathwayType.tuple(*[arg.to_engine() for arg in self.args]) def __new__(cls, *args: DType | EllipsisType) -> Tuple | List: # type: ignore[misc] if any(isinstance(arg, EllipsisType) for arg in args): @@ -391,7 +397,7 @@ def _set_args(self, wrapped): self.wrapped = wrapped def to_engine(self) -> PathwayType: - return api.PathwayType.TUPLE + return api.PathwayType.list(self.wrapped.to_engine()) def is_value_compatible(self, arg): return isinstance(arg, (tuple, list)) and all( diff --git a/python/pathway/internals/expressions/numerical.py b/python/pathway/internals/expressions/numerical.py index 341362e5..f681a092 100644 --- a/python/pathway/internals/expressions/numerical.py +++ b/python/pathway/internals/expressions/numerical.py @@ -124,12 +124,16 @@ def round(self, decimals: expr.ColumnExpression | int = 0) -> expr.ColumnExpress ( (dt.INT, dt.INT), dt.INT, - lambda x, y: api.Expression.apply(round, x, y), + lambda x, y: api.Expression.apply( + round, x, y, dtype=dt.INT.to_engine() + ), ), ( (dt.FLOAT, dt.INT), dt.FLOAT, - lambda x, y: api.Expression.apply(round, x, y), + lambda x, y: api.Expression.apply( + round, x, y, dtype=dt.FLOAT.to_engine() + ), ), ), "num.round", @@ -175,14 +179,18 @@ def fill_na(self, default_value: int | float) -> expr.ColumnExpression: dt.FLOAT, dt.FLOAT, lambda x: api.Expression.apply( - lambda y: float(default_value) if math.isnan(y) else y, x + lambda y: float(default_value) if math.isnan(y) else y, + x, + dtype=dt.FLOAT.to_engine(), ), ), ( dt.Optional(dt.INT), dt.INT, lambda x: api.Expression.apply( - lambda y: int(default_value) if y is None else y, x + lambda y: int(default_value) if y is None else y, + x, + dtype=dt.INT.to_engine(), ), ), ( @@ -195,6 +203,7 @@ def fill_na(self, default_value: int | float) -> expr.ColumnExpression: else y ), x, + dtype=dt.FLOAT.to_engine(), ), ), ), diff --git a/python/pathway/internals/expressions/string.py b/python/pathway/internals/expressions/string.py index 438fc361..73eb85b5 100644 --- a/python/pathway/internals/expressions/string.py +++ b/python/pathway/internals/expressions/string.py @@ -58,7 +58,15 @@ def lower(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - ((dt.STR, dt.STR, 
lambda x: api.Expression.apply(str.lower, x)),), + ( + ( + dt.STR, + dt.STR, + lambda x: api.Expression.apply( + str.lower, x, dtype=dt.STR.to_engine() + ), + ), + ), "str.lower", self._expression, ) @@ -91,7 +99,15 @@ def upper(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - ((dt.STR, dt.STR, lambda x: api.Expression.apply(str.upper, x)),), + ( + ( + dt.STR, + dt.STR, + lambda x: api.Expression.apply( + str.upper, x, dtype=dt.STR.to_engine() + ), + ), + ), "str.upper", self._expression, ) @@ -124,7 +140,15 @@ def reversed(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - ((dt.STR, dt.STR, lambda x: api.Expression.apply(lambda y: y[::-1], x)),), + ( + ( + dt.STR, + dt.STR, + lambda x: api.Expression.apply( + lambda y: y[::-1], x, dtype=dt.STR.to_engine() + ), + ), + ), "str.reverse", self._expression, ) @@ -157,7 +181,13 @@ def len(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - ((dt.STR, dt.INT, lambda x: api.Expression.apply(len, x)),), + ( + ( + dt.STR, + dt.INT, + lambda x: api.Expression.apply(len, x, dtype=dt.INT.to_engine()), + ), + ), "str.len", self._expression, ) @@ -225,7 +255,12 @@ def replace( (dt.STR, dt.STR, dt.STR, dt.INT), dt.STR, lambda x, y, z, c: api.Expression.apply( - lambda s1, s2, s3, cnt: s1.replace(s2, s3, cnt), x, y, z, c + lambda s1, s2, s3, cnt: s1.replace(s2, s3, cnt), + x, + y, + z, + c, + dtype=dt.STR.to_engine(), ), ), ), @@ -268,7 +303,9 @@ def startswith( ( (dt.STR, dt.STR), dt.BOOL, - lambda x, y: api.Expression.apply(str.startswith, x, y), + lambda x, y: api.Expression.apply( + str.startswith, x, y, dtype=dt.BOOL.to_engine() + ), ), ), "str.starts_with", @@ -308,7 +345,9 @@ def endswith( ( (dt.STR, dt.STR), dt.BOOL, - lambda x, y: api.Expression.apply(str.endswith, x, y), + lambda x, y: api.Expression.apply( + str.endswith, x, y, dtype=dt.BOOL.to_engine() + ), ), ), "str.ends_with", @@ -341,7 +380,15 @@ def swapcase(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - ((dt.STR, dt.STR, lambda x: api.Expression.apply(str.swapcase, x)),), + ( + ( + dt.STR, + dt.STR, + lambda x: api.Expression.apply( + str.swapcase, x, dtype=dt.STR.to_engine() + ), + ), + ), "str.swap_case", self._expression, ) @@ -379,7 +426,9 @@ def strip( ( (dt.STR, dt.Optional(dt.STR)), dt.STR, - lambda x, y: api.Expression.apply(str.strip, x, y), + lambda x, y: api.Expression.apply( + str.strip, x, y, dtype=dt.STR.to_engine() + ), ), ), "str.strip", @@ -408,7 +457,15 @@ def title(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - ((dt.STR, dt.STR, lambda x: api.Expression.apply(str.title, x)),), + ( + ( + dt.STR, + dt.STR, + lambda x: api.Expression.apply( + str.title, x, dtype=dt.STR.to_engine() + ), + ), + ), "str.title", self._expression, ) @@ -454,7 +511,9 @@ def count( dt.Optional(dt.INT), ), dt.INT, - lambda *args: api.Expression.apply(str.count, *args), + lambda *args: api.Expression.apply( + str.count, *args, dtype=dt.INT.to_engine() + ), ), ), "str.count", @@ -506,7 +565,9 @@ def find( dt.Optional(dt.INT), ), dt.INT, - lambda *args: api.Expression.apply(str.find, *args), + lambda *args: api.Expression.apply( + str.find, *args, dtype=dt.INT.to_engine() + ), ), ), "str.find", @@ -558,7 +619,9 @@ def rfind( dt.Optional(dt.INT), ), dt.INT, - lambda *args: api.Expression.apply(str.rfind, *args), + lambda *args: api.Expression.apply( + str.rfind, *args, dtype=dt.INT.to_engine() + ), ), ), "str.rfind", @@ -617,7 +680,9 @@ def removeprefix( ( (dt.STR, dt.STR), dt.STR, - 
lambda x, y: api.Expression.apply(str.removeprefix, x, y), + lambda x, y: api.Expression.apply( + str.removeprefix, x, y, dtype=dt.STR.to_engine() + ), ), ), "str.remove_prefix", @@ -674,7 +739,9 @@ def removesuffix( ( (dt.STR, dt.STR), dt.STR, - lambda x, y: api.Expression.apply(str.removesuffix, x, y), + lambda x, y: api.Expression.apply( + str.removesuffix, x, y, dtype=dt.STR.to_engine() + ), ), ), "str.remove_suffix", @@ -721,6 +788,7 @@ def slice( x, y, z, + dtype=dt.STR.to_engine(), ), ), ), diff --git a/python/pathway/internals/graph_runner/expression_evaluator.py b/python/pathway/internals/graph_runner/expression_evaluator.py index 1c9fd044..4154db40 100644 --- a/python/pathway/internals/graph_runner/expression_evaluator.py +++ b/python/pathway/internals/graph_runner/expression_evaluator.py @@ -73,7 +73,7 @@ def column_properties(self, column: clmn.Column) -> api.ColumnProperties: props = column.properties return api.ColumnProperties( trace=column.trace.to_engine(), - dtype=props.dtype.map_to_engine(), + dtype=props.dtype.to_engine(), append_only=props.append_only, ) @@ -337,7 +337,9 @@ def eval_unary_op( ) is not None: return result_expression - return api.Expression.apply(operator_fun, arg) + return api.Expression.apply( + operator_fun, arg, dtype=expression._dtype.to_engine() + ) def eval_binary_op( self, @@ -400,7 +402,7 @@ def eval_const( expression: expr.ColumnConstExpression, eval_state: RowwiseEvalState | None = None, ): - return api.Expression.const(expression._val) + return api.Expression.const(expression._val, expression._dtype.to_engine()) def eval_call( self, @@ -425,6 +427,7 @@ def eval_apply( fun, *(self.eval_expression(arg, eval_state=eval_state) for arg in args), propagate_none=expression._propagate_none, + dtype=expression._dtype.to_engine(), ) def eval_async_apply( @@ -451,6 +454,7 @@ def eval_async_apply( expression._propagate_none, expression._deterministic, self._table_properties(output_storage), + expression._dtype.to_engine(), ) assert eval_state is not None @@ -555,10 +559,13 @@ def eval_require( ] res = val + none_expr = api.Expression.const( + None, dt.Optional(expression._dtype).to_engine() + ) for arg in reversed(args): res = api.Expression.if_else( api.Expression.is_none(arg), - api.Expression.const(None), + none_expr, res, ) diff --git a/python/pathway/internals/operator_mapping.py b/python/pathway/internals/operator_mapping.py index 5bf8bf0d..77737bfb 100644 --- a/python/pathway/internals/operator_mapping.py +++ b/python/pathway/internals/operator_mapping.py @@ -66,9 +66,9 @@ def get_unary_operators_mapping(op, operand_dtype, default=None): def get_unary_expression(expr, op, expr_dtype: dt.DType, default=None): op_engine = _unary_operators_to_engine.get(op) - expr_dtype_engine = expr_dtype.to_engine() - if op_engine is None or expr_dtype_engine is None: + if op_engine is None: return default + expr_dtype_engine = expr_dtype.to_engine() expression = api.Expression.unary_expression(expr, op_engine, expr_dtype_engine) return expression if expression is not None else default @@ -218,7 +218,7 @@ def get_binary_expression( op_engine = _binary_operators_to_engine.get(op) left_dtype_engine = left_dtype.to_engine() right_dtype_engine = right_dtype.to_engine() - if op_engine is None or left_dtype_engine is None or right_dtype_engine is None: + if op_engine is None: return default expression = api.Expression.binary_expression( @@ -251,8 +251,6 @@ def get_cast_operators_mapping( source_type_engine = dt.unoptionalize(source_type).to_engine() target_type_engine = 
dt.unoptionalize(target_type).to_engine() - if source_type_engine is None or target_type_engine is None: - return default if isinstance(source_type, dt.Optional) and isinstance(target_type, dt.Optional): fun = api.Expression.cast_optional else: @@ -271,10 +269,6 @@ def get_convert_operators_mapping( source_type_engine = dt.unoptionalize(source_type).to_engine() target_type_engine = dt.unoptionalize(target_type).to_engine() - assert ( - source_type_engine is not None and target_type_engine is not None - ), "invalid pathway type" - expression = api.Expression.convert_optional( expr, source_type_engine, diff --git a/python/pathway/internals/reducers.py b/python/pathway/internals/reducers.py index 3f20e69b..0cecf3d5 100644 --- a/python/pathway/internals/reducers.py +++ b/python/pathway/internals/reducers.py @@ -184,6 +184,24 @@ def additional_args_from_context( return () +class TupleConvertibleToNDArrayWrappingReducer(TupleWrappingReducer): + def return_type( + self, arg_types: builtins.list[dt.DType], id_type: dt.DType + ) -> dt.DType: + arg_type = arg_types[0] + if self._skip_nones: + arg_type = dt.unoptionalize(arg_type) + if builtins.any( + dt.dtype_issubclass(arg_type, dtype) + for dtype in [dt.FLOAT, dt.ANY_ARRAY, dt.ANY_TUPLE] + ): + return dt.List(arg_type) + raise TypeError( + f"Pathway does not support using reducer {self.name}" + + f" on column of type {arg_type}.\n" + ) + + class StatefulManyReducer(Reducer): name = "stateful_many" combine_many: api.CombineMany @@ -220,6 +238,14 @@ def _tuple(skip_nones: bool): ) +def _ndarray(skip_nones: bool): + return TupleConvertibleToNDArrayWrappingReducer( + name="ndarray", + engine_reducer=api.Reducer.tuple(skip_nones), + skip_nones=skip_nones, + ) + + _argmin = IdTypeUnaryReducer(name="argmin", engine_reducer=api.Reducer.ARG_MIN) _argmax = IdTypeUnaryReducer(name="argmax", engine_reducer=api.Reducer.ARG_MAX) _unique = TypePreservingUnaryReducer(name="unique", engine_reducer=api.Reducer.UNIQUE) @@ -618,9 +644,10 @@ def ndarray(expression: expr.ColumnExpression, *, skip_nones: bool = False): """ from pathway.internals.common import apply_with_type - return apply_with_type( - np.array, np.ndarray, tuple(expression, skip_nones=skip_nones) + tuples = _apply_unary_reducer( + _ndarray(skip_nones), expression, skip_nones=skip_nones ) + return apply_with_type(np.array, np.ndarray, tuples) def earliest(expression: expr.ColumnExpression) -> expr.ColumnExpression: diff --git a/python/pathway/io/_utils.py b/python/pathway/io/_utils.py index 5151e50d..f60fbf25 100644 --- a/python/pathway/io/_utils.py +++ b/python/pathway/io/_utils.py @@ -46,9 +46,7 @@ PathwayType.DATE_TIME_NAIVE: dt.DATE_TIME_NAIVE, PathwayType.DATE_TIME_UTC: dt.DATE_TIME_UTC, PathwayType.DURATION: dt.DURATION, - PathwayType.ARRAY: dt.ANY_ARRAY, PathwayType.JSON: dt.JSON, - PathwayType.TUPLE: dt.ANY_TUPLE, PathwayType.BYTES: dt.BYTES, PathwayType.PY_OBJECT_WRAPPER: dt.ANY_PY_OBJECT_WRAPPER, } diff --git a/python/pathway/io/http/_server.py b/python/pathway/io/http/_server.py index 252e5570..b1591366 100644 --- a/python/pathway/io/http/_server.py +++ b/python/pathway/io/http/_server.py @@ -246,7 +246,7 @@ def _construct_openapi_get_request_schema(self, schema) -> list: } self._add_optional_traits_if_present(field_description, props) openapi_type = _ENGINE_TO_OPENAPI_TYPE.get( - unoptionalize(props.dtype).map_to_engine() + unoptionalize(props.dtype).to_engine() ) if openapi_type: field_description["schema"] = { @@ -266,7 +266,7 @@ def _construct_openapi_json_schema(self, schema) -> dict: 
for name, props in schema.columns().items(): openapi_type = _ENGINE_TO_OPENAPI_TYPE.get( - unoptionalize(props.dtype).map_to_engine() + unoptionalize(props.dtype).to_engine() ) if openapi_type is None: # not something we can clearly define the type for, so it will be @@ -284,7 +284,7 @@ def _construct_openapi_json_schema(self, schema) -> dict: field_description["default"] = props.default_value self._add_optional_traits_if_present(field_description, props) - openapi_format = _ENGINE_TO_OPENAPI_FORMAT.get(props.dtype.map_to_engine()) + openapi_format = _ENGINE_TO_OPENAPI_FORMAT.get(props.dtype.to_engine()) if openapi_format is not None: field_description["format"] = openapi_format diff --git a/python/pathway/io/sqlite/__init__.py b/python/pathway/io/sqlite/__init__.py index 67a89f70..15fb30a1 100644 --- a/python/pathway/io/sqlite/__init__.py +++ b/python/pathway/io/sqlite/__init__.py @@ -49,7 +49,6 @@ def read( storage_type="sqlite", path=fspath(path), table_name=table_name, - column_names=schema.column_names(), mode=api.ConnectorMode.STREAMING, ) data_format = api.DataFormat( diff --git a/python/pathway/stdlib/ml/index.py b/python/pathway/stdlib/ml/index.py index a7fb5f1d..12d17be0 100644 --- a/python/pathway/stdlib/ml/index.py +++ b/python/pathway/stdlib/ml/index.py @@ -135,13 +135,13 @@ def get_nearest_items( >>> pw.debug.compute_and_print(relevant_docs) | document | embeddings ^YYY4HAB... | () | () - ^X1MXHYY... | ('document 2', 'document 3') | ((1, 1, 0), (0, 0, 1)) + ^X1MXHYY... | ('document 2', 'document 3') | ((1.0, 1.0, 0.0), (0.0, 0.0, 1.0)) >>> index = KNNIndex(documents.embeddings, documents, n_dimensions=3, metadata=documents.metadata) >>> relevant_docs_meta = index.get_nearest_items(queries.embeddings, k=2, metadata_filter="foo >= `3`") >>> pw.debug.compute_and_print(relevant_docs_meta) - | document | embeddings | metadata - ^YYY4HAB... | () | () | () - ^X1MXHYY... | ('document 3',) | ((0, 0, 1),) | (pw.Json({'foo': 3}),) + | document | embeddings | metadata + ^YYY4HAB... | () | () | () + ^X1MXHYY... | ('document 3',) | ((0.0, 0.0, 1.0),) | (pw.Json({'foo': 3}),) >>> data = pw.debug.table_from_markdown( ... ''' ... 
x | y | __time__ diff --git a/python/pathway/tests/test_api.py b/python/pathway/tests/test_api.py index 80c655e9..b2748c6a 100644 --- a/python/pathway/tests/test_api.py +++ b/python/pathway/tests/test_api.py @@ -47,7 +47,7 @@ def static_table_from_pandas(scope, df, ptr_columns=(), legacy=True): for col in schema.columns().values(): columns.append( api.ColumnProperties( - dtype=col.dtype.map_to_engine(), + dtype=col.dtype.to_engine(), append_only=col.append_only, ) ) @@ -1188,7 +1188,7 @@ def test_value_type_via_python(event_loop, value): def build(s): key = api.ref_scalar() universe = s.static_universe([key]) - dtype = dt.wrap(type(value)).map_to_engine() + dtype = dt.wrap(type(value)).to_engine() column = s.static_column( universe, [(key, value)], properties=api.ColumnProperties(dtype=dtype) ) diff --git a/python/pathway/tests/test_common.py b/python/pathway/tests/test_common.py index bd8f35b7..a70a4198 100644 --- a/python/pathway/tests/test_common.py +++ b/python/pathway/tests/test_common.py @@ -3199,6 +3199,39 @@ def test_ndarray_reducer(): assert_table_equality_wo_index(res, expected) +def test_ndarray_reducer_on_ndarrays(): + t = pw.debug.table_from_markdown( + """ + a | b | val + 0 | 0 | 1 + 0 | 0 | 2 + 0 | 1 | 3 + 0 | 1 | 4 + 1 | 0 | 5 + 1 | 0 | 6 + 1 | 0 | 7 + 1 | 1 | 8 + 1 | 1 | 9 + 1 | 1 | 0 + """ + ) + s = t.groupby(pw.this.a, pw.this.b, sort_by=pw.this.val).reduce( + pw.this.a, val=pw.reducers.ndarray(pw.this.val) + ) + res = s.groupby(pw.this.a, sort_by=pw.this.val).reduce( + pw.this.a, val=pw.reducers.ndarray(pw.this.val) + ) + expected = pw.debug.table_from_pandas( + pd.DataFrame( + { + "a": [0, 1], + "val": [np.array([[1, 2], [3, 4]]), np.array([[0, 8, 9], [5, 6, 7]])], + } + ) + ) + assert_table_equality_wo_index(res, expected) + + def test_earliest_and_latest_reducer(): t = T( """ diff --git a/python/pathway/tests/test_errors.py b/python/pathway/tests/test_errors.py index 71015af7..80511e0b 100644 --- a/python/pathway/tests/test_errors.py +++ b/python/pathway/tests/test_errors.py @@ -394,6 +394,51 @@ async def div(a: int, b: int) -> int: ) +def test_udf_return_type(): + @pw.udf + def f(a: int) -> str: + if a % 2 == 0: + return str(a) + "x" + else: + return a # type: ignore[return-value] + + res = ( + T( + """ + a + 1 + 2 + 3 + 4 + """ + ) + .select(a=f(pw.this.a)) + .select(a=pw.fill_error(pw.this.a, "xx")) + ) + expected = T( + """ + a + xx + 2x + xx + 4x + """ + ) + expected_err = T( + """ + message + TypeError: cannot create an object of type String from value 1 + TypeError: cannot create an object of type String from value 3 + """, + split_on_whitespace=False, + ) + assert_table_equality_wo_index( + (res, pw.global_error_log().select(pw.this.message)), + (expected, expected_err), + terminate_on_error=False, + ) + + def test_concat(): t1 = pw.debug.table_from_markdown( """ @@ -1077,10 +1122,10 @@ class InputSchema(pw.Schema): expected_errors = T( """ message - failed to parse value "t" at field "c" according to the type Int in schema: invalid digit found in string - failed to parse value "x" at field "b" according to the type Int in schema: invalid digit found in string - failed to parse value "y" at field "c" according to the type Int in schema: invalid digit found in string - failed to parse value "z" at field "b" according to the type Int in schema: invalid digit found in string + failed to parse value "t" at field "c" according to the type int in schema: invalid digit found in string + failed to parse value "x" at field "b" according to the type int in schema: 
invalid digit found in string + failed to parse value "y" at field "c" according to the type int in schema: invalid digit found in string + failed to parse value "z" at field "b" according to the type int in schema: invalid digit found in string """, split_on_whitespace=False, ) @@ -1114,10 +1159,10 @@ class InputSchema(pw.Schema): """ message error in primary key, skipping the row: failed to parse value "x" at field "b" \ -according to the type Int in schema: invalid digit found in string - failed to parse value "y" at field "c" according to the type Int in schema: invalid digit found in string +according to the type int in schema: invalid digit found in string + failed to parse value "y" at field "c" according to the type int in schema: invalid digit found in string error in primary key, skipping the row: failed to parse value "z" at field "b" \ -according to the type Int in schema: invalid digit found in string +according to the type int in schema: invalid digit found in string """, split_on_whitespace=False, ) @@ -1153,10 +1198,10 @@ class InputSchema(pw.Schema): expected_errors = T( """ message - value "x" in field "b" is inconsistent with type Int from schema - value "1" in field "b" is inconsistent with type Int from schema - value "t" in field "c" is inconsistent with type Int / None from schema - value "y" in field "c" is inconsistent with type Int / None from schema + failed to create a field "b" with type int from json payload: "x" + failed to create a field "b" with type int from json payload: "1" + failed to create a field "c" with type int / None from json payload: "t" + failed to create a field "c" with type int / None from json payload: "y" """, split_on_whitespace=False, ).select(message=pw.this.message.str.replace("/", "|")) @@ -1192,9 +1237,9 @@ class InputSchema(pw.Schema): expected_errors = T( """ message - error in primary key, skipping the row: value "x" in field "b" is inconsistent with type Int from schema - error in primary key, skipping the row: value "1" in field "b" is inconsistent with type Int from schema - value "y" in field "c" is inconsistent with type Int / None from schema + error in primary key, skipping the row: failed to create a field "b" with type int from json payload: "x" + error in primary key, skipping the row: failed to create a field "b" with type int from json payload: "1" + failed to create a field "c" with type int / None from json payload: "y" """, split_on_whitespace=False, ).select(message=pw.this.message.str.replace("/", "|")) @@ -1236,8 +1281,8 @@ def run(self): expected_errors = T( """ message - value 2.3 in field "a" is inconsistent with type Int from schema - value 11 in field "b" is inconsistent with type String from schema + cannot create a field "a" with type int from value 2.3 + cannot create a field "b" with type str from value 11 no value for "b" field and no default specified """, split_on_whitespace=False, @@ -1275,8 +1320,8 @@ def run(self): expected_errors = T( """ message - value 2.3 in field "a" is inconsistent with type Int from schema - error in primary key, skipping the row: value 11 in field "b" is inconsistent with type String from schema + cannot create a field "a" with type int from value 2.3 + error in primary key, skipping the row: cannot create a field "b" with type str from value 11 error in primary key, skipping the row: no value for "b" field and no default specified """, split_on_whitespace=False, diff --git a/python/pathway/tests/test_types.py b/python/pathway/tests/test_types.py index 9be85a84..ae43ab75 
100644 --- a/python/pathway/tests/test_types.py +++ b/python/pathway/tests/test_types.py @@ -1,7 +1,14 @@ # Copyright © 2024 Pathway +from typing import Any + +import numpy as np +import pandas as pd +import pytest + import pathway as pw import pathway.internals.dtype as dt +from pathway.internals.schema import schema_from_types from pathway.tests.utils import T, assert_table_equality_wo_index @@ -186,3 +193,49 @@ class OutputNumbersAsString(pw.Schema): ) assert_table_equality_wo_index(t, expected) + + +@pytest.mark.parametrize( + "data,dtype", + [ + ([0.5, 1, 2, 3], float), + ([[1, 2, 3], [0.4, 0.4, 2], [0.1, 0.2, 0.3]], list[float]), + ([[1, 2], [3, 4.2], [5, 6.2]], tuple[int, float]), + ( + [1, "a", "xyz", "", [4, 3, 2], {"a": 2, "b": 3}, pw.Json(10), True, 13.5], + pw.Json, + ), + ([np.array([1, 2, 3]), np.array([2.3, 3.4, 1.2])], dt.Array(None, float)), + ], +) +def test_udfs_and_python_connectors_take_type_into_account( + data: list[Any], dtype: type +): + internal_dtype = dt.wrap(dtype) + schema = schema_from_types(a=dtype) + + class Subject(pw.io.python.ConnectorSubject): + def run(self): + for entry in data: + self.next(a=entry) + + @pw.udf(return_type=dtype) + def producer(index: int): + return data[index] + + @pw.udf(return_type=dtype) + def assert_type_is_correct(entry): + print(entry) + internal_dtype.is_value_compatible(entry) + return entry + + t1 = pw.io.python.read(Subject(), schema=schema).select( + a=assert_type_is_correct(pw.this.a) + ) + t2 = ( + pw.debug.table_from_pandas(pd.DataFrame({"a": np.arange(len(data))})) + .select(a=producer(pw.this.a)) + .select(a=assert_type_is_correct(pw.this.a)) + ) + + assert_table_equality_wo_index(t1, t2) diff --git a/python/pathway/tests/test_udf.py b/python/pathway/tests/test_udf.py index 3c564cff..234315c2 100644 --- a/python/pathway/tests/test_udf.py +++ b/python/pathway/tests/test_udf.py @@ -995,3 +995,29 @@ def f(a: int) -> int: with warnings.catch_warnings(): warnings.simplefilter("error") f(pw.this.a) + + +def test_cast_on_return() -> None: + @pw.udf() + def f(a: int) -> float: + return a + + t = pw.debug.table_from_markdown( + """ + a | b + 1 | 1.5 + 2 | 2.5 + 3 | 3.5 + """ + ).with_columns(a=f(pw.this.a)) + + res = t.select(c=pw.this.a + pw.this.b) + expected = pw.debug.table_from_markdown( + """ + c + 2.5 + 4.5 + 6.5 + """ + ) + assert_table_equality(res, expected) diff --git a/python/pathway/tests/utils.py b/python/pathway/tests/utils.py index b2113320..4a6f40d3 100644 --- a/python/pathway/tests/utils.py +++ b/python/pathway/tests/utils.py @@ -232,6 +232,8 @@ def assert_equal_tables(t0: api.CapturedStream, t1: api.CapturedStream) -> None: def make_value_hashable(val: api.Value): if isinstance(val, np.ndarray): return (type(val), val.dtype, val.shape, str(val)) + elif isinstance(val, pw.Json): + return (type(val), repr(val)) else: return val diff --git a/src/connectors/data_format.rs b/src/connectors/data_format.rs index ae2c83ed..fb331093 100644 --- a/src/connectors/data_format.rs +++ b/src/connectors/data_format.rs @@ -13,16 +13,16 @@ use std::str::{from_utf8, Utf8Error}; use crate::connectors::metadata::SourceMetadata; use crate::connectors::ReaderContext::{Diff, KeyValue, RawBytes, TokenizedEntries}; use crate::connectors::{DataEventType, Offset, ReaderContext, SessionType, SnapshotEvent}; -use crate::engine::error::{limit_length, DynError, DynResult}; -use crate::engine::{CompoundType, DataError, Key, Result, Timestamp, Type, Value}; +use crate::engine::error::{limit_length, DynError, DynResult, 
STANDARD_OBJECT_LENGTH_LIMIT}; +use crate::engine::{Error, Key, Result, Timestamp, Type, Value}; -use itertools::Itertools; +use itertools::{chain, Itertools}; use log::error; use serde::ser::{SerializeMap, Serializer}; use serde_json::json; use serde_json::Value as JsonValue; -use super::data_storage::SpecialEvent; +use super::data_storage::{ConversionError, SpecialEvent}; pub const COMMIT_LITERAL: &str = "*COMMIT*"; const DEBEZIUM_EMPTY_KEY_PAYLOAD: &str = "{\"payload\": {\"before\": {}, \"after\": {}}}"; @@ -140,18 +140,18 @@ pub enum ParseError { SchemaNotSatisfied { value: String, field_name: String, - type_: CompoundType, + type_: Type, error: DynError, }, #[error("too small number of csv tokens in the line: {0}")] UnexpectedNumberOfCsvTokens(usize), - #[error("failed to create a field {field_name:?} with type {type_} from the following json payload: {}", limit_length(format!("{payload}"), 500))] + #[error("failed to create a field {field_name:?} with type {type_} from json payload: {}", limit_length(format!("{payload}"), STANDARD_OBJECT_LENGTH_LIMIT))] FailedToParseFromJson { field_name: String, payload: JsonValue, - type_: CompoundType, + type_: Type, }, #[error("key-value pair has unexpected number of tokens: {0} instead of 2")] @@ -186,20 +186,13 @@ pub enum ParseError { Utf8DecodeFailed(#[from] Utf8Error), #[error("parsing {0} from an external datasource is not supported")] - UnparsableType(CompoundType), + UnparsableType(Type), #[error("error in primary key, skipping the row: {0}")] ErrorInKey(DynError), #[error("no value for {field_name:?} field and no default specified")] NoDefault { field_name: String }, - - #[error("value {} in field {field_name:?} is inconsistent with type {type_} from schema", limit_length(format!("{value}"), 500))] - IncorrectType { - value: Value, - field_name: String, - type_: CompoundType, - }, } #[derive(Debug, thiserror::Error)] @@ -225,43 +218,24 @@ impl From for ParseError { pub type ParseResult = DynResult>; type PrepareStringResult = Result; -fn maybe_add_field_name(err: DynError, field_name: &str) -> DynError { - match err.into() { - DataError::IncorrectType { value, type_ } => ParseError::IncorrectType { - value, - field_name: field_name.to_string(), - type_, - } - .into(), - err => err.into(), - } -} - -#[derive(Clone, Default, Debug)] +#[derive(Clone, Debug)] pub struct InnerSchemaField { - type_: CompoundType, + type_: Type, default: Option, // None means that there is no default for the field } impl InnerSchemaField { - pub fn new(type_: Type, is_optional: bool, default: Option) -> Self { - Self { - type_: CompoundType::new(type_, is_optional), - default, - } + pub fn new(type_: Type, default: Option) -> Self { + Self { type_, default } } - pub fn adjust_value(&self, name: &str, value: Option) -> DynResult { + pub fn maybe_use_default( + &self, + name: &str, + value: Option>>, + ) -> DynResult { match value { - Some(value) => { - if self.type_.get_main_type() == Type::Json { - Ok(Value::from(serialize_value_to_json(&value)?)) - } else { - self.type_ - .convert_value(value.clone()) - .map_err(|err| maybe_add_field_name(err, name)) - } - } + Some(value) => Ok(value?), None => self.default.clone().ok_or( ParseError::NoDefault { field_name: name.to_string(), @@ -389,14 +363,14 @@ impl DsvSettings { Box::new(DsvFormatter::new(self)) } - pub fn parser(self, schema: HashMap) -> Box { - Box::new(DsvParser::new(self, schema)) + pub fn parser(self, schema: HashMap) -> Result> { + Ok(Box::new(DsvParser::new(self, schema)?)) } } #[derive(Clone)] enum 
DsvColumnIndex { - Index(usize), + IndexWithSchema(usize, InnerSchemaField), Metadata, } @@ -408,7 +382,6 @@ pub struct DsvParser { metadata_column_value: Value, key_column_indices: Option>, value_column_indices: Vec, - indexed_schema: HashMap, dsv_header_read: bool, } @@ -442,19 +415,19 @@ fn parse_with_type( field_name: &str, ) -> DynResult { if let Some(default) = &schema.default { - if raw_value.is_empty() && !matches!(schema.type_.get_main_type(), Type::Any | Type::String) + if raw_value.is_empty() && !matches!(schema.type_.unoptionalize(), Type::Any | Type::String) { return Ok(default.clone()); } } - match schema.type_.get_main_type() { + match schema.type_.unoptionalize() { Type::Any | Type::String => Ok(Value::from(raw_value)), Type::Bool => Ok(Value::Bool(parse_bool_advanced(raw_value).map_err( |e| ParseError::SchemaNotSatisfied { field_name: field_name.to_string(), value: raw_value.to_string(), - type_: schema.type_, + type_: schema.type_.clone(), error: Box::new(e), }, )?)), @@ -462,7 +435,7 @@ fn parse_with_type( ParseError::SchemaNotSatisfied { field_name: field_name.to_string(), value: raw_value.to_string(), - type_: schema.type_, + type_: schema.type_.clone(), error: Box::new(e), } })?)), @@ -470,7 +443,7 @@ fn parse_with_type( ParseError::SchemaNotSatisfied { field_name: field_name.to_string(), value: raw_value.to_string(), - type_: schema.type_, + type_: schema.type_.clone(), error: Box::new(e), } })?)), @@ -479,35 +452,59 @@ fn parse_with_type( serde_json::from_str(raw_value).map_err(|e| ParseError::SchemaNotSatisfied { field_name: field_name.to_string(), value: raw_value.to_string(), - type_: schema.type_, + type_: schema.type_.clone(), error: Box::new(e), })?; Ok(Value::from(json)) } - _ => Err(ParseError::UnparsableType(schema.type_).into()), + _ => Err(ParseError::UnparsableType(schema.type_.clone()).into()), + } +} + +fn ensure_all_fields_in_schema( + key_column_names: &Option>, + value_column_names: &Vec, + schema: &HashMap, +) -> Result<()> { + for name in chain!(key_column_names.iter().flatten(), value_column_names) { + if !schema.contains_key(name) { + return Err(Error::FieldNotInSchema { + name: name.clone(), + schema_keys: schema.keys().cloned().collect(), + }); + } } + Ok(()) } /// "magic field" containing the metadata const METADATA_FIELD_NAME: &str = "_metadata"; impl DsvParser { - pub fn new(settings: DsvSettings, schema: HashMap) -> DsvParser { - DsvParser { + pub fn new( + settings: DsvSettings, + schema: HashMap, + ) -> Result { + ensure_all_fields_in_schema( + &settings.key_column_names, + &settings.value_column_names, + &schema, + )?; + Ok(DsvParser { settings, schema, metadata_column_value: Value::None, header: Vec::new(), key_column_indices: None, value_column_indices: Vec::new(), - indexed_schema: HashMap::new(), dsv_header_read: false, - } + }) } fn column_indices_by_names( tokenized_entries: &[String], sought_names: &[String], + schema: &HashMap, ) -> Result, ParseError> { let mut value_indices_found = 0; @@ -528,8 +525,10 @@ impl DsvParser { for (index, value) in tokenized_entries.iter().enumerate() { if let Some(indices) = requested_indices.get(value) { + let schema_item = &schema[value]; for requested_index in indices { - column_indices[*requested_index] = DsvColumnIndex::Index(index); + column_indices[*requested_index] = + DsvColumnIndex::IndexWithSchema(index, schema_item.clone()); value_indices_found += 1; } } @@ -547,21 +546,18 @@ impl DsvParser { fn parse_dsv_header(&mut self, tokenized_entries: &[String]) -> Result<(), ParseError> { 
self.key_column_indices = match &self.settings.key_column_names { - Some(names) => Some(Self::column_indices_by_names(tokenized_entries, names)?), + Some(names) => Some(Self::column_indices_by_names( + tokenized_entries, + names, + &self.schema, + )?), None => None, }; - self.value_column_indices = - Self::column_indices_by_names(tokenized_entries, &self.settings.value_column_names)?; - - self.indexed_schema = { - let mut indexed_schema = HashMap::new(); - for (index, item) in tokenized_entries.iter().enumerate() { - if let Some(schema_item) = self.schema.get(item) { - indexed_schema.insert(index, (*schema_item).clone()); - } - } - indexed_schema - }; + self.value_column_indices = Self::column_indices_by_names( + tokenized_entries, + &self.settings.value_column_names, + &self.schema, + )?; self.header = tokenized_entries.to_vec(); self.dsv_header_read = true; @@ -590,20 +586,12 @@ impl DsvParser { &self, tokens: &[String], indices: &[DsvColumnIndex], - indexed_schema: &HashMap, header: &[String], ) -> ValueFieldsWithErrors { let mut parsed_tokens = Vec::with_capacity(indices.len()); for index in indices { let token = match index { - DsvColumnIndex::Index(index) => { - let default_schema; - let schema_item = if let Some(schema_item) = indexed_schema.get(index) { - schema_item - } else { - default_schema = InnerSchemaField::default(); - &default_schema - }; + DsvColumnIndex::IndexWithSchema(index, schema_item) => { parse_with_type(&tokens[*index], schema_item, &header[*index]) } DsvColumnIndex::Metadata => Ok(self.metadata_column_value.clone()), @@ -629,31 +617,27 @@ impl DsvParser { let mut line_has_enough_tokens = true; if let Some(indices) = &self.key_column_indices { for index in indices { - if let DsvColumnIndex::Index(index) = index { + if let DsvColumnIndex::IndexWithSchema(index, _) = index { line_has_enough_tokens &= index < &tokens.len(); } } } for index in &self.value_column_indices { - if let DsvColumnIndex::Index(index) = index { + if let DsvColumnIndex::IndexWithSchema(index, _) = index { line_has_enough_tokens &= index < &tokens.len(); } } if line_has_enough_tokens { let key = match &self.key_column_indices { Some(indices) => Some( - self.values_by_indices(tokens, indices, &self.indexed_schema, &self.header) + self.values_by_indices(tokens, indices, &self.header) .into_iter() .collect(), ), None => None, }; - let parsed_tokens = self.values_by_indices( - tokens, - &self.value_column_indices, - &self.indexed_schema, - &self.header, - ); + let parsed_tokens = + self.values_by_indices(tokens, &self.value_column_indices, &self.header); let parsed_entry = match event { DataEventType::Insert => ParsedEventWithErrors::Insert((key, parsed_tokens)), DataEventType::Delete => ParsedEventWithErrors::Delete((key, parsed_tokens)), @@ -951,28 +935,45 @@ pub struct DebeziumMessageParser { db_type: DebeziumDBType, } -fn parse_value_from_json(value: &JsonValue) -> Option { - match value { - JsonValue::Null => Some(Value::None), - JsonValue::String(s) => Some(Value::from(s.as_str())), - JsonValue::Number(v) => { - if let Some(parsed_u64) = v.as_u64() { - Some(Value::Int(parsed_u64.try_into().unwrap())) - } else if let Some(parsed_i64) = v.as_i64() { - Some(Value::Int(parsed_i64)) +fn parse_list_from_json(values: &[JsonValue], dtype: &Type) -> Option { + let mut list = Vec::with_capacity(values.len()); + for value in values { + list.push(parse_value_from_json(value, dtype)?); + } + Some(Value::from(list)) +} + +fn parse_tuple_from_json(values: &[JsonValue], dtypes: &[Type]) -> Option { + if 
values.len() != dtypes.len() { + return None; + } + let mut tuple = Vec::with_capacity(values.len()); + for (value, dtype) in values.iter().zip_eq(dtypes.iter()) { + tuple.push(parse_value_from_json(value, dtype)?); + } + Some(Value::from(tuple)) +} + +fn parse_value_from_json(value: &JsonValue, dtype: &Type) -> Option { + match (dtype, value) { + (Type::Json, value) => Some(Value::from(value.clone())), + (Type::Optional(_) | Type::Any, JsonValue::Null) => Some(Value::None), + (Type::Optional(arg), value) => parse_value_from_json(value, arg), + (Type::String | Type::Any, JsonValue::String(s)) => Some(Value::from(s.as_str())), + (Type::Int, JsonValue::Number(v)) => Some(Value::from(v.as_i64()?)), + (Type::Float, JsonValue::Number(v)) => v.as_f64().map(Value::from), + (Type::Any, JsonValue::Number(v)) => { + if let Some(parsed_i64) = v.as_i64() { + Some(Value::from(parsed_i64)) } else { - v.as_f64().map(Value::from) + Some(Value::from(v.as_f64()?)) } } - JsonValue::Bool(v) => Some(Value::Bool(*v)), - JsonValue::Array(v) => { - let mut tuple = Vec::with_capacity(v.len()); - for item in v { - tuple.push(parse_value_from_json(item)?); - } - Some(Value::Tuple(tuple.into())) - } - JsonValue::Object(_) => None, + (Type::Bool | Type::Any, JsonValue::Bool(v)) => Some(Value::Bool(*v)), + (Type::Tuple(dtypes), JsonValue::Array(v)) => parse_tuple_from_json(v, dtypes), + (Type::List(arg), JsonValue::Array(v)) => parse_list_from_json(v, arg), + (Type::Any, JsonValue::Array(v)) => parse_list_from_json(v, &Type::Any), + _ => None, } } @@ -1035,9 +1036,9 @@ fn values_by_names_from_json( for value_field in field_names { let (default_value, dtype) = { if let Some(schema_item) = schema.get(value_field) { - (schema_item.default.as_ref(), schema_item.type_) + (schema_item.default.as_ref(), &schema_item.type_) } else { - (None, CompoundType::default()) + (None, &Type::Any) } }; @@ -1045,20 +1046,14 @@ fn values_by_names_from_json( Ok(metadata_column_value.clone()) } else if let Some(path) = column_paths.get(value_field) { if let Some(value) = payload.pointer(path) { - match dtype.get_main_type() { - Type::Json => Ok(Value::from(value.clone())), - _ => parse_value_from_json(value) - .ok_or_else(|| { - ParseError::FailedToParseFromJson { - field_name: value_field.to_string(), - payload: value.clone(), - type_: dtype, - } - .into() - }) - .and_then(|value| dtype.convert_value(value)) - .map_err(|err| maybe_add_field_name(err, value_field)), - } + parse_value_from_json(value, dtype).ok_or_else(|| { + ParseError::FailedToParseFromJson { + field_name: value_field.to_string(), + payload: value.clone(), + type_: dtype.clone(), + } + .into() + }) } else if let Some(default) = default_value { Ok(default.clone()) } else if field_absence_is_error { @@ -1075,20 +1070,14 @@ fn values_by_names_from_json( let value_specified_in_json = payload.get(value_field).is_some(); if value_specified_in_json { - match dtype.get_main_type() { - Type::Json => Ok(Value::from(payload[&value_field].clone())), - _ => parse_value_from_json(&payload[&value_field]) - .ok_or_else(|| { - ParseError::FailedToParseFromJson { - field_name: value_field.to_string(), - payload: payload[&value_field].clone(), - type_: dtype, - } - .into() - }) - .and_then(|value| dtype.convert_value(value)) - .map_err(|err| maybe_add_field_name(err, value_field)), - } + parse_value_from_json(&payload[&value_field], dtype).ok_or_else(|| { + ParseError::FailedToParseFromJson { + field_name: value_field.to_string(), + payload: payload[&value_field].clone(), + type_: 
dtype.clone(), + } + .into() + }) } else if let Some(default) = default_value { Ok(default.clone()) } else if field_absence_is_error { @@ -1344,8 +1333,9 @@ impl JsonLinesParser { field_absence_is_error: bool, schema: HashMap, session_type: SessionType, - ) -> JsonLinesParser { - JsonLinesParser { + ) -> Result { + ensure_all_fields_in_schema(&key_field_names, &value_field_names, &schema)?; + Ok(JsonLinesParser { key_field_names, value_field_names, column_paths, @@ -1353,7 +1343,7 @@ impl JsonLinesParser { schema, metadata_column_value: Value::None, session_type, - } + }) } } @@ -1454,13 +1444,14 @@ impl TransparentParser { value_field_names: Vec, schema: HashMap, session_type: SessionType, - ) -> TransparentParser { - TransparentParser { + ) -> Result { + ensure_all_fields_in_schema(&key_field_names, &value_field_names, &schema)?; + Ok(TransparentParser { key_field_names, value_field_names, schema, session_type, - } + }) } } @@ -1472,31 +1463,24 @@ impl Parser for TransparentParser { if values.get_special() == Some(SpecialEvent::Commit) { return Ok(vec![ParsedEventWithErrors::AdvanceTime]); } - let key = key - .clone() - .map(Ok) - .or(self.key_field_names.as_ref().map(|key_field_names| { + let key = key.clone().map(Ok).or_else(|| { + self.key_field_names.as_ref().map(|key_field_names| { key_field_names .iter() .map(|name| { - self.schema - .get(name) - .expect( - "there should be an entry in the schema for name in key_field_names", - ) - .adjust_value(name, values.get(name).cloned()) + self.schema[name] // ensure_all_fields_in_schema in new() makes sure that all keys are in the schema + .maybe_use_default(name, values.get(name).cloned()) }) .collect() - })); + }) + }); let values: Vec<_> = self .value_field_names .iter() .map(|name| { - self.schema - .get(name) - .expect("there should be an entry in the schema for name in value_field_names") - .adjust_value(name, values.get(name).cloned()) + self.schema[name] // ensure_all_fields_in_schema in new() makes sure that all keys are in the schema + .maybe_use_default(name, values.get(name).cloned()) }) .collect(); diff --git a/src/connectors/data_storage.rs b/src/connectors/data_storage.rs index 588a922c..d877d6b1 100644 --- a/src/connectors/data_storage.rs +++ b/src/connectors/data_storage.rs @@ -29,6 +29,7 @@ use std::thread::sleep; use std::time::{Duration, Instant, SystemTime}; use chrono::{DateTime, FixedOffset}; +use itertools::Itertools; use log::{error, info, warn}; use postgres::types::ToSql; use tempfile::{tempdir, tempfile, TempDir}; @@ -40,6 +41,9 @@ use crate::connectors::metadata::SourceMetadata; use crate::connectors::offset::EMPTY_OFFSET; use crate::connectors::{Offset, OffsetKey, OffsetValue}; use crate::deepcopy::DeepCopy; +use crate::engine::error::limit_length; +use crate::engine::error::DynResult; +use crate::engine::error::STANDARD_OBJECT_LENGTH_LIMIT; use crate::engine::time::DateTime as EngineDateTime; use crate::engine::Type; use crate::engine::Value; @@ -47,6 +51,7 @@ use crate::engine::{DateTimeNaive, DateTimeUtc, Duration as EngineDuration}; use crate::fs_helpers::ensure_directory; use crate::persistence::frontier::OffsetAntichain; use crate::persistence::{ExternalPersistentId, PersistentId}; +use crate::python_api::extract_value; use crate::python_api::threads::PythonThreadState; use crate::python_api::PythonSubject; use crate::python_api::ValueField; @@ -97,9 +102,6 @@ use rdkafka::producer::{BaseRecord, DefaultProducerContext, Producer, ThreadedPr use rdkafka::topic_partition_list::Offset as KafkaOffset; use 
rdkafka::Message; use rusqlite::types::ValueRef as SqliteValue; -use rusqlite::types::{ - FromSql as FromSqlite, FromSqlError as FromSqliteError, FromSqlResult as FromSqliteResult, -}; use rusqlite::Connection as SqliteConnection; use rusqlite::Error as SqliteError; use s3::bucket::Bucket as S3Bucket; @@ -145,9 +147,9 @@ impl TryFrom<&str> for SpecialEvent { } } -#[derive(Debug, Default, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, Eq, PartialEq)] pub struct ValuesMap { - map: HashMap, + map: HashMap>>, // TODO: use a vector if performance improvement is needed // then Reader has to be aware of the columns order } @@ -156,20 +158,27 @@ impl ValuesMap { const SPECIAL_FIELD_NAME: &'static str = "_pw_special"; pub fn get_special(&self) -> Option { if self.map.len() == 1 { - let value = self.map.get(Self::SPECIAL_FIELD_NAME)?; - value.as_string().ok()?.as_str().try_into().ok() + let value = self.map.get(Self::SPECIAL_FIELD_NAME)?.as_ref(); + value.ok()?.as_string().ok()?.as_str().try_into().ok() } else { None } } - pub fn get(&self, key: &str) -> Option<&Value> { + pub fn get(&self, key: &str) -> Option<&Result>> { self.map.get(key) } + + pub fn to_pure_hashmap(self) -> DynResult> { + self.map + .into_iter() + .map(|(key, value)| Ok((key, value?))) + .try_collect() + } } -impl From> for ValuesMap { - fn from(value: HashMap) -> Self { +impl From>>> for ValuesMap { + fn from(value: HashMap>>) -> Self { ValuesMap { map: value } } } @@ -180,7 +189,7 @@ fn create_async_runtime() -> Result { .build() } -#[derive(PartialEq, Eq, Debug)] +#[derive(Debug)] pub enum ReaderContext { RawBytes(DataEventType, Vec), TokenizedEntries(DataEventType, Vec), @@ -213,7 +222,7 @@ impl ReaderContext { } } -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug)] pub enum ReadResult { Finished, NewSource(Option), @@ -270,6 +279,14 @@ pub enum ReadError { DeltaLakeForbiddenRemoval, } +#[derive(Debug, thiserror::Error, Clone, Eq, PartialEq)] +#[error("cannot create a field {field_name:?} with type {type_} from value {value_repr}")] +pub struct ConversionError { + value_repr: String, + field_name: String, + type_: Type, +} + #[derive(Serialize, Deserialize, Clone, Copy, Debug)] pub enum StorageType { FileSystem, @@ -1378,11 +1395,13 @@ impl Reader for CsvFilesystemReader { pub struct PythonReaderBuilder { subject: Py, persistent_id: Option, + schema: HashMap, } pub struct PythonReader { subject: Py, persistent_id: Option, + schema: HashMap, total_entries_read: u64, is_initialized: bool, is_finished: bool, @@ -1392,10 +1411,15 @@ pub struct PythonReader { } impl PythonReaderBuilder { - pub fn new(subject: Py, persistent_id: Option) -> Self { + pub fn new( + subject: Py, + persistent_id: Option, + schema: HashMap, + ) -> Self { Self { subject, persistent_id, + schema, } } } @@ -1406,11 +1430,13 @@ impl ReaderBuilder for PythonReaderBuilder { let Self { subject, persistent_id, + schema, } = *self; Ok(Box::new(PythonReader { subject, persistent_id, + schema, python_thread_state, total_entries_read: 0, is_initialized: false, @@ -1435,6 +1461,17 @@ impl ReaderBuilder for PythonReaderBuilder { } } +impl PythonReader { + fn conversion_error(ob: &Bound, name: String, type_: Type) -> ConversionError { + let value_repr = limit_length(format!("{ob}"), STANDARD_OBJECT_LENGTH_LIMIT); + ConversionError { + value_repr, + field_name: name, + type_, + } + } +} + impl Reader for PythonReader { fn seek(&mut self, frontier: &OffsetAntichain) -> Result<(), ReadError> { let offset_value = frontier.get_offset(&OffsetKey::Empty); @@ -1460,14 
+1497,26 @@ impl Reader for PythonReader { } Python::with_gil(|py| { - let (event, key, values): (DataEventType, Option, HashMap) = self - .subject - .borrow(py) - .read - .call0(py)? - .extract(py) - .map_err(ReadError::Py)?; + let (event, key, objects): (DataEventType, Option, HashMap>) = + self.subject + .borrow(py) + .read + .call0(py)? + .extract(py) + .map_err(ReadError::Py)?; let key = key.map(|key| vec![key]); + let mut values = HashMap::with_capacity(objects.len()); + for (name, ob) in objects { + let dtype = self.schema.get(&name).unwrap_or(&Type::Any); // Any for special values + let value = extract_value(ob.bind(py), dtype).map_err(|_err| { + Box::new(Self::conversion_error( + ob.bind(py), + name.clone(), + dtype.clone(), + )) + }); + values.insert(name, value); + } let values: ValuesMap = values.into(); if event != DataEventType::Insert && !self.subject.borrow(py).deletions_enabled { @@ -2429,31 +2478,12 @@ impl Reader for S3GenericReader { } } -impl FromSqlite for Value { - /// Convert raw `SQLite` field into one of internal value types - /// There are only five supported types: null, integer, real, text, blob - /// See also: - fn column_result(value: SqliteValue<'_>) -> FromSqliteResult { - match value { - SqliteValue::Null => Ok(Value::None), - SqliteValue::Integer(val) => Ok(Value::Int(val)), - SqliteValue::Real(val) => Ok(Value::Float(val.into())), - SqliteValue::Text(val) => { - let parsed_string = - from_utf8(val).map_err(|e| FromSqliteError::Other(Box::new(e)))?; - Ok(Value::String(parsed_string.into())) - } - SqliteValue::Blob(val) => Ok(Value::Bytes(val.into())), - } - } -} - const SQLITE_DATA_VERSION_PRAGMA: &str = "data_version"; pub struct SqliteReader { connection: SqliteConnection, table_name: String, - column_names: Vec, + schema: Vec<(String, Type)>, last_saved_data_version: Option, stored_state: HashMap, @@ -2464,12 +2494,12 @@ impl SqliteReader { pub fn new( connection: SqliteConnection, table_name: String, - column_names: Vec, + schema: Vec<(String, Type)>, ) -> Self { Self { connection, table_name, - column_names, + schema, last_saved_data_version: None, queued_updates: VecDeque::new(), @@ -2490,10 +2520,52 @@ impl SqliteReader { version.expect("pragma.data_version request should not fail") } + /// Convert raw `SQLite` field into one of internal value types + /// There are only five supported types: null, integer, real, text, blob + /// See also: + fn convert_to_value( + value: SqliteValue<'_>, + field_name: &str, + dtype: &Type, + ) -> Result> { + let value = match (dtype, value) { + (Type::Optional(_) | Type::Any, SqliteValue::Null) => Some(Value::None), + (Type::Optional(arg), value) => Self::convert_to_value(value, field_name, arg).ok(), + (Type::Int | Type::Any, SqliteValue::Integer(val)) => Some(Value::Int(val)), + (Type::Float | Type::Any, SqliteValue::Real(val)) => Some(Value::Float(val.into())), + (Type::String | Type::Any, SqliteValue::Text(val)) => from_utf8(val) + .ok() + .map(|parsed_string| Value::String(parsed_string.into())), + (Type::Json, SqliteValue::Text(val)) => from_utf8(val) + .ok() + .and_then(|parsed_string| { + serde_json::from_str::(parsed_string).ok() + }) + .map(Value::from), + (Type::Bytes | Type::Any, SqliteValue::Blob(val)) => Some(Value::Bytes(val.into())), + _ => None, + }; + if let Some(value) = value { + Ok(value) + } else { + let value_repr = limit_length(format!("{value:?}"), STANDARD_OBJECT_LENGTH_LIMIT); + Err(Box::new(ConversionError { + value_repr, + field_name: field_name.to_owned(), + type_: dtype.clone(), + })) 
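A note on the pattern used in the reader changes above (this is not part of the patch): instead of aborting a whole read when one field fails type conversion, each field now carries its own `Result`, and a `ConversionError` describing the value, field name, and expected type is deferred until the value is actually consumed. Below is a minimal, self-contained sketch of that idea; `Value`, `Type`, `ConversionError`, and `convert` here are illustrative stand-ins, not the engine's real items.

```rust
use std::collections::HashMap;

// Stand-ins for the engine's Value / Type / ConversionError (names are illustrative only).
#[derive(Debug)]
enum Value {
    Int(i64),
    Str(String),
}

#[derive(Debug, Clone)]
enum Type {
    Int,
    Str,
}

#[allow(dead_code)] // fields are only carried for the eventual error message
#[derive(Debug)]
struct ConversionError {
    value_repr: String,
    field_name: String,
    type_: Type,
}

// Convert one raw field, deferring failure into a Result instead of failing the whole row.
fn convert(raw: &str, field_name: &str, dtype: &Type) -> Result<Value, Box<ConversionError>> {
    let converted = match dtype {
        Type::Int => raw.parse::<i64>().ok().map(Value::Int),
        Type::Str => Some(Value::Str(raw.to_string())),
    };
    converted.ok_or_else(|| {
        Box::new(ConversionError {
            value_repr: raw.to_string(),
            field_name: field_name.to_string(),
            type_: dtype.clone(),
        })
    })
}

fn main() {
    let schema = [("a", Type::Int), ("b", Type::Str)];
    let raw_row = [("a", "not-a-number"), ("b", "ok")];

    // Each field keeps its own Result; only "a" carries an error.
    let row: HashMap<&str, Result<Value, Box<ConversionError>>> = raw_row
        .iter()
        .zip(schema.iter())
        .map(|((name, raw), (_, dtype))| (*name, convert(raw, name, dtype)))
        .collect();

    assert!(row["a"].is_err());
    assert!(matches!(row["b"], Ok(Value::Str(_))));
}
```

The design choice this mirrors: a bad cell poisons only its own column, so the error surfaces with a precise message when (and only when) that column is used.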
+ } + } + fn load_table(&mut self) -> Result<(), ReadError> { + let column_names: Vec<&str> = self + .schema + .iter() + .map(|(name, _dtype)| name.as_str()) + .collect(); let query = format!( "SELECT {},_rowid_ FROM {}", - self.column_names.join(","), + column_names.join(","), self.table_name ); @@ -2502,10 +2574,12 @@ impl SqliteReader { let mut present_rowids = HashSet::new(); while let Some(row) = rows.next()? { - let rowid: i64 = row.get(self.column_names.len())?; - let mut values = HashMap::with_capacity(self.column_names.len()); - for (column_idx, column_name) in self.column_names.iter().enumerate() { - values.insert(column_name.clone(), row.get(column_idx)?); + let rowid: i64 = row.get(self.schema.len())?; + let mut values = HashMap::with_capacity(self.schema.len()); + for (column_idx, (column_name, column_dtype)) in self.schema.iter().enumerate() { + let value = + Self::convert_to_value(row.get_ref(column_idx)?, column_name, column_dtype); + values.insert(column_name.clone(), value); } let values: ValuesMap = values.into(); self.stored_state @@ -2740,7 +2814,7 @@ impl DeltaTableWriter { Ok(DTRecordBatch::try_new(self.schema.clone(), data_columns)?) } - fn delta_table_primitive_type(type_: Type) -> Result { + fn delta_table_primitive_type(type_: &Type) -> Result { Ok(DeltaTableKernelType::Primitive(match type_ { Type::Bool => DeltaTablePrimitiveType::Boolean, Type::Float => DeltaTablePrimitiveType::Double, @@ -2749,13 +2823,17 @@ impl DeltaTableWriter { Type::DateTimeNaive => DeltaTablePrimitiveType::TimestampNtz, Type::DateTimeUtc => DeltaTablePrimitiveType::Timestamp, Type::Int | Type::Duration => DeltaTablePrimitiveType::Long, - Type::Any | Type::Array | Type::Tuple | Type::PyObjectWrapper | Type::Pointer => { - return Err(WriteError::UnsupportedType(type_)) - } + Type::Optional(wrapped) => return Self::delta_table_primitive_type(wrapped), + Type::Any + | Type::Array(_, _) + | Type::Tuple(_) + | Type::List(_) + | Type::PyObjectWrapper + | Type::Pointer => return Err(WriteError::UnsupportedType(type_.clone())), })) } - fn arrow_data_type(type_: Type) -> Result { + fn arrow_data_type(type_: &Type) -> Result { Ok(match type_ { Type::Bool => ArrowDataType::Boolean, Type::Int | Type::Duration => ArrowDataType::Int64, @@ -2768,9 +2846,12 @@ impl DeltaTableWriter { Type::DateTimeUtc => { ArrowDataType::Timestamp(ArrowTimeUnit::Microsecond, Some("UTC".into())) } - Type::Any | Type::Array | Type::Tuple | Type::PyObjectWrapper => { - return Err(WriteError::UnsupportedType(type_)) - } + Type::Optional(wrapped) => return Self::arrow_data_type(wrapped), + Type::Any + | Type::Array(_, _) + | Type::Tuple(_) + | Type::List(_) + | Type::PyObjectWrapper => return Err(WriteError::UnsupportedType(type_.clone())), }) } @@ -2779,12 +2860,16 @@ impl DeltaTableWriter { for field in value_fields { schema_fields.push(ArrowField::new( field.name.clone(), - Self::arrow_data_type(field.type_)?, - field.is_optional, + Self::arrow_data_type(&field.type_)?, + field.type_.can_be_none(), )); } for (field, type_) in SPECIAL_OUTPUT_FIELDS { - schema_fields.push(ArrowField::new(field, Self::arrow_data_type(type_)?, false)); + schema_fields.push(ArrowField::new( + field, + Self::arrow_data_type(&type_)?, + false, + )); } Ok(ArrowSchema::new(schema_fields)) } @@ -2798,14 +2883,14 @@ impl DeltaTableWriter { for field in schema_fields { struct_fields.push(DeltaTableStructField::new( field.name.clone(), - Self::delta_table_primitive_type(field.type_)?, - field.is_optional, + 
Self::delta_table_primitive_type(&field.type_)?, + field.type_.can_be_none(), )); } for (field, type_) in SPECIAL_OUTPUT_FIELDS { struct_fields.push(DeltaTableStructField::new( field, - Self::delta_table_primitive_type(type_)?, + Self::delta_table_primitive_type(&type_)?, false, )); } @@ -3048,31 +3133,40 @@ impl Reader for DeltaTableReader { }; let value = match (parquet_value, expected_type) { - (ParquetValue::Null, _) => Value::None, - (ParquetValue::Bool(b), Type::Bool | Type::Any) => Value::from(*b), - (ParquetValue::Long(i), Type::Int | Type::Any) => Value::from(*i), - (ParquetValue::Long(i), Type::Duration) => { - Value::from(EngineDuration::new_with_unit(*i, "us").unwrap()) - } - (ParquetValue::Double(f), Type::Float | Type::Any) => Value::Float((*f).into()), - (ParquetValue::Str(s), Type::String | Type::Any) => Value::String(s.into()), - (ParquetValue::Str(s), Type::Json) => { - let json: serde_json::Value = serde_json::from_str(s).unwrap(); - Value::from(json) - } - (ParquetValue::TimestampMicros(us), Type::DateTimeNaive | Type::Any) => { - Value::from(DateTimeNaive::from_timestamp(*us, "us").unwrap()) + (ParquetValue::Null, _) => Some(Value::None), + (ParquetValue::Bool(b), Type::Bool | Type::Any) => Some(Value::from(*b)), + (ParquetValue::Long(i), Type::Int | Type::Any) => Some(Value::from(*i)), + (ParquetValue::Long(i), Type::Duration) => Some(Value::from( + EngineDuration::new_with_unit(*i, "us").unwrap(), + )), + (ParquetValue::Double(f), Type::Float | Type::Any) => { + Some(Value::Float((*f).into())) } + (ParquetValue::Str(s), Type::String | Type::Any) => Some(Value::String(s.into())), + (ParquetValue::Str(s), Type::Json) => serde_json::from_str::(s) + .ok() + .map(Value::from), + (ParquetValue::TimestampMicros(us), Type::DateTimeNaive | Type::Any) => Some( + Value::from(DateTimeNaive::from_timestamp(*us, "us").unwrap()), + ), (ParquetValue::TimestampMicros(us), Type::DateTimeUtc) => { - Value::from(DateTimeUtc::from_timestamp(*us, "us").unwrap()) + Some(Value::from(DateTimeUtc::from_timestamp(*us, "us").unwrap())) } - (ParquetValue::Bytes(b), Type::Bytes | Type::Any) => Value::Bytes(b.data().into()), - _ => { - return Err(ReadError::WrongParquetType( - parquet_value.clone(), - *expected_type, - )) + (ParquetValue::Bytes(b), Type::Bytes | Type::Any) => { + Some(Value::Bytes(b.data().into())) } + _ => None, + }; + let value = if let Some(value) = value { + Ok(value) + } else { + let value_repr = + limit_length(format!("{parquet_value:?}"), STANDARD_OBJECT_LENGTH_LIMIT); + Err(Box::new(ConversionError { + value_repr, + field_name: name.clone(), + type_: expected_type.clone(), + })) }; row_map.insert(name.clone(), value); } diff --git a/src/engine/error.rs b/src/engine/error.rs index a9bd839a..fb51a581 100644 --- a/src/engine/error.rs +++ b/src/engine/error.rs @@ -6,7 +6,6 @@ use std::fmt; use std::result; use super::ColumnPath; -use super::CompoundType; use super::{Key, Value}; use crate::persistence::metadata_backends::Error as MetadataBackendError; @@ -134,6 +133,12 @@ pub enum Error { #[error(transparent)] DataError(DataError), + + #[error("column {name} is not present in schema. 
Schema keys are: {schema_keys:?}")] + FieldNotInSchema { + name: String, + schema_keys: Vec, + }, } impl Error { @@ -217,10 +222,11 @@ impl fmt::Display for Trace { } } -pub fn limit_length(mut s: String, max_length: usize) -> String { +pub const STANDARD_OBJECT_LENGTH_LIMIT: usize = 500; + +pub fn limit_length(s: String, max_length: usize) -> String { if s.len() > max_length { - s.truncate(max_length); - s + "..." + s.chars().take(max_length - 3).collect::() + "..." } else { s } @@ -312,9 +318,6 @@ pub enum DataError { #[error("mixing types in npsum is not allowed")] MixingTypesInNpSum, - #[error("value {} is inconsistent with type {type_}", limit_length(format!("{value}"), 500))] - IncorrectType { value: Value, type_: CompoundType }, - #[error(transparent)] Other(DynError), } diff --git a/src/engine/expression.rs b/src/engine/expression.rs index 40bbaa41..e2f0ee7f 100644 --- a/src/engine/expression.rs +++ b/src/engine/expression.rs @@ -17,7 +17,7 @@ use smallvec::SmallVec; use super::error::{DataError, DynError, DynResult}; use super::time::{DateTime, DateTimeNaive, DateTimeUtc, Duration}; -use super::value::SimpleType; +use super::value::Kind; use super::{Key, Type, Value}; use crate::engine::ShardPolicy; use crate::mat_mul::mat_mul; @@ -432,8 +432,8 @@ fn are_tuples_equal(lhs: &Arc<[Value]>, rhs: &Arc<[Value]>) -> DynResult { (Value::Int(val_l), Value::Float(val_r)) => Ok(OrderedFloat(*val_l as f64).eq(val_r)), (val, Value::None) | (Value::None, val) => Ok(val == &Value::None), (val_l, val_r) => { - let type_l = val_l.simple_type(); - let type_r = val_r.simple_type(); + let type_l = val_l.kind(); + let type_r = val_r.kind(); if type_l == type_r { Ok(val_l.eq(val_r)) } else { @@ -459,14 +459,10 @@ fn compare_tuples(lhs: &Arc<[Value]>, rhs: &Arc<[Value]>) -> DynResult #[allow(clippy::cast_precision_loss)] (Value::Int(val_l), Value::Float(val_r)) => Ok(OrderedFloat(*val_l as f64).cmp(val_r)), (val_l, val_r) => { - let type_l = val_l.simple_type(); - let type_r = val_r.simple_type(); - let is_incomparable_type = [ - SimpleType::Json, - SimpleType::IntArray, - SimpleType::FloatArray, - ] - .contains(&type_l); + let type_l = val_l.kind(); + let type_r = val_r.kind(); + let is_incomparable_type = + [Kind::Json, Kind::IntArray, Kind::FloatArray].contains(&type_l); if type_l != type_r || is_incomparable_type { let msg = format!( "comparison not supported between instances of '{type_l:?}' and '{type_r:?}'", @@ -651,8 +647,8 @@ impl AnyExpression { (Value::FloatArray(lhs), Value::FloatArray(rhs)) => mat_mul_wrapper(&lhs, &rhs), (Value::IntArray(lhs), Value::IntArray(rhs)) => mat_mul_wrapper(&lhs, &rhs), (lhs_val, rhs_val) => { - let lhs_type = lhs_val.simple_type(); - let rhs_type = rhs_val.simple_type(); + let lhs_type = lhs_val.kind(); + let rhs_type = rhs_val.kind(); Err(DynError::from(DataError::ValueError(format!( "can't perform matrix multiplication on {lhs_type:?} and {rhs_type:?}", )))) diff --git a/src/engine/graph.rs b/src/engine/graph.rs index f7b1d0e2..663dcee7 100644 --- a/src/engine/graph.rs +++ b/src/engine/graph.rs @@ -9,10 +9,11 @@ use std::time::{Duration, SystemTime}; use futures::future::BoxFuture; use id_arena::ArenaBehavior; -use pyo3::exceptions::PyTypeError; +use itertools::Itertools; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::pymethods; use pyo3::pyclass::CompareOp; -use pyo3::{pyclass, PyResult, Python}; +use pyo3::{pyclass, Bound, PyAny, PyResult, Python}; use scopeguard::defer; use crate::connectors::data_format::{Formatter, Parser}; @@ 
-20,6 +21,7 @@ use crate::connectors::data_storage::{ReaderBuilder, Writer}; use crate::connectors::monitoring::ConnectorStats; use crate::external_integration::ExternalIndex; use crate::persistence::ExternalPersistentId; +use crate::python_api::extract_value; use super::error::{DynResult, Trace}; use super::external_index_wrappers::{ExternalIndexData, ExternalIndexQuery}; @@ -201,24 +203,41 @@ impl DataRow { #[pyo3(signature = ( key, values, + *, time = Timestamp(0), diff = 1, shard = None, + dtypes, ))] + #[allow(clippy::needless_pass_by_value)] // can't use &[Type] with pyo3 pub fn new( key: Key, - values: Vec, + values: Vec>, time: Timestamp, diff: isize, shard: Option, - ) -> Self { - Self { + dtypes: Vec, + ) -> PyResult { + if values.len() != dtypes.len() { + let message = format!( + "Length of values ({}) should be equal to the length of dtypes ({}).", + values.len(), + dtypes.len() + ); + return Err(PyValueError::new_err(message)); + } + let extracted_values: Vec<_> = values + .into_iter() + .zip(dtypes) + .map(|(ob, dtype)| extract_value(&ob, &dtype)) + .try_collect()?; + Ok(Self { key, - values, + values: extracted_values, time, diff, shard, - } + }) } fn __repr__(&self) -> String { diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 4533ee4e..4a8769a8 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -10,7 +10,7 @@ pub use self::error::{DataError, Error, Result}; pub mod report_error; pub mod value; -pub use self::value::{CompoundType, Key, KeyImpl, ShardPolicy, Type, Value}; +pub use self::value::{Key, KeyImpl, ShardPolicy, Type, Value}; pub mod reduce; pub use reduce::Reducer; diff --git a/src/engine/value.rs b/src/engine/value.rs index cc370fc0..e53c5006 100644 --- a/src/engine/value.rs +++ b/src/engine/value.rs @@ -428,6 +428,12 @@ impl From<&[Value]> for Value { } } +impl From> for Value { + fn from(t: Vec) -> Self { + Self::Tuple(t.into()) + } +} + impl From> for Value { fn from(a: ArrayD) -> Self { Self::IntArray(Handle::new(a)) @@ -486,7 +492,7 @@ impl From for Value { // so changing them will result in changed IDs #[repr(u8)] #[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum SimpleType { +pub enum Kind { None, Bool, Int, @@ -505,30 +511,8 @@ pub enum SimpleType { PyObjectWrapper, } -impl SimpleType { - pub fn to_type(&self) -> Option { - match self { - SimpleType::None | SimpleType::Error => None, - SimpleType::Bool => Some(Type::Bool), - SimpleType::Int => Some(Type::Int), - SimpleType::Float => Some(Type::Float), - SimpleType::Pointer => Some(Type::Pointer), - SimpleType::String => Some(Type::String), - SimpleType::Tuple => Some(Type::Tuple), - SimpleType::IntArray | SimpleType::FloatArray => Some(Type::Array), - SimpleType::DateTimeNaive => Some(Type::DateTimeNaive), - SimpleType::DateTimeUtc => Some(Type::DateTimeUtc), - SimpleType::Duration => Some(Type::Duration), - SimpleType::Bytes => Some(Type::Bytes), - SimpleType::Json => Some(Type::Json), - SimpleType::PyObjectWrapper => Some(Type::PyObjectWrapper), - } - } -} - -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Type { - #[default] Any, Bool, Int, @@ -539,83 +523,75 @@ pub enum Type { DateTimeNaive, DateTimeUtc, Duration, - Array, + Array(Option, Arc), Json, - Tuple, + Tuple(Arc<[Type]>), + List(Arc), PyObjectWrapper, + Optional(Arc), } -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] -pub struct CompoundType { - pub type_: Type, - pub is_optional: bool, -} - -impl CompoundType { - pub fn new(type_: Type, is_optional: bool) -> Self 
{ - Self { type_, is_optional } +impl Type { + pub fn can_be_none(&self) -> bool { + matches!(self, Self::Optional(_) | Self::Any) } - - pub fn matches(&self, value: &Value) -> bool { - if self.type_ == Type::Any { - true - } else if let Some(value_type) = value.simple_type().to_type() { - self.type_ == value_type - } else { - false - } - } - - #[allow(clippy::cast_precision_loss)] - pub fn convert_value(&self, value: Value) -> DynResult { - if self.matches(&value) || self.is_optional && value == Value::None { - return Ok(value); - } - match (value, self.type_) { - (Value::Int(i), Type::Float) => Ok(Value::from(i as f64)), - (value, _) => Err(DataError::IncorrectType { - value, - type_: *self, - } - .into()), + pub fn unoptionalize(&self) -> &Self { + match self { + Self::Optional(arg) => arg, + type_ => type_, } } - - pub fn get_main_type(&self) -> Type { - self.type_ - } } -impl Display for CompoundType { +impl Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.is_optional { - write!(f, "{:?} | None", self.type_) - } else { - write!(f, "{:?}", self.type_) + match self { + Type::Any => write!(f, "Any"), + Type::Bool => write!(f, "bool"), + Type::Int => write!(f, "int"), + Type::Float => write!(f, "float"), + Type::Pointer => write!(f, "Pointer"), + Type::String => write!(f, "str"), + Type::Bytes => write!(f, "bytes"), + Type::DateTimeNaive => write!(f, "DateTimeNaive"), + Type::DateTimeUtc => write!(f, "DateTimeUtc"), + Type::Duration => write!(f, "Duration"), + Type::Array(dim, arg) => { + if let Some(dim) = dim { + write!(f, "Array({dim}, {arg})") + } else { + write!(f, "Array({arg})") + } + } + Type::Json => write!(f, "Json"), + Type::Tuple(args) => write!(f, "tuple[{}]", args.iter().format(", ")), + Type::List(arg) => write!(f, "list[{arg}]"), + Type::PyObjectWrapper => write!(f, "PyObjectWrapper"), + Type::Optional(arg) => write!(f, "{arg} | None"), } } } impl Value { #[must_use] - pub fn simple_type(&self) -> SimpleType { + pub fn kind(&self) -> Kind { match self { - Self::None => SimpleType::None, - Self::Bool(_) => SimpleType::Bool, - Self::Int(_) => SimpleType::Int, - Self::Float(_) => SimpleType::Float, - Self::Pointer(_) => SimpleType::Pointer, - Self::String(_) => SimpleType::String, - Self::Bytes(_) => SimpleType::Bytes, - Self::Tuple(_) => SimpleType::Tuple, - Self::IntArray(_) => SimpleType::IntArray, - Self::FloatArray(_) => SimpleType::FloatArray, - Self::DateTimeNaive(_) => SimpleType::DateTimeNaive, - Self::DateTimeUtc(_) => SimpleType::DateTimeUtc, - Self::Duration(_) => SimpleType::Duration, - Self::Json(_) => SimpleType::Json, - Self::Error => SimpleType::Error, - Self::PyObjectWrapper(_) => SimpleType::PyObjectWrapper, + Self::None => Kind::None, + Self::Bool(_) => Kind::Bool, + Self::Int(_) => Kind::Int, + Self::Float(_) => Kind::Float, + Self::Pointer(_) => Kind::Pointer, + Self::String(_) => Kind::String, + Self::Bytes(_) => Kind::Bytes, + Self::Tuple(_) => Kind::Tuple, + Self::IntArray(_) => Kind::IntArray, + Self::FloatArray(_) => Kind::FloatArray, + Self::DateTimeNaive(_) => Kind::DateTimeNaive, + Self::DateTimeUtc(_) => Kind::DateTimeUtc, + Self::Duration(_) => Kind::Duration, + Self::Json(_) => Kind::Json, + Self::Error => Kind::Error, + Self::PyObjectWrapper(_) => Kind::PyObjectWrapper, } } } @@ -742,7 +718,7 @@ impl HashInto for Duration { impl HashInto for Value { fn hash_into(&self, hasher: &mut Hasher) { - (self.simple_type() as u8).hash_into(hasher); + (self.kind() as u8).hash_into(hasher); match self { Self::None 
=> {} Self::Bool(b) => b.hash_into(hasher), diff --git a/src/python_api.rs b/src/python_api.rs index 76c203c2..3c09799d 100644 --- a/src/python_api.rs +++ b/src/python_api.rs @@ -23,6 +23,7 @@ use elasticsearch::{ }; use itertools::Itertools; use log::warn; +use ndarray; use numpy::{PyArray, PyReadonlyArrayDyn}; use once_cell::sync::Lazy; use postgres::{Client, NoTls}; @@ -45,6 +46,7 @@ use s3::bucket::Bucket as S3Bucket; use scopeguard::defer; use send_wrapper::SendWrapper; use serde_json::Value as JsonValue; +use std::borrow::Borrow; use std::cell::RefCell; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; @@ -255,6 +257,155 @@ fn is_pathway_json(ob: &Bound) -> PyResult { Ok(type_name == "Json") } +fn array_with_proper_dimensions( + array: ndarray::ArrayD, + dim: Option, +) -> Option> { + match dim { + Some(dim) if array.ndim() == dim => Some(array), + Some(_) => None, + None => Some(array), + } +} + +fn extract_datetime(ob: &Bound, type_: &Type) -> Option { + let type_name = ob.get_type().qualname().ok()?; + let value = if type_name == "datetime" { + value_from_python_datetime(ob).ok() + } else if matches!( + type_name.as_ref(), + "Timestamp" | "DateTimeNaive" | "DateTimeUtc" + ) { + value_from_pandas_timestamp(ob).ok() + } else { + None + }?; + match (&value, type_) { + (Value::DateTimeNaive(_), Type::DateTimeNaive) + | (Value::DateTimeUtc(_), Type::DateTimeUtc) => Some(value), + _ => None, + } +} + +fn extract_int_array(ob: &Bound, dim: Option) -> Option> { + let array = if let Ok(array) = ob.extract::>() { + Some(array.as_array().to_owned()) + } else if let Ok(array) = ob.extract::>() { + Some(array.as_array().mapv(i64::from)) + } else if let Ok(array) = ob.extract::>() { + Some(array.as_array().mapv(i64::from)) + } else { + None + }?; + array_with_proper_dimensions(array, dim) +} + +#[allow(clippy::cast_precision_loss)] +fn extract_float_array(ob: &Bound, dim: Option) -> Option> { + let array = if let Ok(array) = ob.extract::>() { + array.as_array().to_owned() + } else if let Ok(array) = ob.extract::>() { + array.as_array().mapv(f64::from) + } else { + extract_int_array(ob, dim).map(|array| array.mapv(|v| v as f64))? + }; + array_with_proper_dimensions(array, dim) +} + +fn py_type_error(ob: &Bound, type_: &Type) -> PyErr { + PyTypeError::new_err(format!( + "cannot create an object of type {type_:?} from value {ob}" + )) +} + +pub fn extract_value(ob: &Bound, type_: &Type) -> PyResult { + let extracted = match type_ { + Type::Any => ob.extract().ok(), + Type::Optional(arg) => { + if ob.is_none() { + Some(Value::None) + } else { + Some(extract_value(ob, arg)?) 
+ } + } + Type::Bool => ob + .extract::<&PyBool>() + .ok() + .map(|b| Value::from(b.is_true())), + Type::Int => ob.extract::().ok().map(Value::from), + Type::Float => ob.extract::().ok().map(Value::from), + Type::Pointer => ob.extract::().ok().map(Value::from), + Type::String => ob + .downcast::() + .ok() + .and_then(|s| s.to_str().ok()) + .map(Value::from), + Type::Bytes => ob + .downcast::() + .ok() + .map(|b| Value::from(b.as_bytes())), + Type::DateTimeNaive | Type::DateTimeUtc => extract_datetime(ob, type_), + Type::Duration => { + // XXX: check types, not names + let type_name = ob.get_type().qualname()?; + if type_name == "timedelta" { + value_from_python_timedelta(ob).ok() + } else if matches!(type_name.as_ref(), "Timedelta" | "Duration") { + value_from_pandas_timedelta(ob).ok() + } else { + None + } + } + Type::Array(dim, wrapped) => match wrapped.borrow() { + Type::Int => Ok(extract_int_array(ob, *dim).map(Value::from)), + Type::Float => Ok(extract_float_array(ob, *dim).map(Value::from)), + Type::Any => Ok(extract_int_array(ob, *dim) + .map(Value::from) + .or_else(|| extract_float_array(ob, *dim).map(Value::from))), + wrapped => Err(PyValueError::new_err(format!( + "{wrapped:?} is invalid type for Array" + ))), + }?, + Type::Json => { + if is_pathway_json(ob)? { + value_json_from_py_any(&ob.getattr("value")?).ok() + } else { + value_json_from_py_any(ob).ok() + } + } + Type::Tuple(args) => { + let obs = ob.extract::>>()?; + if obs.len() == args.len() { + let values: Vec<_> = obs + .into_iter() + .zip(args.iter()) + .map(|(ob, type_)| extract_value(&ob, type_)) + .try_collect()?; + Some(Value::from(values.as_slice())) + } else { + None + } + } + Type::List(arg) => { + let obs = ob.extract::>>()?; + let values: Vec<_> = obs + .into_iter() + .map(|ob| extract_value(&ob, arg)) + .try_collect()?; + Some(Value::from(values.as_slice())) + } + Type::PyObjectWrapper => { + let value = if let Ok(ob) = ob.extract::() { + ob + } else { + PyObjectWrapper::new(ob.clone().unbind()) + }; + Some(Value::from(value.into_internal())) + } + }; + extracted.ok_or_else(|| py_type_error(ob, type_)) +} + impl<'py> FromPyObject<'py> for Value { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let py = ob.py(); @@ -399,7 +550,7 @@ impl IntoPy for Reducer { impl<'py> FromPyObject<'py> for Type { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { - Ok(ob.extract::>()?.0) + Ok(ob.extract::>()?.0.clone()) } } @@ -929,8 +1080,9 @@ macro_rules! binary_expr { #[pymethods] impl PyExpression { #[staticmethod] - fn r#const(value: Value) -> Self { - Self::new(Arc::new(Expression::new_const(value)), false) + fn r#const(ob: &Bound, type_: Type) -> PyResult { + let value = extract_value(ob, &type_)?; + Ok(Self::new(Arc::new(Expression::new_const(value)), false)) } #[staticmethod] @@ -942,8 +1094,13 @@ impl PyExpression { } #[staticmethod] - #[pyo3(signature = (function, *args, propagate_none=false))] - fn apply(function: Py, args: Vec>, propagate_none: bool) -> Self { + #[pyo3(signature = (function, *args, dtype, propagate_none=false))] + fn apply( + function: Py, + args: Vec>, + dtype: Type, + propagate_none: bool, + ) -> Self { let args = args .into_iter() .map(|expr| expr.inner.clone()) @@ -951,10 +1108,10 @@ impl PyExpression { let func = Box::new(move |input: &[Value]| { Python::with_gil(|py| -> DynResult { let args = PyTuple::new_bound(py, input); - Ok(function.call1(py, args)?.extract::(py)?) + let result = function.call1(py, args)?; + Ok(extract_value(result.bind(py), &dtype)?) 
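The `extract_value` function added above is type-directed: the expected `Type` drives how the Python object is read, and container types (`Optional`, `Tuple`, `List`) recurse into their element types. The following standalone sketch shows the same recursion shape without pyo3; `Dyn`, `Ty`, `Val`, and `extract` are toy stand-ins, not the real engine types.

```rust
// Toy "dynamic input" and expected-type enums; illustrative stand-ins only.
#[derive(Debug)]
enum Dyn {
    Null,
    Int(i64),
    Text(String),
    Seq(Vec<Dyn>),
}

#[derive(Debug)]
enum Ty {
    Int,
    Text,
    Optional(Box<Ty>),
    List(Box<Ty>),
    Tuple(Vec<Ty>),
}

#[derive(Debug, PartialEq)]
enum Val {
    None,
    Int(i64),
    Text(String),
    Tuple(Vec<Val>),
}

// Type-directed extraction: the expected type decides how the dynamic value is read,
// recursing into Optional / List / Tuple elements.
fn extract(ob: &Dyn, ty: &Ty) -> Result<Val, String> {
    match (ty, ob) {
        (Ty::Optional(_), Dyn::Null) => Ok(Val::None),
        (Ty::Optional(inner), ob) => extract(ob, inner),
        (Ty::Int, Dyn::Int(i)) => Ok(Val::Int(*i)),
        (Ty::Text, Dyn::Text(s)) => Ok(Val::Text(s.clone())),
        (Ty::List(inner), Dyn::Seq(items)) => items
            .iter()
            .map(|item| extract(item, inner))
            .collect::<Result<Vec<_>, _>>()
            .map(Val::Tuple),
        (Ty::Tuple(args), Dyn::Seq(items)) if args.len() == items.len() => items
            .iter()
            .zip(args)
            .map(|(item, arg)| extract(item, arg))
            .collect::<Result<Vec<_>, _>>()
            .map(Val::Tuple),
        (ty, ob) => Err(format!("cannot create {ty:?} from {ob:?}")),
    }
}

fn main() {
    let ty = Ty::Optional(Box::new(Ty::List(Box::new(Ty::Int))));
    assert_eq!(extract(&Dyn::Null, &ty), Ok(Val::None));
    assert_eq!(
        extract(&Dyn::Seq(vec![Dyn::Int(1), Dyn::Int(2)]), &ty),
        Ok(Val::Tuple(vec![Val::Int(1), Val::Int(2)]))
    );
    assert!(extract(&Dyn::Text("x".into()), &Ty::Int).is_err());
}
```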
}) }); - let expression = if propagate_none { AnyExpression::OptionalApply(func, args.into()) } else { @@ -1139,13 +1296,27 @@ impl PyExpression { Some(binary_op!(FloatExpression::DurationTrueDiv, lhs, rhs)) } (Op::Mod, Tp::Duration, Tp::Duration) => Some(binary_op!(DurationE::Mod, lhs, rhs)), - (Op::MatMul, Tp::Array, Tp::Array) => Some(binary_op!(AnyE::MatMul, lhs, rhs)), - (Op::Eq, Tp::Tuple, Tp::Tuple) => Some(binary_op!(BoolE::TupleEq, lhs, rhs)), - (Op::Ne, Tp::Tuple, Tp::Tuple) => Some(binary_op!(BoolE::TupleNe, lhs, rhs)), - (Op::Lt, Tp::Tuple, Tp::Tuple) => Some(binary_op!(BoolE::TupleLt, lhs, rhs)), - (Op::Le, Tp::Tuple, Tp::Tuple) => Some(binary_op!(BoolE::TupleLe, lhs, rhs)), - (Op::Gt, Tp::Tuple, Tp::Tuple) => Some(binary_op!(BoolE::TupleGt, lhs, rhs)), - (Op::Ge, Tp::Tuple, Tp::Tuple) => Some(binary_op!(BoolE::TupleGe, lhs, rhs)), + (Op::MatMul, Tp::Array(_, _), Tp::Array(_, _)) => { + Some(binary_op!(AnyE::MatMul, lhs, rhs)) + } + (Op::Eq, Tp::Tuple(_) | Tp::List(_), Tp::Tuple(_) | Tp::List(_)) => { + Some(binary_op!(BoolE::TupleEq, lhs, rhs)) + } + (Op::Ne, Tp::Tuple(_) | Tp::List(_), Tp::Tuple(_) | Tp::List(_)) => { + Some(binary_op!(BoolE::TupleNe, lhs, rhs)) + } + (Op::Lt, Tp::Tuple(_) | Tp::List(_), Tp::Tuple(_) | Tp::List(_)) => { + Some(binary_op!(BoolE::TupleLt, lhs, rhs)) + } + (Op::Le, Tp::Tuple(_) | Tp::List(_), Tp::Tuple(_) | Tp::List(_)) => { + Some(binary_op!(BoolE::TupleLe, lhs, rhs)) + } + (Op::Gt, Tp::Tuple(_) | Tp::List(_), Tp::Tuple(_) | Tp::List(_)) => { + Some(binary_op!(BoolE::TupleGt, lhs, rhs)) + } + (Op::Ge, Tp::Tuple(_) | Tp::List(_), Tp::Tuple(_) | Tp::List(_)) => { + Some(binary_op!(BoolE::TupleGe, lhs, rhs)) + } _ => None, } } @@ -1189,7 +1360,7 @@ impl PyExpression { #[staticmethod] fn convert_optional(expr: &PyExpression, source_type: Type, target_type: Type) -> Option { type Tp = Type; - match (source_type, target_type) { + match (&source_type, &target_type) { (Tp::Json, Tp::Int | Tp::Float | Tp::Bool | Tp::String) => { Some(unary_op!(AnyExpression::JsonToOptional, expr, target_type)) } @@ -1430,16 +1601,30 @@ impl PathwayType { pub const DATE_TIME_UTC: Type = Type::DateTimeUtc; #[classattr] pub const DURATION: Type = Type::Duration; - #[classattr] - pub const ARRAY: Type = Type::Array; + #[staticmethod] + #[pyo3(signature = (dim, wrapped))] + pub fn array(dim: Option, wrapped: Type) -> Type { + Type::Array(dim, wrapped.into()) + } #[classattr] pub const JSON: Type = Type::Json; - #[classattr] - pub const TUPLE: Type = Type::Tuple; + #[staticmethod] + #[pyo3(signature = (*args))] + pub fn tuple(args: Vec) -> Type { + Type::Tuple(args.into()) + } + #[staticmethod] + pub fn list(arg: Type) -> Type { + Type::List(arg.into()) + } #[classattr] pub const BYTES: Type = Type::Bytes; #[classattr] pub const PY_OBJECT_WRAPPER: Type = Type::PyObjectWrapper; + #[staticmethod] + pub fn optional(wrapped: Type) -> Type { + Type::Optional(wrapped.into()) + } } #[pyclass(module = "pathway.engine", frozen, name = "ReadMethod")] @@ -2265,6 +2450,7 @@ impl Scope { Column::new(universe, handle) } + #[allow(clippy::too_many_arguments)] pub fn async_apply_table( self_: &Bound, table: PyRef, @@ -2273,7 +2459,9 @@ impl Scope { propagate_none: bool, deterministic: bool, properties: TableProperties, + dtype: Type, ) -> PyResult> { + let dtype = Arc::new(dtype); let event_loop = self_.borrow().event_loop.clone(); let table_handle = self_.borrow().graph.async_apply_table( Arc::new(move |_, values: &[Value]| { @@ -2290,9 +2478,12 @@ impl Scope { 
pyo3_asyncio::into_future_with_locals(&locals, awaitable) }); - Box::pin(async { - let result = future?.await?; - Python::with_gil(|py| result.extract::(py).map_err(DynError::from)) + Box::pin({ + let dtype = dtype.clone(); + async { + let result = future?.await?; + Python::with_gil(move |py| Ok(extract_value(result.bind(py), &dtype)?)) + } }) }), table.handle, @@ -3459,7 +3650,6 @@ pub struct DataStorage { object_pattern: String, mock_events: Option>>, table_name: Option, - column_names: Option>, header_fields: Vec<(String, usize)>, key_field_index: Option, min_commit_frequency: Option, @@ -3702,32 +3892,30 @@ pub struct ValueField { #[pyo3(get)] pub type_: Type, #[pyo3(get)] - pub is_optional: bool, - #[pyo3(get)] pub default: Option, } impl ValueField { fn as_inner_schema_field(&self) -> InnerSchemaField { - InnerSchemaField::new(self.type_, self.is_optional, self.default.clone()) + InnerSchemaField::new(self.type_.clone(), self.default.clone()) } } #[pymethods] impl ValueField { #[new] - #[pyo3(signature = (name, type_, is_optional=false))] - fn new(name: String, type_: Type, is_optional: bool) -> Self { + #[pyo3(signature = (name, type_))] + fn new(name: String, type_: Type) -> Self { ValueField { name, type_, - is_optional, default: None, } } - fn set_default(&mut self, value: Value) { - self.default = Some(value); + fn set_default(&mut self, ob: &Bound) -> PyResult<()> { + self.default = Some(extract_value(ob, &self.type_)?); + Ok(()) } } @@ -3769,7 +3957,6 @@ impl DataStorage { object_pattern = "*".to_string(), mock_events = None, table_name = None, - column_names = None, header_fields = Vec::new(), key_field_index = None, min_commit_frequency = None, @@ -3794,7 +3981,6 @@ impl DataStorage { object_pattern: String, mock_events: Option>>, table_name: Option, - column_names: Option>, header_fields: Vec<(String, usize)>, key_field_index: Option, min_commit_frequency: Option, @@ -3818,7 +4004,6 @@ impl DataStorage { object_pattern, mock_events, table_name, - column_names, header_fields, key_field_index, min_commit_frequency, @@ -4147,6 +4332,7 @@ impl DataStorage { fn construct_python_reader( &self, py: pyo3::Python, + data_format: &DataFormat, ) -> PyResult<(Box, usize)> { let subject = self.python_subject.clone().ok_or_else(|| { PyValueError::new_err("For Python connector, python_subject should be specified") @@ -4158,11 +4344,19 @@ impl DataStorage { )); } - let reader = PythonReaderBuilder::new(subject, self.internal_persistent_id()); + let reader = PythonReaderBuilder::new( + subject, + self.internal_persistent_id(), + data_format.value_fields_type_map(py), + ); Ok((Box::new(reader), 1)) } - fn construct_sqlite_reader(&self) -> PyResult<(Box, usize)> { + fn construct_sqlite_reader( + &self, + py: pyo3::Python, + data_format: &DataFormat, + ) -> PyResult<(Box, usize)> { let connection = SqliteConnection::open_with_flags( self.path()?, SqliteOpenFlags::SQLITE_OPEN_READ_ONLY | SqliteOpenFlags::SQLITE_OPEN_NO_MUTEX, @@ -4171,10 +4365,12 @@ impl DataStorage { let table_name = self.table_name.clone().ok_or_else(|| { PyValueError::new_err("For Sqlite connector, table_name should be specified") })?; - let column_names = self.column_names.clone().ok_or_else(|| { - PyValueError::new_err("For Sqlite connector, column_names should be specified") - })?; - let reader = SqliteReader::new(connection, table_name, column_names); + + let reader = SqliteReader::new( + connection, + table_name, + data_format.value_fields_type_map(py).into_iter().collect(), + ); Ok((Box::new(reader), 1)) } @@ -4222,8 
+4418,8 @@ impl DataStorage { "s3_csv" => self.construct_s3_csv_reader(py), "csv" => self.construct_csv_reader(py), "kafka" => self.construct_kafka_reader(), - "python" => self.construct_python_reader(py), - "sqlite" => self.construct_sqlite_reader(), + "python" => self.construct_python_reader(py, data_format), + "sqlite" => self.construct_sqlite_reader(py, data_format), "deltalake" => self.construct_deltalake_reader(py, data_format), other => Err(PyValueError::new_err(format!( "Unknown data source {other:?}" @@ -4387,10 +4583,10 @@ impl DataStorage { impl DataFormat { pub fn value_fields_type_map(&self, py: pyo3::Python) -> HashMap { - let mut result = HashMap::new(); + let mut result = HashMap::with_capacity(self.value_fields.len()); for field in &self.value_fields { let name = field.borrow(py).name.clone(); - let type_ = field.borrow(py).type_; + let type_ = field.borrow(py).type_.clone(); result.insert(name, type_); } result @@ -4451,7 +4647,7 @@ impl DataFormat { match self.format_type.as_ref() { "dsv" => { let settings = self.construct_dsv_settings(py)?; - Ok(settings.parser(self.schema(py)?)) + Ok(settings.parser(self.schema(py)?)?) } "debezium" => { let parser = DebeziumMessageParser::new( @@ -4470,7 +4666,7 @@ impl DataFormat { self.field_absence_is_error, self.schema(py)?, self.session_type, - ); + )?; Ok(Box::new(parser)) } "identity" => Ok(Box::new(IdentityParser::new( @@ -4484,7 +4680,7 @@ impl DataFormat { self.value_field_names(py), self.schema(py)?, self.session_type, - ))), + )?)), _ => Err(PyValueError::new_err("Unknown data format")), } } diff --git a/tests/integration/main.rs b/tests/integration/main.rs index 8324c4c9..88cc6c5d 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -27,5 +27,6 @@ mod test_sqlite; mod test_stream_snapshot; mod test_time; mod test_time_column; +mod test_types; mod test_upsert_session; mod test_value_to_sql; diff --git a/tests/integration/test_connector_field_defaults.rs b/tests/integration/test_connector_field_defaults.rs index 06b90efe..21941e09 100644 --- a/tests/integration/test_connector_field_defaults.rs +++ b/tests/integration/test_connector_field_defaults.rs @@ -13,16 +13,31 @@ use pathway_engine::connectors::data_storage::{ use pathway_engine::connectors::SessionType; use pathway_engine::engine::{Type, Value}; +fn get_schema_with_common_parts() -> HashMap { + [ + ( + "seq_id".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ("key".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "value".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ] + .into() +} + #[test] fn test_dsv_with_default_end_of_line() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); - let mut schema = HashMap::new(); + let mut schema = get_schema_with_common_parts(); schema.insert( "number".to_string(), - InnerSchemaField::new(Type::Int, false, Some(Value::Int(42))), + InnerSchemaField::new(Type::Int, Some(Value::Int(42))), ); let reader = CsvFilesystemReader::new( @@ -39,7 +54,7 @@ fn test_dsv_with_default_end_of_line() -> eyre::Result<()> { ',', ), schema, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( @@ -72,11 +87,11 @@ fn test_dsv_with_default_middle_of_line() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); - let mut schema = HashMap::new(); + let mut schema = get_schema_with_common_parts(); schema.insert( "number".to_string(), - 
InnerSchemaField::new(Type::Int, false, Some(Value::Int(42))), + InnerSchemaField::new(Type::Int, Some(Value::Int(42))), ); let reader = CsvFilesystemReader::new( @@ -93,7 +108,7 @@ fn test_dsv_with_default_middle_of_line() -> eyre::Result<()> { ',', ), schema, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( @@ -126,11 +141,9 @@ fn test_dsv_fails_without_default() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); - let mut schema = HashMap::new(); - schema.insert( - "number".to_string(), - InnerSchemaField::new(Type::Int, false, None), - ); + let mut schema = get_schema_with_common_parts(); + + schema.insert("number".to_string(), InnerSchemaField::new(Type::Int, None)); let reader = CsvFilesystemReader::new( "tests/data/dsv_with_skips.txt", @@ -146,7 +159,7 @@ fn test_dsv_fails_without_default() -> eyre::Result<()> { ',', ), schema, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( @@ -179,11 +192,11 @@ fn test_dsv_with_default_nullable() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); - let mut schema = HashMap::new(); + let mut schema = get_schema_with_common_parts(); schema.insert( "number".to_string(), - InnerSchemaField::new(Type::Int, false, Some(Value::None)), + InnerSchemaField::new(Type::Optional(Type::Int.into()), Some(Value::None)), ); let reader = CsvFilesystemReader::new( @@ -200,7 +213,7 @@ fn test_dsv_with_default_nullable() -> eyre::Result<()> { ',', ), schema, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( @@ -228,8 +241,19 @@ fn test_dsv_with_default_nullable() -> eyre::Result<()> { Ok(()) } +fn get_schema_abc() -> HashMap { + [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ] + .into() +} + #[test] fn test_jsonlines_fails_without_default() -> eyre::Result<()> { + let mut schema = get_schema_abc(); + schema.insert("d".to_string(), InnerSchemaField::new(Type::Int, None)); let reader = FilesystemReader::new( "tests/data/jsonlines.txt", ConnectorMode::Static, @@ -242,9 +266,9 @@ fn test_jsonlines_fails_without_default() -> eyre::Result<()> { vec!["b".to_string(), "c".to_string(), "d".to_string()], HashMap::new(), true, - HashMap::new(), + schema, SessionType::Native, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( @@ -271,10 +295,10 @@ fn test_jsonlines_fails_without_default() -> eyre::Result<()> { #[test] fn test_jsonlines_with_default() -> eyre::Result<()> { - let mut schema = HashMap::new(); + let mut schema = get_schema_abc(); schema.insert( "d".to_string(), - InnerSchemaField::new(Type::Int, false, Some(Value::Int(42))), + InnerSchemaField::new(Type::Int, Some(Value::Int(42))), ); let reader = FilesystemReader::new( @@ -291,7 +315,7 @@ fn test_jsonlines_with_default() -> eyre::Result<()> { true, schema, SessionType::Native, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( @@ -318,10 +342,10 @@ fn test_jsonlines_with_default() -> eyre::Result<()> { #[test] fn test_jsonlines_with_default_at_jsonpath() -> eyre::Result<()> { - let mut schema = HashMap::new(); + let mut schema = get_schema_abc(); schema.insert( "d".to_string(), - InnerSchemaField::new(Type::Int, false, Some(Value::Int(42))), + 
InnerSchemaField::new(Type::Int, Some(Value::Int(42))), ); let mut routes = HashMap::new(); @@ -344,7 +368,7 @@ fn test_jsonlines_with_default_at_jsonpath() -> eyre::Result<()> { true, schema, SessionType::Native, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( @@ -371,10 +395,10 @@ fn test_jsonlines_with_default_at_jsonpath() -> eyre::Result<()> { #[test] fn test_jsonlines_explicit_null_not_overridden() -> eyre::Result<()> { - let mut schema = HashMap::new(); + let mut schema = get_schema_abc(); schema.insert( "d".to_string(), - InnerSchemaField::new(Type::Int, true, Some(Value::Int(42))), + InnerSchemaField::new(Type::Optional(Type::Int.into()), Some(Value::Int(42))), ); let reader = FilesystemReader::new( @@ -391,7 +415,7 @@ fn test_jsonlines_explicit_null_not_overridden() -> eyre::Result<()> { true, schema, SessionType::Native, - ); + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!( diff --git a/tests/integration/test_deltalake.rs b/tests/integration/test_deltalake.rs index a03700de..c9977dcc 100644 --- a/tests/integration/test_deltalake.rs +++ b/tests/integration/test_deltalake.rs @@ -16,19 +16,20 @@ use pathway_engine::connectors::data_storage::{ ConnectorMode, DeltaTableReader, DeltaTableWriter, ObjectDownloader, WriteError, Writer, }; use pathway_engine::connectors::SessionType; -use pathway_engine::engine::{DateTimeNaive, DateTimeUtc, Duration, Key, Timestamp, Type, Value}; +use pathway_engine::engine::{ + DateTimeNaive, DateTimeUtc, Duration, Key, Result, Timestamp, Type, Value, +}; use pathway_engine::python_api::ValueField; use crate::helpers::read_data_from_reader; -fn run_single_column_save(type_: Type, values: &[Value]) -> Result<(), WriteError> { +fn run_single_column_save(type_: Type, values: &[Value]) -> eyre::Result<()> { let test_storage = tempdir().expect("tempdir creation failed"); let test_storage_path = test_storage.path(); let value_fields = vec![ValueField { name: "field".to_string(), - type_, - is_optional: true, + type_: Type::Optional(type_.clone().into()), default: None, }]; @@ -47,20 +48,20 @@ fn run_single_column_save(type_: Type, values: &[Value]) -> Result<(), WriteErro writer.write(context)?; } writer.flush(true)?; - let rows_present = read_from_deltalake(test_storage_path.to_str().unwrap(), type_); + let rows_present = read_from_deltalake(test_storage_path.to_str().unwrap(), &type_); assert_eq!(rows_present, values); - let rows_roundtrip = read_with_connector(test_storage_path.to_str().unwrap(), type_); + let rows_roundtrip = read_with_connector(test_storage_path.to_str().unwrap(), type_)?; assert_eq!(rows_roundtrip, values); Ok(()) } -fn read_with_connector(path: &str, type_: Type) -> Vec { +fn read_with_connector(path: &str, type_: Type) -> Result> { let mut schema = HashMap::new(); schema.insert( "field".to_string(), - InnerSchemaField::new(type_, true, None), + InnerSchemaField::new(Type::Optional(type_.clone().into()), None), ); let mut type_map = HashMap::new(); type_map.insert("field".to_string(), type_); @@ -74,7 +75,7 @@ fn read_with_connector(path: &str, type_: Type) -> Vec { ) .unwrap(); let parser = - TransparentParser::new(None, vec!["field".to_string()], schema, SessionType::Native); + TransparentParser::new(None, vec!["field".to_string()], schema, SessionType::Native)?; let values_read = read_data_from_reader(Box::new(reader), Box::new(parser)).unwrap(); let mut result = Vec::new(); for event in values_read { @@ -85,10 +86,10 @@ fn 
read_with_connector(path: &str, type_: Type) -> Vec { let value = values[0].clone(); result.push(value); } - result + Ok(result) } -fn read_from_deltalake(path: &str, type_: Type) -> Vec { +fn read_from_deltalake(path: &str, type_: &Type) -> Vec { let mut reread_values = Vec::new(); tokio::runtime::Builder::new_current_thread() .enable_all() @@ -132,7 +133,7 @@ fn read_from_deltalake(path: &str, type_: Type) -> Vec { (ParquetField::TimestampMicros(us), Type::DateTimeNaive) => Value::from(DateTimeNaive::from_timestamp(*us, "us").unwrap()), (ParquetField::TimestampMicros(us), Type::DateTimeUtc) => Value::from(DateTimeUtc::from_timestamp(*us, "us").unwrap()), (ParquetField::Bytes(b), Type::Bytes) => Value::Bytes(b.data().into()), - _ => panic!("Pathway shouldn't have serialized field of type {type_:?} as {field:?}"), + (field, type_) => panic!("Pathway shouldn't have serialized field of type {type_:?} as {field:?}"), }; reread_values.push(parsed_value); } @@ -145,98 +146,93 @@ fn read_from_deltalake(path: &str, type_: Type) -> Vec { #[test] fn test_save_bool() -> eyre::Result<()> { - Ok(run_single_column_save( - Type::Bool, - &[Value::Bool(true), Value::Bool(false)], - )?) + run_single_column_save(Type::Bool, &[Value::Bool(true), Value::Bool(false)]) } #[test] fn test_save_int() -> eyre::Result<()> { - Ok(run_single_column_save(Type::Int, &[Value::Int(10)])?) + run_single_column_save(Type::Int, &[Value::Int(10)]) } #[test] fn test_save_float() -> eyre::Result<()> { - Ok(run_single_column_save( - Type::Float, - &[Value::Float(0.01.into())], - )?) + run_single_column_save(Type::Float, &[Value::Float(0.01.into())]) } #[test] fn test_save_string() -> eyre::Result<()> { - Ok(run_single_column_save( + run_single_column_save( Type::String, &[Value::String("abc".into()), Value::String("".into())], - )?) + ) } #[test] fn test_save_bytes() -> eyre::Result<()> { let test_bytes: &[u8] = &[1, 10, 5, 19, 55, 67, 9, 87, 28]; - Ok(run_single_column_save( + run_single_column_save( Type::Bytes, &[Value::Bytes([].into()), Value::Bytes(test_bytes.into())], - )?) + ) } #[test] fn test_save_datetimenaive() -> eyre::Result<()> { - Ok(run_single_column_save( + run_single_column_save( Type::DateTimeNaive, &[ Value::DateTimeNaive(DateTimeNaive::from_timestamp(0, "s")?), Value::DateTimeNaive(DateTimeNaive::from_timestamp(10000, "s")?), Value::DateTimeNaive(DateTimeNaive::from_timestamp(-10000, "s")?), ], - )?) + ) } #[test] fn test_save_datetimeutc() -> eyre::Result<()> { - Ok(run_single_column_save( + run_single_column_save( Type::DateTimeUtc, &[ Value::DateTimeUtc(DateTimeUtc::new(0)), Value::DateTimeUtc(DateTimeUtc::new(10_000_000_000_000)), Value::DateTimeUtc(DateTimeUtc::new(-10_000_000_000_000)), ], - )?) + ) } #[test] fn test_save_duration() -> eyre::Result<()> { - Ok(run_single_column_save( + run_single_column_save( Type::Duration, &[ Value::Duration(Duration::new(0)), Value::Duration(Duration::new(10_000_000_000_000)), Value::Duration(Duration::new(-10_000_000_000_000)), ], - )?) + ) } #[test] fn test_save_json() -> eyre::Result<()> { - Ok(run_single_column_save( - Type::Json, - &[Value::from(json!({"A": 100}))], - )?) 
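A note on the test refactor above (not part of the patch): nullability is now expressed by wrapping the column type in `Type::Optional(...)` rather than by a separate `is_optional` flag, which is why the round-trip helpers construct `Type::Optional(type_.clone().into())`. A minimal sketch of that representation, with a stand-in `Ty` enum mirroring the shape of the patch's `can_be_none` / `unoptionalize` helpers:

```rust
use std::sync::Arc;

// Illustrative stand-in: nullability lives inside the type itself.
#[derive(Debug, PartialEq)]
enum Ty {
    Int,
    Float,
    Optional(Arc<Ty>),
}

impl Ty {
    // Mirrors the shape of the patch's can_be_none / unoptionalize helpers.
    fn can_be_none(&self) -> bool {
        matches!(self, Ty::Optional(_))
    }
    fn unoptionalize(&self) -> &Ty {
        match self {
            Ty::Optional(inner) => inner,
            other => other,
        }
    }
}

fn main() {
    let nullable_int = Ty::Optional(Arc::new(Ty::Int));
    assert!(nullable_int.can_be_none());
    assert!(!Ty::Float.can_be_none());
    // Writers that only care about the storage type can strip the wrapper.
    assert_eq!(nullable_int.unoptionalize(), &Ty::Int);
}
```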
+ run_single_column_save(Type::Json, &[Value::from(json!({"A": 100}))]) } #[test] fn test_unsupported_types_fail_as_expected() -> eyre::Result<()> { let unsupported_types = &[ Type::Any, - Type::Array, + Type::Array(Some(2), Type::Int.into()), Type::PyObjectWrapper, - Type::Tuple, + Type::Tuple([].into()), Type::Pointer, ]; for t in unsupported_types { - let save_result = run_single_column_save(*t, &[]); - assert_matches!(save_result, Err(WriteError::UnsupportedType(_))); + let save_result = run_single_column_save(t.clone(), &[]); + assert!(save_result.is_err()); + assert_matches!( + save_result.err().unwrap().downcast::(), + Ok(WriteError::UnsupportedType(_)) + ); } Ok(()) } diff --git a/tests/integration/test_dsv.rs b/tests/integration/test_dsv.rs index 7395ff5e..7a54737b 100644 --- a/tests/integration/test_dsv.rs +++ b/tests/integration/test_dsv.rs @@ -27,8 +27,12 @@ fn test_dsv_read_ok() -> eyre::Result<()> { )?; let mut parser = DsvParser::new( DsvSettings::new(Some(vec!["a".to_string()]), vec!["b".to_string()], ','), - HashMap::new(), - ); + [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::String, None)), + ] + .into(), + )?; reader.read()?; let header_read_result = reader.read()?; @@ -75,8 +79,12 @@ fn test_dsv_column_does_not_exist() -> eyre::Result<()> { )?; let parser = DsvParser::new( DsvSettings::new(Some(vec!["a".to_string()]), vec!["c".to_string()], ','), - HashMap::new(), - ); + [ + ("a".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ] + .into(), + )?; assert_error_shown( Box::new(reader), @@ -99,8 +107,12 @@ fn test_dsv_rows_parsing_ignore_type() -> eyre::Result<()> { )?; let mut parser = DsvParser::new( DsvSettings::new(Some(vec!["a".to_string()]), vec!["b".to_string()], ','), - HashMap::new(), - ); + [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ] + .into(), + )?; reader.read()?; let header_read_result = reader.read()?; @@ -135,8 +147,12 @@ fn test_dsv_not_enough_columns() -> eyre::Result<()> { )?; let mut parser = DsvParser::new( DsvSettings::new(Some(vec!["a".to_string()]), vec!["b".to_string()], ','), - HashMap::new(), - ); + [ + ("a".to_string(), InnerSchemaField::new(Type::Int, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ] + .into(), + )?; let _ = reader .read() @@ -181,8 +197,12 @@ fn test_dsv_autogenerate_pkey() -> eyre::Result<()> { )?; let mut parser = DsvParser::new( DsvSettings::new(None, vec!["a".to_string(), "b".to_string()], ','), - HashMap::new(), - ); + [ + ("a".to_string(), InnerSchemaField::new(Type::Int, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ] + .into(), + )?; let mut keys: HashSet = HashSet::new(); @@ -229,8 +249,13 @@ fn test_dsv_composite_pkey() -> eyre::Result<()> { vec!["c".to_string()], ',', ), - HashMap::new(), - ); + [ + ("a".to_string(), InnerSchemaField::new(Type::Int, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ] + .into(), + )?; let mut keys = Vec::new(); @@ -269,21 +294,16 @@ fn test_dsv_composite_pkey() -> eyre::Result<()> { #[test] fn test_dsv_read_schema_ok() -> eyre::Result<()> { let mut schema = HashMap::new(); - schema.insert( - "bool".to_string(), - InnerSchemaField::new(Type::Bool, false, None), - ); - schema.insert( - "int".to_string(), - InnerSchemaField::new(Type::Int, 
false, None), - ); + schema.insert("key".to_string(), InnerSchemaField::new(Type::String, None)); + schema.insert("bool".to_string(), InnerSchemaField::new(Type::Bool, None)); + schema.insert("int".to_string(), InnerSchemaField::new(Type::Int, None)); schema.insert( "float".to_string(), - InnerSchemaField::new(Type::Float, false, None), + InnerSchemaField::new(Type::Float, None), ); schema.insert( "string".to_string(), - InnerSchemaField::new(Type::String, false, None), + InnerSchemaField::new(Type::String, None), ); let mut reader = FilesystemReader::new( @@ -305,7 +325,7 @@ fn test_dsv_read_schema_ok() -> eyre::Result<()> { ',', ), schema, - ); + )?; reader.read()?; let header_read_result = reader.read()?; @@ -349,21 +369,15 @@ fn test_dsv_read_schema_ok() -> eyre::Result<()> { #[test] fn test_dsv_read_schema_nonparsable() -> eyre::Result<()> { let mut schema = HashMap::new(); - schema.insert( - "bool".to_string(), - InnerSchemaField::new(Type::Bool, false, None), - ); - schema.insert( - "int".to_string(), - InnerSchemaField::new(Type::Int, false, None), - ); + schema.insert("bool".to_string(), InnerSchemaField::new(Type::Bool, None)); + schema.insert("int".to_string(), InnerSchemaField::new(Type::Int, None)); schema.insert( "float".to_string(), - InnerSchemaField::new(Type::Float, false, None), + InnerSchemaField::new(Type::Float, None), ); schema.insert( "string".to_string(), - InnerSchemaField::new(Type::String, false, None), + InnerSchemaField::new(Type::String, None), ); let mut reader = FilesystemReader::new( @@ -385,7 +399,7 @@ fn test_dsv_read_schema_nonparsable() -> eyre::Result<()> { ',', ), schema, - ); + )?; reader.read()?; let header_read_result = reader.read()?; @@ -404,7 +418,7 @@ fn test_dsv_read_schema_nonparsable() -> eyre::Result<()> { assert_error_shown_for_reader_context( &bytes, Box::new(parser), - r#"failed to parse value "zzz" at field "int" according to the type Int in schema: invalid digit found in string"#, + r#"failed to parse value "zzz" at field "int" according to the type int in schema: invalid digit found in string"#, ErrorPlacement::Value(1), ); } diff --git a/tests/integration/test_dsv_dir.rs b/tests/integration/test_dsv_dir.rs index daf1d5ee..6fbbd564 100644 --- a/tests/integration/test_dsv_dir.rs +++ b/tests/integration/test_dsv_dir.rs @@ -2,18 +2,21 @@ use super::helpers::read_data_from_reader; -use std::collections::HashMap; - -use pathway_engine::connectors::data_format::ParsedEvent; use pathway_engine::connectors::data_format::{DsvParser, DsvSettings}; +use pathway_engine::connectors::data_format::{InnerSchemaField, ParsedEvent}; use pathway_engine::connectors::data_storage::{ConnectorMode, CsvFilesystemReader}; -use pathway_engine::engine::Value; +use pathway_engine::engine::{Type, Value}; #[test] fn test_dsv_dir_ok() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::String, None)), + ("foo".to_string(), InnerSchemaField::new(Type::String, None)), + ]; + let reader = CsvFilesystemReader::new( "tests/data/csvdir", builder, @@ -23,8 +26,8 @@ fn test_dsv_dir_ok() -> eyre::Result<()> { )?; let parser = DsvParser::new( DsvSettings::new(Some(vec!["key".to_string()]), vec!["foo".to_string()], ','), - HashMap::new(), - ); + schema.into(), + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; @@ -47,6 +50,11 @@ fn test_single_file_ok() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); 
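A side note on the schema construction used throughout these tests (not part of the patch): `InnerSchemaField::new` now takes only a type and an optional default, so a nullable column with a `None` default is written as `InnerSchemaField::new(Type::Optional(Type::Int.into()), Some(Value::None))`. The sketch below shows, with stand-in names (`Ty`, `Val`, `SchemaField`, `resolve`), how such a field behaves when a column is missing from a row or absent from the schema entirely.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Val {
    None,
    Int(i64),
}

#[derive(Debug, Clone)]
enum Ty {
    Int,
    Optional(Box<Ty>),
}

// Illustrative stand-in: a schema field carries just a type and an optional default.
#[derive(Debug, Clone)]
struct SchemaField {
    #[allow(dead_code)]
    type_: Ty,
    default: Option<Val>,
}

// Fill in a missing column from its default, or report that the column is unknown.
fn resolve(
    schema: &HashMap<String, SchemaField>,
    row: &HashMap<String, Val>,
    name: &str,
) -> Result<Val, String> {
    let field = schema
        .get(name)
        .ok_or_else(|| format!("column {name} is not present in schema"))?;
    match row.get(name) {
        Some(value) => Ok(value.clone()),
        None => field
            .default
            .clone()
            .ok_or_else(|| format!("no value and no default for column {name}")),
    }
}

fn main() {
    let mut schema = HashMap::new();
    schema.insert(
        "number".to_string(),
        SchemaField {
            type_: Ty::Optional(Box::new(Ty::Int)),
            default: Some(Val::None),
        },
    );
    let mut row = HashMap::new();
    row.insert("other".to_string(), Val::Int(5));

    // "number" is absent from the row, so its default (None) is used.
    assert_eq!(resolve(&schema, &row, "number"), Ok(Val::None));
    // "other" is not declared in the schema at all.
    assert!(resolve(&schema, &row, "other").is_err());
}
```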
builder.has_headers(false); + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::String, None)), + ]; + let reader = CsvFilesystemReader::new( "tests/data/sample.txt", builder, @@ -56,8 +64,8 @@ fn test_single_file_ok() -> eyre::Result<()> { )?; let parser = DsvParser::new( DsvSettings::new(Some(vec!["a".to_string()]), vec!["b".to_string()], ','), - HashMap::new(), - ); + schema.into(), + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; @@ -72,6 +80,15 @@ fn test_custom_delimiter() -> eyre::Result<()> { builder.delimiter(b'+'); builder.has_headers(false); + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::String, None)), + ("foo".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "foofoo".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ]; + let reader = CsvFilesystemReader::new( "tests/data/sql_injection.txt", builder, @@ -85,8 +102,8 @@ fn test_custom_delimiter() -> eyre::Result<()> { vec!["foo".to_string(), "foofoo".to_string()], '+', ), - HashMap::new(), - ); + schema.into(), + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!(read_lines.len(), 2); @@ -99,6 +116,18 @@ fn test_escape_fields() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "value,with,comma".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "some other value".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ]; + let reader = CsvFilesystemReader::new( "tests/data/csv_fields_escaped.txt", builder, @@ -115,8 +144,8 @@ fn test_escape_fields() -> eyre::Result<()> { ], ',', ), - HashMap::new(), - ); + schema.into(), + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; @@ -145,6 +174,14 @@ fn test_escape_newlines() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "value".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ]; + let reader = CsvFilesystemReader::new( "tests/data/csv_escaped_newlines.txt", builder, @@ -158,8 +195,8 @@ fn test_escape_newlines() -> eyre::Result<()> { vec!["value".to_string()], ',', ), - HashMap::new(), - ); + schema.into(), + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; @@ -202,6 +239,18 @@ fn test_special_fields() -> eyre::Result<()> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "value".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "data".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ]; + let reader = CsvFilesystemReader::new( "tests/data/csv_special_fields.txt", builder, @@ -215,8 +264,8 @@ fn test_special_fields() -> eyre::Result<()> { vec!["value".to_string(), "data".to_string()], ',', ), - HashMap::new(), - ); + schema.into(), + )?; let read_lines = read_data_from_reader(Box::new(reader), Box::new(parser))?; diff --git a/tests/integration/test_jsonlines.rs b/tests/integration/test_jsonlines.rs index bd60adc0..f7e8b3c6 100644 --- a/tests/integration/test_jsonlines.rs +++ b/tests/integration/test_jsonlines.rs @@ -8,10 +8,10 @@ use std::collections::HashMap; use 
std::sync::Arc; -use pathway_engine::connectors::data_format::{JsonLinesParser, ParsedEvent}; +use pathway_engine::connectors::data_format::{InnerSchemaField, JsonLinesParser, ParsedEvent}; use pathway_engine::connectors::data_storage::{ConnectorMode, FilesystemReader, ReadMethod}; use pathway_engine::connectors::SessionType; -use pathway_engine::engine::Value; +use pathway_engine::engine::{Type, Value}; #[test] fn test_jsonlines_ok() -> eyre::Result<()> { @@ -22,14 +22,19 @@ fn test_jsonlines_ok() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ]; let parser = JsonLinesParser::new( Some(vec!["a".to_string()]), vec!["b".to_string(), "c".to_string()], HashMap::new(), true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; let entries = read_data_from_reader(Box::new(reader), Box::new(parser))?; @@ -62,14 +67,20 @@ fn test_jsonlines_incorrect_key() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ("d".to_string(), InnerSchemaField::new(Type::Int, None)), + ]; let parser = JsonLinesParser::new( Some(vec!["a".to_string(), "d".to_string()]), vec!["b".to_string(), "c".to_string()], HashMap::new(), true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; assert_error_shown( Box::new(reader), @@ -90,14 +101,20 @@ fn test_jsonlines_incomplete_key_to_null() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ("d".to_string(), InnerSchemaField::new(Type::Int, None)), + ]; let parser = JsonLinesParser::new( Some(vec!["a".to_string(), "d".to_string()]), vec!["b".to_string(), "c".to_string()], HashMap::new(), false, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; let entries = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!(entries.len(), 4); @@ -114,14 +131,19 @@ fn test_jsonlines_incorrect_values() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("qqq".to_string(), InnerSchemaField::new(Type::Int, None)), + ]; let parser = JsonLinesParser::new( Some(vec!["a".to_string()]), vec!["b".to_string(), "qqq".to_string()], HashMap::new(), true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; assert_error_shown( Box::new(reader), @@ -142,6 +164,49 @@ fn test_jsonlines_types_parsing() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "float".to_string(), + InnerSchemaField::new(Type::Float, None), + ), + ( + "int_positive".to_string(), + InnerSchemaField::new(Type::Int, None), + ), + ( + "int_negative".to_string(), + InnerSchemaField::new(Type::Int, None), + ), + ( + "string".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "array".to_string(), + InnerSchemaField::new( + Type::Tuple( + [ + Type::String, + Type::Int, + Type::Int, + Type::Float, + 
Type::Tuple([].into()), + ] + .into(), + ), + None, + ), + ), + ( + "bool_true".to_string(), + InnerSchemaField::new(Type::Bool, None), + ), + ( + "bool_false".to_string(), + InnerSchemaField::new(Type::Bool, None), + ), + ]; let parser = JsonLinesParser::new( Some(vec!["a".to_string()]), vec![ @@ -155,9 +220,9 @@ fn test_jsonlines_types_parsing() -> eyre::Result<()> { ], HashMap::new(), true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; let entries = read_data_from_reader(Box::new(reader), Box::new(parser))?; @@ -200,6 +265,24 @@ fn test_jsonlines_complex_paths() -> eyre::Result<()> { routes.insert("pet_name".to_string(), "/pet/name".to_string()); routes.insert("pet_height".to_string(), "/pet/measurements/1".to_string()); + let schema = [ + ( + "owner".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_kind".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_name".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_height".to_string(), + InnerSchemaField::new(Type::Int, None), + ), + ]; let parser = JsonLinesParser::new( None, vec![ @@ -210,9 +293,9 @@ fn test_jsonlines_complex_paths() -> eyre::Result<()> { ], routes, true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; let entries = read_data_from_reader(Box::new(reader), Box::new(parser))?; @@ -251,6 +334,24 @@ fn test_jsonlines_complex_paths_error() -> eyre::Result<()> { "*", )?; + let schema = [ + ( + "owner".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_kind".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_name".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_height".to_string(), + InnerSchemaField::new(Type::Int, None), + ), + ]; let mut routes = HashMap::new(); routes.insert("owner".to_string(), "/name".to_string()); routes.insert("pet_kind".to_string(), "/pet/animal".to_string()); @@ -270,9 +371,9 @@ fn test_jsonlines_complex_paths_error() -> eyre::Result<()> { ], routes, true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; assert_error_shown( Box::new(reader), @@ -294,6 +395,24 @@ fn test_jsonlines_complex_path_ignore_errors() -> eyre::Result<()> { "*", )?; + let schema = [ + ( + "owner".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_kind".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_name".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ( + "pet_height".to_string(), + InnerSchemaField::new(Type::Int, None), + ), + ]; let mut routes = HashMap::new(); routes.insert("owner".to_string(), "/name".to_string()); routes.insert("pet_kind".to_string(), "/pet/animal".to_string()); @@ -313,9 +432,9 @@ fn test_jsonlines_complex_path_ignore_errors() -> eyre::Result<()> { ], routes, false, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; let entries = read_data_from_reader(Box::new(reader), Box::new(parser))?; assert_eq!(entries.len(), 2); @@ -332,14 +451,20 @@ fn test_jsonlines_incorrect_key_verbose_error() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ("d".to_string(), InnerSchemaField::new(Type::Int, None)), + ]; let parser = JsonLinesParser::new( Some(vec!["a".to_string(), "d".to_string()]), vec!["b".to_string(), "c".to_string()], 
HashMap::new(), true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; assert_error_shown( Box::new(reader), @@ -363,14 +488,20 @@ fn test_jsonlines_incorrect_jsonpointer_verbose_error() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ("b".to_string(), InnerSchemaField::new(Type::Int, None)), + ("c".to_string(), InnerSchemaField::new(Type::Int, None)), + ("d".to_string(), InnerSchemaField::new(Type::Int, None)), + ]; let parser = JsonLinesParser::new( Some(vec!["a".to_string(), "d".to_string()]), vec!["b".to_string(), "c".to_string()], routes, true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; assert_error_shown( Box::new(reader), @@ -391,19 +522,20 @@ fn test_jsonlines_failed_to_parse_field() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [("pet".to_string(), InnerSchemaField::new(Type::Any, None))]; let parser = JsonLinesParser::new( None, vec!["pet".to_string()], HashMap::new(), true, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; assert_error_shown( Box::new(reader), Box::new(parser), - r#"failed to create a field "pet" with type Any from the following json payload: {"animal":"dog","measurements":[200,400,600],"name":"Alice"}"#, + r#"failed to create a field "pet" with type Any from json payload: {"animal":"dog","measurements":[200,400,600],"name":"Alice"}"#, ErrorPlacement::Value(0), ); diff --git a/tests/integration/test_metadata.rs b/tests/integration/test_metadata.rs index 6fd786e1..bf731433 100644 --- a/tests/integration/test_metadata.rs +++ b/tests/integration/test_metadata.rs @@ -5,13 +5,14 @@ use super::helpers::read_data_from_reader; use std::collections::HashMap; use pathway_engine::connectors::data_format::{ - DsvParser, DsvSettings, IdentityParser, JsonLinesParser, KeyGenerationPolicy, ParsedEvent, + DsvParser, DsvSettings, IdentityParser, InnerSchemaField, JsonLinesParser, KeyGenerationPolicy, + ParsedEvent, }; use pathway_engine::connectors::data_storage::{ ConnectorMode, CsvFilesystemReader, FilesystemReader, ReadMethod, }; use pathway_engine::connectors::SessionType; -use pathway_engine::engine::Value; +use pathway_engine::engine::{Type, Value}; /// This function requires that _metadata field is the last in the `value_names_list` fn check_file_name_in_metadata(data_read: &ParsedEvent, name: &str) { @@ -36,6 +37,14 @@ fn test_metadata_fs_dir() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::Int, None)), + ("foo".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "_metadata".to_string(), + InnerSchemaField::new(Type::Json, None), + ), + ]; let parser = DsvParser::new( DsvSettings::new( Some(vec!["key".to_string()]), @@ -46,8 +55,8 @@ fn test_metadata_fs_dir() -> eyre::Result<()> { ], ',', ), - HashMap::new(), - ); + schema.into(), + )?; let data_read = read_data_from_reader(Box::new(reader), Box::new(parser))?; check_file_name_in_metadata(&data_read[0], "tests/data/csvdir/a.txt\""); @@ -66,6 +75,14 @@ fn test_metadata_fs_file() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::Int, None)), + ("foo".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "_metadata".to_string(), + InnerSchemaField::new(Type::Json, None), + ), + ]; let parser = DsvParser::new( DsvSettings::new( Some(vec!["key".to_string()]), @@ -76,8 +93,8 @@ fn test_metadata_fs_file() -> 
eyre::Result<()> { ], ',', ), - HashMap::new(), - ); + schema.into(), + )?; let data_read = read_data_from_reader(Box::new(reader), Box::new(parser))?; check_file_name_in_metadata(&data_read[0], "tests/data/minimal.txt\""); @@ -97,6 +114,14 @@ fn test_metadata_csv_dir() -> eyre::Result<()> { None, "*", )?; + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::Int, None)), + ("foo".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "_metadata".to_string(), + InnerSchemaField::new(Type::Json, None), + ), + ]; let parser = DsvParser::new( DsvSettings::new( Some(vec!["key".to_string()]), @@ -107,8 +132,8 @@ fn test_metadata_csv_dir() -> eyre::Result<()> { ], ',', ), - HashMap::new(), - ); + schema.into(), + )?; let data_read = read_data_from_reader(Box::new(reader), Box::new(parser))?; check_file_name_in_metadata(&data_read[0], "tests/data/csvdir/a.txt\""); @@ -130,6 +155,14 @@ fn test_metadata_csv_file() -> eyre::Result<()> { None, "*", )?; + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::Int, None)), + ("foo".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "_metadata".to_string(), + InnerSchemaField::new(Type::Json, None), + ), + ]; let parser = DsvParser::new( DsvSettings::new( Some(vec!["key".to_string()]), @@ -140,8 +173,8 @@ fn test_metadata_csv_file() -> eyre::Result<()> { ], ',', ), - HashMap::new(), - ); + schema.into(), + )?; let data_read = read_data_from_reader(Box::new(reader), Box::new(parser))?; check_file_name_in_metadata(&data_read[0], "tests/data/minimal.txt\""); @@ -158,14 +191,21 @@ fn test_metadata_json_file() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "_metadata".to_string(), + InnerSchemaField::new(Type::Json, None), + ), + ]; let parser = JsonLinesParser::new( None, vec!["a".to_string(), "_metadata".to_string()], HashMap::new(), false, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; let data_read = read_data_from_reader(Box::new(reader), Box::new(parser))?; check_file_name_in_metadata(&data_read[0], "tests/data/jsonlines.txt\""); @@ -182,14 +222,21 @@ fn test_metadata_json_dir() -> eyre::Result<()> { ReadMethod::ByLine, "*", )?; + let schema = [ + ("a".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "_metadata".to_string(), + InnerSchemaField::new(Type::Json, None), + ), + ]; let parser = JsonLinesParser::new( None, vec!["a".to_string(), "_metadata".to_string()], HashMap::new(), false, - HashMap::new(), + schema.into(), SessionType::Native, - ); + )?; let data_read = read_data_from_reader(Box::new(reader), Box::new(parser))?; check_file_name_in_metadata(&data_read[0], "tests/data/jsonlines/one.jsonlines\""); diff --git a/tests/integration/test_parser.rs b/tests/integration/test_parser.rs index 30360912..328a451e 100644 --- a/tests/integration/test_parser.rs +++ b/tests/integration/test_parser.rs @@ -15,36 +15,34 @@ use pathway_engine::engine::{Type, Value}; #[test] fn test_transparent_parser() -> eyre::Result<()> { let value_field_names = vec!["a".to_owned(), "b".to_owned()]; - let schema = HashMap::from([ - ( - "a".to_owned(), - InnerSchemaField::new(Type::Int, false, None), - ), + let schema = [ + ("a".to_owned(), InnerSchemaField::new(Type::Int, None)), ( "b".to_owned(), - InnerSchemaField::new(Type::String, true, None), + InnerSchemaField::new(Type::Optional(Type::String.into()), None), ), - ]); - let mut parser = TransparentParser::new(None, value_field_names, schema, 
SessionType::Native); + ]; + let mut parser = + TransparentParser::new(None, value_field_names, schema.into(), SessionType::Native)?; let contexts = vec![ ReaderContext::from_diff( DataEventType::Insert, None, HashMap::from([ - ("a".to_owned(), Value::Int(3)), - ("b".to_owned(), Value::from("abc")), + ("a".to_owned(), Ok(Value::Int(3))), + ("b".to_owned(), Ok(Value::from("abc"))), ]) .into(), ), ReaderContext::from_diff( DataEventType::Insert, None, - HashMap::from([("b".to_owned(), Value::from("abc"))]).into(), + HashMap::from([("b".to_owned(), Ok(Value::from("abc")))]).into(), ), ReaderContext::from_diff( DataEventType::Insert, None, - HashMap::from([("a".to_owned(), Value::Int(2))]).into(), + HashMap::from([("a".to_owned(), Ok(Value::Int(2)))]).into(), ), ]; let expected = vec![ @@ -69,41 +67,45 @@ fn test_transparent_parser() -> eyre::Result<()> { #[test] fn test_transparent_parser_defaults() -> eyre::Result<()> { let value_field_names = vec!["a".to_owned(), "b".to_owned()]; - let schema = HashMap::from([ + let schema = [ ( "a".to_owned(), - InnerSchemaField::new(Type::Int, false, Some(Value::Int(10))), + InnerSchemaField::new(Type::Int, Some(Value::Int(10))), ), ( "b".to_owned(), - InnerSchemaField::new(Type::String, true, Some(Value::from("default"))), + InnerSchemaField::new( + Type::Optional(Type::String.into()), + Some(Value::from("default")), + ), ), - ]); - let mut parser = TransparentParser::new(None, value_field_names, schema, SessionType::Native); + ]; + let mut parser = + TransparentParser::new(None, value_field_names, schema.into(), SessionType::Native)?; let contexts = vec![ ReaderContext::from_diff( DataEventType::Insert, None, HashMap::from([ - ("a".to_owned(), Value::Int(3)), - ("b".to_owned(), Value::from("abc")), + ("a".to_owned(), Ok(Value::Int(3))), + ("b".to_owned(), Ok(Value::from("abc"))), ]) .into(), ), ReaderContext::from_diff( DataEventType::Insert, None, - HashMap::from([("b".to_owned(), Value::from("abc"))]).into(), + HashMap::from([("b".to_owned(), Ok(Value::from("abc")))]).into(), ), ReaderContext::from_diff( DataEventType::Insert, None, - HashMap::from([("a".to_owned(), Value::Int(2))]).into(), + HashMap::from([("a".to_owned(), Ok(Value::Int(2)))]).into(), ), ReaderContext::from_diff( DataEventType::Delete, None, - HashMap::from([("a".to_owned(), Value::Int(2))]).into(), + HashMap::from([("a".to_owned(), Ok(Value::Int(2)))]).into(), ), ]; let expected = vec![ @@ -129,24 +131,19 @@ fn test_transparent_parser_defaults() -> eyre::Result<()> { #[test] fn test_transparent_parser_upsert() -> eyre::Result<()> { let value_field_names = vec!["a".to_owned(), "b".to_owned()]; - let schema = HashMap::from([ - ( - "a".to_owned(), - InnerSchemaField::new(Type::Int, false, None), - ), - ( - "b".to_owned(), - InnerSchemaField::new(Type::String, false, None), - ), - ]); - let mut parser = TransparentParser::new(None, value_field_names, schema, SessionType::Upsert); + let schema = [ + ("a".to_owned(), InnerSchemaField::new(Type::Int, None)), + ("b".to_owned(), InnerSchemaField::new(Type::String, None)), + ]; + let mut parser = + TransparentParser::new(None, value_field_names, schema.into(), SessionType::Upsert)?; let contexts = vec![ ReaderContext::from_diff( DataEventType::Upsert, None, HashMap::from([ - ("a".to_owned(), Value::Int(3)), - ("b".to_owned(), Value::from("abc")), + ("a".to_owned(), Ok(Value::Int(3))), + ("b".to_owned(), Ok(Value::from("abc"))), ]) .into(), ), @@ -154,8 +151,8 @@ fn test_transparent_parser_upsert() -> eyre::Result<()> { DataEventType::Delete, 
None, HashMap::from([ - ("a".to_owned(), Value::Int(3)), - ("b".to_owned(), Value::from("abc")), + ("a".to_owned(), Ok(Value::Int(3))), + ("b".to_owned(), Ok(Value::from("abc"))), ]) .into(), ), diff --git a/tests/integration/test_seek.rs b/tests/integration/test_seek.rs index 907caa73..cfb49f0c 100644 --- a/tests/integration/test_seek.rs +++ b/tests/integration/test_seek.rs @@ -9,14 +9,14 @@ use std::sync::{Arc, Mutex}; use tempfile::tempdir; use pathway_engine::connectors::data_format::{ - DsvParser, DsvSettings, JsonLinesParser, ParsedEvent, Parser, + DsvParser, DsvSettings, InnerSchemaField, JsonLinesParser, ParsedEvent, Parser, }; use pathway_engine::connectors::data_storage::ReaderBuilder; use pathway_engine::connectors::data_storage::{ ConnectorMode, CsvFilesystemReader, FilesystemReader, ReadMethod, }; use pathway_engine::connectors::SessionType; -use pathway_engine::engine::Value; +use pathway_engine::engine::{Result, Type, Value}; use pathway_engine::persistence::tracker::WorkerPersistentStorage; enum TestedFormat { @@ -24,23 +24,30 @@ enum TestedFormat { Json, } -fn csv_reader_parser_pair(input_path: &str) -> (Box, Box) { +fn csv_reader_parser_pair(input_path: &str) -> Result<(Box, Box)> { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(false); let reader = CsvFilesystemReader::new(input_path, builder, ConnectorMode::Static, Some(1), "*").unwrap(); + let schema = [ + ("key".to_string(), InnerSchemaField::new(Type::String, None)), + ( + "value".to_string(), + InnerSchemaField::new(Type::String, None), + ), + ]; let parser = DsvParser::new( DsvSettings::new( Some(vec!["key".to_string()]), vec!["value".to_string()], ',', ), - HashMap::new(), - ); - (Box::new(reader), Box::new(parser)) + schema.into(), + )?; + Ok((Box::new(reader), Box::new(parser))) } -fn json_reader_parser_pair(input_path: &str) -> (Box, Box) { +fn json_reader_parser_pair(input_path: &str) -> Result<(Box, Box)> { let reader = FilesystemReader::new( input_path, ConnectorMode::Static, @@ -49,27 +56,34 @@ fn json_reader_parser_pair(input_path: &str) -> (Box, Box>>, -) -> FullReadResult { +) -> Result { let (reader, mut parser) = match format { - TestedFormat::Csv => csv_reader_parser_pair(input_path.to_str().unwrap()), - TestedFormat::Json => json_reader_parser_pair(input_path.to_str().unwrap()), + TestedFormat::Csv => csv_reader_parser_pair(input_path.to_str().unwrap())?, + TestedFormat::Json => json_reader_parser_pair(input_path.to_str().unwrap())?, }; - full_cycle_read(reader, parser.as_mut(), persistent_storage) + Ok(full_cycle_read(reader, parser.as_mut(), persistent_storage)) } #[test] @@ -83,7 +97,7 @@ fn test_csv_file_recovery() -> eyre::Result<()> { std::fs::write(&input_path, "key,value\n1,2\na,b").unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, true); - let data_stream = full_cycle_read_kv(TestedFormat::Csv, &input_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Csv, &input_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ @@ -102,7 +116,7 @@ fn test_csv_file_recovery() -> eyre::Result<()> { std::fs::write(&input_path, "key,value\n1,2\na,b\nc,d\n55,66").unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, false); - let data_stream = full_cycle_read_kv(TestedFormat::Csv, &input_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Csv, &input_path, Some(&tracker))?; eprintln!("data stream after: {:?}", data_stream.new_parsed_entries); assert_eq!( 
data_stream.new_parsed_entries, @@ -140,7 +154,7 @@ fn test_csv_dir_recovery() -> eyre::Result<()> { { let tracker = create_persistence_manager(&pstorage_root_path, true); - let data_stream = full_cycle_read_kv(TestedFormat::Csv, &inputs_dir_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Csv, &inputs_dir_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ @@ -176,7 +190,7 @@ fn test_csv_dir_recovery() -> eyre::Result<()> { .unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, false); - let data_stream = full_cycle_read_kv(TestedFormat::Csv, &inputs_dir_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Csv, &inputs_dir_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ParsedEvent::Insert(( @@ -205,7 +219,7 @@ fn test_json_file_recovery() -> eyre::Result<()> { .unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, true); - let data_stream = full_cycle_read_kv(TestedFormat::Json, &input_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Json, &input_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ @@ -224,7 +238,7 @@ fn test_json_file_recovery() -> eyre::Result<()> { .unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, false); - let data_stream = full_cycle_read_kv(TestedFormat::Json, &input_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Json, &input_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ParsedEvent::Insert(( @@ -260,7 +274,7 @@ fn test_json_folder_recovery() -> eyre::Result<()> { .unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, true); - let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ @@ -286,7 +300,7 @@ fn test_json_folder_recovery() -> eyre::Result<()> { .unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, false); - let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ @@ -323,7 +337,7 @@ fn test_json_recovery_from_empty_folder() -> eyre::Result<()> { .unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, true); - let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ @@ -347,7 +361,7 @@ fn test_json_recovery_from_empty_folder() -> eyre::Result<()> { .unwrap(); { let tracker = create_persistence_manager(&pstorage_root_path, false); - let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker)); + let data_stream = full_cycle_read_kv(TestedFormat::Json, &inputs_dir_path, Some(&tracker))?; assert_eq!( data_stream.new_parsed_entries, vec![ diff --git a/tests/integration/test_sqlite.rs b/tests/integration/test_sqlite.rs index 9d9d888f..20d7fc36 100644 --- a/tests/integration/test_sqlite.rs +++ b/tests/integration/test_sqlite.rs @@ -3,6 +3,9 @@ use std::collections::HashMap; use std::sync::Arc; +use assert_matches::assert_matches; +use eyre::eyre; + use 
pathway_engine::connectors::data_format::InnerSchemaField; use pathway_engine::connectors::data_format::ParseError; use pathway_engine::connectors::data_format::TransparentParser; @@ -29,10 +32,10 @@ fn test_sqlite_read_table() -> eyre::Result<()> { SqliteOpenFlags::SQLITE_OPEN_READ_ONLY, )?; let value_field_names = vec![ - "id".to_string(), - "name".to_string(), - "price".to_string(), - "photo".to_string(), + ("id".to_string(), Type::Int), + ("name".to_string(), Type::String), + ("price".to_string(), Type::Float), + ("photo".to_string(), Type::Optional(Type::Bytes.into())), ]; let mut reader = SqliteReader::new(connection, "goods".to_string(), value_field_names); let mut read_results = Vec::new(); @@ -44,36 +47,16 @@ fn test_sqlite_read_table() -> eyre::Result<()> { break; } } - assert_eq!( - read_results, - vec![ + assert_matches!( + read_results.as_slice(), + [ ReadResult::NewSource(None), ReadResult::Data( - ReaderContext::Diff(( - DataEventType::Insert, - Some(vec![Value::Int(1)]), - HashMap::from([ - ("id".to_owned(), Value::Int(1)), - ("name".to_owned(), Value::String("Milk".into())), - ("price".to_owned(), Value::Float(1.1.into())), - ("photo".to_owned(), Value::None) - ]) - .into(), - )), + ReaderContext::Diff((DataEventType::Insert, Some(_), _)), EMPTY_OFFSET ), ReadResult::Data( - ReaderContext::Diff(( - DataEventType::Insert, - Some(vec![Value::Int(2)]), - HashMap::from([ - ("id".to_owned(), Value::Int(2)), - ("name".to_owned(), Value::String("Bread".into())), - ("price".to_owned(), Value::Float(0.75.into())), - ("photo".to_owned(), Value::Bytes(Arc::new([0, 0]))) - ]) - .into(), - )), + ReaderContext::Diff((DataEventType::Insert, Some(_), _,)), EMPTY_OFFSET ), ReadResult::FinishedSource { @@ -81,6 +64,45 @@ fn test_sqlite_read_table() -> eyre::Result<()> { } ] ); + read_results.pop().unwrap(); + let read_result_2 = read_results.pop().unwrap(); + let read_result_1 = read_results.pop().unwrap(); // pop().unwrap()s are safe because read_results matches the pattern above + if let ReadResult::Data( + ReaderContext::Diff((DataEventType::Insert, Some(key), values_map)), + EMPTY_OFFSET, + ) = read_result_1 + { + assert_eq!(key, vec![Value::Int(1)]); + assert_eq!( + values_map.to_pure_hashmap().map_err(|e| eyre!(e))?, + HashMap::from([ + ("id".to_owned(), Value::Int(1)), + ("name".to_owned(), Value::String("Milk".into())), + ("price".to_owned(), Value::Float(1.1.into())), + ("photo".to_owned(), Value::None) + ]) + ) + } else { + unreachable!(); //data in read_results[1] matches the structure above + } + if let ReadResult::Data( + ReaderContext::Diff((DataEventType::Insert, Some(key), values_map)), + EMPTY_OFFSET, + ) = read_result_2 + { + assert_eq!(key, vec![Value::Int(2)]); + assert_eq!( + values_map.to_pure_hashmap().map_err(|e| eyre!(e))?, + HashMap::from([ + ("id".to_owned(), Value::Int(2)), + ("name".to_owned(), Value::String("Bread".into())), + ("price".to_owned(), Value::Float(0.75.into())), + ("photo".to_owned(), Value::Bytes(Arc::new([0, 0]))) + ]) + ) + } else { + unreachable!(); //data in read_results[1] matches the structure above + } Ok(()) } @@ -90,32 +112,21 @@ fn test_sqlite_read_table_with_parser() -> eyre::Result<()> { "tests/data/sqlite/goods_test.db", SqliteOpenFlags::SQLITE_OPEN_READ_ONLY, )?; - let value_field_names = vec![ - "id".to_string(), - "name".to_string(), - "price".to_string(), - "photo".to_string(), + let schema = vec![ + ("id".to_string(), Type::Int), + ("name".to_string(), Type::String), + ("price".to_string(), Type::Float), + ("photo".to_string(), 
Type::Optional(Type::Bytes.into())), ]; - let schema = HashMap::from([ - ( - "id".to_owned(), - InnerSchemaField::new(Type::Int, false, None), - ), - ( - "name".to_owned(), - InnerSchemaField::new(Type::String, false, None), - ), - ( - "price".to_owned(), - InnerSchemaField::new(Type::Float, false, None), - ), - ( - "photo".to_owned(), - InnerSchemaField::new(Type::Bytes, true, None), - ), - ]); - let mut reader = SqliteReader::new(connection, "goods".to_string(), value_field_names.clone()); - let mut parser = TransparentParser::new(None, value_field_names, schema, SessionType::Native); + let value_field_names = schema.iter().map(|(name, _dtype)| name.clone()).collect(); + let schema_map = schema + .clone() + .into_iter() + .map(|(name, dtype)| (name, InnerSchemaField::new(dtype, None))) + .collect(); + let mut reader = SqliteReader::new(connection, "goods".to_string(), schema); + let mut parser = + TransparentParser::new(None, value_field_names, schema_map, SessionType::Native)?; let mut parsed_events: Vec = Vec::new(); loop { @@ -163,32 +174,20 @@ fn test_sqlite_read_table_nonparsable() -> eyre::Result<()> { "tests/data/sqlite/goods_test.db", SqliteOpenFlags::SQLITE_OPEN_READ_ONLY, )?; - let value_field_names = vec![ - "id".to_string(), - "name".to_string(), - "price".to_string(), - "photo".to_string(), + let schema = vec![ + ("id".to_string(), Type::Int), + ("name".to_string(), Type::String), + ("price".to_string(), Type::Float), + ("photo".to_string(), Type::Bytes), ]; - let schema = HashMap::from([ - ( - "id".to_owned(), - InnerSchemaField::new(Type::Int, false, None), - ), - ( - "name".to_owned(), - InnerSchemaField::new(Type::String, false, None), - ), - ( - "price".to_owned(), - InnerSchemaField::new(Type::Float, false, None), - ), - ( - "photo".to_owned(), - InnerSchemaField::new(Type::Bytes, false, None), - ), - ]); - let mut reader = SqliteReader::new(connection, "goods".to_string(), value_field_names.clone()); - let parser = TransparentParser::new(None, value_field_names, schema, SessionType::Native); + let value_field_names = schema.iter().map(|(name, _dtype)| name.clone()).collect(); + let schema_map = schema + .clone() + .into_iter() + .map(|(name, dtype)| (name, InnerSchemaField::new(dtype, None))) + .collect(); + let mut reader = SqliteReader::new(connection, "goods".to_string(), schema.clone()); + let parser = TransparentParser::new(None, value_field_names, schema_map, SessionType::Native)?; reader.read()?; let read_result = reader.read()?; @@ -196,17 +195,17 @@ fn test_sqlite_read_table_nonparsable() -> eyre::Result<()> { ReadResult::Data(context, _) => assert_error_shown_for_reader_context( &context, Box::new(parser), - r#"value None in field "photo" is inconsistent with type Bytes from schema"#, + r#"cannot create a field "photo" with type bytes from value None"#, ErrorPlacement::Value(3), ), _ => panic!("row_read_result is not Data"), } reader.read()?; - assert_eq!( + assert_matches!( reader.read()?, ReadResult::FinishedSource { commit_allowed: true, - }, + } ); Ok(()) } diff --git a/tests/integration/test_types.rs b/tests/integration/test_types.rs new file mode 100644 index 00000000..43625c26 --- /dev/null +++ b/tests/integration/test_types.rs @@ -0,0 +1,39 @@ +// Copyright © 2024 Pathway + +use pathway_engine::engine::Type; + +#[test] +fn test_type_display() { + assert_eq!(Type::Any.to_string(), "Any"); + assert_eq!(Type::Bool.to_string(), "bool"); + assert_eq!(Type::Int.to_string(), "int"); + assert_eq!(Type::Float.to_string(), "float"); + 
assert_eq!(Type::Pointer.to_string(), "Pointer"); + assert_eq!(Type::String.to_string(), "str"); + assert_eq!(Type::Bytes.to_string(), "bytes"); + assert_eq!(Type::DateTimeNaive.to_string(), "DateTimeNaive"); + assert_eq!(Type::DateTimeUtc.to_string(), "DateTimeUtc"); + assert_eq!(Type::Duration.to_string(), "Duration"); + assert_eq!( + Type::Array(Some(2), Type::Int.into()).to_string(), + "Array(2, int)" + ); + assert_eq!( + Type::Array(None, Type::Float.into()).to_string(), + "Array(float)" + ); + assert_eq!(Type::Json.to_string(), "Json"); + assert_eq!(Type::Tuple([].into()).to_string(), "tuple[]"); + assert_eq!( + Type::Tuple([Type::String, Type::Bytes].into()).to_string(), + "tuple[str, bytes]" + ); + assert_eq!(Type::List(Type::Bool.into()).to_string(), "list[bool]"); + assert_eq!(Type::List(Type::Int.into()).to_string(), "list[int]"); + assert_eq!(Type::PyObjectWrapper.to_string(), "PyObjectWrapper"); + assert_eq!( + Type::Optional(Type::Pointer.into()).to_string(), + "Pointer | None" + ); + assert_eq!(Type::Optional(Type::Int.into()).to_string(), "int | None"); +}
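
// Illustrative sketch, not taken from the diff above (the function name below is
// hypothetical). The Display strings pinned down in test_types.rs appear to be the same
// lowercase names ("int", "bytes", "str", "int | None", ...) used in the reworked parser
// error messages earlier in this patch. Under that assumption, the snippet shows the
// schema-construction pattern the updated tests rely on: column optionality is expressed
// as Type::Optional(...) inside InnerSchemaField::new(type, default) rather than a
// separate bool flag, and parser constructors validate the schema and return Result,
// hence the trailing `?`. The JSON-lines and transparent parsers follow the same pattern.

use pathway_engine::connectors::data_format::{DsvParser, DsvSettings, InnerSchemaField};
use pathway_engine::engine::Type;

fn build_example_parser() -> eyre::Result<DsvParser> {
    let schema = [
        // Required string column used as the primary key.
        ("key".to_string(), InnerSchemaField::new(Type::String, None)),
        (
            // Nullable column: typed as Optional(String), which would display as "str | None".
            "value".to_string(),
            InnerSchemaField::new(Type::Optional(Type::String.into()), None),
        ),
    ];
    let parser = DsvParser::new(
        DsvSettings::new(Some(vec!["key".to_string()]), vec!["value".to_string()], ','),
        schema.into(),
    )?;
    Ok(parser)
}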