Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch closing exceptions #162

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions ldp/alg/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import random
from collections.abc import Sequence
from contextlib import suppress
from typing import Any, cast

from aviary.core import Environment, TaskDataset
Expand All @@ -16,7 +17,7 @@
from ldp.shims import tqdm, trange

from .callbacks import Callback, ClearContextCallback
from .rollout import RolloutManager
from .rollout import EnvError, RolloutManager, reraise_exc_as


async def _run_eval_loop(
Expand Down Expand Up @@ -53,7 +54,7 @@ async def _run_eval_loop(

# Close the environment after we have sampled from it,
# in case it needs to tear down resources.
await asyncio.gather(*(env.close() for env in batch))
await _close_envs(batch, rollout_manager.catch_env_failures)

await asyncio.gather(*[
callback.after_eval_step(trajectories) for callback in callbacks
Expand Down Expand Up @@ -251,7 +252,7 @@ async def _training_step(self, i_iter: int, envs: Sequence[Environment]) -> None
)

# Close the environments after we have sampled from them, in case they need to tear down resources.
await asyncio.gather(*[env.close() for env in envs])
await _close_envs(envs, self.config.catch_env_failures)

training_batch.extend(traj for traj in trajectories if not traj.failed)

Expand Down Expand Up @@ -359,3 +360,18 @@ async def train(self) -> None:
await asyncio.gather(*[
callback.after_train_step(batch) for callback in self.callbacks
])


async def _close_envs(envs: Sequence[Environment], catch_env_failures: bool):
# Note that the reraise happens per-env, not over the whole batch, so one env failing
# to close won't affect the others
await asyncio.gather(*[safe_close_env(env, catch_env_failures) for env in envs])


async def safe_close_env(env: Environment, catch_env_failures: bool):
"""Close an environment.

If catch_env_failures is set, will not raise exceptions.
"""
with suppress(EnvError), reraise_exc_as(EnvError, enabled=catch_env_failures):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this log the closing's Exception somewhere? Would be good to log it at least

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, reraise_exc_as has logging builtin. This is the same logic that we use to catch reset/step exceptions

await env.close()
4 changes: 4 additions & 0 deletions ldp/alg/tree_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TEnv,
reraise_exc_as,
)
from .runners import safe_close_env

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,6 +90,9 @@ async def sample_tree(self, env: TEnv, max_depth: int | None) -> TransitionTree:
])
except CaughtError:
return tree
finally:
# Tear down env resources if necessary
await safe_close_env(env, self.catch_env_failures)

await self._descend(
tree=tree,
Expand Down