Skip to content

Commit

Permalink
Tidy StreamedStr type hints. Improve README example. (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmpcollins authored Sep 6, 2023
1 parent 49b1aed commit 08ad9c0
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 33 deletions.
36 changes: 31 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ See the [examples directory](examples/) for more.

### Streaming

The `StreamedStr` (and `AsyncStreamedStr`) class can be used to stream the output of the LLM. This allows you to process the text while it is being generated, rather than receiving the whole output at once. Multiple `StreamedStr` can be created at the same time to stream LLM outputs concurrently. In the below example, generating the description for multiple countries takes approximately the same amount of time as for a single country.
The `StreamedStr` (and `AsyncStreamedStr`) class can be used to stream the output of the LLM. This allows you to process the text while it is being generated, rather than receiving the whole output at once.

```python
from magentic import prompt, StreamedStr
Expand All @@ -145,12 +145,38 @@ def describe_country(country: str) -> StreamedStr:
for chunk in describe_country("Brazil"):
print(chunk, end="")
# 'Brazil, officially known as the Federative Republic of Brazil, is ...'
```

Multiple `StreamedStr` can be created at the same time to stream LLM outputs concurrently. In the below example, generating the description for multiple countries takes approximately the same amount of time as for a single country.

```python
from time import time

countries = ["Australia", "Brazil", "Chile"]


# Generate the descriptions one at a time
start_time = time()
for country in countries:
# Converting `StreamedStr` to `str` blocks until the LLM output is fully generated
description = str(describe_country(country))
print(f"{time() - start_time:.2f}s : {country} - {len(description)} chars")

# 22.72s : Australia - 2130 chars
# 41.63s : Brazil - 1884 chars
# 74.31s : Chile - 2968 chars


# Generate the descriptions concurrently by creating the StreamedStrs at the same time
start_time = time()
streamed_strs = [describe_country(country) for country in countries]
for country, streamed_str in zip(countries, streamed_strs):
description = str(streamed_str)
print(f"{time() - start_time:.2f}s : {country} - {len(description)} chars")

# Generate text concurrently by creating the streams before consuming them
streamed_strs = [describe_country(c) for c in ["Australia", "Brazil", "Chile"]]
[str(s) for s in streamed_strs]
# ["Australia is a country ...", "Brazil, officially known as ...", "Chile, officially known as ..."]
# 22.79s : Australia - 2147 chars
# 23.64s : Brazil - 2202 chars
# 24.67s : Chile - 2186 chars
```

### Additional Features
Expand Down
27 changes: 18 additions & 9 deletions src/magentic/streamed_str.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
from typing import AsyncIterator, Iterator
from collections.abc import AsyncIterable, Iterable
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")

class StreamedStr:

async def async_iter(iterable: Iterable[T]) -> AsyncIterator[T]:
"""Get an AsyncIterator for an Iterable."""
for item in iterable:
yield item


class StreamedStr(Iterable[str]):
"""A string that is generated in chunks."""

def __init__(self, generator: Iterator[str]):
self._generator = generator
def __init__(self, chunks: Iterable[str]):
self._chunks = chunks
self._cached_chunks: list[str] = []

def __iter__(self) -> Iterator[str]:
yield from self._cached_chunks
for chunk in self._generator:
for chunk in self._chunks:
self._cached_chunks.append(chunk)
yield chunk

Expand All @@ -22,19 +31,19 @@ def to_string(self) -> str:
return str(self)


class AsyncStreamedStr:
class AsyncStreamedStr(AsyncIterable[str]):
"""Async version of `StreamedStr`."""

def __init__(self, generator: AsyncIterator[str]):
self._generator = generator
def __init__(self, chunks: AsyncIterable[str]):
self._chunks = chunks
self._cached_chunks: list[str] = []

async def __aiter__(self) -> AsyncIterator[str]:
# Cannot use `yield from` inside an async function
# https://peps.python.org/pep-0525/#asynchronous-yield-from
for chunk in self._cached_chunks:
yield chunk
async for chunk in self._generator:
async for chunk in self._chunks:
self._cached_chunks.append(chunk)
yield chunk

Expand Down
36 changes: 17 additions & 19 deletions tests/test_streamed_str.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,41 @@
from typing import AsyncIterator, Iterator
from typing import AsyncIterator

import pytest

from magentic import AsyncStreamedStr, StreamedStr
from magentic.streamed_str import async_iter


def test_streamed_str_iter():
def generator() -> Iterator[str]:
yield from ["Hello", " World"]
@pytest.mark.asyncio
async def test_async_iter():
output = async_iter(["Hello", " World"])
assert isinstance(output, AsyncIterator)
assert [chunk async for chunk in output] == ["Hello", " World"]

streamed_str = StreamedStr(generator())

def test_streamed_str_iter():
iter_chunks = iter(["Hello", " World"])
streamed_str = StreamedStr(iter_chunks)
assert list(streamed_str) == ["Hello", " World"]
assert list(iter_chunks) == [] # iterator is exhausted
assert list(streamed_str) == ["Hello", " World"]


def test_streamed_str_str():
def generator() -> Iterator[str]:
yield from ["Hello", " World"]

streamed_str = StreamedStr(generator())
streamed_str = StreamedStr(["Hello", " World"])
assert str(streamed_str) == "Hello World"


@pytest.mark.asyncio
async def test_async_streamed_str_iter():
async def generator() -> AsyncIterator[str]:
for chunk in ["Hello", " World"]:
yield chunk

async_streamed_str = AsyncStreamedStr(generator())
aiter_chunks = async_iter(["Hello", " World"])
async_streamed_str = AsyncStreamedStr(aiter_chunks)
assert [chunk async for chunk in async_streamed_str] == ["Hello", " World"]
assert [chunk async for chunk in aiter_chunks] == [] # iterator is exhausted
assert [chunk async for chunk in async_streamed_str] == ["Hello", " World"]


@pytest.mark.asyncio
async def test_async_streamed_str_to_string():
async def generator() -> AsyncIterator[str]:
for chunk in ["Hello", " World"]:
yield chunk

async_streamed_str = AsyncStreamedStr(generator())
async_streamed_str = AsyncStreamedStr(async_iter(["Hello", " World"]))
assert await async_streamed_str.to_string() == "Hello World"

0 comments on commit 08ad9c0

Please sign in to comment.