Skip to content

Commit

Permalink
PolicyOrValueNetworkを追加
Browse files Browse the repository at this point in the history
  • Loading branch information
Geson-anko committed Jul 25, 2024
1 parent 713c078 commit 6c57495
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ami/models/model_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ class ModelNames(str, Enum):
IMAGE_DECODER = "image_decoder"
POLICY_VALUE = "policy_value"
FORWARD_DYNAMICS = "forward_dynamics"
POLICY = "policy"
VALUE = "value"
39 changes: 39 additions & 0 deletions ami/models/policy_or_value_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch.nn as nn
from torch import Tensor
from torch.distributions import Distribution


class PolicyOrValueNetwork(nn.Module):
    """Module for policy or value network.

    The network is assembled from five caller-supplied sub-modules. The
    observation and the forward-dynamics hidden state are projected
    separately, merged by ``observation_hidden_projection``, passed through
    ``core_model`` and finally handed to ``dist_head``, whose output is
    returned unchanged.
    """

    def __init__(
        self,
        observation_projection: nn.Module,
        forward_dynamics_hidden_projection: nn.Module,
        observation_hidden_projection: nn.Module,
        core_model: nn.Module,
        dist_head: nn.Module,
    ) -> None:
        """Constructs the model with components.

        Args:
            observation_projection: Layer that processes observations only.
            forward_dynamics_hidden_projection: Layer that processes hidden
                states of the Forward Dynamics model only.
            observation_hidden_projection: Layer that receives and integrates
                observations and hidden states. It is called with two
                positional arguments (observation embedding, hidden
                embedding).
            core_model: Layer that processes the integrated tensor.
            dist_head: Layer that generates prediction distribution.
        """
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # sub-modules, so their parameters are tracked by this module.
        self.observation_projection = observation_projection
        self.forward_dynamics_hidden_projection = forward_dynamics_hidden_projection
        self.observation_hidden_projection = observation_hidden_projection
        self.core_model = core_model
        self.dist_head = dist_head

    def forward(self, observation: Tensor, forward_dynamics_hidden: Tensor) -> Distribution:
        """Returns the prediction distribution.

        Args:
            observation: Input passed to ``observation_projection``.
            forward_dynamics_hidden: Input passed to
                ``forward_dynamics_hidden_projection``.

        Returns:
            The output of ``dist_head`` applied to the core model's output.
        """
        obs_embed = self.observation_projection(observation)
        hidden_embed = self.forward_dynamics_hidden_projection(forward_dynamics_hidden)
        # The merge layer takes both embeddings as separate positional arguments.
        x = self.observation_hidden_projection(obs_embed, hidden_embed)
        h = self.core_model(x)
        return self.dist_head(h)
26 changes: 26 additions & 0 deletions tests/models/test_policy_or_value_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
import torch
import torch.nn as nn

from ami.models.components.fully_connected_normal import FullyConnectedNormal
from ami.models.policy_or_value_network import PolicyOrValueNetwork


class CatObsHidden(nn.Module):
    """Test helper that merges two tensors by concatenation."""

    def forward(self, obs: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        """Concatenates ``obs`` and ``hidden`` along the last dimension.

        Args:
            obs: First tensor; placed first in the result.
            hidden: Second tensor; placed after ``obs``.

        Returns:
            The result of ``torch.cat([obs, hidden], -1)``.
        """
        return torch.cat([obs, hidden], -1)


class TestPolicyOrValueNetwork:
    """Smoke tests for ``PolicyOrValueNetwork``."""

    @pytest.fixture
    def net(self) -> PolicyOrValueNetwork:
        """Builds a small ``PolicyOrValueNetwork`` from simple layers."""
        obs_layer = nn.Linear(128, 64)
        hidden_layer = nn.Linear(256, 128)
        obs_hidden_proj = CatObsHidden()
        # CatObsHidden concatenates the 64-wide observation embedding with the
        # 128-wide hidden embedding, hence the 128 + 64 input width here.
        core_model = nn.Linear(128 + 64, 16)
        # NOTE(review): presumably FullyConnectedNormal(16, 8) maps 16 features
        # to an 8-dimensional normal distribution; confirm against its definition.
        head = FullyConnectedNormal(16, 8)
        return PolicyOrValueNetwork(obs_layer, hidden_layer, obs_hidden_proj, core_model, head)

    def test_forward(self, net: PolicyOrValueNetwork) -> None:
        """Checks that a sample drawn from the forward output has shape ``(8,)``."""
        dist = net.forward(torch.randn(128), torch.randn(256))
        assert dist.sample().shape == (8,)

0 comments on commit 6c57495

Please sign in to comment.