Skip to content

Commit

Permalink
feedback fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
albertbou92 committed Sep 20, 2024
1 parent ed6afa6 commit 9b2829d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
6 changes: 3 additions & 3 deletions ldp/alg/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class EvaluatorConfig(BaseModel):
max_rollout_steps: int | None = None
catch_agent_failures: bool = True
catch_env_failures: bool = True
clear_ctx_at_each_iter: bool = True
clear_ctx_at_each_iter: bool = False

def make_rollout_manager(
self, agent: Agent, callbacks: Sequence[Callback]
Expand Down Expand Up @@ -152,7 +152,7 @@ class OnlineTrainerConfig(EvaluatorConfig):
True, # noqa: FBT003
description="If True (default), run an evaluation loop before training.",
)
clear_ctx_at_each_iter: bool = True
clear_ctx_at_each_iter: bool = False


class OnlineTrainer:
Expand Down Expand Up @@ -273,7 +273,7 @@ class OfflineTrainerConfig(BaseModel):
1,
description="Number of training iterations to run before updating the model.",
)
clear_ctx_at_each_iter: bool = True
clear_ctx_at_each_iter: bool = False
# TODO: add some concept of eval loops


Expand Down
43 changes: 31 additions & 12 deletions tests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@


@pytest.mark.asyncio
async def test_online_trainer():
@pytest.mark.parametrize("clear_ctx_at_each_iter", [True, False])
async def test_online_trainer(clear_ctx_at_each_iter):
agent = MemoryAgent()
opt = default_optimizer_factory(agent)
dataset = TaskDataset.from_name("dummy")
Expand All @@ -35,6 +36,7 @@ async def test_online_trainer():
max_rollout_steps=1,
num_eval_iterations=1,
eval_every=1,
clear_ctx_at_each_iter=clear_ctx_at_each_iter,
)
trainer = OnlineTrainer(
config=train_conf,
Expand All @@ -50,18 +52,24 @@ async def test_online_trainer():
# eval is run 3 times: before training, during training, after training
assert v == (3 if "eval" in k else 1)

for ctx_data in OpCtx._CTX_REGISTRY.values():
assert not ctx_data.data
# Check the final state of the OpCtx registry. The results must be asserted:
# a bare all()/any() call is evaluated and thrown away, so it can never fail.
if clear_ctx_at_each_iter:
    # With clearing enabled, every registered context is expected to be empty.
    assert all(not ctx_data.data for ctx_data in OpCtx._CTX_REGISTRY.values())
else:
    # With clearing disabled, at least one context is expected to retain data.
    assert any(ctx_data.data for ctx_data in OpCtx._CTX_REGISTRY.values())


@pytest.mark.asyncio
async def test_evaluator() -> None:
@pytest.mark.parametrize("clear_ctx_at_each_iter", [True, False])
async def test_evaluator(clear_ctx_at_each_iter) -> None:
agent = SimpleAgent()
dataset = TaskDataset.from_name("dummy")
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
count_callback = DummyCallback()

eval_conf = EvaluatorConfig(num_eval_iterations=1)
eval_conf = EvaluatorConfig(
num_eval_iterations=1,
clear_ctx_at_each_iter=clear_ctx_at_each_iter,
)
evaluator = Evaluator(
config=eval_conf,
agent=agent,
Expand All @@ -77,12 +85,15 @@ async def test_evaluator() -> None:
for k, v in count_callback.fn_invocations.items():
assert v == (1 if "eval" in k else 0)

for ctx_data in OpCtx._CTX_REGISTRY.values():
assert not ctx_data.data
# Check the final state of the OpCtx registry. The results must be asserted:
# a bare all()/any() call is evaluated and thrown away, so it can never fail.
if clear_ctx_at_each_iter:
    # With clearing enabled, every registered context is expected to be empty.
    assert all(not ctx_data.data for ctx_data in OpCtx._CTX_REGISTRY.values())
else:
    # With clearing disabled, at least one context is expected to retain data.
    assert any(ctx_data.data for ctx_data in OpCtx._CTX_REGISTRY.values())


@pytest.mark.asyncio
async def test_offline_trainer():
@pytest.mark.parametrize("clear_ctx_at_each_iter", [True, False])
async def test_offline_trainer(clear_ctx_at_each_iter):
# This is kind of a system test of getting trajectories from the evaluator
# and then training on them "offline"
agent = MemoryAgent()
Expand All @@ -91,7 +102,10 @@ async def test_offline_trainer():
traj_callback = StoreTrajectoriesCallback()

evaluator = Evaluator(
config=EvaluatorConfig(num_eval_iterations=1),
config=EvaluatorConfig(
num_eval_iterations=1,
clear_ctx_at_each_iter=clear_ctx_at_each_iter,
),
agent=agent,
dataset=dataset,
callbacks=[traj_callback],
Expand All @@ -100,7 +114,10 @@ async def test_offline_trainer():
assert len(traj_callback.trajectories) == 1

count_callback = DummyCallback()
train_conf = OfflineTrainerConfig(batch_size=1)
train_conf = OfflineTrainerConfig(
batch_size=1,
clear_ctx_at_each_iter=clear_ctx_at_each_iter,
)
trainer = OfflineTrainer(
config=train_conf,
agent=agent,
Expand All @@ -117,8 +134,10 @@ async def test_offline_trainer():
"after_update": 1,
}

for ctx_data in OpCtx._CTX_REGISTRY.values():
assert not ctx_data.data
# Check the final state of the OpCtx registry. The results must be asserted:
# a bare all()/any() call is evaluated and thrown away, so it can never fail.
if clear_ctx_at_each_iter:
    # With clearing enabled, every registered context is expected to be empty.
    assert all(not ctx_data.data for ctx_data in OpCtx._CTX_REGISTRY.values())
else:
    # With clearing disabled, at least one context is expected to retain data.
    assert any(ctx_data.data for ctx_data in OpCtx._CTX_REGISTRY.values())


class StoreTrajectoriesCallback(Callback):
Expand Down

0 comments on commit 9b2829d

Please sign in to comment.