Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature parity with Python 3.13 #160

Merged
merged 14 commits into from
Oct 23, 2024
Merged
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2019 - 2024 Max Fischer
Copyright (c) 2019 - 2024 Max Kühn

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion asyncstdlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .asynctools import borrow, scoped_iter, await_each, any_iter, apply, sync
from .heapq import merge, nlargest, nsmallest

__version__ = "3.12.5"
__version__ = "3.13.0"

__all__ = [
"anext",
Expand Down
35 changes: 16 additions & 19 deletions asyncstdlib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Generic,
Generator,
Optional,
Coroutine,
AsyncContextManager,
Type,
cast,
Expand Down Expand Up @@ -66,25 +65,25 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.value!r})"


class _FutureCachedValue(Generic[R, T]):
"""A placeholder object to control concurrent access to a cached awaitable value.
class _FutureCachedPropertyValue(Generic[R, T]):
"""
A placeholder object to control concurrent access to a cached awaitable value

When given a lock to coordinate access, only the first task to await on a
cached property triggers the underlying coroutine. Once a value has been
produced, all tasks are unblocked and given the same, single value.

"""

__slots__ = ("_get_attribute", "_instance", "_name", "_lock")
__slots__ = ("_func", "_instance", "_name", "_lock")

def __init__(
self,
get_attribute: Callable[[T], Coroutine[Any, Any, R]],
func: Callable[[T], Awaitable[R]],
instance: T,
name: str,
lock: AsyncContextManager[Any],
):
self._get_attribute = get_attribute
self._func = func
self._instance = instance
self._name = name
self._lock = lock
Expand All @@ -98,7 +97,6 @@ def _instance_value(self) -> Awaitable[R]:

If the instance (no longer) has this attribute, it was deleted and the
process is restarted by delegating to the descriptor.

"""
try:
return self._instance.__dict__[self._name]
Expand All @@ -116,12 +114,17 @@ async def _await_impl(self) -> R:
# the instance attribute is still this placeholder, and we
# hold the lock. Start the getter to store the value on the
# instance and return the value.
return await self._get_attribute(self._instance)
return await self._get_attribute()

# another task produced a value, or the instance.__dict__ object was
# deleted in the interim.
return await stored

async def _get_attribute(self) -> R:
value = await self._func(self._instance)
self._instance.__dict__[self._name] = AwaitableValue(value)
return value

