From 7b04cda582b32cba34a7bfb0eb1dd535f0ea88a5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 12 Feb 2024 11:20:30 +0000 Subject: [PATCH] Safe async cancellations. (#880) * Connection pool work * Connection pool work * Connection pool work * Connection pool work * Comments * Comments * Connection pool work * Reraise * Lookin sharp * nocover directive * Safe cancellations * Update CHANGELOG --- CHANGELOG.md | 1 + httpcore/_async/connection.py | 16 +- httpcore/_async/connection_pool.py | 364 ++++++++++++++------------- httpcore/_sync/connection.py | 16 +- httpcore/_sync/connection_pool.py | 364 ++++++++++++++------------- httpcore/_synchronization.py | 58 +++++ requirements.txt | 2 +- tests/_async/test_connection_pool.py | 35 ++- tests/_sync/test_connection_pool.py | 35 ++- 9 files changed, 512 insertions(+), 379 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 061358f4..3d537c8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## Unreleased +- Fix support for async cancellations. (#880) - Fix trace extension when used with socks proxy. (#849) - Fix SSL context for connections using the "wss" scheme (#869) diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 3aeb8ed9..2f439cf0 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -6,7 +6,7 @@ from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream -from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._exceptions import ConnectError, ConnectTimeout from .._models import Origin, Request, Response from .._ssl import default_ssl_context from .._synchronization import AsyncLock @@ -70,9 +70,9 @@ async def handle_async_request(self, request: Request) -> Response: f"Attempted to send request to {request.url.origin} on connection to {self._origin}" ) - async with self._request_lock: - if self._connection is None: - try: + try: + async with self._request_lock: + if self._connection is None: stream = await self._connect(request) ssl_object = stream.get_extra_info("ssl_object") @@ -94,11 +94,9 @@ async def handle_async_request(self, request: Request) -> Response: stream=stream, keepalive_expiry=self._keepalive_expiry, ) - except Exception as exc: - self._connect_failed = True - raise exc - elif not self._connection.is_available(): - raise ConnectionNotAvailable() + except BaseException as exc: + self._connect_failed = True + raise exc return await self._connection.handle_async_request(request) diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 0320c6d8..018b0ba2 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -1,31 +1,30 @@ import ssl import sys -import time from types import TracebackType from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend -from .._exceptions import ConnectionNotAvailable, PoolTimeout, UnsupportedProtocol +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation +from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock from .connection import AsyncHTTPConnection from .interfaces import 
AsyncConnectionInterface, AsyncRequestInterface -class RequestStatus: - def __init__(self, request: Request): +class AsyncPoolRequest: + def __init__(self, request: Request) -> None: self.request = request self.connection: Optional[AsyncConnectionInterface] = None self._connection_acquired = AsyncEvent() - def set_connection(self, connection: AsyncConnectionInterface) -> None: - assert self.connection is None + def assign_to_connection( + self, connection: Optional[AsyncConnectionInterface] + ) -> None: self.connection = connection self._connection_acquired.set() - def unset_connection(self) -> None: - assert self.connection is not None + def clear_connection(self) -> None: self.connection = None self._connection_acquired = AsyncEvent() @@ -37,6 +36,9 @@ async def wait_for_connection( assert self.connection is not None return self.connection + def is_queued(self) -> bool: + return self.connection is None + class AsyncConnectionPool(AsyncRequestInterface): """ @@ -107,14 +109,21 @@ def __init__( self._local_address = local_address self._uds = uds - self._pool: List[AsyncConnectionInterface] = [] - self._requests: List[RequestStatus] = [] - self._pool_lock = AsyncLock() self._network_backend = ( AutoBackend() if network_backend is None else network_backend ) self._socket_options = socket_options + # The mutable state on a connection pool is the queue of incoming requests, + # and the set of connections that are servicing those requests. + self._connections: List[AsyncConnectionInterface] = [] + self._requests: List[AsyncPoolRequest] = [] + + # We only mutate the state of the connection pool within an 'optional_thread_lock' + # context. This holds a threading lock unless we're running in async mode, + # in which case it is a no-op. + self._optional_thread_lock = AsyncThreadLock() + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: return AsyncHTTPConnection( origin=origin, @@ -145,64 +154,7 @@ def connections(self) -> List[AsyncConnectionInterface]: ] ``` """ - return list(self._pool) - - async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool: - """ - Attempt to provide a connection that can handle the given origin. - """ - origin = status.request.url.origin - - # If there are queued requests in front of us, then don't acquire a - # connection. We handle requests strictly in order. - waiting = [s for s in self._requests if s.connection is None] - if waiting and waiting[0] is not status: - return False - - # Reuse an existing connection if one is currently available. - for idx, connection in enumerate(self._pool): - if connection.can_handle_request(origin) and connection.is_available(): - self._pool.pop(idx) - self._pool.insert(0, connection) - status.set_connection(connection) - return True - - # If the pool is currently full, attempt to close one idle connection. - if len(self._pool) >= self._max_connections: - for idx, connection in reversed(list(enumerate(self._pool))): - if connection.is_idle(): - await connection.aclose() - self._pool.pop(idx) - break - - # If the pool is still full, then we cannot acquire a connection. - if len(self._pool) >= self._max_connections: - return False - - # Otherwise create a new connection. - connection = self.create_connection(origin) - self._pool.insert(0, connection) - status.set_connection(connection) - return True - - async def _close_expired_connections(self) -> None: - """ - Clean up the connection pool by closing off any connections that have expired. 
- """ - # Close any connections that have expired their keep-alive time. - for idx, connection in reversed(list(enumerate(self._pool))): - if connection.has_expired(): - await connection.aclose() - self._pool.pop(idx) - - # If the pool size exceeds the maximum number of allowed keep-alive connections, - # then close off idle connections as required. - pool_size = len(self._pool) - for idx, connection in reversed(list(enumerate(self._pool))): - if connection.is_idle() and pool_size > self._max_keepalive_connections: - await connection.aclose() - self._pool.pop(idx) - pool_size -= 1 + return list(self._connections) async def handle_async_request(self, request: Request) -> Response: """ @@ -220,116 +172,147 @@ async def handle_async_request(self, request: Request) -> Response: f"Request URL has an unsupported protocol '{scheme}://'." ) - status = RequestStatus(request) timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("pool", None) - if timeout is not None: - deadline = time.monotonic() + timeout - else: - deadline = float("inf") - - async with self._pool_lock: - self._requests.append(status) - await self._close_expired_connections() - await self._attempt_to_acquire_connection(status) - - while True: - try: - connection = await status.wait_for_connection(timeout=timeout) - except BaseException as exc: - # If we timeout here, or if the task is cancelled, then make - # sure to remove the request from the queue before bubbling - # up the exception. - async with self._pool_lock: - # Ensure only remove when task exists. - if status in self._requests: - self._requests.remove(status) - raise exc - - try: - response = await connection.handle_async_request(request) - except ConnectionNotAvailable: - # The ConnectionNotAvailable exception is a special case, that - # indicates we need to retry the request on a new connection. - # - # The most common case where this can occur is when multiple - # requests are queued waiting for a single connection, which - # might end up as an HTTP/2 connection, but which actually ends - # up as HTTP/1.1. - async with self._pool_lock: - # Maintain our position in the request queue, but reset the - # status so that the request becomes queued again. - status.unset_connection() - await self._attempt_to_acquire_connection(status) - except BaseException as exc: - with AsyncShieldCancellation(): - await self.response_closed(status) - raise exc - else: - break - - timeout = deadline - time.monotonic() - if timeout < 0: - raise PoolTimeout # pragma: nocover - - # When we return the response, we wrap the stream in a special class - # that handles notifying the connection pool once the response - # has been released. + with self._optional_thread_lock: + # Add the incoming request to our request queue. + pool_request = AsyncPoolRequest(request) + self._requests.append(pool_request) + + try: + while True: + with self._optional_thread_lock: + # Assign incoming requests to available connections, + # closing or creating new connections as required. + closing = self._assign_requests_to_connections() + await self._close_connections(closing) + + # Wait until this request has an assigned connection. + connection = await pool_request.wait_for_connection(timeout=timeout) + + try: + # Send the request on the assigned connection. + response = await connection.handle_async_request( + pool_request.request + ) + except ConnectionNotAvailable: + # In some cases a connection may initially be available to + # handle a request, but then become unavailable. 
+ # + # In this case we clear the connection and try again. + pool_request.clear_connection() + else: + break # pragma: nocover + + except BaseException as exc: + with self._optional_thread_lock: + # For any exception or cancellation we remove the request from + # the queue, and then re-assign requests to connections. + self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() + + await self._close_connections(closing) + raise exc from None + + # Return the response. Note that in this case we still have to manage + # the point at which the response is closed. assert isinstance(response.stream, AsyncIterable) return Response( status=response.status, headers=response.headers, - content=ConnectionPoolByteStream(response.stream, self, status), + content=PoolByteStream( + stream=response.stream, pool_request=pool_request, pool=self + ), extensions=response.extensions, ) - async def response_closed(self, status: RequestStatus) -> None: + def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: """ - This method acts as a callback once the request/response cycle is complete. + Manage the state of the connection pool, assigning incoming + requests to connections as available. - It is called into from the `ConnectionPoolByteStream.aclose()` method. - """ - assert status.connection is not None - connection = status.connection - - async with self._pool_lock: - # Update the state of the connection pool. - if status in self._requests: - self._requests.remove(status) - - if connection.is_closed() and connection in self._pool: - self._pool.remove(connection) - - # Since we've had a response closed, it's possible we'll now be able - # to service one or more requests that are currently pending. - for status in self._requests: - if status.connection is None: - acquired = await self._attempt_to_acquire_connection(status) - # If we could not acquire a connection for a queued request - # then we don't need to check anymore requests that are - # queued later behind it. - if not acquired: - break - - # Housekeeping. - await self._close_expired_connections() + Called whenever a new request is added or removed from the pool. - async def aclose(self) -> None: + Any closing connections are returned, allowing the I/O for closing + those connections to be handled seperately. """ - Close any connections in the pool. - """ - async with self._pool_lock: - for connection in self._pool: + closing_connections = [] + + # First we handle cleaning up any connections that are closed, + # have expired their keep-alive, or surplus idle connections. + for connection in list(self._connections): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): + # log: "closing expired connection" + self._connections.remove(connection) + closing_connections.append(connection) + elif ( + connection.is_idle() + and len([connection.is_idle() for connection in self._connections]) + > self._max_keepalive_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) + + # Assign queued requests to connections. 
+ queued_requests = [request for request in self._requests if request.is_queued()] + for pool_request in queued_requests: + origin = pool_request.request.url.origin + avilable_connections = [ + connection + for connection in self._connections + if connection.can_handle_request(origin) and connection.is_available() + ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] + + # There are three cases for how we may be able to handle the request: + # + # 1. There is an existing connection that can handle the request. + # 2. We can create a new connection to handle the request. + # 3. We can close an idle connection and then create a new connection + # to handle the request. + if avilable_connections: + # log: "reusing existing connection" + connection = avilable_connections[0] + pool_request.assign_to_connection(connection) + elif len(self._connections) < self._max_connections: + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + elif idle_connections: + # log: "closing idle connection" + connection = idle_connections[0] + self._connections.remove(connection) + closing_connections.append(connection) + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + + return closing_connections + + async def _close_connections(self, closing: List[AsyncConnectionInterface]) -> None: + # Close connections which have been removed from the pool. + with AsyncShieldCancellation(): + for connection in closing: await connection.aclose() - self._pool = [] - self._requests = [] + + async def aclose(self) -> None: + # Explicitly close the connection pool. + # Clears all existing requests and connections. + with self._optional_thread_lock: + closing_connections = list(self._connections) + self._connections = [] + await self._close_connections(closing_connections) async def __aenter__(self) -> "AsyncConnectionPool": - # Acquiring the pool lock here ensures that we have the - # correct dependencies installed as early as possible. - async with self._pool_lock: - pass return self async def __aexit__( @@ -340,31 +323,58 @@ async def __aexit__( ) -> None: await self.aclose() + def __repr__(self) -> str: + class_name = self.__class__.__name__ + with self._optional_thread_lock: + request_is_queued = [request.is_queued() for request in self._requests] + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + + num_active_requests = request_is_queued.count(False) + num_queued_requests = request_is_queued.count(True) + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) + + requests_info = ( + f"Requests: {num_active_requests} active, {num_queued_requests} queued" + ) + connection_info = ( + f"Connections: {num_active_connections} active, {num_idle_connections} idle" + ) + + return f"<{class_name} [{requests_info} | {connection_info}]>" -class ConnectionPoolByteStream: - """ - A wrapper around the response byte stream, that additionally handles - notifying the connection pool when the response has been closed. 
- """ +class PoolByteStream: def __init__( self, stream: AsyncIterable[bytes], + pool_request: AsyncPoolRequest, pool: AsyncConnectionPool, - status: RequestStatus, ) -> None: self._stream = stream + self._pool_request = pool_request self._pool = pool - self._status = status + self._closed = False async def __aiter__(self) -> AsyncIterator[bytes]: - async for part in self._stream: - yield part + try: + async for part in self._stream: + yield part + except BaseException as exc: + await self.aclose() + raise exc from None async def aclose(self) -> None: - try: - if hasattr(self._stream, "aclose"): - await self._stream.aclose() - finally: + if not self._closed: + self._closed = True with AsyncShieldCancellation(): - await self._pool.response_closed(self._status) + if hasattr(self._stream, "aclose"): + await self._stream.aclose() + + with self._pool._optional_thread_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + + await self._pool._close_connections(closing) diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index f6b99f1b..c3890f34 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -6,7 +6,7 @@ from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream -from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._exceptions import ConnectError, ConnectTimeout from .._models import Origin, Request, Response from .._ssl import default_ssl_context from .._synchronization import Lock @@ -70,9 +70,9 @@ def handle_request(self, request: Request) -> Response: f"Attempted to send request to {request.url.origin} on connection to {self._origin}" ) - with self._request_lock: - if self._connection is None: - try: + try: + with self._request_lock: + if self._connection is None: stream = self._connect(request) ssl_object = stream.get_extra_info("ssl_object") @@ -94,11 +94,9 @@ def handle_request(self, request: Request) -> Response: stream=stream, keepalive_expiry=self._keepalive_expiry, ) - except Exception as exc: - self._connect_failed = True - raise exc - elif not self._connection.is_available(): - raise ConnectionNotAvailable() + except BaseException as exc: + self._connect_failed = True + raise exc return self._connection.handle_request(request) diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index ccfb8d22..8dcf348c 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -1,31 +1,30 @@ import ssl import sys -import time from types import TracebackType from typing import Iterable, Iterator, Iterable, List, Optional, Type from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend -from .._exceptions import ConnectionNotAvailable, PoolTimeout, UnsupportedProtocol +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import Event, Lock, ShieldCancellation +from .._synchronization import Event, ShieldCancellation, ThreadLock from .connection import HTTPConnection from .interfaces import ConnectionInterface, RequestInterface -class RequestStatus: - def __init__(self, request: Request): +class PoolRequest: + def __init__(self, request: Request) -> None: self.request = request self.connection: Optional[ConnectionInterface] = None self._connection_acquired = Event() - def set_connection(self, connection: 
ConnectionInterface) -> None: - assert self.connection is None + def assign_to_connection( + self, connection: Optional[ConnectionInterface] + ) -> None: self.connection = connection self._connection_acquired.set() - def unset_connection(self) -> None: - assert self.connection is not None + def clear_connection(self) -> None: self.connection = None self._connection_acquired = Event() @@ -37,6 +36,9 @@ def wait_for_connection( assert self.connection is not None return self.connection + def is_queued(self) -> bool: + return self.connection is None + class ConnectionPool(RequestInterface): """ @@ -107,14 +109,21 @@ def __init__( self._local_address = local_address self._uds = uds - self._pool: List[ConnectionInterface] = [] - self._requests: List[RequestStatus] = [] - self._pool_lock = Lock() self._network_backend = ( SyncBackend() if network_backend is None else network_backend ) self._socket_options = socket_options + # The mutable state on a connection pool is the queue of incoming requests, + # and the set of connections that are servicing those requests. + self._connections: List[ConnectionInterface] = [] + self._requests: List[PoolRequest] = [] + + # We only mutate the state of the connection pool within an 'optional_thread_lock' + # context. This holds a threading lock unless we're running in async mode, + # in which case it is a no-op. + self._optional_thread_lock = ThreadLock() + def create_connection(self, origin: Origin) -> ConnectionInterface: return HTTPConnection( origin=origin, @@ -145,64 +154,7 @@ def connections(self) -> List[ConnectionInterface]: ] ``` """ - return list(self._pool) - - def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool: - """ - Attempt to provide a connection that can handle the given origin. - """ - origin = status.request.url.origin - - # If there are queued requests in front of us, then don't acquire a - # connection. We handle requests strictly in order. - waiting = [s for s in self._requests if s.connection is None] - if waiting and waiting[0] is not status: - return False - - # Reuse an existing connection if one is currently available. - for idx, connection in enumerate(self._pool): - if connection.can_handle_request(origin) and connection.is_available(): - self._pool.pop(idx) - self._pool.insert(0, connection) - status.set_connection(connection) - return True - - # If the pool is currently full, attempt to close one idle connection. - if len(self._pool) >= self._max_connections: - for idx, connection in reversed(list(enumerate(self._pool))): - if connection.is_idle(): - connection.close() - self._pool.pop(idx) - break - - # If the pool is still full, then we cannot acquire a connection. - if len(self._pool) >= self._max_connections: - return False - - # Otherwise create a new connection. - connection = self.create_connection(origin) - self._pool.insert(0, connection) - status.set_connection(connection) - return True - - def _close_expired_connections(self) -> None: - """ - Clean up the connection pool by closing off any connections that have expired. - """ - # Close any connections that have expired their keep-alive time. - for idx, connection in reversed(list(enumerate(self._pool))): - if connection.has_expired(): - connection.close() - self._pool.pop(idx) - - # If the pool size exceeds the maximum number of allowed keep-alive connections, - # then close off idle connections as required. 
- pool_size = len(self._pool) - for idx, connection in reversed(list(enumerate(self._pool))): - if connection.is_idle() and pool_size > self._max_keepalive_connections: - connection.close() - self._pool.pop(idx) - pool_size -= 1 + return list(self._connections) def handle_request(self, request: Request) -> Response: """ @@ -220,116 +172,147 @@ def handle_request(self, request: Request) -> Response: f"Request URL has an unsupported protocol '{scheme}://'." ) - status = RequestStatus(request) timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("pool", None) - if timeout is not None: - deadline = time.monotonic() + timeout - else: - deadline = float("inf") - - with self._pool_lock: - self._requests.append(status) - self._close_expired_connections() - self._attempt_to_acquire_connection(status) - - while True: - try: - connection = status.wait_for_connection(timeout=timeout) - except BaseException as exc: - # If we timeout here, or if the task is cancelled, then make - # sure to remove the request from the queue before bubbling - # up the exception. - with self._pool_lock: - # Ensure only remove when task exists. - if status in self._requests: - self._requests.remove(status) - raise exc - - try: - response = connection.handle_request(request) - except ConnectionNotAvailable: - # The ConnectionNotAvailable exception is a special case, that - # indicates we need to retry the request on a new connection. - # - # The most common case where this can occur is when multiple - # requests are queued waiting for a single connection, which - # might end up as an HTTP/2 connection, but which actually ends - # up as HTTP/1.1. - with self._pool_lock: - # Maintain our position in the request queue, but reset the - # status so that the request becomes queued again. - status.unset_connection() - self._attempt_to_acquire_connection(status) - except BaseException as exc: - with ShieldCancellation(): - self.response_closed(status) - raise exc - else: - break - - timeout = deadline - time.monotonic() - if timeout < 0: - raise PoolTimeout # pragma: nocover - - # When we return the response, we wrap the stream in a special class - # that handles notifying the connection pool once the response - # has been released. + with self._optional_thread_lock: + # Add the incoming request to our request queue. + pool_request = PoolRequest(request) + self._requests.append(pool_request) + + try: + while True: + with self._optional_thread_lock: + # Assign incoming requests to available connections, + # closing or creating new connections as required. + closing = self._assign_requests_to_connections() + self._close_connections(closing) + + # Wait until this request has an assigned connection. + connection = pool_request.wait_for_connection(timeout=timeout) + + try: + # Send the request on the assigned connection. + response = connection.handle_request( + pool_request.request + ) + except ConnectionNotAvailable: + # In some cases a connection may initially be available to + # handle a request, but then become unavailable. + # + # In this case we clear the connection and try again. + pool_request.clear_connection() + else: + break # pragma: nocover + + except BaseException as exc: + with self._optional_thread_lock: + # For any exception or cancellation we remove the request from + # the queue, and then re-assign requests to connections. + self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() + + self._close_connections(closing) + raise exc from None + + # Return the response. 
Note that in this case we still have to manage + # the point at which the response is closed. assert isinstance(response.stream, Iterable) return Response( status=response.status, headers=response.headers, - content=ConnectionPoolByteStream(response.stream, self, status), + content=PoolByteStream( + stream=response.stream, pool_request=pool_request, pool=self + ), extensions=response.extensions, ) - def response_closed(self, status: RequestStatus) -> None: + def _assign_requests_to_connections(self) -> List[ConnectionInterface]: """ - This method acts as a callback once the request/response cycle is complete. + Manage the state of the connection pool, assigning incoming + requests to connections as available. - It is called into from the `ConnectionPoolByteStream.close()` method. - """ - assert status.connection is not None - connection = status.connection - - with self._pool_lock: - # Update the state of the connection pool. - if status in self._requests: - self._requests.remove(status) - - if connection.is_closed() and connection in self._pool: - self._pool.remove(connection) - - # Since we've had a response closed, it's possible we'll now be able - # to service one or more requests that are currently pending. - for status in self._requests: - if status.connection is None: - acquired = self._attempt_to_acquire_connection(status) - # If we could not acquire a connection for a queued request - # then we don't need to check anymore requests that are - # queued later behind it. - if not acquired: - break - - # Housekeeping. - self._close_expired_connections() + Called whenever a new request is added or removed from the pool. - def close(self) -> None: + Any closing connections are returned, allowing the I/O for closing + those connections to be handled seperately. """ - Close any connections in the pool. - """ - with self._pool_lock: - for connection in self._pool: + closing_connections = [] + + # First we handle cleaning up any connections that are closed, + # have expired their keep-alive, or surplus idle connections. + for connection in list(self._connections): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): + # log: "closing expired connection" + self._connections.remove(connection) + closing_connections.append(connection) + elif ( + connection.is_idle() + and len([connection.is_idle() for connection in self._connections]) + > self._max_keepalive_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) + + # Assign queued requests to connections. + queued_requests = [request for request in self._requests if request.is_queued()] + for pool_request in queued_requests: + origin = pool_request.request.url.origin + avilable_connections = [ + connection + for connection in self._connections + if connection.can_handle_request(origin) and connection.is_available() + ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] + + # There are three cases for how we may be able to handle the request: + # + # 1. There is an existing connection that can handle the request. + # 2. We can create a new connection to handle the request. + # 3. We can close an idle connection and then create a new connection + # to handle the request. 
+ if avilable_connections: + # log: "reusing existing connection" + connection = avilable_connections[0] + pool_request.assign_to_connection(connection) + elif len(self._connections) < self._max_connections: + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + elif idle_connections: + # log: "closing idle connection" + connection = idle_connections[0] + self._connections.remove(connection) + closing_connections.append(connection) + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + + return closing_connections + + def _close_connections(self, closing: List[ConnectionInterface]) -> None: + # Close connections which have been removed from the pool. + with ShieldCancellation(): + for connection in closing: connection.close() - self._pool = [] - self._requests = [] + + def close(self) -> None: + # Explicitly close the connection pool. + # Clears all existing requests and connections. + with self._optional_thread_lock: + closing_connections = list(self._connections) + self._connections = [] + self._close_connections(closing_connections) def __enter__(self) -> "ConnectionPool": - # Acquiring the pool lock here ensures that we have the - # correct dependencies installed as early as possible. - with self._pool_lock: - pass return self def __exit__( @@ -340,31 +323,58 @@ def __exit__( ) -> None: self.close() + def __repr__(self) -> str: + class_name = self.__class__.__name__ + with self._optional_thread_lock: + request_is_queued = [request.is_queued() for request in self._requests] + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + + num_active_requests = request_is_queued.count(False) + num_queued_requests = request_is_queued.count(True) + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) + + requests_info = ( + f"Requests: {num_active_requests} active, {num_queued_requests} queued" + ) + connection_info = ( + f"Connections: {num_active_connections} active, {num_idle_connections} idle" + ) + + return f"<{class_name} [{requests_info} | {connection_info}]>" -class ConnectionPoolByteStream: - """ - A wrapper around the response byte stream, that additionally handles - notifying the connection pool when the response has been closed. 
- """ +class PoolByteStream: def __init__( self, stream: Iterable[bytes], + pool_request: PoolRequest, pool: ConnectionPool, - status: RequestStatus, ) -> None: self._stream = stream + self._pool_request = pool_request self._pool = pool - self._status = status + self._closed = False def __iter__(self) -> Iterator[bytes]: - for part in self._stream: - yield part + try: + for part in self._stream: + yield part + except BaseException as exc: + self.close() + raise exc from None def close(self) -> None: - try: - if hasattr(self._stream, "close"): - self._stream.close() - finally: + if not self._closed: + self._closed = True with ShieldCancellation(): - self._pool.response_closed(self._status) + if hasattr(self._stream, "close"): + self._stream.close() + + with self._pool._optional_thread_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + + self._pool._close_connections(closing) diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 119d89fc..9619a398 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -45,6 +45,13 @@ def current_async_library() -> str: class AsyncLock: + """ + This is a standard lock. + + In the sync case `Lock` provides thread locking. + In the async case `AsyncLock` provides async locking. + """ + def __init__(self) -> None: self._backend = "" @@ -82,6 +89,26 @@ async def __aexit__( self._anyio_lock.release() +class AsyncThreadLock: + """ + This is a threading-only lock for no-I/O contexts. + + In the sync case `ThreadLock` provides thread locking. + In the async case `AsyncThreadLock` is a no-op. + """ + + def __enter__(self) -> "AsyncThreadLock": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + pass + + class AsyncEvent: def __init__(self) -> None: self._backend = "" @@ -202,6 +229,13 @@ def __exit__( class Lock: + """ + This is a standard lock. + + In the sync case `Lock` provides thread locking. + In the async case `AsyncLock` provides async locking. + """ + def __init__(self) -> None: self._lock = threading.Lock() @@ -218,6 +252,30 @@ def __exit__( self._lock.release() +class ThreadLock: + """ + This is a threading-only lock for no-I/O contexts. + + In the sync case `ThreadLock` provides thread locking. + In the async case `AsyncThreadLock` is a no-op. 
+ """ + + def __init__(self) -> None: + self._lock = threading.Lock() + + def __enter__(self) -> "ThreadLock": + self._lock.acquire() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + self._lock.release() + + class Event: def __init__(self) -> None: self._event = threading.Event() diff --git a/requirements.txt b/requirements.txt index 71cf2164..d125b321 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,5 +20,5 @@ trio-typing==0.10.0 types-certifi==2021.10.8.3 pytest==8.0.0 pytest-httpbin==2.0.0 -pytest-trio==0.7.0 +pytest-trio==0.8.0 werkzeug<2.1 # See: https://github.com/postmanlabs/httpbin/issues/673 diff --git a/tests/_async/test_connection_pool.py b/tests/_async/test_connection_pool.py index 61ee1e54..2fc27204 100644 --- a/tests/_async/test_connection_pool.py +++ b/tests/_async/test_connection_pool.py @@ -38,6 +38,10 @@ async def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) await response.aread() assert response.status == 200 @@ -46,6 +50,10 @@ async def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) # Sending a second request to the same origin will reuse the existing IDLE connection. async with pool.stream("GET", "https://example.com/") as response: @@ -53,6 +61,10 @@ async def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) await response.aread() assert response.status == 200 @@ -61,23 +73,35 @@ async def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) # Sending a request to a different origin will not reuse the existing IDLE connection. async with pool.stream("GET", "http://example.com/") as response: info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] + assert ( + repr(pool) + == "" + ) await response.aread() assert response.status == 200 assert response.content == b"Hello, world!" info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] + assert ( + repr(pool) + == "" + ) @pytest.mark.anyio @@ -219,6 +243,7 @@ async def test_connection_pool_with_http2_goaway(): ] # Sending a second request to the same origin will require a new connection. + # The original connection has now been closed. response = await pool.request("GET", "https://example.com/") assert response.status == 200 assert response.content == b"Hello, world!" @@ -226,7 +251,6 @@ async def test_connection_pool_with_http2_goaway(): info = [repr(c) for c in pool.connections] assert info == [ "", - "", ] @@ -620,6 +644,11 @@ async def fetch(pool, domain, info_list): "", ] + assert ( + repr(pool) + == "" + ) + @pytest.mark.anyio async def test_unsupported_protocol(): diff --git a/tests/_sync/test_connection_pool.py b/tests/_sync/test_connection_pool.py index c9621c7b..ee303e5c 100644 --- a/tests/_sync/test_connection_pool.py +++ b/tests/_sync/test_connection_pool.py @@ -38,6 +38,10 @@ def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) response.read() assert response.status == 200 @@ -46,6 +50,10 @@ def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) # Sending a second request to the same origin will reuse the existing IDLE connection. 
with pool.stream("GET", "https://example.com/") as response: @@ -53,6 +61,10 @@ def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) response.read() assert response.status == 200 @@ -61,23 +73,35 @@ def test_connection_pool_with_keepalive(): assert info == [ "" ] + assert ( + repr(pool) + == "" + ) # Sending a request to a different origin will not reuse the existing IDLE connection. with pool.stream("GET", "http://example.com/") as response: info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] + assert ( + repr(pool) + == "" + ) response.read() assert response.status == 200 assert response.content == b"Hello, world!" info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] + assert ( + repr(pool) + == "" + ) @@ -219,6 +243,7 @@ def test_connection_pool_with_http2_goaway(): ] # Sending a second request to the same origin will require a new connection. + # The original connection has now been closed. response = pool.request("GET", "https://example.com/") assert response.status == 200 assert response.content == b"Hello, world!" @@ -226,7 +251,6 @@ def test_connection_pool_with_http2_goaway(): info = [repr(c) for c in pool.connections] assert info == [ "", - "", ] @@ -620,6 +644,11 @@ def fetch(pool, domain, info_list): "", ] + assert ( + repr(pool) + == "" + ) + def test_unsupported_protocol():
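
A brief usage sketch of the behaviour this patch targets (not part of the patch itself; the URL and the 1ms timeout are illustrative placeholders, and the snippet performs real network I/O): a request is cancelled mid-flight, and the pool is expected to drop the queued request and clean up any connection state rather than leak it, leaving the pool usable for subsequent requests.

```python
import anyio
import httpcore


async def main() -> None:
    async with httpcore.AsyncConnectionPool() as pool:
        # Cancel a request partway through. With this change the cancellation
        # is handled safely: the request is removed from the pool's queue and
        # any connection being set up for it is closed under a cancel shield.
        with anyio.move_on_after(0.001):
            await pool.request("GET", "https://example.com/")

        # The pool is left in a consistent state and can keep serving requests.
        print(repr(pool))  # e.g. "<AsyncConnectionPool [Requests: 0 active, 0 queued | ...]>"
        response = await pool.request("GET", "https://example.com/")
        assert response.status == 200


anyio.run(main)
```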