diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index d39993a8..590a99d5 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -18,9 +18,6 @@ from ..._models import build, construct_type from ..._streaming import Stream, AsyncStream -if TYPE_CHECKING: - from ..._client import Anthropic, AsyncAnthropic - class MessageStream: text_stream: Iterator[str] @@ -33,24 +30,15 @@ class MessageStream: ``` """ - response: httpx.Response - - def __init__( - self, - *, - cast_to: type[RawMessageStreamEvent], - response: httpx.Response, - client: Anthropic, - ) -> None: - self.response = response - self._cast_to = cast_to - self._client = client - + def __init__(self, raw_stream: Stream[RawMessageStreamEvent]) -> None: + self._raw_stream = raw_stream self.text_stream = self.__stream_text__() + self._iterator = self.__stream__() self.__final_message_snapshot: Message | None = None - self._iterator = self.__stream__() - self._raw_stream: Stream[RawMessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client) + @property + def response(self) -> httpx.Response: + return self._raw_stream.response def __next__(self) -> MessageStreamEvent: return self._iterator.__next__() @@ -76,7 +64,7 @@ def close(self) -> None: Automatically called if the response body is read to completion. """ - self.response.close() + self._raw_stream.close() def get_final_message(self) -> Message: """Waits until the stream has been read to completion and returns @@ -151,13 +139,7 @@ def __init__( def __enter__(self) -> MessageStream: raw_stream = self.__api_request() - - self.__stream = MessageStream( - cast_to=raw_stream._cast_to, - response=raw_stream.response, - client=raw_stream._client, - ) - + self.__stream = MessageStream(raw_stream) return self.__stream def __exit__( @@ -181,26 +163,15 @@ class AsyncMessageStream: ``` """ - response: httpx.Response - - def __init__( - self, - *, - cast_to: type[RawMessageStreamEvent], - response: httpx.Response, - client: AsyncAnthropic, - ) -> None: - self.response = response - self._cast_to = cast_to - self._client = client - + def __init__(self, raw_stream: AsyncStream[RawMessageStreamEvent]) -> None: + self._raw_stream = raw_stream self.text_stream = self.__stream_text__() + self._iterator = self.__stream__() self.__final_message_snapshot: Message | None = None - self._iterator = self.__stream__() - self._raw_stream: AsyncStream[RawMessageStreamEvent] = AsyncStream( - cast_to=cast_to, response=response, client=client - ) + @property + def response(self) -> httpx.Response: + return self._raw_stream.response async def __anext__(self) -> MessageStreamEvent: return await self._iterator.__anext__() @@ -226,7 +197,7 @@ async def close(self) -> None: Automatically called if the response body is read to completion. """ - await self.response.aclose() + await self._raw_stream.close() async def get_final_message(self) -> Message: """Waits until the stream has been read to completion and returns @@ -303,13 +274,7 @@ def __init__( async def __aenter__(self) -> AsyncMessageStream: raw_stream = await self.__api_request - - self.__stream = AsyncMessageStream( - cast_to=raw_stream._cast_to, - response=raw_stream.response, - client=raw_stream._client, - ) - + self.__stream = AsyncMessageStream(raw_stream) return self.__stream async def __aexit__(