Skip to content

Commit

Permalink
Make changes for sync dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
yakimka committed Apr 21, 2024
1 parent b83d4bd commit c75a2bc
Show file tree
Hide file tree
Showing 7 changed files with 302 additions and 99 deletions.
4 changes: 2 additions & 2 deletions nanodi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from nanodi.nanodi import Depends, inject, shutdown_resources
from nanodi.nanodi import Depends, inject, resource, shutdown_resources

__all__ = ["Depends", "inject", "shutdown_resources"]
__all__ = ["Depends", "inject", "shutdown_resources", "resource"]
149 changes: 101 additions & 48 deletions nanodi/nanodi.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from __future__ import annotations

import functools
import inspect
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from typing import Any, ParamSpec, TypeAlias, TypeVar
from collections.abc import Callable, Coroutine, Generator
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
ExitStack,
asynccontextmanager,
contextmanager,
)
from dataclasses import dataclass, field
from typing import Any, AsyncContextManager, ContextManager, ParamSpec, TypeVar

Dependency = Callable[..., Any]

Dependency: TypeAlias = Callable[[], Any]

_not_injected = object()
_registry: dict[Dependency, Any] = {}
_shutdown_callbacks: list[Callable[[], None]] = []
_unset = object()
_resources_exit_stack = ExitStack()
_resources: dict[Dependency, AsyncContextManager | ContextManager] = {}
_resources_result_cache: dict[Dependency, Any] = {}


def Depends(dependency: Dependency, /, use_cache: bool = True) -> Any: # noqa: N802
_registry[dependency] = _not_injected
return _Depends(dependency, use_cache)


@dataclass
@dataclass(frozen=True)
class _Depends:
dependency: Dependency
use_cache: bool
Expand All @@ -26,6 +36,42 @@ class _Depends:
P = ParamSpec("P")


@dataclass(frozen=True)
class ResolvedDependency:
original: Dependency
context_manager: ContextManager | AsyncContextManager | None = field(compare=False)
is_async: bool = field(default=False, compare=False)
use_cache: bool = True

@classmethod
def resolve(cls, depends: _Depends) -> ResolvedDependency:
context_manager: ContextManager | AsyncContextManager | None = None
is_async = False
if inspect.isasyncgenfunction(depends.dependency):
context_manager = asynccontextmanager(depends.dependency)()
is_async = True
elif inspect.isgeneratorfunction(depends.dependency):
context_manager = contextmanager(depends.dependency)()
return cls(depends.dependency, context_manager, is_async, depends.use_cache)


TC = TypeVar("TC", bound=Callable)


def resource(fn: TC) -> TC:
manager: ContextManager | AsyncContextManager
if inspect.isasyncgenfunction(fn):
manager = asynccontextmanager(fn)()
elif inspect.isgeneratorfunction(fn):
manager = contextmanager(fn)()
else:
raise ValueError("Resource must be a generator or async generator function")
_resources[fn] = manager
_resources_result_cache[fn] = _unset

return fn


def inject(fn: Callable[P, T]) -> Callable[P, T | Coroutine[Any, Any, T]]:
signature = inspect.signature(fn)

Expand All @@ -34,48 +80,55 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()

to_close = []
dependencies: dict[ResolvedDependency, list[str]] = {}
for name, value in bound.arguments.items():
if isinstance(value, _Depends):
if not value.use_cache:
result, close_callback = _resolve_dependency(value.dependency)
bound.arguments[name] = result
if close_callback is not None:
to_close.append(close_callback)
else:
if _registry.get(value.dependency) is _not_injected:
result, close_callback = _resolve_dependency(value.dependency)
_registry[value.dependency] = result
if close_callback is not None:
_shutdown_callbacks.append(close_callback)
bound.arguments[name] = _registry[value.dependency]

result = fn(*bound.args, **bound.kwargs)
for close_callback in to_close:
close_callback()
return result

return wrapper
dependencies.setdefault(ResolvedDependency.resolve(value), []).append(
name
)

with _call_dependencies(dependencies) as arguments:
bound.arguments.update(arguments)
return fn(*bound.args, **bound.kwargs)

def shutdown_resources():
for close_callback in _shutdown_callbacks:
close_callback()


def _resolve_dependency(
dependency: Dependency,
) -> tuple[Any, Callable[[], None] | None]:
result = dependency()
close_callback = None
if inspect.isgeneratorfunction(dependency):
generator = result
result = next(generator)
return wrapper

def close_callback():
try:
next(generator)
except StopIteration:
pass

return result, close_callback
def shutdown_resources() -> None:
_resources_exit_stack.close()


