diff --git a/ldp/alg/runners.py b/ldp/alg/runners.py index 31baada..97bbd39 100644 --- a/ldp/alg/runners.py +++ b/ldp/alg/runners.py @@ -4,6 +4,7 @@ import math import random from collections.abc import Sequence +from contextlib import suppress from typing import Any, cast from aviary.core import Environment, TaskDataset @@ -16,7 +17,7 @@ from ldp.shims import tqdm, trange from .callbacks import Callback, ClearContextCallback -from .rollout import RolloutManager +from .rollout import EnvError, RolloutManager, reraise_exc_as async def _run_eval_loop( @@ -53,7 +54,7 @@ async def _run_eval_loop( # Close the environment after we have sampled from it, # in case it needs to tear down resources. - await asyncio.gather(*(env.close() for env in batch)) + await _close_envs(batch, rollout_manager.catch_env_failures) await asyncio.gather(*[ callback.after_eval_step(trajectories) for callback in callbacks @@ -251,7 +252,7 @@ async def _training_step(self, i_iter: int, envs: Sequence[Environment]) -> None ) # Close the environments after we have sampled from them, in case they need to tear down resources. - await asyncio.gather(*[env.close() for env in envs]) + await _close_envs(envs, self.config.catch_env_failures) training_batch.extend(traj for traj in trajectories if not traj.failed) @@ -359,3 +360,18 @@ async def train(self) -> None: await asyncio.gather(*[ callback.after_train_step(batch) for callback in self.callbacks ]) + + +async def _close_envs(envs: Sequence[Environment], catch_env_failures: bool): + # Note that the reraise happens per-env, not over the whole batch, so one env failing + # to close won't affect the others + await asyncio.gather(*[safe_close_env(env, catch_env_failures) for env in envs]) + + +async def safe_close_env(env: Environment, catch_env_failures: bool): + """Close an environment. + + If catch_env_failures is set, will not raise exceptions. + """ + with suppress(EnvError), reraise_exc_as(EnvError, enabled=catch_env_failures): + await env.close() diff --git a/ldp/alg/tree_search.py b/ldp/alg/tree_search.py index 1cb6fac..9d74ba0 100644 --- a/ldp/alg/tree_search.py +++ b/ldp/alg/tree_search.py @@ -19,6 +19,7 @@ TEnv, reraise_exc_as, ) +from .runners import safe_close_env logger = logging.getLogger(__name__) @@ -89,6 +90,9 @@ async def sample_tree(self, env: TEnv, max_depth: int | None) -> TransitionTree: ]) except CaughtError: return tree + finally: + # Tear down env resources if necessary + await safe_close_env(env, self.catch_env_failures) await self._descend( tree=tree,