From c75a2bc47348541cfe959e69121fab87dc5ffba8 Mon Sep 17 00:00:00 2001 From: yakimka Date: Sun, 21 Apr 2024 17:47:23 +0300 Subject: [PATCH] Make changes for sync dependencies --- nanodi/__init__.py | 4 +- nanodi/nanodi.py | 149 +++++++++++++++++++---------- nanodi/providers.py | 44 +++++++++ tests/test_complex_logic.py | 41 ++++++++ tests/test_sync_di.py | 44 +++++++-- tests/test_sync_di_with_closing.py | 60 ++++-------- tests/test_sync_resource.py | 59 ++++++++++++ 7 files changed, 302 insertions(+), 99 deletions(-) create mode 100644 nanodi/providers.py create mode 100644 tests/test_complex_logic.py create mode 100644 tests/test_sync_resource.py diff --git a/nanodi/__init__.py b/nanodi/__init__.py index 071ce5b..d633b2d 100644 --- a/nanodi/__init__.py +++ b/nanodi/__init__.py @@ -1,3 +1,3 @@ -from nanodi.nanodi import Depends, inject, shutdown_resources +from nanodi.nanodi import Depends, inject, resource, shutdown_resources -__all__ = ["Depends", "inject", "shutdown_resources"] +__all__ = ["Depends", "inject", "shutdown_resources", "resource"] diff --git a/nanodi/nanodi.py b/nanodi/nanodi.py index e718190..56ea61a 100644 --- a/nanodi/nanodi.py +++ b/nanodi/nanodi.py @@ -1,22 +1,32 @@ +from __future__ import annotations + import functools import inspect -from collections.abc import Callable, Coroutine -from dataclasses import dataclass -from typing import Any, ParamSpec, TypeAlias, TypeVar +from collections.abc import Callable, Coroutine, Generator +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + ExitStack, + asynccontextmanager, + contextmanager, +) +from dataclasses import dataclass, field +from typing import Any, AsyncContextManager, ContextManager, ParamSpec, TypeVar + +Dependency = Callable[..., Any] -Dependency: TypeAlias = Callable[[], Any] -_not_injected = object() -_registry: dict[Dependency, Any] = {} -_shutdown_callbacks: list[Callable[[], None]] = [] +_unset = object() +_resources_exit_stack = ExitStack() +_resources: dict[Dependency, AsyncContextManager | ContextManager] = {} +_resources_result_cache: dict[Dependency, Any] = {} def Depends(dependency: Dependency, /, use_cache: bool = True) -> Any: # noqa: N802 - _registry[dependency] = _not_injected return _Depends(dependency, use_cache) -@dataclass +@dataclass(frozen=True) class _Depends: dependency: Dependency use_cache: bool @@ -26,6 +36,42 @@ class _Depends: P = ParamSpec("P") +@dataclass(frozen=True) +class ResolvedDependency: + original: Dependency + context_manager: ContextManager | AsyncContextManager | None = field(compare=False) + is_async: bool = field(default=False, compare=False) + use_cache: bool = True + + @classmethod + def resolve(cls, depends: _Depends) -> ResolvedDependency: + context_manager: ContextManager | AsyncContextManager | None = None + is_async = False + if inspect.isasyncgenfunction(depends.dependency): + context_manager = asynccontextmanager(depends.dependency)() + is_async = True + elif inspect.isgeneratorfunction(depends.dependency): + context_manager = contextmanager(depends.dependency)() + return cls(depends.dependency, context_manager, is_async, depends.use_cache) + + +TC = TypeVar("TC", bound=Callable) + + +def resource(fn: TC) -> TC: + manager: ContextManager | AsyncContextManager + if inspect.isasyncgenfunction(fn): + manager = asynccontextmanager(fn)() + elif inspect.isgeneratorfunction(fn): + manager = contextmanager(fn)() + else: + raise ValueError("Resource must be a generator or async generator function") + _resources[fn] = manager + _resources_result_cache[fn] = _unset + + return fn + + def inject(fn: Callable[P, T]) -> Callable[P, T | Coroutine[Any, Any, T]]: signature = inspect.signature(fn) @@ -34,48 +80,55 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: bound = signature.bind(*args, **kwargs) bound.apply_defaults() - to_close = [] + dependencies: dict[ResolvedDependency, list[str]] = {} for name, value in bound.arguments.items(): if isinstance(value, _Depends): - if not value.use_cache: - result, close_callback = _resolve_dependency(value.dependency) - bound.arguments[name] = result - if close_callback is not None: - to_close.append(close_callback) - else: - if _registry.get(value.dependency) is _not_injected: - result, close_callback = _resolve_dependency(value.dependency) - _registry[value.dependency] = result - if close_callback is not None: - _shutdown_callbacks.append(close_callback) - bound.arguments[name] = _registry[value.dependency] - - result = fn(*bound.args, **bound.kwargs) - for close_callback in to_close: - close_callback() - return result - - return wrapper + dependencies.setdefault(ResolvedDependency.resolve(value), []).append( + name + ) + with _call_dependencies(dependencies) as arguments: + bound.arguments.update(arguments) + return fn(*bound.args, **bound.kwargs) -def shutdown_resources(): - for close_callback in _shutdown_callbacks: - close_callback() - - -def _resolve_dependency( - dependency: Dependency, -) -> tuple[Any, Callable[[], None] | None]: - result = dependency() - close_callback = None - if inspect.isgeneratorfunction(dependency): - generator = result - result = next(generator) + return wrapper - def close_callback(): - try: - next(generator) - except StopIteration: - pass - return result, close_callback +def shutdown_resources() -> None: + _resources_exit_stack.close() + + +@contextmanager +def _call_dependencies( + dependencies: dict[ResolvedDependency, list[str]], +) -> Generator[dict[str, Any], None, None]: + managers: list[tuple[AbstractContextManager, list[str]]] = [] + async_managers: list[tuple[AbstractAsyncContextManager, list[str]]] = [] + results = {} + for dependency, names in dependencies.items(): + if context_manager := _resources.get(dependency.original): + if isinstance(context_manager, AbstractContextManager): + if _resources_result_cache.get(dependency.original) is _unset: + result = _resources_exit_stack.enter_context(context_manager) + _resources_result_cache[dependency.original] = result + + result = _resources_result_cache[dependency.original] + results.update({name: result for name in names}) + elif dependency.context_manager: + if isinstance(dependency.context_manager, AbstractAsyncContextManager): + async_managers.append((dependency.context_manager, names)) + else: + managers.append((dependency.context_manager, names)) + else: + if dependency.use_cache: + result = dependency.original() + results.update({name: result for name in names}) + else: + results.update({name: dependency.original() for name in names}) + + with ExitStack() as stack: + values = {manager: stack.enter_context(manager) for manager, _ in managers} + for manager, names in managers: + for name in names: + results[name] = values[manager] + yield results diff --git a/nanodi/providers.py b/nanodi/providers.py new file mode 100644 index 0000000..8225eaf --- /dev/null +++ b/nanodi/providers.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import TracebackType + + +class Resource: + def init(self) -> Any: + raise NotImplementedError + + def close(self) -> None: + raise NotImplementedError + + def __enter__(self) -> Any: + return self.init() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + +class AsyncResource: + async def init(self) -> Any: + raise NotImplementedError + + async def close(self) -> None: + raise NotImplementedError + + async def __aenter__(self) -> Any: + return await self.init() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.close() diff --git a/tests/test_complex_logic.py b/tests/test_complex_logic.py new file mode 100644 index 0000000..d67947f --- /dev/null +++ b/tests/test_complex_logic.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from nanodi import Depends, inject + + +def get_redis() -> str: + yield "Redis" + + +@inject +def get_sessions_storage(redis: str = Depends(get_redis)) -> str: + return f"SessionsStorage({redis})" + + +def get_postgres_connection() -> str: + return "Postgres" + + +@inject +def get_db(postgres: str = Depends(get_postgres_connection)) -> str: + return f"{postgres} DB" + + +@inject +def get_users_repository(db: str = Depends(get_db)) -> str: + return f"UsersRepository({db})" + + +@inject +def get_users_service( + users_repository: str = Depends(get_users_repository), + sessions_storage: str = Depends(get_sessions_storage), +) -> str: + return f"UsersService({users_repository}, {sessions_storage})" + + +def test_resolve_complex_service(): + users_service = get_users_service() + + expected = "UsersService(UsersRepository(Postgres DB), SessionsStorage(Redis))" + assert users_service == expected diff --git a/tests/test_sync_di.py b/tests/test_sync_di.py index fd19c74..6e6b44d 100644 --- a/tests/test_sync_di.py +++ b/tests/test_sync_di.py @@ -22,21 +22,22 @@ def my_service(redis: Redis = Depends(get_redis)): assert isinstance(redis, Redis) -def test_dependencies_must_use_cache(): +def test_dependencies_in_single_call_must_use_cache(): @inject - def my_service(redis: Redis = Depends(get_redis)): - return redis + def my_service( + redis1: Redis = Depends(get_redis), redis2: Redis = Depends(get_redis) + ): + return redis1, redis2 - redis1 = my_service() - redis2 = my_service() + redis1, redis2 = my_service() assert isinstance(redis1, Redis) assert redis1 is redis2 -def test_dependencies_must_without_cache(): +def test_dependencies_dont_share_cache_between_calls(): @inject - def my_service(redis: Redis = Depends(get_redis, use_cache=False)): + def my_service(redis: Redis = Depends(get_redis)): return redis redis1 = my_service() @@ -45,3 +46,32 @@ def my_service(redis: Redis = Depends(get_redis, use_cache=False)): assert isinstance(redis1, Redis) assert isinstance(redis2, Redis) assert redis1 is not redis2 + + +def test_dependencies_in_single_call_dont_use_cache_if_specified(): + @inject + def my_service( + redis1: Redis = Depends(get_redis, use_cache=False), + redis2: Redis = Depends(get_redis, use_cache=False), + ): + return redis1, redis2 + + redis1, redis2 = my_service() + + assert isinstance(redis1, Redis) + assert isinstance(redis2, Redis) + assert redis1 is not redis2 + + +def test_nested_dependencies(): + @inject + def my_service_inner(redis: Redis = Depends(get_redis)): + return redis + + @inject + def my_service_outer(inner_service: Redis = Depends(my_service_inner)): + return inner_service + + inner_service = my_service_outer() + + assert isinstance(inner_service, Redis) diff --git a/tests/test_sync_di_with_closing.py b/tests/test_sync_di_with_closing.py index 72a39a5..2650a54 100644 --- a/tests/test_sync_di_with_closing.py +++ b/tests/test_sync_di_with_closing.py @@ -1,7 +1,14 @@ -from collections.abc import Generator +from __future__ import annotations + from dataclasses import dataclass +from typing import TYPE_CHECKING + +import pytest -from nanodi import Depends, inject, shutdown_resources +from nanodi import Depends, inject + +if TYPE_CHECKING: + from collections.abc import Generator @dataclass @@ -9,6 +16,11 @@ class Redis: host: str closed: bool = False + def make_request(self) -> None: + if self.closed: + raise ValueError("Connection is closed") + return None + def close(self) -> None: self.closed = True @@ -29,49 +41,13 @@ def my_service(redis: Redis = Depends(get_redis)): assert isinstance(redis, Redis) -def test_close_dependency_after_call_if_not_cached(): +@pytest.mark.parametrize("use_cache", [True, False]) +def test_close_dependency_after_call(use_cache): @inject - def my_service(redis: Redis = Depends(get_redis, use_cache=False)): + def my_service(redis: Redis = Depends(get_redis, use_cache=use_cache)): + redis.make_request() return redis redis = my_service() assert redis.closed is True - - -def test_close_dependency_globally_if_cached(): - @inject - def my_service(redis: Redis = Depends(get_redis, use_cache=True)): - return redis - - redis = my_service() - assert redis.closed is False - - shutdown_resources() - - assert redis.closed is True - - -def test_dependencies_must_use_cache(): - @inject - def my_service(redis: Redis = Depends(get_redis)): - return redis - - redis1 = my_service() - redis2 = my_service() - - assert isinstance(redis1, Redis) - assert redis1 is redis2 - - -def test_dependencies_must_without_cache(): - @inject - def my_service(redis: Redis = Depends(get_redis, use_cache=False)): - return redis - - redis1 = my_service() - redis2 = my_service() - - assert isinstance(redis1, Redis) - assert isinstance(redis2, Redis) - assert redis1 is not redis2 diff --git a/tests/test_sync_resource.py b/tests/test_sync_resource.py new file mode 100644 index 0000000..88e31ab --- /dev/null +++ b/tests/test_sync_resource.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import pytest + +from nanodi import Depends, inject, resource, shutdown_resources + +if TYPE_CHECKING: + from collections.abc import Generator + + +@dataclass +class Redis: + host: str + closed: bool = False + + def make_request(self) -> None: + if self.closed: + raise ValueError("Connection is closed") + return None + + def close(self) -> None: + self.closed = True + + +@pytest.fixture() +def redis_dependency(): + @resource + def get_redis() -> Generator[Redis, None, None]: + redis = Redis(host="localhost") + yield redis + redis.close() + + return get_redis + + +def test_resources_dont_close_automatically(redis_dependency): + @inject + def my_service(redis: Redis = Depends(redis_dependency)): + redis.make_request() + return redis + + redis = my_service() + + assert redis.closed is False + + +def test_resources_can_be_closed_manually(redis_dependency): + @inject + def my_service(redis: Redis = Depends(redis_dependency)): + redis.make_request() + return redis + + redis = my_service() + + shutdown_resources() + assert redis.closed is True