Skip to content

Commit

Permalink
Improved typing for builtins (#127)
Browse files Browse the repository at this point in the history
* move typehints to stub

* improved typing
  • Loading branch information
maxfischer2781 authored Feb 15, 2024
1 parent 1e7d76a commit 91c1f80
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 277 deletions.
280 changes: 4 additions & 276 deletions asyncstdlib/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
Optional,
Dict,
Any,
overload,
)
import builtins as _sync_builtins

from ._typing import T, T1, T2, T3, T4, T5, R, HK, LT, ADD, AnyIterable
from ._typing import T, R, HK, LT, AnyIterable
from ._core import (
aiter,
ScopedIter,
Expand All @@ -27,14 +26,6 @@
__ANEXT_DEFAULT = Sentinel("<no default>")


@overload
async def anext(iterator: AsyncIterator[T]) -> T: ...


@overload
async def anext(iterator: AsyncIterator[T], default: T) -> T: ...


async def anext(
iterator: AsyncIterator[T], default: Union[Sentinel, T] = __ANEXT_DEFAULT
) -> T:
Expand Down Expand Up @@ -63,16 +54,6 @@ async def anext(
__ITER_DEFAULT = Sentinel("<no default>")


@overload
def iter(subject: AnyIterable[T]) -> AsyncIterator[T]:
pass


@overload
def iter(subject: Callable[[], Awaitable[T]], sentinel: T) -> AsyncIterator[T]:
pass


def iter(
subject: Union[AnyIterable[T], Callable[[], Awaitable[T]]],
sentinel: Union[Sentinel, T] = __ITER_DEFAULT,
Expand Down Expand Up @@ -116,7 +97,7 @@ async def acallable_iterator(
value = await subject()


async def all(iterable: AnyIterable[T]) -> bool:
async def all(iterable: AnyIterable[Any]) -> bool:
"""
Return :py:data:`True` if none of the elements of the (async) ``iterable`` are false
"""
Expand All @@ -127,7 +108,7 @@ async def all(iterable: AnyIterable[T]) -> bool:
return True


async def any(iterable: AnyIterable[T]) -> bool:
async def any(iterable: AnyIterable[Any]) -> bool:
"""
Return :py:data:`False` if none of the elements of the (async) ``iterable`` are true
"""
Expand All @@ -138,68 +119,6 @@ async def any(iterable: AnyIterable[T]) -> bool:
return False


@overload
def zip(
__it1: AnyIterable[T1],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2, T3]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2, T3, T4]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2, T3, T4, T5]]: ...


@overload
def zip(
__it1: AnyIterable[Any],
__it2: AnyIterable[Any],
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
*iterables: AnyIterable[Any],
strict: bool = ...,
) -> AsyncIterator[Tuple[Any, ...]]: ...


async def zip(
*iterables: AnyIterable[Any], strict: bool = False
) -> AsyncIterator[Tuple[Any, ...]]:
Expand Down Expand Up @@ -285,118 +204,6 @@ async def _zip_inner_strict(
return


@overload
def map(
function: Callable[[T1], Awaitable[R]],
__it1: AnyIterable[T1],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1], R],
__it1: AnyIterable[T1],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2], R], __it1: AnyIterable[T1], __it2: AnyIterable[T2]
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4, T5], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4, T5], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[..., Awaitable[R]],
__it1: AnyIterable[Any],
__it2: AnyIterable[Any],
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
*iterable: AnyIterable[Any],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[..., R],
__it1: AnyIterable[Any],
__it2: AnyIterable[Any],
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
*iterable: AnyIterable[Any],
) -> AsyncIterator[R]: ...


async def map(
function: Union[Callable[..., R], Callable[..., Awaitable[R]]],
*iterable: AnyIterable[Any],
Expand Down Expand Up @@ -428,26 +235,6 @@ async def map(
__MIN_MAX_DEFAULT = Sentinel("<no default>")


@overload
async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...


@overload
async def max(
iterable: AnyIterable[LT], *, key: None = ..., default: T
) -> Union[LT, T]: ...


@overload
async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...


@overload
async def max(
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
) -> Union[T1, T2]: ...


async def max(
iterable: AnyIterable[Any],
*,
Expand All @@ -474,26 +261,6 @@ async def max(
return await _min_max(iterable, key, True, default)


@overload
async def min(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...


@overload
async def min(
iterable: AnyIterable[LT], *, key: None = ..., default: T
) -> Union[LT, T]: ...


@overload
async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...


@overload
async def min(
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
) -> Union[T1, T2]: ...


async def min(
iterable: AnyIterable[Any],
*,
Expand Down Expand Up @@ -594,18 +361,6 @@ async def enumerate(
count += 1


@overload
async def sum(iterable: AnyIterable[int]) -> int: ...


@overload
async def sum(iterable: AnyIterable[float]) -> float: ...


@overload
async def sum(iterable: AnyIterable[ADD], start: ADD) -> ADD: ...


async def sum(iterable: AnyIterable[Any], start: Any = 0) -> Any:
"""
Sum of ``start`` and all elements in the (async) iterable
Expand All @@ -632,21 +387,6 @@ async def tuple(iterable: Union[Iterable[T], AsyncIterable[T]] = ()) -> Tuple[T,
return (*[element async for element in aiter(iterable)],)


@overload
async def dict( # noqa: F811
iterable: Union[Iterable[Tuple[HK, T]], AsyncIterable[Tuple[HK, T]]] = (),
) -> Dict[HK, T]:
pass


@overload # noqa: F811
async def dict( # noqa: F811
iterable: Union[Iterable[Tuple[HK, T]], AsyncIterable[Tuple[HK, T]]] = (),
**kwargs: T,
) -> Dict[Union[HK, str], T]:
pass


async def dict( # noqa: F811
iterable: Union[Iterable[Tuple[HK, T]], AsyncIterable[Tuple[HK, T]]] = (),
**kwargs: T,
Expand Down Expand Up @@ -674,18 +414,6 @@ async def set(iterable: Union[Iterable[T], AsyncIterable[T]] = ()) -> Set[T]:
return {element async for element in aiter(iterable)}


@overload
async def sorted(
iterable: AnyIterable[LT], *, key: None = ..., reverse: bool = ...
) -> List[LT]: ...


@overload
async def sorted(
iterable: AnyIterable[T], *, key: Callable[[T], LT], reverse: bool = ...
) -> List[T]: ...


async def sorted(
iterable: AnyIterable[T],
*,
Expand Down Expand Up @@ -716,7 +444,7 @@ async def sorted(
try:
return _sync_builtins.sorted(iterable, reverse=reverse) # type: ignore
except TypeError:
items = [item async for item in aiter(iterable)]
items: "_sync_builtins.list[Any]" = [item async for item in aiter(iterable)]
items.sort(reverse=reverse)
return items
else:
Expand Down
Loading

0 comments on commit 91c1f80

Please sign in to comment.