diff --git a/ldp/data_structures.py b/ldp/data_structures.py
index cc01e22..ddaf6bb 100644
--- a/ldp/data_structures.py
+++ b/ldp/data_structures.py
@@ -3,7 +3,7 @@
 import json
 import logging
 import os
-from collections.abc import Callable, Hashable
+from collections.abc import Callable, Hashable, Iterable
 from typing import Any, ClassVar, Self, cast
 from uuid import UUID
 
@@ -272,7 +272,7 @@ def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:
 
             if children := list(self.tree.successors(step_id)):
                 # V_{t+1}(s') = sum_{a'} p(a'|s') * Q_{t+1}(s', a')
-                # Here we assume p(a'|s') is uniform over the sampled actions..
+                # Here we assume p(a'|s') is uniform over the sampled actions.
                 # TODO: don't make that assumption where a logprob is available
                 weights = [self.get_weight(child_id) for child_id in children]
                 steps = [self.get_transition(child_id) for child_id in children]
@@ -286,6 +286,50 @@ def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:
             # (we are assuming the environment is deterministic)
             step.value = step.reward + discount_factor * v_tp1
 
+    def compute_advantages(self) -> None:
+        """Replace Transition.value with an advantage (in-place).
+
+        A(s, a) = Q(s, a) - V(s), where V(s) is estimated as the weighted
+        average of Q(s, a') over all a' sampled at s. This includes the
+        initial state: children of the root are baselined against V(s_0).
+
+        TODO: put this in Transition.metadata['advantage']. Not doing
+        this right now due to implementation details in an optimizer.
+        """
+        state_values: dict[str, float] = {}
+
+        for step_id in cast(Iterable[str], nx.topological_sort(self.tree)):
+            # Topological order visits a parent before its children, so V is
+            # always computed from the children's not-yet-overwritten Q values.
+
+            # First, update V_t so that we can compute A_{t+1} for children.
+            # This is done for the root too (it has no transition of its own, but
+            # its children are the actions sampled at the initial state).
+            children = [
+                self.tree.nodes[child_id] for child_id in self.tree.successors(step_id)
+            ]
+            if children:
+                state_action_values = [child["transition"].value for child in children]
+                weights = [child["weight"] for child in children]
+                state_values[step_id] = sum(
+                    w * v for w, v in zip(weights, state_action_values, strict=True)
+                ) / sum(weights)
+
+            step: Transition | None = self.tree.nodes[step_id]["transition"]
+            if step is None:
+                # Root node: there is no (s, a) pair here to assign an advantage to.
+                continue
+
+            # Now compute A_t and replace Q_t with it in-place.
+            # Note that we are guaranteed at least one parent, since the `step is None`
+            # check above should have caught the root node.
+            parent_id, *extra = list(self.rev_tree.successors(step_id))
+            assert not extra, "self.tree is not a tree!"
+            step.value -= state_values[parent_id]
+            # TODO: switch to the following, instead of overwriting step.value.
+            # See docstring for explanation.
+            # step.metadata["advantage"] = step.value - state_values[parent_id]
+
     def merge_identical_nodes(
         self,
         agent_state_hash_fn: Callable[[Any], Hashable],
diff --git a/tests/test_data_structures.py b/tests/test_data_structures.py
index 267b4c5..7a72fb4 100644
--- a/tests/test_data_structures.py
+++ b/tests/test_data_structures.py
@@ -47,6 +47,10 @@ def test_tree_mc_value():
         0.0 + 0.9 * ((1.9 - 1) / 2), rel=0.001
     )
 
+    # Check we can compute advantages w/o crashing for now. TODO: test the assigned
+    # advantages. Will do so after the TODO in compute_advantages() is resolved.
+    tree.compute_advantages()
+
 
 def test_tree_node_merging() -> None:
     root_id = "dummy"