Merge pull request #26 from MLShukai/feature/#18/models-frame-works
Implement the ForwardDynamicsWithActionReward class, which also predicts the action and reward, together with its Trainer.
Geson-anko authored Jul 25, 2024
2 parents 661f410 + 5111dcd commit 1aad514
Showing 4 changed files with 407 additions and 8 deletions.
33 changes: 33 additions & 0 deletions ami/models/forward_dynamics.py
@@ -29,3 +29,36 @@ def forward(self, obs: Tensor, hidden: Tensor, action: Tensor) -> tuple[Distribution, Tensor]:
x, next_hidden = self.core_model(x, hidden)
obs_hat_dist = self.obs_hat_dist_head(x)
return obs_hat_dist, next_hidden


class ForwardDynamcisWithActionReward(nn.Module):
def __init__(
self,
observation_flatten: nn.Module,
action_flatten: nn.Module,
obs_action_projection: nn.Module,
core_model: StackedHiddenState,
obs_hat_dist_head: nn.Module,
action_hat_dist_head: nn.Module,
reward_hat_dist_head: nn.Module,
) -> None:
super().__init__()
self.observation_flatten = observation_flatten
self.action_flatten = action_flatten
self.obs_action_projection = obs_action_projection
self.core_model = core_model
self.obs_hat_dist_head = obs_hat_dist_head
self.action_hat_dist_head = action_hat_dist_head
self.reward_hat_dist_head = reward_hat_dist_head

def forward(
self, obs: Tensor, hidden: Tensor, action: Tensor
) -> tuple[Distribution, Distribution, Distribution, Tensor]:
obs_flat = self.observation_flatten(obs)
action_flat = self.action_flatten(action)
x = self.obs_action_projection(torch.cat((obs_flat, action_flat), dim=-1))
x, next_hidden = self.core_model(x, hidden)
obs_hat_dist = self.obs_hat_dist_head(x)
action_hat_dist = self.action_hat_dist_head(x)
reward_hat_dist = self.reward_hat_dist_head(x)
return obs_hat_dist, action_hat_dist, reward_hat_dist, next_hidden
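
Note: ForwardDynamcisWithActionReward mirrors the existing ForwardDynamics wiring and adds two extra heads for the next action and the reward. A minimal usage sketch follows, mirroring the test fixtures further below; the FullyConnectedFixedStdNormal import path and the concrete dimension values are assumptions, not part of this commit.

import torch
from torch import nn

from ami.models.components.fully_connected_fixed_std_normal import FullyConnectedFixedStdNormal
from ami.models.components.sconv import SConv
from ami.models.forward_dynamics import ForwardDynamcisWithActionReward

# Illustrative sizes (the test file defines its own constants; these are placeholders).
DEPTH, DIM, DIM_FF_HIDDEN, DROPOUT = 8, 16, 32, 0.1
DIM_OBS, DIM_ACTION, LEN = 16, 8, 64

model = ForwardDynamcisWithActionReward(
    observation_flatten=nn.Identity(),
    action_flatten=nn.Identity(),
    obs_action_projection=nn.Linear(DIM_OBS + DIM_ACTION, DIM),
    core_model=SConv(DEPTH, DIM, DIM_FF_HIDDEN, DROPOUT),
    obs_hat_dist_head=FullyConnectedFixedStdNormal(DIM, DIM_OBS),
    action_hat_dist_head=FullyConnectedFixedStdNormal(DIM, DIM_ACTION),
    reward_hat_dist_head=FullyConnectedFixedStdNormal(DIM, 1),
)

obs = torch.randn(LEN, DIM_OBS)
action = torch.randn(LEN, DIM_ACTION)
hidden = torch.randn(DEPTH, DIM)  # one hidden vector per stacked layer

obs_dist, action_dist, reward_dist, next_hidden = model(obs, hidden, action)
# obs_dist, action_dist, reward_dist are torch Distributions over the next step;
# next_hidden is the updated recurrent state.
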
140 changes: 136 additions & 4 deletions ami/trainers/forward_dynamics_trainer.py
@@ -2,17 +2,15 @@
from pathlib import Path

import torch
from torch.distributions import kl_divergence
from torch.distributions.normal import Normal
from torch.nn.functional import mse_loss
from torch.distributions import Distribution
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import override

from ami.data.buffers.buffer_names import BufferNames
from ami.data.buffers.causal_data_buffer import CausalDataBuffer
from ami.data.interfaces import ThreadSafeDataUser
from ami.models.forward_dynamics import ForwardDynamics
from ami.models.forward_dynamics import ForwardDynamcisWithActionReward, ForwardDynamics
from ami.models.model_names import ModelNames
from ami.models.model_wrapper import ModelWrapper
from ami.tensorboard_loggers import StepIntervalLogger
@@ -119,3 +117,137 @@ def save_state(self, path: Path) -> None:
def load_state(self, path: Path) -> None:
self.optimizer_state = torch.load(path / "optimizer.pt")
self.logger.load_state_dict(torch.load(path / "logger.pt"))


class ForwardDynamicsWithActionRewardTrainer(BaseTrainer):
def __init__(
self,
partial_dataloader: partial[DataLoader[torch.Tensor]],
partial_optimizer: partial[Optimizer],
device: torch.device,
logger: StepIntervalLogger,
observation_encoder_name: ModelNames | None = None,
max_epochs: int = 1,
minimum_dataset_size: int = 2,
minimum_new_data_count: int = 0,
obs_loss_coef: float = 1.0,
action_loss_coef: float = 1.0,
reward_loss_coef: float = 1.0,
) -> None:
"""Initialization.
Args:
partial_dataloader: A partially instantiated dataloader lacking a provided dataset.
partial_optimizer: A partially instantiated optimizer lacking provided parameters.
device: The accelerator device (e.g., CPU, GPU) utilized for training the model.
minimum_new_data_count: Minimum number of new data count required to run the training.
"""
super().__init__()
self.partial_optimizer = partial_optimizer
self.partial_dataloader = partial_dataloader
self.device = device
self.logger = logger
self.observation_encoder_name = observation_encoder_name
self.max_epochs = max_epochs
assert minimum_dataset_size >= 2, "minimum_dataset_size must be at least 2"
self.minimum_dataset_size = minimum_dataset_size
self.minimum_new_data_count = minimum_new_data_count
self.obs_loss_coef = obs_loss_coef
self.action_loss_coef = action_loss_coef
self.reward_loss_coef = reward_loss_coef

def on_data_users_dict_attached(self) -> None:
self.trajectory_data_user: ThreadSafeDataUser[CausalDataBuffer] = self.get_data_user(
BufferNames.FORWARD_DYNAMICS_TRAJECTORY
)

def on_model_wrappers_dict_attached(self) -> None:
self.forward_dynamics: ModelWrapper[ForwardDynamcisWithActionReward] = self.get_training_model(
ModelNames.FORWARD_DYNAMICS
)
self.optimizer_state = self.partial_optimizer(self.forward_dynamics.parameters()).state_dict()
if self.observation_encoder_name is None:
self.observation_encoder = None
else:
self.observation_encoder = self.get_frozen_model(self.observation_encoder_name)

def is_trainable(self) -> bool:
self.trajectory_data_user.update()
return len(self.trajectory_data_user.buffer) >= self.minimum_dataset_size and self._is_new_data_available()

def _is_new_data_available(self) -> bool:
return self.trajectory_data_user.buffer.new_data_count >= self.minimum_new_data_count

def train(self) -> None:
self.forward_dynamics.to(self.device)
if self.observation_encoder is not None:
self.observation_encoder.to(self.device)

optimizer = self.partial_optimizer(self.forward_dynamics.parameters())
optimizer.load_state_dict(self.optimizer_state)
dataset = self.trajectory_data_user.get_dataset()
dataloader = self.partial_dataloader(dataset=dataset)

for _ in range(self.max_epochs):
for batch in dataloader:
observations, hiddens, actions, rewards = batch

if self.observation_encoder is not None:
with torch.no_grad():
observations = self.observation_encoder.infer(observations)

observations = observations.to(self.device)
actions = actions.to(self.device)

observations, hidden, actions, observations_next, actions_next, rewards = (
observations[:-1], # o_0:T-1
hiddens[0], # h_0
actions[:-1], # a_0:T-1
observations[1:], # o_1:T
actions[1:], # a_1:T
rewards[:-1],  # r_1:T; the reward stored at step t is the one received at t+1.
)

hidden = hidden.to(self.device)
rewards = rewards.to(self.device)

optimizer.zero_grad()

observations_next_hat_dist: Distribution
actions_next_hat_dist: Distribution
reward_hat_dist: Distribution
observations_next_hat_dist, actions_next_hat_dist, reward_hat_dist, _ = self.forward_dynamics(
observations, hidden, actions
)

observation_loss = -observations_next_hat_dist.log_prob(observations_next).mean()
action_loss = -actions_next_hat_dist.log_prob(actions_next).mean()
reward_loss = -reward_hat_dist.log_prob(rewards).mean()

loss = (
self.obs_loss_coef * observation_loss
+ self.action_loss_coef * action_loss
+ self.reward_loss_coef * reward_loss
)
prefix = "forward_dynamics/"
self.logger.log(prefix + "loss", loss)
self.logger.log(prefix + "observation_loss", observation_loss)
self.logger.log(prefix + "action_loss", action_loss)
self.logger.log(prefix + "reward_loss", reward_loss)

loss.backward()
optimizer.step()
self.logger.update()

self.optimizer_state = optimizer.state_dict()

@override
def save_state(self, path: Path) -> None:
path.mkdir()
torch.save(self.optimizer_state, path / "optimizer.pt")
torch.save(self.logger.state_dict(), path / "logger.pt")

@override
def load_state(self, path: Path) -> None:
self.optimizer_state = torch.load(path / "optimizer.pt")
self.logger.load_state_dict(torch.load(path / "logger.pt"))
68 changes: 67 additions & 1 deletion tests/models/test_forward_dynamics.py
@@ -6,7 +6,7 @@
FullyConnectedFixedStdNormal,
)
from ami.models.components.sconv import SConv
from ami.models.forward_dynamics import ForwardDynamics
from ami.models.forward_dynamics import ForwardDynamcisWithActionReward, ForwardDynamics

BATCH = 4
DEPTH = 8
@@ -64,3 +64,69 @@ def test_forward_dynamycs(self, forward_dynamics):
obs_hat_dist, hidden = forward_dynamics(obs, hidden[:, -1, :], action)
assert obs_hat_dist.sample().shape == obs_shape
assert hidden.shape == hidden_shape


class TestForwardDynamicsWithActionReward:
@pytest.fixture
def core_model(self):
sconv = SConv(DEPTH, DIM, DIM_FF_HIDDEN, DROPOUT)
return sconv

@pytest.fixture
def observation_flatten(self):
return nn.Identity()

@pytest.fixture
def action_flatten(self):
return nn.Identity()

@pytest.fixture
def obs_action_projection(self):
return nn.Linear(DIM_OBS + DIM_ACTION, DIM)

@pytest.fixture
def obs_hat_dist_head(self):
return FullyConnectedFixedStdNormal(DIM, DIM_OBS)

@pytest.fixture
def action_hat_dist_head(self):
return FullyConnectedFixedStdNormal(DIM, DIM_ACTION)

@pytest.fixture
def reward_head(self):
return FullyConnectedFixedStdNormal(DIM, 1)

@pytest.fixture
def forward_dynamics(
self,
observation_flatten,
action_flatten,
obs_action_projection,
core_model,
obs_hat_dist_head,
action_hat_dist_head,
reward_head,
):
return ForwardDynamcisWithActionReward(
observation_flatten,
action_flatten,
obs_action_projection,
core_model,
obs_hat_dist_head,
action_hat_dist_head,
reward_head,
)

def test_forward_dynamycs(self, forward_dynamics):
obs_shape = (LEN, DIM_OBS)
obs = torch.randn(*obs_shape)
action_shape = (LEN, DIM_ACTION)
action = torch.randn(*action_shape)
hidden_shape = (DEPTH, LEN, DIM)
hidden = torch.randn(*hidden_shape)

obs_hat_dist, action_hat_dist, reward_dist, hidden = forward_dynamics(obs, hidden[:, -1, :], action)
assert obs_hat_dist.sample().shape == obs_shape
assert action_hat_dist.sample().shape == action_shape
assert reward_dist.sample().shape == (LEN, 1)
assert hidden.shape == hidden_shape
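
Note: the new trainer is constructed from partially applied factories, in the same style as the existing ForwardDynamicsTrainer. A rough configuration sketch; the optimizer choice, batch size, and learning rate are illustrative assumptions, not values from this commit.

from functools import partial

from torch.optim import AdamW
from torch.utils.data import DataLoader

partial_dataloader = partial(DataLoader, batch_size=1, shuffle=False)
partial_optimizer = partial(AdamW, lr=1e-4)

# The trainer completes these later:
#   dataloader = partial_dataloader(dataset=dataset)
#   optimizer = partial_optimizer(model.parameters())
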
