Skip to content

Commit

Permalink
trying to refactor more [meh]
Browse files Browse the repository at this point in the history
  • Loading branch information
yakimka committed Apr 23, 2024
1 parent b574376 commit f657a20
Showing 1 changed file with 51 additions and 38 deletions.
89 changes: 51 additions & 38 deletions nanodi/nanodi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import functools
import inspect
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Awaitable, Callable, Coroutine, Generator
from contextlib import (
AsyncExitStack,
ExitStack,
Expand All @@ -12,10 +12,20 @@
nullcontext,
)
from dataclasses import dataclass, field
from typing import Any, AsyncContextManager, ContextManager, ParamSpec, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
ContextManager,
ParamSpec,
TypeVar,
)

from nanodi.scopes import NullScope, Scope, SingletonScope

if TYPE_CHECKING:
from inspect import BoundArguments

Dependency = Callable[..., Any]
T = TypeVar("T")
P = ParamSpec("P")
Expand Down Expand Up @@ -66,7 +76,7 @@ def my_service(db=Provide(get_db), settings=Provide(get_settings)):
return Depends.from_dependency(dependency, use_cache)


def inject(fn: Callable[P, T]) -> Callable[P, T | Coroutine[Any, Any, T]]: # noqa: C901
def inject(fn: Callable[P, T]) -> Callable[P, T | Coroutine[Any, Any, T]]:
"""
Decorator to inject dependencies into a function.
Use it in combination with `Provide` to declare dependencies.
Expand All @@ -87,21 +97,15 @@ def my_service(db=Provide(some_dependency_func)):
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()

dependencies: dict[Depends, list[str]] = {}
for name, value in bound.arguments.items():
if isinstance(value, Depends):
dependencies.setdefault(value, []).append(name)

exit_stack = AsyncExitStack()
for depends, names in dependencies.items():
get_value = functools.partial(
_get_value_from_depends_async, depends, exit_stack
)
for depends, names, get_value in _resolve_depends(
bound, exit_stack, is_async=True
):
if depends.use_cache:
value = await get_value()
get_value = functools.partial(lambda v: v, value)
bound.arguments.update({name: get_value() for name in names})
bound.arguments.update({name: value for name in names})
else:
bound.arguments.update({name: await get_value() for name in names})

async with exit_stack:
result = await fn(*bound.args, **bound.kwargs)
Expand All @@ -113,21 +117,15 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()

dependencies: dict[Depends, list[str]] = {}
for name, value in bound.arguments.items():
if isinstance(value, Depends):
dependencies.setdefault(value, []).append(name)

exit_stack = ExitStack()
for depends, names in dependencies.items():
get_value = functools.partial(
_get_value_from_depends_sync, depends, exit_stack
)
for depends, names, get_value in _resolve_depends(
bound, exit_stack, is_async=False
):
if depends.use_cache:
value = get_value()
get_value = functools.partial(lambda v: v, value)
bound.arguments.update({name: get_value() for name in names})
bound.arguments.update({name: value for name in names})
else:
bound.arguments.update({name: get_value() for name in names})

with exit_stack:
result = fn(*bound.args, **bound.kwargs)
Expand Down Expand Up @@ -174,7 +172,7 @@ def init_resources() -> Awaitable:
_get_value_from_depends_async(depends, _async_exit_stack)
)
else:
_get_value_from_depends_sync(depends, _exit_stack)
_get_value_from_depends(depends, _exit_stack)

return asyncio.gather(*async_resources)

Expand Down Expand Up @@ -216,9 +214,24 @@ def from_dependency(cls, dependency: Dependency, use_cache: bool) -> Depends:
return cls(dependency, use_cache, context_manager, is_async)


async def _get_value_from_depends_async(
def _resolve_depends(
bound: BoundArguments, exit_stack: AsyncExitStack | ExitStack, is_async: bool
) -> Generator[tuple[Depends, list[str], Callable[[], Any]], None, None]:
dependencies: dict[Depends, list[str]] = {}
for name, value in bound.arguments.items():
if isinstance(value, Depends):
dependencies.setdefault(value, []).append(name)

get_val = _get_value_from_depends_async if is_async else _get_value_from_depends

for depends, names in dependencies.items():
get_value = functools.partial(get_val, depends, exit_stack) # type: ignore
yield depends, names, get_value


def _get_value_from_depends(
depends: Depends,
local_exit_stack: AsyncExitStack,
local_exit_stack: ExitStack,
) -> Any:
scope_name = depends.get_scope_name()
scope = _scopes[scope_name]
Expand All @@ -228,18 +241,18 @@ async def _get_value_from_depends_async(
context_manager = depends.value_as_context_manager()
exit_stack = local_exit_stack
if scope_name == "singleton":
exit_stack = _async_exit_stack
exit_stack = _exit_stack
if depends.is_async:
value = await exit_stack.enter_async_context(context_manager)
value = depends.dependency
else:
value = exit_stack.enter_context(context_manager)
scope.set(depends.dependency, value)
scope.set(depends.dependency, value)
return value


def _get_value_from_depends_sync(
async def _get_value_from_depends_async(
depends: Depends,
local_exit_stack: ExitStack,
local_exit_stack: AsyncExitStack,
) -> Any:
scope_name = depends.get_scope_name()
scope = _scopes[scope_name]
Expand All @@ -249,10 +262,10 @@ def _get_value_from_depends_sync(
context_manager = depends.value_as_context_manager()
exit_stack = local_exit_stack
if scope_name == "singleton":
exit_stack = _exit_stack
exit_stack = _async_exit_stack
if depends.is_async:
value = depends.dependency
value = await exit_stack.enter_async_context(context_manager)
else:
value = exit_stack.enter_context(context_manager)
scope.set(depends.dependency, value)
scope.set(depends.dependency, value)
return value

0 comments on commit f657a20

Please sign in to comment.