@contextmanager
def _call_dependencies(
dependencies: dict[ResolvedDependency, list[str]],
) -> Generator[dict[str, Any], None, None]:
managers: list[tuple[AbstractContextManager, list[str]]] = []
async_managers: list[tuple[AbstractAsyncContextManager, list[str]]] = []
results = {}
for dependency, names in dependencies.items():
if context_manager := _resources.get(dependency.original):
if isinstance(context_manager, AbstractContextManager):
if _resources_result_cache.get(dependency.original) is _unset:
result = _resources_exit_stack.enter_context(context_manager)
_resources_result_cache[dependency.original] = result

result = _resources_result_cache[dependency.original]
results.update({name: result for name in names})
elif dependency.context_manager:
if isinstance(dependency.context_manager, AbstractAsyncContextManager):
async_managers.append((dependency.context_manager, names))
else:
managers.append((dependency.context_manager, names))
else:
if dependency.use_cache:
result = dependency.original()
results.update({name: result for name in names})
else:
results.update({name: dependency.original() for name in names})

with ExitStack() as stack:
values = {manager: stack.enter_context(manager) for manager, _ in managers}
for manager, names in managers:
for name in names:
results[name] = values[manager]
yield results
44 changes: 44 additions & 0 deletions nanodi/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from types import TracebackType


class Resource:
def init(self) -> Any:
raise NotImplementedError

def close(self) -> None:
raise NotImplementedError

def __enter__(self) -> Any:
return self.init()

def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.close()


class AsyncResource:
async def init(self) -> Any:
raise NotImplementedError

async def close(self) -> None:
raise NotImplementedError

async def __aenter__(self) -> Any:
return await self.init()

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
traceback: TracebackType | None,
) -> None:
await self.close()
41 changes: 41 additions & 0 deletions tests/test_complex_logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from nanodi import Depends, inject


def get_redis() -> str:
yield "Redis"


@inject
def get_sessions_storage(redis: str = Depends(get_redis)) -> str:
return f"SessionsStorage({redis})"


def get_postgres_connection() -> str:
return "Postgres"


@inject
def get_db(postgres: str = Depends(get_postgres_connection)) -> str:
return f"{postgres} DB"


@inject
def get_users_repository(db: str = Depends(get_db)) -> str:
return f"UsersRepository({db})"


@inject
def get_users_service(
users_repository: str = Depends(get_users_repository),
sessions_storage: str = Depends(get_sessions_storage),
) -> str:
return f"UsersService({users_repository}, {sessions_storage})"


def test_resolve_complex_service():
users_service = get_users_service()

expected = "UsersService(UsersRepository(Postgres DB), SessionsStorage(Redis))"
assert users_service == expected
44 changes: 37 additions & 7 deletions tests/test_sync_di.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@ def my_service(redis: Redis = Depends(get_redis)):
assert isinstance(redis, Redis)


def test_dependencies_must_use_cache():
def test_dependencies_in_single_call_must_use_cache():
@inject
def my_service(redis: Redis = Depends(get_redis)):
return redis
def my_service(
redis1: Redis = Depends(get_redis), redis2: Redis = Depends(get_redis)
):
return redis1, redis2

redis1 = my_service()
redis2 = my_service()
redis1, redis2 = my_service()

assert isinstance(redis1, Redis)
assert redis1 is redis2


def test_dependencies_must_without_cache():
def test_dependencies_dont_share_cache_between_calls():
@inject
def my_service(redis: Redis = Depends(get_redis, use_cache=False)):
def my_service(redis: Redis = Depends(get_redis)):
return redis

redis1 = my_service()
Expand All @@ -45,3 +46,32 @@ def my_service(redis: Redis = Depends(get_redis, use_cache=False)):
assert isinstance(redis1, Redis)
assert isinstance(redis2, Redis)
assert redis1 is not redis2


def test_dependencies_in_single_call_dont_use_cache_if_specified():
@inject
def my_service(
redis1: Redis = Depends(get_redis, use_cache=False),
redis2: Redis = Depends(get_redis, use_cache=False),
):
return redis1, redis2

redis1, redis2 = my_service()

assert isinstance(redis1, Redis)
assert isinstance(redis2, Redis)
assert redis1 is not redis2


def test_nested_dependencies():
@inject
def my_service_inner(redis: Redis = Depends(get_redis)):
return redis

@inject
def my_service_outer(inner_service: Redis = Depends(my_service_inner)):
return inner_service

inner_service = my_service_outer()

assert isinstance(inner_service, Redis)
Loading

0 comments on commit c75a2bc

Please sign in to comment.