Add OperatorType and PipableOperatorType protocols
vxgmichel committed Jul 23, 2023
1 parent a505abe commit b2e70bf
Showing 12 changed files with 424 additions and 262 deletions.
398 changes: 232 additions & 166 deletions aiostream/core.py

Large diffs are not rendered by default.
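
Since the core.py diff is collapsed, here is a minimal sketch of what protocol-based operator typing can look like. The names OperatorType and PipableOperatorType come from the commit title; the signatures below are illustrative assumptions, not the actual contents of aiostream/core.py.

# Hypothetical sketch only -- the real definitions live in the
# unrendered aiostream/core.py diff and may differ.
from typing import AsyncIterable, AsyncIterator, Protocol, TypeVar
from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T", covariant=True)
X = TypeVar("X", contravariant=True)


class OperatorType(Protocol[P, T]):
    # Calling the operator produces a stream, sketched here as a bare
    # async iterator (the real return type is aiostream's Stream).
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]:
        ...

    # .raw() skips the Stream wrapper, matching the `.raw(...)` calls
    # visible throughout this diff.
    def raw(self, *args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]:
        ...


class PipableOperatorType(Protocol[X, P, T]):
    # A pipable operator additionally takes an upstream source as its
    # first positional argument.
    def __call__(
        self, source: AsyncIterable[X], /, *args: P.args, **kwargs: P.kwargs
    ) -> AsyncIterator[T]:
        ...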

37 changes: 18 additions & 19 deletions aiostream/stream/advanced.py
@@ -1,12 +1,12 @@
"""Advanced operators (to deal with streams of higher order) ."""
from __future__ import annotations

-from typing import AsyncIterator, AsyncIterable, TypeVar, Union, Callable
+from typing import AsyncIterator, AsyncIterable, TypeVar, Union
from typing_extensions import ParamSpec

from . import combine

-from ..core import operator, Streamer
+from ..core import Streamer, pipable_operator
from ..manager import StreamerManager


@@ -21,7 +21,7 @@
# Helper to manage stream of higher order


-@operator(pipable=True)
+@pipable_operator
async def base_combine(
source: AsyncIterable[AsyncIterable[T]],
switch: bool = False,
@@ -113,7 +113,7 @@ async def base_combine(
# Advanced operators (for streams of higher order)


-@operator(pipable=True)
+@pipable_operator
def concat(
source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
@@ -128,7 +128,7 @@ def concat(
return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=True)


-@operator(pipable=True)
+@pipable_operator
def flatten(
source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
@@ -143,7 +143,7 @@ def flatten(
return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=False)


-@operator(pipable=True)
+@pipable_operator
def switch(source: AsyncIterable[AsyncIterable[T]]) -> AsyncIterator[T]:
"""Given an asynchronous sequence of sequences, generate the elements of
the most recent sequence.
@@ -161,10 +161,10 @@ def switch(source: AsyncIterable[AsyncIterable[T]]) -> AsyncIterator[T]:
# Advanced *-map operators


-@operator(pipable=True)
+@pipable_operator
def concatmap(
source: AsyncIterable[T],
-func: Callable[P, AsyncIterable[U]],
+func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
task_limit: int | None = None,
) -> AsyncIterator[U]:
@@ -177,15 +177,14 @@ def concatmap(
although it's possible to limit the number of running sequences using
the `task_limit` argument.
"""
-return concat.raw(
-    combine.smap.raw(source, func, *more_sources), task_limit=task_limit
-)
+mapped = combine.smap.raw(source, func, *more_sources)
+return concat.raw(mapped, task_limit=task_limit)


-@operator(pipable=True)
+@pipable_operator
def flatmap(
source: AsyncIterable[T],
-func: Callable[P, AsyncIterable[U]],
+func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
task_limit: int | None = None,
) -> AsyncIterator[U]:
@@ -200,15 +199,14 @@ def flatmap(
Errors raised in a source or output sequence are propagated.
"""
-return flatten.raw(
-    combine.smap.raw(source, func, *more_sources), task_limit=task_limit
-)
+mapped = combine.smap.raw(source, func, *more_sources)
+return flatten.raw(mapped, task_limit=task_limit)


-@operator(pipable=True)
+@pipable_operator
def switchmap(
source: AsyncIterable[T],
-func: Callable[P, AsyncIterable[U]],
+func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
) -> AsyncIterator[U]:
"""Apply a given function that creates a sequence from the elements of one
@@ -219,4 +217,5 @@ def switchmap(
asynchronous sequence. Errors raised in a source or output sequence (that
was not already closed) are propagated.
"""
-return switch.raw(combine.smap.raw(source, func, *more_sources))
+mapped = combine.smap.raw(source, func, *more_sources)
+return switch.raw(mapped)
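
With these protocol-typed signatures, the mapper passed to the *-map operators takes the source elements positionally and returns an async iterable, per combine.SmapCallable[T, AsyncIterable[U]]. A small usage sketch against the public aiostream API; the spell helper is made up for illustration:

import asyncio
from typing import AsyncIterable

from aiostream import stream


def spell(word: str) -> AsyncIterable[str]:
    # Matches SmapCallable[str, AsyncIterable[str]]: one positional
    # element in, an async iterable out.
    return stream.iterate(word)


async def main() -> None:
    words = stream.iterate(["ab", "cd"])
    letters = stream.concatmap(words, spell)
    print(await stream.list(letters))  # ['a', 'b', 'c', 'd']


asyncio.run(main())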
8 changes: 4 additions & 4 deletions aiostream/stream/aggregate.py
@@ -9,14 +9,14 @@

from . import select
from ..aiter_utils import anext
-from ..core import operator, streamcontext
+from ..core import pipable_operator, streamcontext

__all__ = ["accumulate", "reduce", "list"]

T = TypeVar("T")


-@operator(pipable=True)
+@pipable_operator
async def accumulate(
source: AsyncIterable[T],
func: Callable[[T, T], Awaitable[T] | T] = op.add,
@@ -52,7 +52,7 @@ async def accumulate(
yield value


-@operator(pipable=True)
+@pipable_operator
def reduce(
source: AsyncIterable[T],
func: Callable[[T, T], Awaitable[T] | T],
@@ -69,7 +69,7 @@ def reduce(
return select.item.raw(acc, -1)


-@operator(pipable=True)
+@pipable_operator
async def list(source: AsyncIterable[T]) -> AsyncIterator[builtins.list[T]]:
"""Build a list from an asynchronous sequence.
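
The aggregation operators keep their public signatures; only the decorator changes. As a quick usage sketch against the public aiostream API:

import asyncio
import operator

from aiostream import stream


async def main() -> None:
    # accumulate emits every running total; reduce emits only the last one.
    totals = stream.accumulate(stream.range(1, 5), operator.add)
    print(await stream.list(totals))  # [1, 3, 6, 10]
    total = stream.reduce(stream.range(1, 5), operator.add)
    print(await total)  # 10


asyncio.run(main())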
102 changes: 75 additions & 27 deletions aiostream/stream/combine.py
@@ -4,11 +4,19 @@
import asyncio
import builtins

-from typing import Awaitable, TypeVar, AsyncIterable, AsyncIterator, Callable
+from typing import (
+    Awaitable,
+    Protocol,
+    TypeVar,
+    AsyncIterable,
+    AsyncIterator,
+    Callable,
+    cast,
+)
from typing_extensions import ParamSpec

from ..aiter_utils import AsyncExitStack, anext
-from ..core import operator, streamcontext
+from ..core import streamcontext, pipable_operator

from . import create
from . import select
@@ -23,22 +31,27 @@
P = ParamSpec("P")


-@operator(pipable=True)
-async def chain(*sources: AsyncIterable[T]) -> AsyncIterator[T]:
+@pipable_operator
+async def chain(
+    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+) -> AsyncIterator[T]:
"""Chain asynchronous sequences together, in the order they are given.
Note: the sequences are not iterated until required,
so if the operation is interrupted, the remaining sequences
will be left untouched.
"""
+sources = source, *more_sources
for source in sources:
async with streamcontext(source) as streamer:
async for item in streamer:
yield item


-@operator(pipable=True)
-async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]:
+@pipable_operator
+async def zip(
+    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+) -> AsyncIterator[tuple[T, ...]]:
"""Combine and forward the elements of several asynchronous sequences.
Each generated value is a tuple of elements, using the same order as
@@ -48,9 +61,7 @@ async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]:
Note: the different sequences are awaited in parallel, so that their
waiting times don't add up.
"""
-# Zero sources
-if len(sources) == 0:
-    return
+sources = source, *more_sources

# One source
if len(sources) == 1:
@@ -77,9 +88,30 @@ async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]:
yield tuple(items)


-@operator(pipable=True)
+X = TypeVar("X", contravariant=True)
+Y = TypeVar("Y", covariant=True)
+
+
+class SmapCallable(Protocol[X, Y]):
+    def __call__(self, arg: X, /, *args: X) -> Y:
+        ...
+
+
+class AmapCallable(Protocol[X, Y]):
+    def __call__(self, arg: X, /, *args: X) -> Awaitable[Y]:
+        ...
+
+
+class MapCallable(Protocol[X, Y]):
+    def __call__(self, arg: X, /, *args: X) -> Awaitable[Y] | Y:
+        ...
+
+
+@pipable_operator
async def smap(
-source: AsyncIterable[T], func: Callable[..., U], *more_sources: AsyncIterable[T]
+    source: AsyncIterable[T],
+    func: SmapCallable[T, U],
+    *more_sources: AsyncIterable[T],
) -> AsyncIterator[U]:
"""Apply a given function to the elements of one or several
asynchronous sequences.
@@ -97,10 +129,10 @@ async def smap(
yield func(*item)


-@operator(pipable=True)
+@pipable_operator
def amap(
source: AsyncIterable[T],
-corofn: Callable[P, Awaitable[U]],
+corofn: AmapCallable[T, U],
*more_sources: AsyncIterable[T],
ordered: bool = True,
task_limit: int | None = None,
@@ -123,8 +155,8 @@ def amap(
so that their waiting times don't add up.
"""

-def func(*args: P.args, **kwargs: P.kwargs) -> AsyncIterable[U]:
-    return create.call(corofn, *args, **kwargs)  # type: ignore
+def func(arg: T, *args: T) -> AsyncIterable[U]:
+    return create.call(corofn, arg, *args)  # type: ignore

if ordered:
return advanced.concatmap.raw(
@@ -133,10 +165,10 @@ def func(*args: P.args, **kwargs: P.kwargs) -> AsyncIterable[U]:
return advanced.flatmap.raw(source, func, *more_sources, task_limit=task_limit)


-@operator(pipable=True)
+@pipable_operator
def map(
source: AsyncIterable[T],
-func: Callable[P, Awaitable[U] | U],
+func: MapCallable[T, U],
*more_sources: AsyncIterable[T],
ordered: bool = True,
task_limit: int | None = None,
@@ -175,23 +207,31 @@ def map(
return amap.raw(
source, func, *more_sources, ordered=ordered, task_limit=task_limit
)
-return smap.raw(source, func, *more_sources)  # type: ignore
+sync_func = cast("SmapCallable[T, U]", func)
+return smap.raw(source, sync_func, *more_sources)


-@operator(pipable=True)
-def merge(*sources: AsyncIterable[T]) -> AsyncIterator[T]:
+@pipable_operator
+def merge(
+    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+) -> AsyncIterator[T]:
"""Merge several asynchronous sequences together.
All the sequences are iterated simultaneously and their elements
are forwarded as soon as they're available. The generation continues
until all the sequences are exhausted.
"""
-return advanced.flatten.raw(create.iterate.raw(sources))
+sources = [source, *more_sources]
+source_stream: AsyncIterable[AsyncIterable[T]] = create.iterate.raw(sources)
+return advanced.flatten.raw(source_stream)


-@operator(pipable=True)
+@pipable_operator
def ziplatest(
-    *sources: AsyncIterable[T], partial: bool = True, default: T | None = None
+    source: AsyncIterable[T],
+    *more_sources: AsyncIterable[T],
+    partial: bool = True,
+    default: T | None = None,
) -> AsyncIterator[tuple[T | None, ...]]:
"""Combine several asynchronous sequences together, producing a tuple with
the latest element of each sequence whenever a new element is received.
@@ -206,16 +246,21 @@ def ziplatest(
are forwarded as soon as they're available. The generation continues
until all the sequences are exhausted.
"""
+sources = source, *more_sources
n = len(sources)

# Custom getter
def getter(dct: dict[int, T]) -> Callable[[int], T | None]:
return lambda key: dct.get(key, default)

# Add source index to the items
-new_sources = [
-    smap.raw(source, lambda x, i=i: {i: x}) for i, source in enumerate(sources)
-]
+def make_func(i: int) -> SmapCallable[T, dict[int, T]]:
+    def func(x: T, *_: object) -> dict[int, T]:
+        return {i: x}
+
+    return func
+
+new_sources = [smap.raw(source, make_func(i)) for i, source in enumerate(sources)]

# Merge the sources
merged = merge.raw(*new_sources)
@@ -231,4 +276,7 @@ def getter(dct: dict[int, T]) -> Callable[[int], T | None]:
)

# Convert the state dict to a tuple
-return smap.raw(filtered, lambda x: tuple(builtins.map(getter(x), range(n))))
+def dict_to_tuple(x: dict[int, T], *_: object) -> tuple[T | None, ...]:
+    return tuple(builtins.map(getter(x), range(n)))
+
+return smap.raw(filtered, dict_to_tuple)
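
The ziplatest rewrite replaces the old `lambda x, i=i: {i: x}` default-argument trick with the make_func factory; both guard against Python's late-binding closures, and the factory also gives the mapper an SmapCallable-compatible signature. A self-contained sketch of the pitfall and the fix:

# Late binding: every lambda below closes over the same i, so all of
# them see its final value.
wrong = [lambda x: {i: x} for i in range(3)]
assert [f("a") for f in wrong] == [{2: "a"}, {2: "a"}, {2: "a"}]


# Factory fix (the pattern used by ziplatest above): each call to
# make_func freezes its own i in a fresh scope.
def make_func(i):
    def func(x, *_):
        return {i: x}

    return func


right = [make_func(i) for i in range(3)]
assert [f("a") for f in right] == [{0: "a"}, {1: "a"}, {2: "a"}]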
