Skip to content

Commit

Permalink
Merge pull request #22 from MLShukai/myxy/lerp-stacked-hidden
Browse files Browse the repository at this point in the history
隠れ状態の線形補間
  • Loading branch information
myxyy authored Jul 23, 2024
2 parents 93a4e3c + 729d246 commit ced214e
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 0 deletions.
17 changes: 17 additions & 0 deletions ami/models/components/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch.nn as nn
from torch import Tensor


class ResNetFF(nn.Module):
def __init__(self, dim: int, dim_hidden: int, depth: int, activation: nn.Module = nn.ReLU()):
super().__init__()
self.ff_list = nn.ModuleList(
[nn.Sequential(nn.Linear(dim, dim_hidden), activation, nn.Linear(dim_hidden, dim)) for _ in range(depth)]
)

def forward(self, x: Tensor) -> Tensor:
for ff in self.ff_list:
x_ = x
x = ff(x)
x = x + x_
return x
58 changes: 58 additions & 0 deletions ami/models/policy_value_common_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,61 @@ def forward(self, flattened_obs: Tensor, stacked_hidden: Tensor) -> Tensor:
if self.transpose:
out = out.transpose(-2, -1)
return out


class LerpStackedHidden(nn.Module):
"""Linear interpolation along depth of stacked hidden.
Shape:
- stacked_hidden: (D, N) | (B, D, N)
Return shape: (N,) | (B, N)
"""

def __init__(self, dim: int, depth: int, num_head: int) -> None:
super().__init__()
self.hidden_linear_weight = nn.Parameter(torch.randn(depth, dim, dim) * (dim**-0.5))
self.hidden_linear_bias = nn.Parameter(torch.randn(depth, dim) * (dim**-0.5))
self.logit_coef_proj = nn.Linear(depth * dim, depth)
self.num_head = num_head
self.norm = nn.InstanceNorm1d(num_head)

def forward(self, stacked_hidden: Tensor) -> Tensor:
is_batch = len(stacked_hidden.shape) == 3
if not is_batch:
stacked_hidden = stacked_hidden.unsqueeze(0)

