diff --git a/ldp/alg/__init__.py b/ldp/alg/__init__.py
index d0524022..e33b5bb9 100644
--- a/ldp/alg/__init__.py
+++ b/ldp/alg/__init__.py
@@ -1,4 +1,4 @@
-from .algorithms import discounted_returns, to_network
+from .algorithms import to_network
 from .beam_search import Beam, BeamSearchRollout
 from .callbacks import (
     Callback,
@@ -42,6 +42,5 @@
     "TrajectoryMetricsCallback",
     "TreeSearchRollout",
     "WandBLoggingCallback",
-    "discounted_returns",
     "to_network",
 ]
diff --git a/ldp/alg/algorithms.py b/ldp/alg/algorithms.py
index 567ded74..9e74845d 100644
--- a/ldp/alg/algorithms.py
+++ b/ldp/alg/algorithms.py
@@ -9,51 +9,6 @@
 from ldp.graph.ops import GradOutType, OpResult
 
 
-def discounted_returns(
-    rewards: list[float], terminated: list[bool], discount: float = 1.0
-) -> list[float]:
-    r"""
-    Calculate the discounted returns for a list of rewards, considering termination flags and a discount factor.
-
-    The discounted return represents the future discounted rewards from each time step onwards, taking into account
-    whether an episode has terminated at each step.
-
-    The discounted return \( G_t \) is given by:
-
-    .. math::
-        G_t = \sum_{k=1}^{\infty} \gamma^{k-1} R_{t+k}
-
-    where:
-    - \( G_t \) is the discounted return starting from time step \( t \).
-    - \( \gamma \) is the discount factor.
-    - \( R_{t+k} \) is the reward received at time step \( t+k \).
-
-
-    Args:
-        rewards: A list of rewards at each time step.
-        terminated: A list of boolean flags indicating whether the episode terminated at each time step.
-        discount: Discount factor to apply to future rewards. Defaults to 1.0 which means no discounting is applied.
-
-    Returns:
-        A list of discounted returns (rewards to go), with each element representing the
-        total discounted reward from that step onwards.
-
-    Example:
-        >>> rewards = [1.0, 2.0, 3.0]
-        >>> terminated = [False, False, True]
-        >>> discounted_returns(rewards, terminated, discount=0.9)
-        [5.23, 4.7, 3.0]
-    """
-    returns = []
-    r = 0.0
-    for reward, term in zip(reversed(rewards), reversed(terminated), strict=False):
-        # 1 - term is 0 if the episode has terminated
-        r = reward + discount * r * (1 - term)
-        returns.append(r)
-    returns.reverse()
-    return returns
-
-
 def to_network(  # noqa: C901
     op_result: OpResult,
     max_label_height: int | None = None,
diff --git a/ldp/alg/beam_search.py b/ldp/alg/beam_search.py
index a5fc2aaa..1c883e9f 100644
--- a/ldp/alg/beam_search.py
+++ b/ldp/alg/beam_search.py
@@ -7,10 +7,11 @@
 from aviary.env import Environment
 
 from ldp.agent.agent import Agent, TAgentState
-from ldp.alg.callbacks import Callback
-from ldp.alg.rollout import AgentError, EnvError, TEnv, reraise_exc_as
 from ldp.data_structures import Trajectory, Transition
 
+from .callbacks import Callback
+from .rollout import AgentError, EnvError, TEnv, reraise_exc_as
+
 
 class Beam(NamedTuple):
     # An ongoing beam contains two things: the trajectory up to now
diff --git a/ldp/alg/datasets.py b/ldp/alg/datasets.py
index 687587fc..a86892e7 100644
--- a/ldp/alg/datasets.py
+++ b/ldp/alg/datasets.py
@@ -1,7 +1,7 @@
 from aviary.env import TASK_DATASET_REGISTRY
 from aviary.env import DummyTaskDataset as _DummyTaskDataset
 
-from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin
+from .callbacks import ComputeTrajectoryMetricsMixin
 
 
 class DummyTaskDataset(_DummyTaskDataset, ComputeTrajectoryMetricsMixin):
diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py
index ace16eb7..eec32ae5 100644
--- a/ldp/alg/rollout.py
+++ b/ldp/alg/rollout.py
@@ -10,9 +10,10 @@
 from aviary.message import Message
 
 from ldp.agent import Agent
-from ldp.alg.callbacks import Callback
 from ldp.data_structures import Trajectory, Transition
 
+from .callbacks import Callback
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/ldp/alg/runners.py b/ldp/alg/runners.py
index 2a801997..010bfc45 100644
--- a/ldp/alg/runners.py
+++ b/ldp/alg/runners.py
@@ -10,14 +10,15 @@
 from pydantic import BaseModel, ConfigDict, Field
 
 from ldp.agent import Agent
-from ldp.alg.callbacks import Callback, ClearContextCallback
 from ldp.alg.optimizer import Optimizer
-from ldp.alg.rollout import RolloutManager
 from ldp.data_structures import Trajectory
 from ldp.graph.op_utils import eval_mode, train_mode
 from ldp.graph.ops import OpResult
 from ldp.shims import tqdm, trange
 
+from .callbacks import Callback, ClearContextCallback
+from .rollout import RolloutManager
+
 
 async def _run_eval_loop(
     dataset: TaskDataset,
diff --git a/ldp/alg/tree_search.py b/ldp/alg/tree_search.py
index f4621f54..6662c3f6 100644
--- a/ldp/alg/tree_search.py
+++ b/ldp/alg/tree_search.py
@@ -8,8 +8,10 @@
 from aviary.utils import is_coroutine_callable
 
 from ldp.agent import Agent
-from ldp.alg.callbacks import Callback
-from ldp.alg.rollout import (
+from ldp.data_structures import TransitionTree
+
+from .callbacks import Callback
+from .rollout import (
     AgentError,
     CaughtError,
     EnvError,
@@ -17,7 +19,6 @@
     TEnv,
     reraise_exc_as,
 )
-from ldp.data_structures import TransitionTree
 
 logger = logging.getLogger(__name__)
 
diff --git a/ldp/data_structures.py b/ldp/data_structures.py
index c2814be6..00c923d2 100644
--- a/ldp/data_structures.py
+++ b/ldp/data_structures.py
@@ -12,8 +12,8 @@
 from aviary.tools import ToolRequestMessage, ToolResponseMessage
 from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
 
-from ldp.alg.algorithms import discounted_returns
 from ldp.graph.ops import OpResult
+from ldp.utils import discounted_returns
 
 logger = logging.getLogger(__name__)
 
diff --git a/ldp/utils.py b/ldp/utils.py
index c6ebe2e7..78b44276 100644
--- a/ldp/utils.py
+++ b/ldp/utils.py
@@ -58,3 +58,49 @@ def configure_stdout_logs(
         },
         "root": {"level": level, "handlers": ["stdout"]},
     })
+
+
+def discounted_returns(
+    rewards: list[float], terminated: list[bool], discount: float = 1.0
+) -> list[float]:
+    r"""
+    Calculate the discounted returns for a list of rewards, considering termination flags and a discount factor.
+
+    The discounted return represents the future discounted rewards from each time step onwards, taking into account
+    whether an episode has terminated at each step.
+
+    The discounted return \( G_t \) is given by:
+
+    .. math::
+        G_t = \sum_{k=1}^{\infty} \gamma^{k-1} R_{t+k}
+
+    where:
+    - \( G_t \) is the discounted return starting from time step \( t \).
+    - \( \gamma \) is the discount factor.
+    - \( R_{t+k} \) is the reward received at time step \( t+k \).
+
+    NOTE: this could live in ldp.alg, but it's here to avoid circular imports.
+
+    Args:
+        rewards: A list of rewards at each time step.
+        terminated: A list of boolean flags indicating whether the episode terminated at each time step.
+        discount: Discount factor to apply to future rewards. Defaults to 1.0 which means no discounting is applied.
+
+    Returns:
+        A list of discounted returns (rewards to go), with each element representing the
+        total discounted reward from that step onwards.
+
+    Example:
+        >>> rewards = [1.0, 2.0, 3.0]
+        >>> terminated = [False, False, True]
+        >>> discounted_returns(rewards, terminated, discount=0.9)
+        [5.23, 4.7, 3.0]
+    """
+    returns = []
+    r = 0.0
+    for reward, term in zip(reversed(rewards), reversed(terminated), strict=False):
+        # 1 - term is 0 if the episode has terminated
+        r = reward + discount * r * (1 - term)
+        returns.append(r)
+    returns.reverse()
+    return returns
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index a953dec8..e1d45d64 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -2,7 +2,7 @@
 from aviary.env import DummyEnv
 
 from ldp.agent import SimpleAgent
-from ldp.alg import discounted_returns
+from ldp.utils import discounted_returns
 
 
 @pytest.mark.asyncio
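
For downstream callers, the only visible change is the import path: `discounted_returns` now lives in `ldp.utils` and is no longer re-exported from `ldp.alg` or defined in `ldp.alg.algorithms`. A minimal usage sketch of the relocated helper, with inputs and expected output taken from the docstring's doctest:

```python
# Usage sketch for the relocated helper. The import path follows this diff;
# the values mirror the docstring example.
from ldp.utils import discounted_returns  # previously: from ldp.alg import discounted_returns

rewards = [1.0, 2.0, 3.0]
terminated = [False, False, True]

# Rewards-to-go with a 0.9 discount; the episode terminates at the final step,
# so the last return is just that step's reward.
returns = discounted_returns(rewards, terminated, discount=0.9)
print(returns)  # [5.23, 4.7, 3.0]
```

Relocating the helper to `ldp.utils` lets `ldp.data_structures` use it without importing from `ldp.alg`, which is the circular-import concern noted in the moved docstring.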