From 4adc48c9105f3e4b4b65aa37e539993336e0d6dd Mon Sep 17 00:00:00 2001 From: James Braza Date: Wed, 2 Oct 2024 15:11:49 -0700 Subject: [PATCH] Created `__init__.py` shortcuts for `ldp.alg` (#81) --- ldp/alg/__init__.py | 45 ++++++++++++++++++++++++++++++++++++++++ tests/test_agents.py | 2 +- tests/test_algorithms.py | 2 +- tests/test_rollouts.py | 5 +---- tests/test_runners.py | 13 ++++++------ 5 files changed, 55 insertions(+), 12 deletions(-) diff --git a/ldp/alg/__init__.py b/ldp/alg/__init__.py index e69de29b..046766a4 100644 --- a/ldp/alg/__init__.py +++ b/ldp/alg/__init__.py @@ -0,0 +1,45 @@ +from .algorithms import discounted_returns, to_network +from .beam_search import Beam, BeamSearchRollout +from .callbacks import ( + Callback, + ClearContextCallback, + ComputeTrajectoryMetricsMixin, + LoggingCallback, + MeanMetricsCallback, + RolloutDebugDumpCallback, + TrajectoryMetricsCallback, + WandBLoggingCallback, +) +from .rollout import RolloutManager +from .runners import ( + Evaluator, + EvaluatorConfig, + OfflineTrainer, + OfflineTrainerConfig, + OnlineTrainer, + OnlineTrainerConfig, +) +from .tree_search import TreeSearchRollout + +__all__ = [ + "Beam", + "BeamSearchRollout", + "Callback", + "ClearContextCallback", + "ComputeTrajectoryMetricsMixin", + "Evaluator", + "EvaluatorConfig", + "LoggingCallback", + "MeanMetricsCallback", + "OfflineTrainer", + "OfflineTrainerConfig", + "OnlineTrainer", + "OnlineTrainerConfig", + "RolloutDebugDumpCallback", + "RolloutManager", + "TrajectoryMetricsCallback", + "TreeSearchRollout", + "WandBLoggingCallback", + "discounted_returns", + "to_network", +] diff --git a/tests/test_agents.py b/tests/test_agents.py index 0f68f191..a15110e0 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -24,7 +24,7 @@ SimpleAgentState, make_simple_agent_server, ) -from ldp.alg.algorithms import to_network +from ldp.alg import to_network from ldp.graph.common_ops import LLMCallOp from ldp.graph.gradient_estimators import llm_straight_through_estimator as llm_ste from ldp.graph.gradient_estimators import straight_through_estimator as ste diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 97cb753f..a953dec8 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -2,7 +2,7 @@ from aviary.env import DummyEnv from ldp.agent import SimpleAgent -from ldp.alg.algorithms import discounted_returns +from ldp.alg import discounted_returns @pytest.mark.asyncio diff --git a/tests/test_rollouts.py b/tests/test_rollouts.py index fc3ea53e..d52e8525 100644 --- a/tests/test_rollouts.py +++ b/tests/test_rollouts.py @@ -11,10 +11,7 @@ from pydantic import BaseModel from ldp.agent import Agent, SimpleAgent, SimpleAgentState -from ldp.alg.beam_search import BeamSearchRollout -from ldp.alg.callbacks import Callback -from ldp.alg.rollout import RolloutManager -from ldp.alg.tree_search import TreeSearchRollout +from ldp.alg import BeamSearchRollout, Callback, RolloutManager, TreeSearchRollout from ldp.data_structures import Trajectory, Transition from ldp.graph.common_ops import FxnOp from ldp.graph.op_utils import compute_graph, set_training_mode diff --git a/tests/test_runners.py b/tests/test_runners.py index 4e010e50..fbbaf57a 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -6,19 +6,20 @@ from aviary.env import DummyEnv, TaskDataset from ldp.agent import MemoryAgent, SimpleAgent -from ldp.alg.callbacks import Callback, MeanMetricsCallback -from ldp.alg.datasets import ( # noqa: F401 # Force TASK_DATASET_REGISTRY update - DummyTaskDataset, -) -from ldp.alg.optimizer import default_optimizer_factory -from ldp.alg.runners import ( +from ldp.alg import ( + Callback, Evaluator, EvaluatorConfig, + MeanMetricsCallback, OfflineTrainer, OfflineTrainerConfig, OnlineTrainer, OnlineTrainerConfig, ) +from ldp.alg.datasets import ( # noqa: F401 # Force TASK_DATASET_REGISTRY update + DummyTaskDataset, +) +from ldp.alg.optimizer import default_optimizer_factory from ldp.data_structures import Trajectory from ldp.graph.ops import OpCtx