Commit

Resolved discounted_returns circular import (#83)
jamesbraza authored Oct 2, 2024
1 parent 33b629b commit 4bcdd1b
Showing 10 changed files with 62 additions and 58 deletions.
3 changes: 1 addition & 2 deletions ldp/alg/__init__.py
@@ -1,4 +1,4 @@
from .algorithms import discounted_returns, to_network
from .algorithms import to_network
from .beam_search import Beam, BeamSearchRollout
from .callbacks import (
    Callback,
@@ -42,6 +42,5 @@
"TrajectoryMetricsCallback",
"TreeSearchRollout",
"WandBLoggingCallback",
"discounted_returns",
"to_network",
]
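One consequence visible in this file: discounted_returns leaves ldp.alg's public __all__, and none of the ten files in this commit adds a re-export, so downstream code presumably has to switch import paths, as the test change at the bottom of this commit does. A minimal before/after sketch (the old path is assumed to now raise ImportError):

# Old path, removed from ldp/alg/__init__.py by this commit (assumed to now fail):
# from ldp.alg import discounted_returns

# New path, matching ldp/data_structures.py and tests/test_algorithms.py below:
from ldp.utils import discounted_returns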
45 changes: 0 additions & 45 deletions ldp/alg/algorithms.py
@@ -9,51 +9,6 @@
from ldp.graph.ops import GradOutType, OpResult


def discounted_returns(
    rewards: list[float], terminated: list[bool], discount: float = 1.0
) -> list[float]:
    r"""
    Calculate the discounted returns for a list of rewards, considering termination flags and a discount factor.

    The discounted return represents the future discounted rewards from each time step onwards, taking into account
    whether an episode has terminated at each step.

    The discounted return \( G_t \) is given by:

    .. math::
        G_t = \sum_{k=1}^{\infty} \gamma^{k-1} R_{t+k}

    where:
    - \( G_t \) is the discounted return starting from time step \( t \).
    - \( \gamma \) is the discount factor.
    - \( R_{t+k} \) is the reward received at time step \( t+k \).

    Args:
        rewards: A list of rewards at each time step.
        terminated: A list of boolean flags indicating whether the episode terminated at each time step.
        discount: Discount factor to apply to future rewards. Defaults to 1.0 which means no discounting is applied.

    Returns:
        A list of discounted returns (rewards to go), with each element representing the
        total discounted reward from that step onwards.

    Example:
        >>> rewards = [1.0, 2.0, 3.0]
        >>> terminated = [False, False, True]
        >>> discounted_returns(rewards, terminated, discount=0.9)
        [5.23, 4.7, 3.0]
    """
    returns = []
    r = 0.0
    for reward, term in zip(reversed(rewards), reversed(terminated), strict=False):
        # 1 - term is 0 if the episode has terminated
        r = reward + discount * r * (1 - term)
        returns.append(r)
    returns.reverse()
    return returns


def to_network( # noqa: C901
op_result: OpResult,
max_label_height: int | None = None,
5 changes: 3 additions & 2 deletions ldp/alg/beam_search.py
@@ -7,10 +7,11 @@
from aviary.env import Environment

from ldp.agent.agent import Agent, TAgentState
from ldp.alg.callbacks import Callback
from ldp.alg.rollout import AgentError, EnvError, TEnv, reraise_exc_as
from ldp.data_structures import Trajectory, Transition

from .callbacks import Callback
from .rollout import AgentError, EnvError, TEnv, reraise_exc_as


class Beam(NamedTuple):
    # An ongoing beam contains two things: the trajectory up to now
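Beyond the discounted_returns move, this and the following ldp/alg modules swap absolute ldp.alg.* imports for package-relative ones (judging by the addition/deletion counts in each hunk). Within a package the two spellings bind the same module object, so this part is organizational. A small runnable check, assuming the ldp package is installed:

import importlib

# The relative form "from .callbacks import Callback" used inside ldp/alg resolves
# against the ldp.alg package, so it lands on the same module as the absolute form:
abs_mod = importlib.import_module("ldp.alg.callbacks")
rel_mod = importlib.import_module(".callbacks", package="ldp.alg")
assert abs_mod is rel_mod
print(abs_mod.Callback is rel_mod.Callback)  # True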
2 changes: 1 addition & 1 deletion ldp/alg/datasets.py
@@ -1,7 +1,7 @@
from aviary.env import TASK_DATASET_REGISTRY
from aviary.env import DummyTaskDataset as _DummyTaskDataset

from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin
from .callbacks import ComputeTrajectoryMetricsMixin


class DummyTaskDataset(_DummyTaskDataset, ComputeTrajectoryMetricsMixin):
3 changes: 2 additions & 1 deletion ldp/alg/rollout.py
@@ -10,9 +10,10 @@
from aviary.message import Message

from ldp.agent import Agent
from ldp.alg.callbacks import Callback
from ldp.data_structures import Trajectory, Transition

from .callbacks import Callback

logger = logging.getLogger(__name__)


5 changes: 3 additions & 2 deletions ldp/alg/runners.py
@@ -10,14 +10,15 @@
from pydantic import BaseModel, ConfigDict, Field

from ldp.agent import Agent
from ldp.alg.callbacks import Callback, ClearContextCallback
from ldp.alg.optimizer import Optimizer
from ldp.alg.rollout import RolloutManager
from ldp.data_structures import Trajectory
from ldp.graph.op_utils import eval_mode, train_mode
from ldp.graph.ops import OpResult
from ldp.shims import tqdm, trange

from .callbacks import Callback, ClearContextCallback
from .rollout import RolloutManager


async def _run_eval_loop(
dataset: TaskDataset,
7 changes: 4 additions & 3 deletions ldp/alg/tree_search.py
@@ -8,16 +8,17 @@
from aviary.utils import is_coroutine_callable

from ldp.agent import Agent
from ldp.alg.callbacks import Callback
from ldp.alg.rollout import (
from ldp.data_structures import TransitionTree

from .callbacks import Callback
from .rollout import (
    AgentError,
    CaughtError,
    EnvError,
    RolloutManager,
    TEnv,
    reraise_exc_as,
)
from ldp.data_structures import TransitionTree

logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion ldp/data_structures.py
@@ -12,8 +12,8 @@
from aviary.tools import ToolRequestMessage, ToolResponseMessage
from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator

from ldp.alg.algorithms import discounted_returns
from ldp.graph.ops import OpResult
from ldp.utils import discounted_returns

logger = logging.getLogger(__name__)

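This one-line switch in ldp/data_structures.py is the heart of the fix. Inferring the chain from the imports visible above: ldp.data_structures imported ldp.alg.algorithms, importing that module first runs ldp/alg/__init__.py, which imports beam_search, which imports ldp.data_structures back, so one of the modules is hit while only partially initialized. A self-contained reproduction of that failure mode, using throwaway module names (pkg, structures, algorithms are hypothetical, not this repo's layout):

import importlib
import pathlib
import sys
import tempfile

root = pathlib.Path(tempfile.mkdtemp())
pkg = root / "pkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
# Stand-in for ldp/data_structures.py before this commit:
(pkg / "structures.py").write_text(
    "from pkg.algorithms import discounted_returns\n"
    "class Trajectory: ...\n"
)
# Stand-in for the ldp.alg side, which (through its __init__ and beam_search)
# imported ldp.data_structures back:
(pkg / "algorithms.py").write_text(
    "from pkg.structures import Trajectory\n"
    "def discounted_returns(rewards, terminated, discount=1.0): ...\n"
)

sys.path.insert(0, str(root))
try:
    importlib.import_module("pkg.structures")
except ImportError as exc:
    # e.g. "cannot import name 'Trajectory' from partially initialized module
    # 'pkg.structures' (most likely due to a circular import)"
    print(exc)

After this commit, ldp/data_structures.py imports the helper from ldp/utils.py instead, which sits outside that cycle (as the NOTE in the relocated docstring below suggests).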
46 changes: 46 additions & 0 deletions ldp/utils.py
@@ -58,3 +58,49 @@ def configure_stdout_logs(
        },
        "root": {"level": level, "handlers": ["stdout"]},
    })


def discounted_returns(
    rewards: list[float], terminated: list[bool], discount: float = 1.0
) -> list[float]:
    r"""
    Calculate the discounted returns for a list of rewards, considering termination flags and a discount factor.

    The discounted return represents the future discounted rewards from each time step onwards, taking into account
    whether an episode has terminated at each step.

    The discounted return \( G_t \) is given by:

    .. math::
        G_t = \sum_{k=1}^{\infty} \gamma^{k-1} R_{t+k}

    where:
    - \( G_t \) is the discounted return starting from time step \( t \).
    - \( \gamma \) is the discount factor.
    - \( R_{t+k} \) is the reward received at time step \( t+k \).

    NOTE: this could live in ldp.alg, but it's here to avoid circular imports.

    Args:
        rewards: A list of rewards at each time step.
        terminated: A list of boolean flags indicating whether the episode terminated at each time step.
        discount: Discount factor to apply to future rewards. Defaults to 1.0 which means no discounting is applied.

    Returns:
        A list of discounted returns (rewards to go), with each element representing the
        total discounted reward from that step onwards.

    Example:
        >>> rewards = [1.0, 2.0, 3.0]
        >>> terminated = [False, False, True]
        >>> discounted_returns(rewards, terminated, discount=0.9)
        [5.23, 4.7, 3.0]
    """
    returns = []
    r = 0.0
    for reward, term in zip(reversed(rewards), reversed(terminated), strict=False):
        # 1 - term is 0 if the episode has terminated
        r = reward + discount * r * (1 - term)
        returns.append(r)
    returns.reverse()
    return returns
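As a quick sanity check of the relocated helper, here is the doctest's example expanded step by step (same numbers as the docstring above; the printed floats may carry minor rounding noise):

from ldp.utils import discounted_returns

rewards = [1.0, 2.0, 3.0]
terminated = [False, False, True]

# Walking backwards with discount = 0.9:
#   G_2 = 3.0                    (terminal step: no future reward is added)
#   G_1 = 2.0 + 0.9 * 3.0 = 4.7
#   G_0 = 1.0 + 0.9 * 4.7 = 5.23
print(discounted_returns(rewards, terminated, discount=0.9))
# Expected per the docstring above: [5.23, 4.7, 3.0]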
2 changes: 1 addition & 1 deletion tests/test_algorithms.py
@@ -2,7 +2,7 @@
from aviary.env import DummyEnv

from ldp.agent import SimpleAgent
from ldp.alg import discounted_returns
from ldp.utils import discounted_returns


@pytest.mark.asyncio
