Skip to content

Commit

Permalink
Catch closing exceptions (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan authored Nov 27, 2024
1 parent 86f0b25 commit feb0222
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
22 changes: 19 additions & 3 deletions ldp/alg/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import random
from collections.abc import Sequence
from contextlib import suppress
from typing import Any, cast

from aviary.core import Environment, TaskDataset
Expand All @@ -16,7 +17,7 @@
from ldp.shims import tqdm, trange

from .callbacks import Callback, ClearContextCallback
from .rollout import RolloutManager
from .rollout import EnvError, RolloutManager, reraise_exc_as


async def _run_eval_loop(
Expand Down Expand Up @@ -53,7 +54,7 @@ async def _run_eval_loop(

# Close the environment after we have sampled from it,
# in case it needs to tear down resources.
await asyncio.gather(*(env.close() for env in batch))
await _close_envs(batch, rollout_manager.catch_env_failures)

await asyncio.gather(*[
callback.after_eval_step(trajectories) for callback in callbacks
Expand Down Expand Up @@ -251,7 +252,7 @@ async def _training_step(self, i_iter: int, envs: Sequence[Environment]) -> None
)

# Close the environments after we have sampled from them, in case they need to tear down resources.
await asyncio.gather(*[env.close() for env in envs])
await _close_envs(envs, self.config.catch_env_failures)

training_batch.extend(traj for traj in trajectories if not traj.failed)

Expand Down Expand Up @@ -359,3 +360,18 @@ async def train(self) -> None:
await asyncio.gather(*[
callback.after_train_step(batch) for callback in self.callbacks
])


async def _close_envs(envs: Sequence[Environment], catch_env_failures: bool):
# Note that the reraise happens per-env, not over the whole batch, so one env failing
# to close won't affect the others
await asyncio.gather(*[safe_close_env(env, catch_env_failures) for env in envs])


async def safe_close_env(env: Environment, catch_env_failures: bool):
"""Close an environment.
If catch_env_failures is set, will not raise exceptions.
"""
with suppress(EnvError), reraise_exc_as(EnvError, enabled=catch_env_failures):
await env.close()
4 changes: 4 additions & 0 deletions ldp/alg/tree_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TEnv,
reraise_exc_as,
)
from .runners import safe_close_env

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,6 +90,9 @@ async def sample_tree(self, env: TEnv, max_depth: int | None) -> TransitionTree:
])
except CaughtError:
return tree
finally:
# Tear down env resources if necessary
await safe_close_env(env, self.catch_env_failures)

await self._descend(
tree=tree,
Expand Down

0 comments on commit feb0222

Please sign in to comment.