def __repr__(self) -> str:
return (
f"<{type(self).__name__} for '{type(self._instance).__name__}."
Expand All @@ -135,9 +138,10 @@ def __init__(
getter: Callable[[T], Awaitable[R]],
asynccontextmanager_type: Type[AsyncContextManager[Any]] = nullcontext,
):
self.func = getter
self.func = self.__wrapped__ = getter
self.attrname = None
self.__doc__ = getter.__doc__
self.__module__ = getter.__module__
self._asynccontextmanager_type = asynccontextmanager_type

def __set_name__(self, owner: Any, name: str) -> None:
Expand Down Expand Up @@ -175,19 +179,12 @@ def __get__(
# on this instance. It takes care of coordinating between different
# tasks awaiting on the placeholder until the cached value has been
# produced.
wrapper = _FutureCachedValue(
self._get_attribute, instance, name, self._asynccontextmanager_type()
wrapper = _FutureCachedPropertyValue(
self.func, instance, name, self._asynccontextmanager_type()
)
cache[name] = wrapper
return wrapper

async def _get_attribute(self, instance: T) -> R:
value = await self.func(instance)
name = self.attrname
assert name is not None # enforced in __get__
instance.__dict__[name] = AwaitableValue(value)
return value


def cached_property(
type_or_getter: Union[Type[AsyncContextManager[Any]], Callable[[T], Awaitable[R]]],
Expand Down
23 changes: 18 additions & 5 deletions asyncstdlib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
zip,
enumerate as aenumerate,
iter as aiter,
tuple as atuple,
)

S = TypeVar("S")
Expand Down Expand Up @@ -122,17 +121,31 @@ async def accumulate(iterable, function, *, initial):
yield value


async def batched(iterable: AnyIterable[T], n: int) -> AsyncIterator[Tuple[T, ...]]:
async def batched(
iterable: AnyIterable[T], n: int, strict: bool = False
) -> AsyncIterator[Tuple[T, ...]]:
"""
Batch the ``iterable`` to tuples of the length ``n``.

This lazily exhausts ``iterable`` and returns each batch as soon as it's ready.
This lazily exhausts ``iterable`` and returns each batch as soon as it is ready.
If ``strict`` is :py:data:`True` and the last batch is smaller than ``n``,
:py:exc:`ValueError` is raised.
"""
if n < 1:
raise ValueError("n must be at least one")
async with ScopedIter(iterable) as item_iter:
while batch := await atuple(islice(_borrow(item_iter), n)):
yield batch
batch: list[T] = []
try:
while True:
batch.clear()
for _ in range(n):
batch.append(await anext(item_iter))
yield tuple(batch)
except StopAsyncIteration:
if batch:
if strict and len(batch) < n:
raise ValueError("batched(): incomplete batch") from None
yield tuple(batch)


class chain(AsyncIterator[T]):
Expand Down
20 changes: 13 additions & 7 deletions asyncstdlib/itertools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,33 @@ def accumulate(
initial: T1,
) -> AsyncIterator[T1]: ...
@overload
def batched(iterable: AnyIterable[T], n: Literal[1]) -> AsyncIterator[tuple[T]]: ...
def batched(
iterable: AnyIterable[T], n: Literal[1], strict: bool = ...
) -> AsyncIterator[tuple[T]]: ...
@overload
def batched(iterable: AnyIterable[T], n: Literal[2]) -> AsyncIterator[tuple[T, T]]: ...
def batched(
iterable: AnyIterable[T], n: Literal[2], strict: bool = ...
) -> AsyncIterator[tuple[T, T]]: ...
@overload
def batched(
iterable: AnyIterable[T], n: Literal[3]
iterable: AnyIterable[T], n: Literal[3], strict: bool = ...
) -> AsyncIterator[tuple[T, T, T]]: ...
@overload
def batched(
iterable: AnyIterable[T], n: Literal[4]
iterable: AnyIterable[T], n: Literal[4], strict: bool = ...
) -> AsyncIterator[tuple[T, T, T, T]]: ...
@overload
def batched(
iterable: AnyIterable[T], n: Literal[5]
iterable: AnyIterable[T], n: Literal[5], strict: bool = ...
) -> AsyncIterator[tuple[T, T, T, T, T]]: ...
@overload
def batched(
iterable: AnyIterable[T], n: Literal[6]
iterable: AnyIterable[T], n: Literal[6], strict: bool = ...
) -> AsyncIterator[tuple[T, T, T, T, T, T]]: ...
@overload
def batched(iterable: AnyIterable[T], n: int) -> AsyncIterator[tuple[T, ...]]: ...
def batched(
iterable: AnyIterable[T], n: int, strict: bool = ...
) -> AsyncIterator[tuple[T, ...]]: ...

class chain(AsyncIterator[T]):
__slots__: tuple[str, ...]
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# -- Project information -----------------------------------------------------

project = "asyncstdlib"
author = "Max Fischer"
author = "Max Kühn"
copyright = f"2019-2024 {author}"

# The short X.Y version
Expand Down
6 changes: 5 additions & 1 deletion docs/source/api/itertools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,15 @@ Iterator splitting

.. versionadded:: 3.10.0

.. autofunction:: batched(iterable: (async) iter T, n: int)
.. autofunction:: batched(iterable: (async) iter T, n: int, strict: bool = False)
:async-for: :T

.. versionadded:: 3.11.0

.. versionadded:: 3.13.0

The ``strict`` parameter.

.. py:function:: groupby(iterable: (async) iter T)
:async-for: :(T, async iter T)
:noindex:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "flit_core.buildapi"
dynamic = ["version", "description"]
name = "asyncstdlib"
authors = [
{name = "Max Fischer", email = "[email protected]"},
{name = "Max Kühn", email = "[email protected]"},
]
readme = "README.rst"
classifiers = [
Expand Down
13 changes: 13 additions & 0 deletions unittests/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ async def test_batched_invalid(length):
await a.list(a.batched(range(10), length))


@sync
@pytest.mark.parametrize("values", ([1, 2, 3, 4], [1, 2, 3, 4, 5], [1]))
async def test_batched_strict(values: "list[int]"):
for n in range(1, len(values) + 1):
batches = a.batched(values, n, strict=True)
if len(values) % n == 0:
assert values == list(await a.reduce(lambda a, b: a + b, batches))
else:
assert await a.anext(batches)
with pytest.raises(ValueError):
await a.list(batches)


@sync
async def test_cycle():
async for _ in a.cycle([]):
Expand Down
Loading