diff --git a/src/middlewared/middlewared/api/base/handler/accept.py b/src/middlewared/middlewared/api/base/handler/accept.py index 660c7f5be88ff..499393cbb06a9 100644 --- a/src/middlewared/middlewared/api/base/handler/accept.py +++ b/src/middlewared/middlewared/api/base/handler/accept.py @@ -1,19 +1,70 @@ from pydantic_core import ValidationError +from middlewared.api.base.model import BaseModel from middlewared.service_exception import CallError, ValidationErrors -def accept_params(model, args, *, exclude_unset=False, expose_secrets=True): +def accept_params(model: type[BaseModel], args: list, *, exclude_unset=False, expose_secrets=True) -> list: + """ + Accepts a list of `args` for a method call and validates it using `model`. + + Parameters are accepted in the order they are defined in the `model`. + + Returns the list of valid parameters (or raises `ValidationErrors`). + + :param model: `BaseModel` that defines method args. + :param args: a list of method args. + :param exclude_unset: if true, will not append default parameters to the list. + :param expose_secrets: if false, will replace `Private` parameters with a placeholder. + :return: a validated list of method args. + """ + args_as_dict = model_dict_from_list(model, args) + + dump = validate_model(model, args_as_dict, exclude_unset=exclude_unset, expose_secrets=expose_secrets) + + fields = list(model.model_fields) + if exclude_unset: + fields = fields[:len(args)] + + return [dump[field] for field in fields] + + +def model_dict_from_list(model: type[BaseModel], args: list) -> dict: + """ + Converts a list of `args` for a method call to a dictionary using `model`. + + Parameters are accepted in the order they are defined in the `model`. + + For example, given the model that has fields `b` and `a`, and `args` equal to `[1, 2]`, it will return + `{"b": 1, "a": 2"}`. + + :param model: `BaseModel` that defines method args. + :param args: a list of method args. + :return: a dictionary of method args. + """ if len(args) > len(model.model_fields): raise CallError(f"Too many arguments (expected {len(model.model_fields)}, found {len(args)})") - args_as_dict = { + return { field: value for field, value in zip(model.model_fields.keys(), args) } + +def validate_model(model: type[BaseModel], data: dict, *, exclude_unset=False, expose_secrets=True) -> dict: + """ + Validates `data` against the `model`, sanitizes values, sets defaults. + + Raises `ValidationErrors` if any validation errors occur. + + :param model: `BaseModel` subclass. + :param data: provided data. + :param exclude_unset: if true, will not add default values. + :param expose_secrets: if false, will replace `Private` fields with a placeholder. + :return: validated data. + """ try: - instance = model(**args_as_dict) + instance = model(**data) except ValidationError as e: verrors = ValidationErrors() for error in e.errors(): @@ -26,10 +77,4 @@ def accept_params(model, args, *, exclude_unset=False, expose_secrets=True): else: mode = "json" - dump = instance.model_dump(mode=mode, exclude_unset=exclude_unset, warnings=False) - - fields = list(model.model_fields) - if exclude_unset: - fields = fields[:len(args)] - - return [dump[field] for field in fields] + return instance.model_dump(mode=mode, exclude_unset=exclude_unset, warnings=False) diff --git a/src/middlewared/middlewared/api/base/handler/dump_params.py b/src/middlewared/middlewared/api/base/handler/dump_params.py index 1fb214d37c078..f66f99aedc1d3 100644 --- a/src/middlewared/middlewared/api/base/handler/dump_params.py +++ b/src/middlewared/middlewared/api/base/handler/dump_params.py @@ -4,11 +4,20 @@ from middlewared.api.base import BaseModel, Private, PRIVATE_VALUE from middlewared.service_exception import ValidationErrors from .accept import accept_params +from .inspect import model_field_is_model, model_field_is_list_of_models __all__ = ["dump_params"] -def dump_params(model, args, expose_secrets): +def dump_params(model: type[BaseModel], args: list, expose_secrets: bool) -> list: + """ + Dumps a list of `args` for a method call that accepts `model` parameters. + + :param model: `BaseModel` that defines method args. + :param args: a list of method args. + :param expose_secrets: if false, will replace `Private` parameters with a placeholder. + :return: A list of method call arguments ready to be printed. + """ try: return accept_params(model, args, exclude_unset=True, expose_secrets=expose_secrets) except ValidationErrors: @@ -19,15 +28,21 @@ def dump_params(model, args, expose_secrets): ] -def remove_secrets(model, value): - if isinstance(model, type) and issubclass(model, BaseModel) and isinstance(value, dict): +def remove_secrets(model: type[BaseModel], value): + """ + Removes `Private` values from a model value. + :param model: `BaseModel` that corresponds to `value`. + :param value: value that potentially contains `Private` data. + :return: `value` with `Private` parameters replaced with a placeholder. + """ + if isinstance(value, dict) and (nested_model := model_field_is_model(model)): return { k: remove_secrets(v.annotation, value[k]) - for k, v in model.model_fields.items() + for k, v in nested_model.model_fields.items() if k in value } - elif typing.get_origin(model) is list and len(args := typing.get_args(model)) == 1 and isinstance(value, list): - return [remove_secrets(args[0], v) for v in value] + elif isinstance(value, list) and (nested_model := model_field_is_list_of_models(model)): + return [remove_secrets(nested_model, v) for v in value] elif typing.get_origin(model) is Private: return PRIVATE_VALUE else: diff --git a/src/middlewared/middlewared/api/base/handler/inspect.py b/src/middlewared/middlewared/api/base/handler/inspect.py new file mode 100644 index 0000000000000..ced3df3199456 --- /dev/null +++ b/src/middlewared/middlewared/api/base/handler/inspect.py @@ -0,0 +1,23 @@ +import typing + +from middlewared.api.base import BaseModel + + +def model_field_is_model(model) -> type[BaseModel] | None: + """ + Return` model` if it is an API model. Otherwise, returns `None`. + :param model: potentially, API model. + :return: `model` or `None` + """ + if isinstance(model, type) and issubclass(model, BaseModel): + return model + + +def model_field_is_list_of_models(model) -> type[BaseModel] | None: + """ + If` model` represents a list of API models X, then it will return that model X. Otherwise, returns `None`. + :param model: potentially, a model that represents a list of API models. + :return: nested API model or `None` + """ + if typing.get_origin(model) is list and len(args := typing.get_args(model)) == 1: + return args[0] diff --git a/src/middlewared/middlewared/api/base/handler/version.py b/src/middlewared/middlewared/api/base/handler/version.py new file mode 100644 index 0000000000000..7bbbdddfd22ef --- /dev/null +++ b/src/middlewared/middlewared/api/base/handler/version.py @@ -0,0 +1,172 @@ +import enum +from types import ModuleType + +from middlewared.api.base import BaseModel, ForUpdateMetaclass +from .accept import validate_model +from .inspect import model_field_is_model, model_field_is_list_of_models + + +class Direction(enum.StrEnum): + DOWNGRADE = "DOWNGRADE" + UPGRADE = "UPGRADE" + + +class APIVersionDoesNotExistException(Exception): + def __init__(self, version: str): + self.version = version + super().__init__(f"API Version {self.version!r} does not exist") + + +class APIVersionDoesNotContainModelException(Exception): + def __init__(self, version: str, model_name: str): + self.version = version + self.model_name = model_name + super().__init__(f"API version {version!r} does not contain model {model_name!r}") + + +class APIVersion: + def __init__(self, version: str, models: dict[str, type[BaseModel]]): + """ + :param version: API version name + :param models: a dictionary which keys are model names and values are models used in the API version + """ + self.version: str = version + self.models: dict[str, type[BaseModel]] = models + + @classmethod + def from_module(cls, version: str, module: ModuleType) -> "APIVersion": + """ + Create `APIVersion` from a module (e.g. `middlewared.api.v25_04_0`). + :param version: API version name + :param module: module object + :return: `APIVersion` instance + """ + return cls( + version, + { + model_name: model + for model_name, model in [ + (model_name, getattr(module, model_name)) + for model_name in dir(module) + ] + if isinstance(model, type) and issubclass(model, BaseModel) + }, + ) + + def __repr__(self): + return f"" + + +class APIVersionsAdapter: + """ + Converts method parameters and return results between different API versions. + """ + + def __init__(self, versions: list[APIVersion]): + """ + :param versions: A chronologically sorted list of API versions. + """ + self.versions: dict[str, APIVersion] = {version.version: version for version in versions} + self.versions_history: list[str] = list(self.versions.keys()) + self.current_version: str = self.versions_history[-1] + + def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> dict: + """ + Adapts `value` (that matches a model identified by `model_name`) from API `version1` to API `version2`). + + :param value: a value to convert + :param model_name: a name of the model. Must exist in all API versions, including intermediate ones, or + `APIVersionDoesNotContainModelException` will be raised. + :param version1: original API version from which the `value` comes from + :param version2: target API version that needs `value` + :return: converted value + """ + try: + version1_index = self.versions_history.index(version1) + except ValueError: + raise APIVersionDoesNotExistException(version1) from None + + try: + version2_index = self.versions_history.index(version2) + except ValueError: + raise APIVersionDoesNotExistException(version2) from None + + current_version = self.versions[version1] + try: + current_version_model = current_version.models[model_name] + except KeyError: + raise APIVersionDoesNotContainModelException(current_version.version, model_name) + value = validate_model(current_version_model, value) + + if version1_index < version2_index: + step = 1 + direction = Direction.UPGRADE + else: + step = -1 + direction = Direction.DOWNGRADE + + for version_index in range(version1_index + step, version2_index + step, step): + new_version = self.versions[self.versions_history[version_index]] + + value = self._adapt_model(value, model_name, current_version, new_version, direction) + + current_version = new_version + + return value + + def _adapt_model(self, value: dict, model_name: str, current_version: APIVersion, new_version: APIVersion, + direction: Direction): + try: + current_model = current_version.models[model_name] + except KeyError: + raise APIVersionDoesNotContainModelException(current_version.version, model_name) from None + + try: + new_model = new_version.models[model_name] + except KeyError: + raise APIVersionDoesNotContainModelException(new_version.version, model_name) from None + + return self._adapt_value(value, current_model, new_model, direction) + + def _adapt_value(self, value: dict, current_model: type[BaseModel], new_model: type[BaseModel], + direction: Direction): + for k in value: + if k in current_model.model_fields and k in new_model.model_fields: + current_model_field = current_model.model_fields[k].annotation + new_model_field = new_model.model_fields[k].annotation + if ( + isinstance(value[k], dict) and + (current_nested_model := model_field_is_model(current_model_field)) and + (new_nested_model := model_field_is_model(new_model_field)) and + current_nested_model.__class__.__name__ == new_nested_model.__class__.__name__ + ): + value[k] = self._adapt_value(value[k], current_nested_model, new_nested_model, direction) + elif ( + isinstance(value[k], list) and + (current_nested_model := model_field_is_list_of_models(current_model_field)) and + (current_nested_model := model_field_is_model(current_nested_model)) and + (new_nested_model := model_field_is_list_of_models(new_model_field)) and + (new_nested_model := model_field_is_model(new_nested_model)) and + current_nested_model.__class__.__name__ == new_nested_model.__class__.__name__ + ): + value[k] = [ + self._adapt_value(v, current_nested_model, new_nested_model, direction) + for v in value[k] + ] + + if new_model.__class__ is not ForUpdateMetaclass: + for k, field in new_model.model_fields.items(): + if k not in value and not field.is_required(): + value[k] = field.get_default() + + match direction: + case Direction.DOWNGRADE: + value = current_model.to_previous(value) + case Direction.UPGRADE: + value = new_model.from_previous(value) + + for k in list(value): + if k in current_model.model_fields and k not in new_model.model_fields: + value.pop(k) + + return value diff --git a/src/middlewared/middlewared/api/base/model.py b/src/middlewared/middlewared/api/base/model.py index 0b51b443c6f14..2645195924df8 100644 --- a/src/middlewared/middlewared/api/base/model.py +++ b/src/middlewared/middlewared/api/base/model.py @@ -34,6 +34,23 @@ def dump(t): "cannot be a member of an Optional or a Union, please make the whole field Private." ) + @classmethod + def from_previous(cls, value): + """ + Converts model value from a preceding API version to this API version. `value` can be modified in-place. + :param value: value of the same model in the preceding API version. + :return: value in this API version. + """ + return value + + @classmethod + def to_previous(cls, value): + """ + Converts model value from this API version to a preceding API version. `value` can be modified in-place. + :param value: value in this API version. + :return: value of the same model in the preceding API version. + """ + return value class ForUpdateMetaclass(ModelMetaclass): @@ -76,19 +93,41 @@ def _field_for_update(field): return new.annotation, new -def single_argument_args(name): +def single_argument_args(name: str): + """ + Model class decorator used to define an arguments model for a method that accepts a single dictionary argument. + + :param name: name for that single argument. + :return: a model class that consists of unique `name` field that is represented by a class being decorated. + Class name will be preserved. + """ def wrapper(klass): - return create_model( + model = create_model( klass.__name__, __base__=(BaseModel,), __module__=klass.__module__, **{name: Annotated[klass, Field()]}, ) + model.from_previous = classmethod(klass.from_previous) + model.to_previous = classmethod(klass.to_previous) + return model return wrapper def single_argument_result(klass, klass_name=None): + """ + Can be used as: + * Decorator for a class. In that case, it will create a class that represents a return value for a function that + returns a single dictionary, represented by the decorated class. + * Standalone model generator. Will return a model class named `klass_name` that consists of a single field + represented by `klass` (in that case, `klass` can be a primitive type). + + :param klass: class or a primitive type to create model from. + :param klass_name: required, when being called as a standalone model generator. Returned class will have that name. + (otherwise, the decorated class name will be preserved). + :return: a model class that consists of unique `result` field that corresponds to `klass`. + """ if klass is None: klass = NoneType @@ -98,9 +137,13 @@ def single_argument_result(klass, klass_name=None): else: klass_name = klass_name or klass.__name__ - return create_model( + model = create_model( klass_name, __base__=(BaseModel,), __module__=inspect.getmodule(inspect.stack()[1][0]), **{"result": Annotated[klass, Field()]}, ) + if issubclass(klass, BaseModel): + model.from_previous = classmethod(klass.from_previous) + model.to_previous = classmethod(klass.to_previous) + return model diff --git a/src/middlewared/middlewared/api/base/server/legacy_api_method.py b/src/middlewared/middlewared/api/base/server/legacy_api_method.py new file mode 100644 index 0000000000000..e4a3d691ff1a3 --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/legacy_api_method.py @@ -0,0 +1,79 @@ +from typing import TYPE_CHECKING + +from middlewared.api.base.handler.accept import model_dict_from_list +from middlewared.api.base.handler.dump_params import dump_params +from middlewared.api.base.handler.version import APIVersionsAdapter +from middlewared.api.base.server.method import Method +from middlewared.utils.service.call import MethodNotFoundError +from middlewared.utils.service.crud import real_crud_method + +if TYPE_CHECKING: + from middlewared.api.base.server.ws_handler.rpc import RpcWebSocketApp + from middlewared.main import Middleware + + +class LegacyAPIMethod(Method): + """ + Represents a middleware legacy API method used in JSON-RPC server. Converts method parameters and return value + between most recent API version (used in the code) and predetermined legacy API version. + """ + + def __init__(self, middleware: "Middleware", name: str, api_version: str, adapter: APIVersionsAdapter): + """ + :param middleware: `Middleware` instance + :param name: method name + :param api_version: API version name used to convert parameters and return value + :param adapter: `APIVersionsAdapter` instance + """ + super().__init__(middleware, name) + self.api_version = api_version + self.adapter = adapter + + methodobj = self.methodobj + if crud_methodobj := real_crud_method(methodobj): + methodobj = crud_methodobj + if hasattr(methodobj, "new_style_accepts"): + self.accepts_model = methodobj.new_style_accepts + self.returns_model = methodobj.new_style_returns + else: + self.accepts_model = None + self.returns_model = None + + async def call(self, app: "RpcWebSocketApp", params): + if self.accepts_model: + return self._adapt_result(await super().call(app, self._adapt_params(params))) + + return await super().call(app, params) + + def _adapt_params(self, params): + try: + legacy_accepts_model = self.adapter.versions[self.api_version].models[self.accepts_model.__name__] + except KeyError: + # The legacy API does not contain signature definition for this method, which means it didn't exist + # when that API was released. + raise MethodNotFoundError(*self.name.rsplit(".", 1)) + + params_dict = model_dict_from_list(legacy_accepts_model, params) + + adapted_params_dict = self.adapter.adapt( + params_dict, + legacy_accepts_model.__name__, + self.api_version, + self.adapter.current_version, + ) + + return [adapted_params_dict[field] for field in self.accepts_model.model_fields] + + def _adapt_result(self, result): + return self.adapter.adapt( + {"result": result}, + self.returns_model.__name__, + self.adapter.current_version, + self.api_version, + )["result"] + + def dump_args(self, params): + if self.accepts_model: + return dump_params(self.accepts_model, params, False) + + return super().dump_args(params) diff --git a/src/middlewared/middlewared/api/base/server/method.py b/src/middlewared/middlewared/api/base/server/method.py index fa02b1a9ba66c..ebdd9b92b0ecc 100644 --- a/src/middlewared/middlewared/api/base/server/method.py +++ b/src/middlewared/middlewared/api/base/server/method.py @@ -1,22 +1,43 @@ import types +from typing import TYPE_CHECKING from middlewared.job import Job +if TYPE_CHECKING: + from middlewared.api.base.server.ws_handler.rpc import RpcWebSocketApp + from middlewared.main import Middleware + class Method: + """ + Represents a middleware API method used in JSON-RPC server. + """ + def __init__(self, middleware: "Middleware", name: str): + """ + :param middleware: `Middleware` instance + :param name: method name + """ self.middleware = middleware self.name = name + self.serviceobj, self.methodobj = self.middleware.get_method(self.name) - async def call(self, app: "RpcWebSocketApp", params): - serviceobj, methodobj = self.middleware.get_method(self.name) + async def call(self, app: "RpcWebSocketApp", params: list): + """ + Calls the method in the context of a given `app`. + + :param app: `RpcWebSocketApp` instance. + :param params: method arguments. + :return: method return value. + """ + methodobj = self.methodobj await self.middleware.authorize_method_call(app, self.name, methodobj, params) if mock := self.middleware._mock_method(self.name, params): methodobj = mock - result = await self.middleware.call_with_audit(self.name, serviceobj, methodobj, params, app) + result = await self.middleware.call_with_audit(self.name, self.serviceobj, methodobj, params, app) if isinstance(result, Job): result = result.id elif isinstance(result, types.GeneratorType): @@ -26,5 +47,11 @@ async def call(self, app: "RpcWebSocketApp", params): return result - def dump_args(self, params): + def dump_args(self, params: list) -> list: + """ + Dumps the method call params (i.e., removes secrets). + + :param params: method call arguments. + :return: dumped method call arguments. + """ return self.middleware.dump_args(params, method_name=self.name) diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py b/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py index c0390794b20fb..71f1e325dc033 100644 --- a/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py +++ b/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py @@ -1,8 +1,20 @@ +from typing import Callable, TYPE_CHECKING + from ..method import Method from .rpc import RpcWebSocketHandler +if TYPE_CHECKING: + from middlewared.main import Middleware + -def create_rpc_ws_handler(middleware: "Middleware"): +def create_rpc_ws_handler(middleware: "Middleware", method_factory: Callable[["Middleware", str], Method]): + """ + Creates a `RpcWebSocketHandler` instance. + :param middleware: `Middleware` instance. + :param method_factory: a callable that creates `Method` instance. Will be called for each discovered middleware + method. + :return: `RpcWebSocketHandler` instance. + """ methods = {} for service_name, service in middleware.get_services().items(): for attribute in dir(service): @@ -14,6 +26,6 @@ def create_rpc_ws_handler(middleware: "Middleware"): method_name = f"{service_name}.{attribute}" - methods[method_name] = Method(middleware, method_name) + methods[method_name] = method_factory(middleware, method_name) return RpcWebSocketHandler(middleware, methods) diff --git a/src/middlewared/middlewared/main.py b/src/middlewared/middlewared/main.py index 5f1a8c9cd5fef..8743fe8f262df 100644 --- a/src/middlewared/middlewared/main.py +++ b/src/middlewared/middlewared/main.py @@ -1,5 +1,8 @@ from .api.base.handler.dump_params import dump_params from .api.base.handler.result import serialize_result +from .api.base.handler.version import APIVersion, APIVersionsAdapter +from .api.base.server.legacy_api_method import LegacyAPIMethod +from .api.base.server.method import Method from .api.base.server.ws_handler.base import BaseWebSocketHandler from .api.base.server.ws_handler.rpc import RpcWebSocketApp, RpcWebSocketAppEvent from .api.base.server.ws_handler.rpc_factory import create_rpc_ws_handler @@ -29,6 +32,7 @@ from .utils.profile import profile_wrap from .utils.rate_limit.cache import RateLimitCache from .utils.service.call import ServiceCallMixin +from .utils.service.crud import real_crud_method from .utils.syslog import syslog_message from .utils.threading import set_thread_name, IoThreadPoolExecutor, io_thread_pool_executor from .utils.time_utils import utc_now @@ -53,10 +57,12 @@ import errno import fcntl import functools +import importlib import inspect import itertools import multiprocessing import os +import pathlib import pickle import re import queue @@ -99,13 +105,6 @@ class LoopMonitorIgnoreFrame: cut_below: bool = False -def real_crud_method(method): - if method.__name__ in ['create', 'update', 'delete'] and hasattr(method, '__self__'): - child_method = getattr(method.__self__, f'do_{method.__name__}', None) - if child_method is not None: - return child_method - - class Application(RpcWebSocketApp): def __init__( self, @@ -892,13 +891,27 @@ def create_task(self, coro, *, name=None): task.add_done_callback(self.tasks.discard) return task + def _load_api_versions(self): + versions = [] + api_dir = os.path.join(os.path.dirname(__file__), 'api') + for version_dir in sorted(pathlib.Path(api_dir).iterdir()): + if version_dir.name.startswith('v') and version_dir.is_dir(): + version = version_dir.name.replace('_', '.') + self._console_write(f'loading API version {version}') + versions.append( + APIVersion.from_module( + version, + importlib.import_module(f'middlewared.api.{version_dir.name}'), + ), + ) + return versions + def __init_services(self): from middlewared.service import CoreService self.add_service(CoreService(self)) self.event_register('core.environ', 'Send on middleware process environment changes.', private=True) - async def __plugins_load(self): - + def __plugins_load(self): setup_funcs = [] def on_module_begin(mod): @@ -1999,7 +2012,6 @@ def _loop_monitor_thread(self): last = current def run(self): - self._console_write('starting') set_thread_name('asyncio_loop') @@ -2031,9 +2043,12 @@ async def __initialize(self): ], loop=self.loop) self.app['middleware'] = self + api_versions = self._load_api_versions() + api_versions_adapter = APIVersionsAdapter(api_versions) + # Needs to happen after setting debug or may cause race condition # http://bugs.python.org/issue30805 - setup_funcs = await self.__plugins_load() + setup_funcs = self.__plugins_load() self._console_write('registering services') @@ -2049,9 +2064,23 @@ async def __initialize(self): self.loop.add_signal_handler(signal.SIGUSR1, self.pdb) self.loop.add_signal_handler(signal.SIGUSR2, self.log_threads_stacks) - rpc_ws_handler = create_rpc_ws_handler(self) - app.router.add_route('GET', '/api/current', rpc_ws_handler) - app.router.add_route('GET', '/api/v25.04.0', rpc_ws_handler) + current_rpc_ws_handler = create_rpc_ws_handler(self, Method) + app.router.add_route('GET', '/api/current', current_rpc_ws_handler) + app.router.add_route('GET', f'/api/{api_versions[-1].version}', current_rpc_ws_handler) + for version in api_versions[:-1]: + app.router.add_route( + 'GET', + f'/api/{version.version}', + create_rpc_ws_handler( + self, + lambda middleware, method_name: LegacyAPIMethod( + middleware, + method_name, + version.version, + api_versions_adapter, + ) + ), + ) app.router.add_route('GET', '/websocket', self.ws_handler) diff --git a/src/middlewared/middlewared/pytest/unit/api/base/server/__init__.py b/src/middlewared/middlewared/pytest/unit/api/base/server/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/middlewared/middlewared/pytest/unit/api/base/server/test_legacy_api_method.py b/src/middlewared/middlewared/pytest/unit/api/base/server/test_legacy_api_method.py new file mode 100644 index 0000000000000..1873d15eafcad --- /dev/null +++ b/src/middlewared/middlewared/pytest/unit/api/base/server/test_legacy_api_method.py @@ -0,0 +1,65 @@ +from unittest.mock import Mock + +from middlewared.api.base import BaseModel +from middlewared.api.base.decorator import api_method +from middlewared.api.base.handler.version import APIVersion, APIVersionsAdapter +from middlewared.api.base.server.legacy_api_method import LegacyAPIMethod + + +class MethodArgs(BaseModel): + number: int + multiplier: int = 2 + + +class MethodResult(BaseModel): + result: str + + +MethodArgsV1 = MethodArgs +MethodResultV1 = MethodResult + + +class MethodArgs(BaseModel): + number: int + text: str = "Default" + multiplier: int = 2 + + +class MethodResult(BaseModel): + result: int + + @classmethod + def to_previous(cls, value): + value["result"] = str(value["result"]) + + return value + + +@api_method(MethodArgs, MethodResult) +def method(number, text, multiplier): + return { + "number": number * multiplier, + "text": text * multiplier, + } + + +adapter = APIVersionsAdapter([ + APIVersion("v1", {"MethodArgs": MethodArgsV1, "MethodResult": MethodResultV1}), + APIVersion("v2", {"MethodArgs": MethodArgs, "MethodResult": MethodResult}), +]) +legacy_api_method = LegacyAPIMethod( + Mock( + get_method=Mock(return_value=(Mock(), method)) + ), + "core.test", + "v1", + adapter, +) + + +def test_adapt_params(): + assert legacy_api_method._adapt_params([1]) == [1, "Default", 2] + + +def test_adapt_result(): + assert legacy_api_method._adapt_result(1) == "1" diff --git a/src/middlewared/middlewared/pytest/unit/api/handler/version/__init__.py b/src/middlewared/middlewared/pytest/unit/api/handler/version/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt.py b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt.py new file mode 100644 index 0000000000000..cd7c2089a13b5 --- /dev/null +++ b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt.py @@ -0,0 +1,69 @@ +from pydantic import EmailStr +import pytest + +from middlewared.api.base import BaseModel +from middlewared.api.base.handler.version import APIVersion, APIVersionsAdapter + + +class SettingsV1(BaseModel): + email: EmailStr | None = None + + +class SettingsV2(BaseModel): + emails: list[EmailStr] + + @classmethod + def from_previous(cls, value): + email = value.pop("email") + + if email is None: + value["emails"] = [] + else: + value["emails"] = [email] + + return value + + @classmethod + def to_previous(cls, value): + emails = value.pop("emails") + + if emails: + value["email"] = emails[0] + else: + value["email"] = None + + return value + + +class SettingsV3(BaseModel): + contacts: list[dict] + + @classmethod + def from_previous(cls, value): + emails = value.pop("emails") + + value["contacts"] = [{"name": email.split("@")[0].title(), "email": email} + for email in emails] + + return value + + @classmethod + def to_previous(cls, value): + contacts = value.pop("contacts") + + value["emails"] = [contact["email"] for contact in contacts] + + return value + + +@pytest.mark.parametrize("version1,value,version2,result", [ + ("v1", {"email": "alice@ixsystems.com"}, "v3", {"contacts": [{"name": "Alice", "email": "alice@ixsystems.com"}]}), + ("v3", {"contacts": [{"name": "Alice", "email": "alice@ixsystems.com"}]}, "v1", {"email": "alice@ixsystems.com"}), +]) +def test_adapt(version1, value, version2, result): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Settings": SettingsV1}), + APIVersion("v2", {"Settings": SettingsV2}), + APIVersion("v3", {"Settings": SettingsV3}), + ]) + assert adapter.adapt(value, "Settings", version1, version2) == result diff --git a/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_default.py b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_default.py new file mode 100644 index 0000000000000..a42a272e09998 --- /dev/null +++ b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_default.py @@ -0,0 +1,25 @@ +import pytest + +from middlewared.api.base import BaseModel +from middlewared.api.base.handler.version import APIVersion, APIVersionsAdapter + + +class SettingsV1(BaseModel): + text1: str + + +class SettingsV2(BaseModel): + text1: str + text2: str = "text2" + + +@pytest.mark.parametrize("version1,value,version2,result", [ + ("v1", {"text1": "text1"}, "v2", {"text1": "text1", "text2": "text2"}), + ("v2", {"text1": "text1", "text2": "text2"}, "v1", {"text1": "text1"}), +]) +def test_adapt(version1, value, version2, result): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Settings": SettingsV1}), + APIVersion("v2", {"Settings": SettingsV2}), + ]) + assert adapter.adapt(value, "Settings", version1, version2) == result diff --git a/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_nested_model.py b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_nested_model.py new file mode 100644 index 0000000000000..4957d3a5e90ec --- /dev/null +++ b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_nested_model.py @@ -0,0 +1,84 @@ +from pydantic import EmailStr +import pytest + +from middlewared.api.base import BaseModel +from middlewared.api.base.handler.version import APIVersion, APIVersionsAdapter + + +class Settings(BaseModel): + email: EmailStr | None = None + + +class UpdateSettingsArgsV1(BaseModel): + settings: Settings + + +class Settings(BaseModel): + emails: list[EmailStr] + + @classmethod + def from_previous(cls, value): + email = value.pop("email") + + if email is None: + value["emails"] = [] + else: + value["emails"] = [email] + + return value + + @classmethod + def to_previous(cls, value): + emails = value.pop("emails") + + if emails: + value["email"] = emails[0] + else: + value["email"] = None + + return value + + +class UpdateSettingsArgsV2(BaseModel): + settings: Settings + + +class Settings(BaseModel): + contacts: list[dict] + + @classmethod + def from_previous(cls, value): + emails = value.pop("emails") + + value["contacts"] = [{"name": email.split("@")[0].title(), "email": email} + for email in emails] + + return value + + @classmethod + def to_previous(cls, value): + contacts = value.pop("contacts") + + value["emails"] = [contact["email"] for contact in contacts] + + return value + + +class UpdateSettingsArgsV3(BaseModel): + settings: Settings + + +@pytest.mark.parametrize("version1,value,version2,result", [ + ("v1", {"settings": {"email": "alice@ixsystems.com"}}, + "v3", {"settings": {"contacts": [{"name": "Alice", "email": "alice@ixsystems.com"}]}}), + ("v3", {"settings": {"contacts": [{"name": "Alice", "email": "alice@ixsystems.com"}]}}, + "v1", {"settings": {"email": "alice@ixsystems.com"}}), +]) +def test_adapt(version1, value, version2, result): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"UpdateSettingsArgs": UpdateSettingsArgsV1}), + APIVersion("v2", {"UpdateSettingsArgs": UpdateSettingsArgsV2}), + APIVersion("v3", {"UpdateSettingsArgs": UpdateSettingsArgsV3}), + ]) + + assert adapter.adapt(value, "UpdateSettingsArgs", version1, version2) == result diff --git a/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_nested_model_list.py b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_nested_model_list.py new file mode 100644 index 0000000000000..c38b8e7ca05df --- /dev/null +++ b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_nested_model_list.py @@ -0,0 +1,54 @@ +import pytest + +from middlewared.api.base import BaseModel +from middlewared.api.base.handler.version import APIVersion, APIVersionsAdapter + + +class Contact(BaseModel): + name: str + email: str + + +class SettingsV1(BaseModel): + contacts: list[Contact] + + +class Contact(BaseModel): + first_name: str + last_name: str + email: str + + @classmethod + def from_previous(cls, value): + if " " in value["name"]: + value["first_name"], value["last_name"] = value.pop("name").split(" ", 1) + else: + value["first_name"] = value.pop("name") + value["last_name"] = "" + + return value + + @classmethod + def to_previous(cls, value): + value["name"] = f"{value.pop('first_name')} {value.pop('last_name')}" + + return value + + +class SettingsV2(BaseModel): + contacts: list[Contact] + + +@pytest.mark.parametrize("version1,value,version2,result", [ + ("v1", {"contacts": [{"name": "Jane Doe", "email": "jane@ixsystems.com"}]}, + "v2", {"contacts": [{"first_name": "Jane", "last_name": "Doe", "email": "jane@ixsystems.com"}]}), + ("v2", {"contacts": [{"first_name": "Jane", "last_name": "Doe", "email": "jane@ixsystems.com"}]}, + "v1", {"contacts": [{"name": "Jane Doe", "email": "jane@ixsystems.com"}]}), +]) +def test_adapt(version1, value, version2, result): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Settings": SettingsV1}), + APIVersion("v2", {"Settings": SettingsV2}), + ]) + + assert adapter.adapt(value, "Settings", version1, version2) == result diff --git a/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_shortcuts.py b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_shortcuts.py new file mode 100644 index 0000000000000..32bc0b25d8867 --- /dev/null +++ b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_shortcuts.py @@ -0,0 +1,78 @@ +import pytest + +from middlewared.api.base import BaseModel, ForUpdateMetaclass, single_argument_args, single_argument_result +from middlewared.api.base.handler.version import APIVersion, APIVersionsAdapter + + +class ModelV1(BaseModel, metaclass=ForUpdateMetaclass): + number: int = 1 + + +class ModelV2(BaseModel, metaclass=ForUpdateMetaclass): + number: int = 1 + text: str = "1" + + +def test_adapt_for_update_metaclass(): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Model": ModelV1}), + APIVersion("v2", {"Model": ModelV2}), + ]) + assert adapter.adapt({}, "Model", "v1", "v2") == {} + + +class ArgsV1(BaseModel): + count: int + force: bool = False + + +@single_argument_args("options") +class ArgsV2(BaseModel): + count: int + exclude: list[str] = [] + force: bool = False + + def from_previous(cls, value): + return { + "options": value, + } + + +@pytest.mark.parametrize("version1,value,version2,result", [ + ("v1", {"count": 1}, "v2", {"options": {"count": 1, "force": False}}), +]) +def test_adapt_single_argument_args(version1, value, version2, result): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Args": ArgsV1}), + APIVersion("v2", {"Args": ArgsV2}), + ]) + assert adapter.adapt(value, "Args", version1, version2) == result + + +class ResultV1(BaseModel): + result: int + + +@single_argument_result +class ResultV2(BaseModel): + value: int + status: str + + def from_previous(cls, value): + return { + "result": { + "value": value["result"], + "status": "OK", + }, + } + + +@pytest.mark.parametrize("version1,value,version2,result", [ + ("v1", {"result": 1}, "v2", {"result": {"value": 1, "status": "OK"}}), +]) +def test_adapt_single_argument_result(version1, value, version2, result): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Result": ResultV1}), + APIVersion("v2", {"Result": ResultV2}), + ]) + assert adapter.adapt(value, "Result", version1, version2) == result diff --git a/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_validation.py b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_validation.py new file mode 100644 index 0000000000000..69bcca0e05644 --- /dev/null +++ b/src/middlewared/middlewared/pytest/unit/api/handler/version/test_adapt_validation.py @@ -0,0 +1,57 @@ +from unittest.mock import ANY + +from pydantic import EmailStr +import pytest + +from middlewared.api.base import BaseModel +from middlewared.api.base.handler.version import APIVersion, APIVersionsAdapter +from middlewared.service_exception import ValidationErrors, ValidationError + + +class SettingsV1(BaseModel): + email: EmailStr | None = None + + +class SettingsV2(BaseModel): + emails: list[EmailStr] + + @classmethod + def from_previous(cls, value): + email = value.pop("email") + + if email is None: + value["emails"] = [] + else: + value["emails"] = [email] + + return value + + @classmethod + def to_previous(cls, value): + emails = value.pop("emails") + + if emails: + value["email"] = emails[0] + else: + value["email"] = None + + return value + + +def test_adapt_validation(): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Settings": SettingsV1}), + APIVersion("v2", {"Settings": SettingsV2}), + ]) + with pytest.raises(ValidationErrors) as ve: + assert adapter.adapt({"email": ""}, "Settings", "v1", "v2") + + assert ve.value.errors == [ValidationError("email", ANY)] + + +def test_adapt_default(): + adapter = APIVersionsAdapter([ + APIVersion("v1", {"Settings": SettingsV1}), + APIVersion("v2", {"Settings": SettingsV2}), + ]) + assert adapter.adapt({}, "Settings", "v1", "v2") == {"emails": []} diff --git a/src/middlewared/middlewared/utils/service/crud.py b/src/middlewared/middlewared/utils/service/crud.py new file mode 100644 index 0000000000000..8ea8396485495 --- /dev/null +++ b/src/middlewared/middlewared/utils/service/crud.py @@ -0,0 +1,8 @@ +__all__ = ['real_crud_method'] + + +def real_crud_method(method): + if method.__name__ in ['create', 'update', 'delete'] and hasattr(method, '__self__'): + child_method = getattr(method.__self__, f'do_{method.__name__}', None) + if child_method is not None: + return child_method