diff --git a/src/middlewared/middlewared/api/__init__.py b/src/middlewared/middlewared/api/__init__.py index 632c94b1e329d..43f1482a98d6e 100644 --- a/src/middlewared/middlewared/api/__init__.py +++ b/src/middlewared/middlewared/api/__init__.py @@ -1 +1,3 @@ from .base.decorator import * + +API_LOADING_FORBIDDEN = False diff --git a/src/middlewared/middlewared/api/base/types/urls.py b/src/middlewared/middlewared/api/base/types/urls.py index 474be0551c2fa..e15da67da7e03 100644 --- a/src/middlewared/middlewared/api/base/types/urls.py +++ b/src/middlewared/middlewared/api/base/types/urls.py @@ -1,10 +1,12 @@ -from typing import Annotated +from typing import Annotated, Literal, TypeAlias from pydantic import AfterValidator, HttpUrl from middlewared.api.base.validators import https_only_check -__all__ = ["HttpsOnlyURL"] +__all__ = ["HttpsOnlyURL", "HttpVerb"] HttpsOnlyURL = Annotated[HttpUrl, AfterValidator(https_only_check)] + +HttpVerb: TypeAlias = Literal["GET", "POST", "PUT", "DELETE", "CALL", "SUBSCRIBE", "*"] diff --git a/src/middlewared/middlewared/api/current.py b/src/middlewared/middlewared/api/current.py index 98fd54d020da1..c8c4584b80f8f 100644 --- a/src/middlewared/middlewared/api/current.py +++ b/src/middlewared/middlewared/api/current.py @@ -1 +1,8 @@ +from . import API_LOADING_FORBIDDEN +if API_LOADING_FORBIDDEN: + raise RuntimeError( + "Middleware API loading forbidden in this code path as it is too resource-consuming. Please, inspect the " + "provided traceback and ensure that nothing is imported from `middlewared.api.current`." + ) + from .v25_04_0 import * # noqa diff --git a/src/middlewared/middlewared/api/v25_04_0/api_key.py b/src/middlewared/middlewared/api/v25_04_0/api_key.py index 638ab7ad7cfcc..d91743cd88834 100644 --- a/src/middlewared/middlewared/api/v25_04_0/api_key.py +++ b/src/middlewared/middlewared/api/v25_04_0/api_key.py @@ -1,17 +1,14 @@ from datetime import datetime -from typing import Annotated, Literal, TypeAlias +from typing import Annotated, Literal from pydantic import Secret, StringConstraints from middlewared.api.base import ( BaseModel, Excluded, excluded_field, ForUpdateMetaclass, NonEmptyString, - LocalUsername, RemoteUsername + LocalUsername, RemoteUsername, HttpVerb, ) -HttpVerb: TypeAlias = Literal["GET", "POST", "PUT", "DELETE", "CALL", "SUBSCRIBE", "*"] - - class AllowListItem(BaseModel): method: HttpVerb resource: NonEmptyString diff --git a/src/middlewared/middlewared/main.py b/src/middlewared/middlewared/main.py index 594a23a97c4e7..004079d81fe39 100644 --- a/src/middlewared/middlewared/main.py +++ b/src/middlewared/middlewared/main.py @@ -17,7 +17,7 @@ from .schema import OROperator import middlewared.service from .service_exception import CallError, ErrnoMixin -from .utils import MIDDLEWARE_RUN_DIR, sw_version +from .utils import MIDDLEWARE_RUN_DIR, MIDDLEWARE_STARTED_SENTINEL_PATH, sw_version from .utils.audit import audit_username_from_session from .utils.debug import get_threads_stacks from .utils.limits import MsgSizeError, MsgSizeLimit, parse_message @@ -27,7 +27,6 @@ from .utils.rate_limit.cache import RateLimitCache from .utils.service.call import ServiceCallMixin from .utils.service.crud import real_crud_method -from .utils.syslog import syslog_message from .utils.threading import set_thread_name, IoThreadPoolExecutor, io_thread_pool_executor from .utils.time_utils import utc_now from .utils.type import copy_function_metadata @@ -200,7 +199,7 @@ def _add_api_route(self, version: str, api: API): self.app.router.add_route('GET', f'/api/{version}', RpcWebSocketHandler(self, api.methods)) def __init_services(self): - from middlewared.service import CoreService + from middlewared.service.core_service import CoreService self.add_service(CoreService(self)) self.event_register('core.environ', 'Send on middleware process environment changes.', private=True) @@ -448,7 +447,7 @@ def __notify_startup_progress(self): systemd_notify(f'EXTEND_TIMEOUT_USEC={SYSTEMD_EXTEND_USECS}') def __notify_startup_complete(self): - with open(middlewared.service.MIDDLEWARE_STARTED_SENTINEL_PATH, 'w'): + with open(MIDDLEWARE_STARTED_SENTINEL_PATH, 'w'): pass systemd_notify('READY=1') diff --git a/src/middlewared/middlewared/plugins/datastore/connection.py b/src/middlewared/middlewared/plugins/datastore/connection.py index 36f32e235813c..45d965294651d 100644 --- a/src/middlewared/middlewared/plugins/datastore/connection.py +++ b/src/middlewared/middlewared/plugins/datastore/connection.py @@ -7,7 +7,7 @@ from middlewared.service import private, Service -from middlewared.plugins.config import FREENAS_DATABASE +from middlewared.utils.db import FREENAS_DATABASE thread_pool = ThreadPoolExecutor(1) diff --git a/src/middlewared/middlewared/plugins/datastore/read.py b/src/middlewared/middlewared/plugins/datastore/read.py index d4dd86ad368c0..476ac0b447ca9 100644 --- a/src/middlewared/middlewared/plugins/datastore/read.py +++ b/src/middlewared/middlewared/plugins/datastore/read.py @@ -3,9 +3,7 @@ from sqlalchemy import and_, func, select from sqlalchemy.sql import Alias -from sqlalchemy.sql.elements import UnaryExpression from sqlalchemy.sql.expression import nullsfirst, nullslast -from sqlalchemy.sql.operators import desc_op, nullsfirst_op, nullslast_op from middlewared.schema import accepts, Bool, Dict, Int, List, Ref, Str from middlewared.service import Service diff --git a/src/middlewared/middlewared/plugins/zettarepl.py b/src/middlewared/middlewared/plugins/zettarepl.py index 9b75c5621e1e4..3f2355cd3e042 100644 --- a/src/middlewared/middlewared/plugins/zettarepl.py +++ b/src/middlewared/middlewared/plugins/zettarepl.py @@ -44,7 +44,8 @@ from zettarepl.zettarepl import create_zettarepl from middlewared.logger import setup_logging -from middlewared.service import CallError, Service +from middlewared.service.service import Service +from middlewared.service_exception import CallError from middlewared.utils.cgroups import move_to_root_cgroups from middlewared.utils.prctl import die_with_parent from middlewared.utils.size import format_size diff --git a/src/middlewared/middlewared/pytest/unit/test_service_cli_descriptions.py b/src/middlewared/middlewared/pytest/unit/test_service_cli_descriptions.py index d7af9d5d5356e..941cc602b5a90 100644 --- a/src/middlewared/middlewared/pytest/unit/test_service_cli_descriptions.py +++ b/src/middlewared/middlewared/pytest/unit/test_service_cli_descriptions.py @@ -2,7 +2,7 @@ import pytest -from middlewared.service import CoreService +from middlewared.service.core_service import CoreService @pytest.mark.parametrize("doc,names,descriptions", [ diff --git a/src/middlewared/middlewared/service/__init__.py b/src/middlewared/middlewared/service/__init__.py index 6779f1bf5fe40..2432d8d2097bf 100644 --- a/src/middlewared/middlewared/service/__init__.py +++ b/src/middlewared/middlewared/service/__init__.py @@ -6,7 +6,6 @@ from .compound_service import CompoundService # noqa from .config_service import ConfigService # noqa -from .core_service import CoreService, MIDDLEWARE_RUN_DIR, MIDDLEWARE_STARTED_SENTINEL_PATH # noqa from .crud_service import CRUDService # noqa from .decorators import ( # noqa cli_private, filterable, filterable_returns, item_method, job, lock, no_auth_required, diff --git a/src/middlewared/middlewared/service/core_service.py b/src/middlewared/middlewared/service/core_service.py index 0635b1eb5c559..67d991113077b 100644 --- a/src/middlewared/middlewared/service/core_service.py +++ b/src/middlewared/middlewared/service/core_service.py @@ -33,7 +33,7 @@ from middlewared.pipe import Pipes from middlewared.schema import accepts, Any, Bool, Datetime, Dict, Int, List, Str from middlewared.service_exception import CallError, ValidationErrors -from middlewared.utils import BOOTREADY, filter_list, MIDDLEWARE_RUN_DIR +from middlewared.utils import BOOTREADY, filter_list, MIDDLEWARE_STARTED_SENTINEL_PATH from middlewared.utils.debug import get_frame_details, get_threads_stacks from middlewared.validators import IpAddress, Range @@ -44,9 +44,6 @@ from .service import Service -MIDDLEWARE_STARTED_SENTINEL_PATH = os.path.join(MIDDLEWARE_RUN_DIR, 'middlewared-started') - - def is_service_class(service, klass): return ( isinstance(service, klass) or diff --git a/src/middlewared/middlewared/service/crud_service.py b/src/middlewared/middlewared/service/crud_service.py index 8e1ad4847e6e6..8f61a6f2c4637 100644 --- a/src/middlewared/middlewared/service/crud_service.py +++ b/src/middlewared/middlewared/service/crud_service.py @@ -5,9 +5,10 @@ from pydantic import create_model, Field -from middlewared.api import api_method +from middlewared.api import API_LOADING_FORBIDDEN, api_method from middlewared.api.base.model import BaseModel, query_result, query_result_item -from middlewared.api.current import QueryArgs, QueryOptions +if not API_LOADING_FORBIDDEN: + from middlewared.api.current import QueryArgs, QueryOptions from middlewared.service_exception import CallError, InstanceNotFound from middlewared.schema import accepts, Any, Bool, convert_schema, Dict, Int, List, OROperator, Patch, Ref, returns from middlewared.utils import filter_list diff --git a/src/middlewared/middlewared/service/decorators.py b/src/middlewared/middlewared/service/decorators.py index 2ca5f2b012ee0..d03a3a5da29af 100644 --- a/src/middlewared/middlewared/service/decorators.py +++ b/src/middlewared/middlewared/service/decorators.py @@ -4,9 +4,10 @@ from collections import defaultdict, namedtuple from functools import wraps -from middlewared.api import api_method +from middlewared.api import API_LOADING_FORBIDDEN, api_method from middlewared.api.base import query_result -from middlewared.api.current import QueryArgs, GenericQueryResult +if not API_LOADING_FORBIDDEN: + from middlewared.api.current import QueryArgs, GenericQueryResult from middlewared.schema import accepts, Int, List, OROperator, Ref, returns diff --git a/src/middlewared/middlewared/utils/__init__.py b/src/middlewared/middlewared/utils/__init__.py index c6ae743a7fd2b..dc55572ab97cf 100644 --- a/src/middlewared/middlewared/utils/__init__.py +++ b/src/middlewared/middlewared/utils/__init__.py @@ -34,6 +34,7 @@ class ProductNames: MID_PID = None MIDDLEWARE_RUN_DIR = '/var/run/middleware' +MIDDLEWARE_STARTED_SENTINEL_PATH = f'{MIDDLEWARE_RUN_DIR}/middlewared-started' BOOTREADY = f'{MIDDLEWARE_RUN_DIR}/.bootready' MANIFEST_FILE = '/data/manifest.json' BRAND = ProductName.PRODUCT_NAME diff --git a/src/middlewared/middlewared/utils/allowlist.py b/src/middlewared/middlewared/utils/allowlist.py index 193abe8459324..bbb5d1bb4a858 100644 --- a/src/middlewared/middlewared/utils/allowlist.py +++ b/src/middlewared/middlewared/utils/allowlist.py @@ -1,7 +1,7 @@ import fnmatch import re -from middlewared.api.current import HttpVerb +from middlewared.api.base.types import HttpVerb from middlewared.utils.privilege_constants import ALLOW_LIST_FULL_ADMIN diff --git a/src/middlewared/middlewared/utils/plugins.py b/src/middlewared/middlewared/utils/plugins.py index 0eaee7578b80d..5e224b427220e 100644 --- a/src/middlewared/middlewared/utils/plugins.py +++ b/src/middlewared/middlewared/utils/plugins.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -def load_modules(directory, base=None, depth=0): +def load_modules(directory, base=None, depth=0, whitelist=None): directory = os.path.normpath(directory) if base is None: middlewared_root = os.path.dirname(os.path.dirname(__file__)) @@ -32,13 +32,15 @@ def load_modules(directory, base=None, depth=0): base = '.'.join(os.path.relpath(directory, new_module_path).split('/')) _, dirs, files = next(os.walk(directory)) - for f in filter(lambda x: x[-3:] == '.py' and x.find('_freebsd') == -1, files): - yield importlib.import_module(base if f == '__init__.py' else f'{base}.{f[:-3]}') + for f in filter(lambda x: x[-3:] == '.py', files): + module_name = base if f == '__init__.py' else f'{base}.{f[:-3]}' + if whitelist is None or any(module_name.startswith(w) for w in whitelist): + yield importlib.import_module(module_name) - for f in filter(lambda x: x.find('_freebsd') == -1, dirs): + for f in dirs: if depth > 0: path = os.path.join(directory, f) - yield from load_modules(path, f'{base}.{f}', depth - 1) + yield from load_modules(path, f'{base}.{f}', depth - 1, whitelist) def load_classes(module, base, blacklist): @@ -92,7 +94,7 @@ def __init__(self): self._services_aliases = {} super().__init__() - def _load_plugins(self, on_module_begin=None, on_module_end=None, on_modules_loaded=None): + def _load_plugins(self, on_module_begin=None, on_module_end=None, on_modules_loaded=None, whitelist=None): from middlewared.service import Service, CompoundService, ABSTRACT_SERVICES services = [] @@ -100,7 +102,7 @@ def _load_plugins(self, on_module_begin=None, on_module_end=None, on_modules_loa if not os.path.exists(plugins_dir): raise ValueError(f'plugins dir not found: {plugins_dir}') - for mod in load_modules(plugins_dir, depth=1): + for mod in load_modules(plugins_dir, depth=1, whitelist=whitelist): if on_module_begin: on_module_begin(mod) diff --git a/src/middlewared/middlewared/worker.py b/src/middlewared/middlewared/worker.py index 43d14411d6ae2..bb3d55816400d 100755 --- a/src/middlewared/middlewared/worker.py +++ b/src/middlewared/middlewared/worker.py @@ -5,6 +5,7 @@ from truenas_api_client import Client +import middlewared.api from . import logger from .common.environ import environ_update from .utils import MIDDLEWARE_RUN_DIR @@ -49,29 +50,32 @@ def call_sync(self, method, *params, timeout=None, **kwargs): """ Calls a method using middleware client """ - serviceobj, methodobj = self.get_method(method) - - if serviceobj._config.process_pool and not hasattr(method, '_job'): - if asyncio.iscoroutinefunction(methodobj): - try: - # Search for a synchronous implementation of the asynchronous method (i.e. `get_instance`). - # Why is this needed? Imagine we have a `ZFSSnapshot` service that uses a process pool. Let's say - # its `create` method calls `zfs.snapshot.get_instance` to return the result. That call will have - # to be forwarded to the main middleware process, which will call `zfs.snapshot.query` in the - # process pool. If the process pool is already exhausted, it will lead to a deadlock. - # By executing a synchronous implementation of the same method in the same process pool we - # eliminate `Hold and wait` condition and prevent deadlock situation from arising. - _, sync_methodobj = self.get_method(f'{method}__sync') - except MethodNotFoundError: - # FIXME: Make this an exception in 22.MM - self.logger.warning('Service uses a process pool but has an asynchronous method: %r', method) - sync_methodobj = None - else: - sync_methodobj = methodobj - - if sync_methodobj is not None: - self.logger.trace('Calling %r in current process', method) - return sync_methodobj(*params) + try: + serviceobj, methodobj = self.get_method(method) + except Exception: + pass + else: + if serviceobj._config.process_pool and not hasattr(method, '_job'): + if asyncio.iscoroutinefunction(methodobj): + try: + # Search for a synchronous implementation of the asynchronous method (i.e. `get_instance`). + # Why is this needed? Imagine we have a `ZFSSnapshot` service that uses a process pool. Let's say + # its `create` method calls `zfs.snapshot.get_instance` to return the result. That call will have + # to be forwarded to the main middleware process, which will call `zfs.snapshot.query` in the + # process pool. If the process pool is already exhausted, it will lead to a deadlock. + # By executing a synchronous implementation of the same method in the same process pool we + # eliminate `Hold and wait` condition and prevent deadlock situation from arising. + _, sync_methodobj = self.get_method(f'{method}__sync') + except MethodNotFoundError: + # FIXME: Make this an exception in 22.MM + self.logger.warning('Service uses a process pool but has an asynchronous method: %r', method) + sync_methodobj = None + else: + sync_methodobj = methodobj + + if sync_methodobj is not None: + self.logger.trace('Calling %r in current process', method) + return sync_methodobj(*params) return self.client.call(method, *params, timeout=timeout, **kwargs) @@ -128,9 +132,13 @@ def receive_events(): def worker_init(debug_level, log_handler): global MIDDLEWARE + middlewared.api.API_LOADING_FORBIDDEN = True MIDDLEWARE = FakeMiddleware() os.environ['MIDDLEWARED_LOADING'] = 'True' - MIDDLEWARE._load_plugins() + MIDDLEWARE._load_plugins(whitelist=[ + 'middlewared.plugins.datastore', + 'middlewared.plugins.zfs_', + ]) os.environ['MIDDLEWARED_LOADING'] = 'False' setproctitle.setproctitle('middlewared (worker)') die_with_parent()