Skip to content

Commit

Permalink
feedback fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Sep 20, 2024
1 parent a6612bc commit 1c1a096
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
10 changes: 9 additions & 1 deletion ldp/alg/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ldp.agent import Agent
from ldp.data_structures import Trajectory, Transition
from ldp.graph.ops import OpResult
from ldp.graph.ops import OpCtx, OpResult

try:
import wandb
Expand Down Expand Up @@ -375,3 +375,11 @@ def _eval_log(metrics: list[dict[str, list[float]]]) -> None:
f"eval/{key}_mean": sum(vals) / len(vals)
for key, vals in flattened_metrics.items()
})


class ClearContextCallback(Callback):
    """Callback that clears the data held by all op contexts after each step.

    Both hooks call ``OpCtx.clear_data()`` with no ``op_names``, so every
    registered context is cleared rather than a selected subset.
    """

    async def after_eval_step(self, trajectories: Sequence[Trajectory]) -> None:
        """Clear all op context data once an evaluation step has finished.

        Args:
            trajectories: Trajectories from the step; not used by this hook.
        """
        OpCtx.clear_data()

    async def after_train_step(self, trajectories: Sequence[Trajectory]) -> None:
        """Clear all op context data once a training step has finished.

        Args:
            trajectories: Trajectories from the step; not used by this hook.
        """
        OpCtx.clear_data()
22 changes: 14 additions & 8 deletions ldp/alg/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from tqdm import tqdm, trange

from ldp.agent import Agent
from ldp.alg.callbacks import Callback
from ldp.alg.callbacks import Callback, ClearContextCallback
from ldp.alg.optimizer import Optimizer
from ldp.alg.rollout import RolloutManager
from ldp.data_structures import Trajectory
from ldp.graph.op_utils import eval_mode, train_mode
from ldp.graph.ops import OpCtx, OpResult
from ldp.graph.ops import OpResult


async def _run_eval_loop(
Expand Down Expand Up @@ -96,11 +96,15 @@ def __init__(
agent: Agent,
dataset: TaskDataset,
callbacks: Sequence[Callback] | None = None,
clear_ctx_at_each_iter: bool = True,
):
self.config = config
self.agent = agent
self.dataset = dataset
self.callbacks = callbacks or []
if clear_ctx_at_each_iter:
clear_cb = ClearContextCallback()
self.callbacks = [*self.callbacks, clear_cb] if callbacks else [clear_cb]
self.rollout_manager = self.config.make_rollout_manager(agent, self.callbacks)

@eval_mode()
Expand Down Expand Up @@ -159,6 +163,7 @@ def __init__(
train_dataset: TaskDataset,
eval_dataset: TaskDataset | None = None,
callbacks: Sequence[Callback] | None = None,
clear_ctx_at_each_iter: bool = True,
):
if config.eval_every is not None and eval_dataset is None:
raise ValueError("Must specify eval_dataset if eval_every is set")
Expand All @@ -169,6 +174,9 @@ def __init__(
self.eval_dataset = eval_dataset
self.optimizer = optimizer
self.callbacks = callbacks or []
if clear_ctx_at_each_iter:
clear_cb = ClearContextCallback()
self.callbacks = [*self.callbacks, clear_cb] if callbacks else [clear_cb]
self.rollout_manager = self.config.make_rollout_manager(
agent=agent, callbacks=self.callbacks
)
Expand Down Expand Up @@ -197,9 +205,6 @@ async def train(self) -> None:
if pbar.n == self.config.num_train_iterations:
break

# Clear all op contexts
OpCtx.clear_registry()

pbar.close()

await self._eval_loop()
Expand Down Expand Up @@ -279,13 +284,17 @@ def __init__(
optimizer: Optimizer,
train_trajectories: list[Trajectory],
callbacks: Sequence[Callback] | None = None,
clear_ctx_at_each_iter: bool = True,
):
self.config = config
self.agent = agent
self.optimizer = optimizer
# copy so we can shuffle
self.train_trajectories = train_trajectories.copy()
self.callbacks = callbacks or []
if clear_ctx_at_each_iter:
clear_cb = ClearContextCallback()
self.callbacks = [*self.callbacks, clear_cb] if callbacks else [clear_cb]

async def train(self) -> None:
random.shuffle(self.train_trajectories)
Expand Down Expand Up @@ -317,6 +326,3 @@ async def train(self) -> None:
await asyncio.gather(*[
callback.after_train_step(batch) for callback in self.callbacks
])

# Clear all op contexts
OpCtx.clear_registry()
2 changes: 1 addition & 1 deletion ldp/graph/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def get_or_create(cls, op_name: str) -> OpCtx:
return cls(op_name=op_name) # Create

@classmethod
def clear_registry(cls, op_names: Collection[str] | None = None) -> None:
def clear_data(cls, op_names: Collection[str] | None = None) -> None:
"""Clear the context registry. If op_names is provided, only clear those contexts."""
if op_names is None:
op_names = cls._CTX_REGISTRY.keys()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,13 @@ async def test_clear_contexts():
# Test global clear
await op2(await op1(1))
assert len(op1.ctx.data) == len(op2.ctx.data) == 1
OpCtx.clear_registry()
OpCtx.clear_data()
assert not op1.ctx.data
assert not op2.ctx.data

# Test global clear by op name
await op2(await op1(1))
OpCtx.clear_registry(op_names=["op1"])
OpCtx.clear_data(op_names=["op1"])
assert not op1.ctx.data
assert len(op2.ctx.data) == 1

Expand Down

0 comments on commit 1c1a096

Please sign in to comment.