Skip to content

Commit

Permalink
Computing advantages on trees (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan authored Nov 12, 2024
1 parent a7a6764 commit 0875041
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
45 changes: 43 additions & 2 deletions ldp/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import os
from collections.abc import Callable, Hashable
from collections.abc import Callable, Hashable, Iterable
from typing import Any, ClassVar, Self, cast
from uuid import UUID

Expand Down Expand Up @@ -272,7 +272,7 @@ def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:

if children := list(self.tree.successors(step_id)):
# V_{t+1}(s') = sum_{a'} p(a'|s') * Q_{t+1}(s', a')
# Here we assume p(a'|s') is uniform over the sampled actions..
# Here we assume p(a'|s') is uniform over the sampled actions.
# TODO: don't make that assumption where a logprob is available
weights = [self.get_weight(child_id) for child_id in children]
steps = [self.get_transition(child_id) for child_id in children]
Expand All @@ -286,6 +286,47 @@ def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:
# (we are assuming the environment is deterministic)
step.value = step.reward + discount_factor * v_tp1

def compute_advantages(self) -> None:
"""Replace Transition.value with an advantage (in-place).
A(s, a) = Q(s, a) - V(s), where V(s) is estimated as the
average of Q(s, a') over all a' sampled at s.
TODO: put this in Transition.metadata['advantage']. Not doing
this right now due to implementation details in an optimizer.
"""
state_values: dict[str, float] = {}

for step_id in cast(Iterable[str], nx.topological_sort(self.tree)):
# topological sort means we will update a parent node in-place before
# descending to its children

step: Transition | None = self.tree.nodes[step_id]["transition"]
if step is None:
state_values[step_id] = 0.0
continue

# First, update V_t so that we can compute A_{t+1} for children
children = [
self.tree.nodes[child_id] for child_id in self.tree.successors(step_id)
]
if children:
state_action_values = [child["transition"].value for child in children]
weights = [child["weight"] for child in children]
state_values[step_id] = sum(
w * v for w, v in zip(weights, state_action_values, strict=True)
) / sum(weights)

# Now compute A_t and replace Q_t with it in-place
# Note that we are guaranteed at least one parent, since the `step is None`
# check above should have caught the root node.
parent_id, *extra = list(self.rev_tree.successors(step_id))
assert not extra, "self.tree is not a tree!"
step.value -= state_values[parent_id]
# TODO: switch to the following, instead of overwriting step.value.
# See docstring for explanation.
# step.metadata["advantage"] = step.value - state_values[parent_id]

def merge_identical_nodes(
self,
agent_state_hash_fn: Callable[[Any], Hashable],
Expand Down
4 changes: 4 additions & 0 deletions tests/test_data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def test_tree_mc_value():
0.0 + 0.9 * ((1.9 - 1) / 2), rel=0.001
)

# Check we can compute advantages w/o crashing for now. TODO: test the assigned
# advantages. Will do so after the TODO in compute_advantages() is resolved.
tree.compute_advantages()


def test_tree_node_merging() -> None:
root_id = "dummy"
Expand Down

0 comments on commit 0875041

Please sign in to comment.