diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 245f9eee..35fd7b8a 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -13,6 +13,8 @@ This library adheres to `Semantic Versioning 2.0 `_. (`#836 `_; PR by @graingert) - Fixed ``AssertionError`` when using ``nest-asyncio`` (`#840 `_) +- Fixed return type annotation of various context managers' ``__exit__`` method + (`#847 `_; PR by @Enegg) **4.7.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 5a0aa936..11582529 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -449,7 +449,7 @@ def __exit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> bool: del exc_tb if not self._active: @@ -2116,10 +2116,9 @@ def __exit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> None: for sig in self._handled_signals: self._loop.remove_signal_handler(sig) - return None def __aiter__(self) -> _SignalReceiver: return self @@ -2448,7 +2447,7 @@ def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter: return CapacityLimiter(total_tokens) @classmethod - async def run_sync_in_worker_thread( + async def run_sync_in_worker_thread( # type: ignore[return] cls, func: Callable[[Unpack[PosArgsT]], T_Retval], args: tuple[Unpack[PosArgsT]], @@ -2470,7 +2469,7 @@ async def run_sync_in_worker_thread( async with limiter or cls.current_default_thread_limiter(): with CancelScope(shield=not abandon_on_cancel) as scope: - future: asyncio.Future = asyncio.Future() + future = asyncio.Future[T_Retval]() root_task = find_root_task() if not idle_workers: worker = WorkerThread(root_task, workers, idle_workers) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 70a0a605..32ae8ace 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -132,8 +132,7 @@ def __exit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: - # https://github.com/python-trio/trio-typing/pull/79 + ) -> bool: return self.__original.__exit__(exc_type, exc_val, exc_tb) def cancel(self) -> None: @@ -186,9 +185,10 @@ async def __aexit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> bool: try: - return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) + # trio.Nursery.__exit__ returns bool; .open_nursery has wrong type + return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) # type: ignore[return-value] except BaseExceptionGroup as exc: if not exc.split(trio.Cancelled)[1]: raise trio.Cancelled._create() from exc diff --git a/src/anyio/_core/_synchronization.py b/src/anyio/_core/_synchronization.py index 7878ba66..a6331328 100644 --- a/src/anyio/_core/_synchronization.py +++ b/src/anyio/_core/_synchronization.py @@ -728,6 +728,5 @@ def __exit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> None: self._guarded = False - return None diff --git a/src/anyio/_core/_tasks.py b/src/anyio/_core/_tasks.py index 2f21ea20..fe490151 100644 --- a/src/anyio/_core/_tasks.py +++ b/src/anyio/_core/_tasks.py @@ -88,7 +88,7 @@ def __exit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> bool: raise NotImplementedError diff --git a/src/anyio/to_process.py b/src/anyio/to_process.py index 5050dee2..495de2ae 100644 --- a/src/anyio/to_process.py +++ b/src/anyio/to_process.py @@ -35,7 +35,7 @@ _default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_limiter") -async def run_sync( +async def run_sync( # type: ignore[return] func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT], cancellable: bool = False, diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 6410f5e3..1c1a654c 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -865,7 +865,8 @@ async def task(task_status: TaskStatus) -> NoReturn: completed = True scope.shield = False await sleep(1) - pytest.fail("Execution should not reach this point") + + pytest.fail("Execution should not reach this point") async with create_task_group() as tg: await tg.start(task)