Skip to content

Commit

Permalink
Created __init__.py shortcuts for ldp.alg (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Oct 2, 2024
1 parent 21e90c6 commit 4adc48c
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 12 deletions.
45 changes: 45 additions & 0 deletions ldp/alg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from .algorithms import discounted_returns, to_network
from .beam_search import Beam, BeamSearchRollout
from .callbacks import (
Callback,
ClearContextCallback,
ComputeTrajectoryMetricsMixin,
LoggingCallback,
MeanMetricsCallback,
RolloutDebugDumpCallback,
TrajectoryMetricsCallback,
WandBLoggingCallback,
)
from .rollout import RolloutManager
from .runners import (
Evaluator,
EvaluatorConfig,
OfflineTrainer,
OfflineTrainerConfig,
OnlineTrainer,
OnlineTrainerConfig,
)
from .tree_search import TreeSearchRollout

__all__ = [
"Beam",
"BeamSearchRollout",
"Callback",
"ClearContextCallback",
"ComputeTrajectoryMetricsMixin",
"Evaluator",
"EvaluatorConfig",
"LoggingCallback",
"MeanMetricsCallback",
"OfflineTrainer",
"OfflineTrainerConfig",
"OnlineTrainer",
"OnlineTrainerConfig",
"RolloutDebugDumpCallback",
"RolloutManager",
"TrajectoryMetricsCallback",
"TreeSearchRollout",
"WandBLoggingCallback",
"discounted_returns",
"to_network",
]
2 changes: 1 addition & 1 deletion tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SimpleAgentState,
make_simple_agent_server,
)
from ldp.alg.algorithms import to_network
from ldp.alg import to_network
from ldp.graph.common_ops import LLMCallOp
from ldp.graph.gradient_estimators import llm_straight_through_estimator as llm_ste
from ldp.graph.gradient_estimators import straight_through_estimator as ste
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from aviary.env import DummyEnv

from ldp.agent import SimpleAgent
from ldp.alg.algorithms import discounted_returns
from ldp.alg import discounted_returns


@pytest.mark.asyncio
Expand Down
5 changes: 1 addition & 4 deletions tests/test_rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
from pydantic import BaseModel

from ldp.agent import Agent, SimpleAgent, SimpleAgentState
from ldp.alg.beam_search import BeamSearchRollout
from ldp.alg.callbacks import Callback
from ldp.alg.rollout import RolloutManager
from ldp.alg.tree_search import TreeSearchRollout
from ldp.alg import BeamSearchRollout, Callback, RolloutManager, TreeSearchRollout
from ldp.data_structures import Trajectory, Transition
from ldp.graph.common_ops import FxnOp
from ldp.graph.op_utils import compute_graph, set_training_mode
Expand Down
13 changes: 7 additions & 6 deletions tests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
from aviary.env import DummyEnv, TaskDataset

from ldp.agent import MemoryAgent, SimpleAgent
from ldp.alg.callbacks import Callback, MeanMetricsCallback
from ldp.alg.datasets import ( # noqa: F401 # Force TASK_DATASET_REGISTRY update
DummyTaskDataset,
)
from ldp.alg.optimizer import default_optimizer_factory
from ldp.alg.runners import (
from ldp.alg import (
Callback,
Evaluator,
EvaluatorConfig,
MeanMetricsCallback,
OfflineTrainer,
OfflineTrainerConfig,
OnlineTrainer,
OnlineTrainerConfig,
)
from ldp.alg.datasets import ( # noqa: F401 # Force TASK_DATASET_REGISTRY update
DummyTaskDataset,
)
from ldp.alg.optimizer import default_optimizer_factory
from ldp.data_structures import Trajectory
from ldp.graph.ops import OpCtx

Expand Down

0 comments on commit 4adc48c

Please sign in to comment.