batch, depth, dim = stacked_hidden.shape
stacked_hidden = self.norm(stacked_hidden.reshape(batch * depth, self.num_head, dim // self.num_head)).reshape(
batch, depth, dim
)

logit_coef = self.logit_coef_proj(stacked_hidden.reshape(batch, depth * dim))

hidden_linear = torch.einsum(
"dij,bdj->bdi", self.hidden_linear_weight, stacked_hidden
) + self.hidden_linear_bias.unsqueeze(0)

out = torch.einsum("bd,bdi->bi", nn.functional.softmax(logit_coef, dim=-1), hidden_linear)

if not is_batch:
out = out.squeeze(0)
return out


class ConcatFlattenedObservationAndLerpedHidden(nn.Module):
"""Concatenates the flattened observation and stacked hidden states.
Shape:
- flattened_obs: (*, N_OBS)
- lerped_hidden: (*, N_HIDDEN)
Return shape: (*, N_OUT)
"""

def __init__(self, dim_obs: int, dim_hidden: int, dim_out: int):
super().__init__()
self.fc = nn.Linear(dim_obs + dim_hidden, dim_out)

def forward(self, flattened_obs: Tensor, lerped_hidden: Tensor) -> Tensor:
return self.fc(torch.cat([flattened_obs, lerped_hidden], dim=-1))
7 changes: 7 additions & 0 deletions configs/experiment/world_models_sioconv_lerp_hidden.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @package _global_

defaults:
- world_models_sioconv
- override /models: world_models_sioconv_lerp_hidden

task_name: world_models_sioconv_lerp_hidden
94 changes: 94 additions & 0 deletions configs/models/world_models_sioconv_lerp_hidden.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
image_encoder:
_target_: ami.models.vae.EncoderWrapper
default_device: ${devices.0}
has_inference: True
model:
_target_: ami.models.vae.Conv2dEncoder
height: ${shared.image_height}
width: ${shared.image_width}
channels: ${shared.image_channels}
latent_dim: 512 # From primitive AMI.
do_batchnorm: True

image_decoder:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
has_inference: False
model:
_target_: ami.models.vae.Conv2dDecoder
height: ${models.image_encoder.model.height}
width: ${models.image_encoder.model.width}
channels: ${models.image_encoder.model.channels}
latent_dim: ${models.image_encoder.model.latent_dim}
do_batchnorm: True

forward_dynamics:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
has_inference: True
model:
_target_: ami.models.forward_dynamics.ForwardDynamics
observation_flatten:
_target_: torch.nn.Identity
action_flatten:
_target_: ami.models.components.multi_embeddings.MultiEmbeddings
choices_per_category:
_target_: hydra.utils.get_object
path: ami.interactions.environments.actuators.vrchat_osc_discrete_actuator.ACTION_CHOICES_PER_CATEGORY
embedding_dim: 8
do_flatten: True
obs_action_projection:
_target_: torch.nn.Linear
# action_embedding_dim * num_action_choices + obs_embedding_dim
in_features: ${python.eval:"${..action_flatten.embedding_dim} * 5 + ${models.image_encoder.model.latent_dim}"}
out_features: ${..core_model.dim}
core_model:
_target_: ami.models.components.sioconv.SioConv
depth: 8
dim: 512
num_head: 8
dim_ff_hidden: 512
chunk_size: 512
dropout: 0.1
obs_hat_dist_head:
_target_: ami.models.components.fully_connected_fixed_std_normal.FullyConnectedFixedStdNormal
dim_in: ${..core_model.dim}
dim_out: ${models.image_encoder.model.latent_dim}
normal_cls:
_target_: hydra.utils.get_class
path: ami.models.components.fully_connected_fixed_std_normal.DeterministicNormal

policy_value:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
has_inference: True
model:
_target_: ami.models.policy_value_common_net.PolicyValueCommonNet
observation_projection:
_target_: torch.nn.Linear
in_features: ${models.image_encoder.model.latent_dim}
out_features: 512
forward_dynamics_hidden_projection:
_target_: ami.models.policy_value_common_net.LerpStackedHidden
dim: ${models.forward_dynamics.model.core_model.dim}
depth: ${models.forward_dynamics.model.core_model.depth}
num_head: ${models.forward_dynamics.model.core_model.num_head}
observation_hidden_projection:
_target_: ami.models.policy_value_common_net.ConcatFlattenedObservationAndLerpedHidden
dim_obs: ${..observation_projection.out_features}
dim_hidden: ${models.forward_dynamics.model.core_model.dim}
dim_out: 512
core_model:
_target_: ami.models.components.resnet.ResNetFF
dim: ${..observation_hidden_projection.dim_out}
dim_hidden: 1024
depth: 4
policy_head:
_target_: ami.models.components.discrete_policy_head.DiscretePolicyHead
dim_in: ${..observation_hidden_projection.dim_out}
action_choices_per_category:
_target_: hydra.utils.get_object
path: ami.interactions.environments.actuators.vrchat_osc_discrete_actuator.ACTION_CHOICES_PER_CATEGORY
value_head:
_target_: ami.models.components.fully_connected_value_head.FullyConnectedValueHead
dim_in: ${..observation_hidden_projection.dim_out}
24 changes: 24 additions & 0 deletions tests/models/components/test_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
import torch

from ami.models.components.resnet import ResNetFF


class TestResNetFF:
@pytest.mark.parametrize(
"""
batch,
dim,
dim_hidden,
depth,
""",
[
(3, 128, 256, 4),
(6, 28, 56, 2),
],
)
def test_resnet_ff(self, batch, dim, dim_hidden, depth):
mod = ResNetFF(dim, dim_hidden, depth)
x = torch.randn(batch, dim)
x = mod(x)
assert x.shape == (batch, dim)
21 changes: 21 additions & 0 deletions tests/models/test_policy_value_common_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from ami.models.components.discrete_policy_head import DiscretePolicyHead
from ami.models.components.fully_connected_value_head import FullyConnectedValueHead
from ami.models.policy_value_common_net import (
ConcatFlattenedObservationAndLerpedHidden,
ConcatFlattenedObservationAndStackedHidden,
LerpStackedHidden,
PolicyValueCommonNet,
)

Expand Down Expand Up @@ -49,3 +51,22 @@ def test_forward(self):
out = mod.forward(obs, hidden)
assert out.shape == (3, 10, 5)
assert torch.equal(obs, out[:, :, 0])


class TestLerpedStackedHidden:
def test_forward(self):
mod = LerpStackedHidden(128, 8, 4)

hidden = torch.randn(4, 8, 128)
out = mod.forward(hidden)
assert out.shape == (4, 128)


class TestConcatFlattenedObservationAndLerpedHidden:
def test_forward(self):
mod = ConcatFlattenedObservationAndLerpedHidden(32, 64, 128)

obs = torch.randn(4, 32)
hidden = torch.randn(4, 64)
out = mod.forward(obs, hidden)
assert out.shape == (4, 128)

0 comments on commit ced214e

Please sign in to comment.