diff --git a/kopf/_cogs/clients/auth.py b/kopf/_cogs/clients/auth.py index b92d101b..e5c13a44 100644 --- a/kopf/_cogs/clients/auth.py +++ b/kopf/_cogs/clients/auth.py @@ -4,7 +4,7 @@ import ssl import tempfile from contextvars import ContextVar -from typing import Any, Callable, Dict, Iterator, Mapping, Optional, TypeVar, cast +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, TypeVar, cast import aiohttp @@ -36,13 +36,22 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: # If a context is explicitly passed, make it a simple call without re-auth. # Exceptions are escalated to a caller, which is probably wrapped itself. if 'context' in kwargs: - return await fn(*args, **kwargs) + context = kwargs['context'] + response = await fn(*args, **kwargs) + if isinstance(response, aiohttp.ClientResponse): + # Keep track of responses which are using this context. + context.add_response(response) + return response # Otherwise, attempt the execution with the vault credentials and re-authenticate on 401s. vault: credentials.Vault = vault_var.get() async for key, info, context in vault.extended(APIContext, 'contexts'): try: - return await fn(*args, **kwargs, context=context) + response = await fn(*args, **kwargs, context=context) + if isinstance(response, aiohttp.ClientResponse): + # Keep track of responses which are using this context. + context.add_response(response) + return response except errors.APIUnauthorizedError as e: await vault.invalidate(key, exc=e) @@ -74,6 +83,9 @@ class APIContext: server: str default_namespace: Optional[str] + # List of open responses. + responses: List[aiohttp.ClientResponse] + # Temporary caches of the information retrieved for and from the environment. _tempfiles: "_TempFiles" @@ -166,10 +178,32 @@ def __init__( self.server = info.server self.default_namespace = info.default_namespace + self.responses = [] + # For purging on garbage collection. self._tempfiles = tempfiles + def flush_closed_responses(self) -> None: + # There's no point keeping references to already closed responses. + self.responses[:] = [_response for _response in self.responses if not _response.closed] + + def add_response(self, response: aiohttp.ClientResponse) -> None: + # Keep track of responses so they can be closed later when the session + # is closed. + self.flush_closed_responses() + if not response.closed: + self.responses.append(response) + + def close_open_responses(self) -> None: + # Close all responses that are still open and are using this session. + for response in self.responses: + if not response.closed: + response.close() + self.responses.clear() + async def close(self) -> None: + # Close all open responses that use this session before closing the session itself. + self.close_open_responses() # Closing is triggered by `Vault._flush_caches()` -- forward it to the actual session. await self.session.close()