From 1c1a096c10ee80d0bcda8ad3c5357a4ae2eddbbe Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Fri, 20 Sep 2024 13:52:26 -0700 Subject: [PATCH] feedback fixes --- ldp/alg/callbacks.py | 10 +++++++++- ldp/alg/runners.py | 22 ++++++++++++++-------- ldp/graph/ops.py | 2 +- tests/test_ops.py | 4 ++-- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/ldp/alg/callbacks.py b/ldp/alg/callbacks.py index 4e473f81..846612ff 100644 --- a/ldp/alg/callbacks.py +++ b/ldp/alg/callbacks.py @@ -15,7 +15,7 @@ from ldp.agent import Agent from ldp.data_structures import Trajectory, Transition -from ldp.graph.ops import OpResult +from ldp.graph.ops import OpCtx, OpResult try: import wandb @@ -375,3 +375,11 @@ def _eval_log(metrics: list[dict[str, list[float]]]) -> None: f"eval/{key}_mean": sum(vals) / len(vals) for key, vals in flattened_metrics.items() }) + + +class ClearContextCallback(Callback): + async def after_eval_step(self, trajectories: Sequence[Trajectory]) -> None: + OpCtx.clear_data() + + async def after_train_step(self, trajectories: Sequence[Trajectory]) -> None: + OpCtx.clear_data() diff --git a/ldp/alg/runners.py b/ldp/alg/runners.py index 59a4f5bf..6679645c 100644 --- a/ldp/alg/runners.py +++ b/ldp/alg/runners.py @@ -11,12 +11,12 @@ from tqdm import tqdm, trange from ldp.agent import Agent -from ldp.alg.callbacks import Callback +from ldp.alg.callbacks import Callback, ClearContextCallback from ldp.alg.optimizer import Optimizer from ldp.alg.rollout import RolloutManager from ldp.data_structures import Trajectory from ldp.graph.op_utils import eval_mode, train_mode -from ldp.graph.ops import OpCtx, OpResult +from ldp.graph.ops import OpResult async def _run_eval_loop( @@ -96,11 +96,15 @@ def __init__( agent: Agent, dataset: TaskDataset, callbacks: Sequence[Callback] | None = None, + clear_ctx_at_each_iter: bool = True, ): self.config = config self.agent = agent self.dataset = dataset self.callbacks = callbacks or [] + if clear_ctx_at_each_iter: + clear_cb = ClearContextCallback() + self.callbacks = [*self.callbacks, clear_cb] if callbacks else [clear_cb] self.rollout_manager = self.config.make_rollout_manager(agent, self.callbacks) @eval_mode() @@ -159,6 +163,7 @@ def __init__( train_dataset: TaskDataset, eval_dataset: TaskDataset | None = None, callbacks: Sequence[Callback] | None = None, + clear_ctx_at_each_iter: bool = True, ): if config.eval_every is not None and eval_dataset is None: raise ValueError("Must specify eval_dataset if eval_every is set") @@ -169,6 +174,9 @@ def __init__( self.eval_dataset = eval_dataset self.optimizer = optimizer self.callbacks = callbacks or [] + if clear_ctx_at_each_iter: + clear_cb = ClearContextCallback() + self.callbacks = [*self.callbacks, clear_cb] if callbacks else [clear_cb] self.rollout_manager = self.config.make_rollout_manager( agent=agent, callbacks=self.callbacks ) @@ -197,9 +205,6 @@ async def train(self) -> None: if pbar.n == self.config.num_train_iterations: break - # Clear all op contexts - OpCtx.clear_registry() - pbar.close() await self._eval_loop() @@ -279,6 +284,7 @@ def __init__( optimizer: Optimizer, train_trajectories: list[Trajectory], callbacks: Sequence[Callback] | None = None, + clear_ctx_at_each_iter: bool = True, ): self.config = config self.agent = agent @@ -286,6 +292,9 @@ def __init__( # copy so we can shuffle self.train_trajectories = train_trajectories.copy() self.callbacks = callbacks or [] + if clear_ctx_at_each_iter: + clear_cb = ClearContextCallback() + self.callbacks = [*self.callbacks, clear_cb] if callbacks else [clear_cb] async def train(self) -> None: random.shuffle(self.train_trajectories) @@ -317,6 +326,3 @@ async def train(self) -> None: await asyncio.gather(*[ callback.after_train_step(batch) for callback in self.callbacks ]) - - # Clear all op contexts - OpCtx.clear_registry() diff --git a/ldp/graph/ops.py b/ldp/graph/ops.py index c122807c..6084683f 100644 --- a/ldp/graph/ops.py +++ b/ldp/graph/ops.py @@ -337,7 +337,7 @@ def get_or_create(cls, op_name: str) -> OpCtx: return cls(op_name=op_name) # Create @classmethod - def clear_registry(cls, op_names: Collection[str] | None = None) -> None: + def clear_data(cls, op_names: Collection[str] | None = None) -> None: """Clear the context registry. If op_names is provided, only clear those contexts.""" if op_names is None: op_names = cls._CTX_REGISTRY.keys() diff --git a/tests/test_ops.py b/tests/test_ops.py index f0cd0b15..bb9c766c 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -347,13 +347,13 @@ async def test_clear_contexts(): # Test global clear await op2(await op1(1)) assert len(op1.ctx.data) == len(op2.ctx.data) == 1 - OpCtx.clear_registry() + OpCtx.clear_data() assert not op1.ctx.data assert not op2.ctx.data # Test global clear by op name await op2(await op1(1)) - OpCtx.clear_registry(op_names=["op1"]) + OpCtx.clear_data(op_names=["op1"]) assert not op1.ctx.data assert len(op2.ctx.data) == 1