diff --git a/aiostream/aiter_utils.py b/aiostream/aiter_utils.py index 03189be..283e1ca 100644 --- a/aiostream/aiter_utils.py +++ b/aiostream/aiter_utils.py @@ -64,11 +64,11 @@ def anext(obj: AsyncIterator[T]) -> Awaitable[T]: @overload -def anext(obj: AsyncIterator[T], default: U) -> Awaitable[T] | U: +def anext(obj: AsyncIterator[T], default: U) -> Awaitable[T | U]: pass -def anext(obj: AsyncIterator[T], default: Any = UNSET) -> Awaitable[T] | Any: +def anext(obj: AsyncIterator[T], default: Any = UNSET) -> Awaitable[T | Any]: """Access anext magic method.""" assert_async_iterator(obj) if default is UNSET: diff --git a/aiostream/stream/combine.py b/aiostream/stream/combine.py index 43e1313..cc266a4 100644 --- a/aiostream/stream/combine.py +++ b/aiostream/stream/combine.py @@ -81,20 +81,23 @@ async def zip( # Loop over items _StopSentinelType = enum.Enum("_StopSentinelType", "STOP_SENTINEL") STOP_SENTINEL = _StopSentinelType.STOP_SENTINEL + items: list[T] while True: - coros = ( - anext(streamer, STOP_SENTINEL) if strict else anext(streamer) - for streamer in streamers - ) - try: - items = await asyncio.gather(*coros) - except StopAsyncIteration: # can only happen in non-strict mode - break if strict: - if all(item == STOP_SENTINEL for item in items): + coros = (anext(streamer, STOP_SENTINEL) for streamer in streamers) + _items = await asyncio.gather(*coros) + if all(item == STOP_SENTINEL for item in _items): break - elif any(item == STOP_SENTINEL for item in items): + elif any(item == STOP_SENTINEL for item in _items): raise ValueError("iterables have different lengths") + # This holds because we've ruled out STOP_SENTINEL above: + items = cast(list[T], _items) + else: + coros = (anext(streamer) for streamer in streamers) + try: + items = await asyncio.gather(*coros) + except StopAsyncIteration: + break yield tuple(items)