This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Commit

[refactor] Replaced batch processing methods of the 1-D model wrapper with a single _process_batch() method. Base Model class also provides a default implementation of this.
luisenp committed Sep 27, 2021
1 parent 2d6a732 commit da0c537
Showing 3 changed files with 36 additions and 37 deletions.
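In short: the wrapper's two input-conversion helpers merge into a single _get_model_input (accepting numpy arrays or tensors), its batch-processing method becomes _process_batch, and the base Model gains a default _process_batch of its own. A minimal sketch of the resulting shape (signatures taken from the diffs below; bodies elided):

from typing import Tuple

import torch
from torch import nn

from mbrl.types import TransitionBatch


class Model(nn.Module):
    def _process_batch(self, batch: TransitionBatch) -> Tuple[torch.Tensor, ...]:
        # Default: (obs, act, next_obs, rewards, dones) as tensors on self.device.
        ...


class OneDTransitionRewardModel(Model):
    def _process_batch(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Override: (model_in, target), ready for loss() and eval_score().
        ...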
2 changes: 1 addition & 1 deletion mbrl/models/basic_ensemble.py
@@ -92,7 +92,7 @@ def __iter__(self):
     # --------------------------------------------------------------------- #
     # These are customized for this class, to avoid unnecessary computation
     def _default_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        predictions = [model(x) for model in self.members]
+        predictions = [model.forward(x) for model in self.members]
         all_means = torch.stack([p[0] for p in predictions], dim=0)
         if predictions[0][1] is not None:
             all_logvars = torch.stack([p[1] for p in predictions], dim=0)
15 changes: 14 additions & 1 deletion mbrl/models/model.py
@@ -10,7 +10,8 @@
 import torch
 from torch import nn as nn
 
-from mbrl.types import ModelInput
+from mbrl.models.util import to_tensor
+from mbrl.types import ModelInput, TransitionBatch
 
 # TODO: these are temporary, eventually it will be tuple(tensor, dict), keeping this
 # for back-compatibility with v0.1.x, and will be removed in v0.2.0
@@ -56,6 +57,18 @@ def __init__(
         super().__init__()
         self.device = device
 
+    def _process_batch(self, batch: TransitionBatch) -> Tuple[torch.Tensor, ...]:
+        def _convert(x):
+            return to_tensor(x).to(self.device)
+
+        return (
+            _convert(batch.obs),
+            _convert(batch.act),
+            _convert(batch.next_obs),
+            _convert(batch.rewards).view(-1, 1),
+            _convert(batch.dones).view(-1, 1),
+        )
+
     def forward(self, x: ModelInput, **kwargs) -> Tuple[torch.Tensor, ...]:
         """Computes the output of the dynamics model.
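For illustration, a minimal sketch of exercising the new default. The DummyModel subclass, its stubs, and the device keyword argument are hypothetical; it also assumes TransitionBatch is constructible from the obs/act/next_obs/rewards/dones fields referenced above:

import numpy as np
import torch

from mbrl.models import Model
from mbrl.types import TransitionBatch


class DummyModel(Model):
    # Stubs only; _process_batch is the behavior of interest here.
    def forward(self, x, **kwargs):
        return (x,)

    def loss(self, model_in, target=None):
        return torch.zeros(())

    def eval_score(self, model_in, target=None):
        return torch.zeros(())


batch = TransitionBatch(
    obs=np.zeros((32, 3), dtype=np.float32),
    act=np.zeros((32, 1), dtype=np.float32),
    next_obs=np.zeros((32, 3), dtype=np.float32),
    rewards=np.zeros(32, dtype=np.float32),
    dones=np.zeros(32, dtype=bool),
)

model = DummyModel(device="cpu")  # constructor signature assumed
obs, act, next_obs, rewards, dones = model._process_batch(batch)
print(rewards.shape, dones.shape)  # rewards and dones come back as [32, 1] columns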
56 changes: 21 additions & 35 deletions mbrl/models/one_dim_tr_model.py
@@ -53,7 +53,8 @@ class OneDTransitionRewardModel(Model):
             Defaults to ``True``. Can be deactivated per dimension using ``no_delta_list``.
         normalize (bool): if true, the wrapper will create a normalizer for model inputs,
             which will be used every time the model is called using the methods in this
-            class. To update the normalizer statistics, the user needs to call
+            class. Assumes the given base model has an attribute ``in_size``.
+            To update the normalizer statistics, the user needs to call
             :meth:`update_normalizer` before using the model. Defaults to ``False``.
         normalize_double_precision (bool): if ``True``, the normalizer will work with
             double precision.
@@ -109,26 +110,21 @@ def __init__(
             else None
         )
 
-    def _get_model_input_from_np(
-        self, obs: np.ndarray, action: np.ndarray, device: torch.device
-    ) -> torch.Tensor:
-        if self.obs_process_fn:
-            obs = self.obs_process_fn(obs)
-        model_in_np = np.concatenate([obs, action], axis=obs.ndim - 1)
-        if self.input_normalizer:
-            # Normalizer lives on device
-            return self.input_normalizer.normalize(model_in_np).float().to(device)
-        return torch.from_numpy(model_in_np).to(device)
-
-    def _get_model_input_from_tensors(self, obs: torch.Tensor, action: torch.Tensor):
+    def _get_model_input(
+        self,
+        obs: mbrl.types.TensorType,
+        action: mbrl.types.TensorType,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if self.obs_process_fn:
             obs = self.obs_process_fn(obs)
-        model_in = torch.cat([obs, action], axis=obs.ndim - 1)
+        obs = model_util.to_tensor(obs).to(self.device)
+        action = model_util.to_tensor(action).to(self.device)
+        model_in = torch.cat([obs, action], dim=obs.ndim - 1)
         if self.input_normalizer:
             model_in = self.input_normalizer.normalize(model_in).float()
-        return model_in
+        return model_in, obs, action
 
-    def _get_model_input_and_target_from_batch(
+    def _process_batch(
         self, batch: mbrl.types.TransitionBatch
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         obs, action, next_obs, reward, _ = batch.astuple()
@@ -138,21 +134,14 @@ def _get_model_input_and_target_from_batch(
                 target_obs[..., dim] = next_obs[..., dim]
         else:
             target_obs = next_obs
+        target_obs = model_util.to_tensor(target_obs).to(self.device)
 
-        model_in = self._get_model_input_from_np(obs, action, self.device)
+        model_in, *_ = self._get_model_input(obs, action)
         if self.learned_rewards:
-            target = (
-                torch.from_numpy(
-                    np.concatenate(
-                        [target_obs, np.expand_dims(reward, axis=reward.ndim)],
-                        axis=obs.ndim - 1,
-                    )
-                )
-                .float()
-                .to(self.device)
-            )
+            reward = model_util.to_tensor(reward).to(self.device).unsqueeze(reward.ndim)
+            target = torch.cat([target_obs, reward], dim=obs.ndim - 1)
         else:
-            target = torch.from_numpy(target_obs).float().to(self.device)
+            target = target_obs
         return model_in, target
 
     def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
@@ -199,7 +188,7 @@ def loss(
             (tensor and optional dict): as returned by `model.loss().`
         """
         assert target is None
-        model_in, target = self._get_model_input_and_target_from_batch(batch)
+        model_in, target = self._process_batch(batch)
         return self.model.loss(model_in, target=target)
 
     def update(
@@ -239,7 +228,7 @@ def eval_score(
         """
         assert target is None
         with torch.no_grad():
-            model_in, target = self._get_model_input_and_target_from_batch(batch)
+            model_in, target = self._process_batch(batch)
             return self.model.eval_score(model_in, target=target)
 
     def get_output_and_targets(
@@ -258,7 +247,7 @@ def get_output_and_targets(
             (tuple(tensor), tensor): the model outputs and the target for this batch.
         """
         with torch.no_grad():
-            model_in, target = self._get_model_input_and_target_from_batch(batch)
+            model_in, target = self._process_batch(batch)
             output = self.model.forward(model_in)
             return output, target
 
@@ -289,10 +278,7 @@ def sample(
         Returns:
             (tuple of two tensors): predicted next_observation (o_{t+1}) and rewards (r_{t+1}).
         """
-        obs = model_util.to_tensor(model_state["obs"]).to(self.device)
-        actions = model_util.to_tensor(act).to(self.device)
-
-        model_in = self._get_model_input_from_tensors(obs, actions)
+        model_in, obs, action = self._get_model_input(model_state["obs"], act)
         if not hasattr(self.model, "sample_1d"):
             raise RuntimeError(
                 "OneDTransitionRewardModel requires wrapped model to define method sample_1d"
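With _get_model_input now routing everything through model_util.to_tensor, callers can pass numpy arrays or tensors interchangeably, and loss, eval_score, and get_output_and_targets all share _process_batch. A hypothetical end-to-end call; the GaussianMLP and wrapper constructor signatures below are assumptions, not part of this diff:

import numpy as np
import torch

from mbrl.models import GaussianMLP, OneDTransitionRewardModel
from mbrl.types import TransitionBatch

# Assumed signatures: GaussianMLP(in_size, out_size, device) and
# OneDTransitionRewardModel(model, target_is_delta=..., normalize=..., learned_rewards=...).
base = GaussianMLP(4, 4, torch.device("cpu"))
wrapper = OneDTransitionRewardModel(
    base, target_is_delta=True, normalize=False, learned_rewards=True
)

batch = TransitionBatch(
    obs=np.zeros((8, 3), dtype=np.float32),
    act=np.zeros((8, 1), dtype=np.float32),
    next_obs=np.zeros((8, 3), dtype=np.float32),
    rewards=np.zeros(8, dtype=np.float32),
    dones=np.zeros(8, dtype=bool),
)

model_in, target = wrapper._process_batch(batch)
print(model_in.shape, target.shape)  # both (8, 4): 3 obs dims + 1 action/reward column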
