From da0c5373221b9aa3583d7f741eeb076823f40e11 Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Wed, 12 May 2021 15:06:15 -0400
Subject: [PATCH] [refactor] Replaced batch processing methods of the 1-D model
 wrapper with a single _process_batch() method. Base Model class also provides
 a default implementation of this.

---
 mbrl/models/basic_ensemble.py   |  2 +-
 mbrl/models/model.py            | 15 ++++++++-
 mbrl/models/one_dim_tr_model.py | 56 +++++++++++++--------------------
 3 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/mbrl/models/basic_ensemble.py b/mbrl/models/basic_ensemble.py
index fd18ecdd..06c88764 100644
--- a/mbrl/models/basic_ensemble.py
+++ b/mbrl/models/basic_ensemble.py
@@ -92,7 +92,7 @@ def __iter__(self):
     # --------------------------------------------------------------------- #
     # These are customized for this class, to avoid unnecessary computation
     def _default_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        predictions = [model(x) for model in self.members]
+        predictions = [model.forward(x) for model in self.members]
         all_means = torch.stack([p[0] for p in predictions], dim=0)
         if predictions[0][1] is not None:
             all_logvars = torch.stack([p[1] for p in predictions], dim=0)
diff --git a/mbrl/models/model.py b/mbrl/models/model.py
index 25d71db1..a1e155ac 100644
--- a/mbrl/models/model.py
+++ b/mbrl/models/model.py
@@ -10,7 +10,8 @@
 import torch
 from torch import nn as nn

-from mbrl.types import ModelInput
+from mbrl.models.util import to_tensor
+from mbrl.types import ModelInput, TransitionBatch

 # TODO: these are temporary, eventually it will be tuple(tensor, dict), keeping this
 # for back-compatibility with v0.1.x, and will be removed in v0.2.0
@@ -56,6 +57,18 @@ def __init__(
         super().__init__()
         self.device = device

+    def _process_batch(self, batch: TransitionBatch) -> Tuple[torch.Tensor, ...]:
+        def _convert(x):
+            return to_tensor(x).to(self.device)
+
+        return (
+            _convert(batch.obs),
+            _convert(batch.act),
+            _convert(batch.next_obs),
+            _convert(batch.rewards).view(-1, 1),
+            _convert(batch.dones).view(-1, 1),
+        )
+
     def forward(self, x: ModelInput, **kwargs) -> Tuple[torch.Tensor, ...]:
         """Computes the output of the dynamics model.

diff --git a/mbrl/models/one_dim_tr_model.py b/mbrl/models/one_dim_tr_model.py
index c758aacc..57e11216 100644
--- a/mbrl/models/one_dim_tr_model.py
+++ b/mbrl/models/one_dim_tr_model.py
@@ -53,7 +53,8 @@ class OneDTransitionRewardModel(Model):
             Defaults to ``True``. Can be deactivated per dimension using ``no_delta_list``.
         normalize (bool): if true, the wrapper will create a normalizer for model
             inputs, which will be used every time the model is called using the methods in this
-            class. To update the normalizer statistics, the user needs to call
+            class. Assumes the given base model has an attribute ``in_size``.
+            To update the normalizer statistics, the user needs to call
             :meth:`update_normalizer` before using the model. Defaults to ``False``.
         normalize_double_precision (bool): if ``True``, the normalizer will work with
             double precision.
@@ -109,26 +110,21 @@ def __init__(
             else None
         )

-    def _get_model_input_from_np(
-        self, obs: np.ndarray, action: np.ndarray, device: torch.device
-    ) -> torch.Tensor:
-        if self.obs_process_fn:
-            obs = self.obs_process_fn(obs)
-        model_in_np = np.concatenate([obs, action], axis=obs.ndim - 1)
-        if self.input_normalizer:
-            # Normalizer lives on device
-            return self.input_normalizer.normalize(model_in_np).float().to(device)
-        return torch.from_numpy(model_in_np).to(device)
-
-    def _get_model_input_from_tensors(self, obs: torch.Tensor, action: torch.Tensor):
+    def _get_model_input(
+        self,
+        obs: mbrl.types.TensorType,
+        action: mbrl.types.TensorType,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if self.obs_process_fn:
             obs = self.obs_process_fn(obs)
-        model_in = torch.cat([obs, action], axis=obs.ndim - 1)
+        obs = model_util.to_tensor(obs).to(self.device)
+        action = model_util.to_tensor(action).to(self.device)
+        model_in = torch.cat([obs, action], dim=obs.ndim - 1)
         if self.input_normalizer:
             model_in = self.input_normalizer.normalize(model_in).float()
-        return model_in
+        return model_in, obs, action

-    def _get_model_input_and_target_from_batch(
+    def _process_batch(
         self, batch: mbrl.types.TransitionBatch
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         obs, action, next_obs, reward, _ = batch.astuple()
@@ -138,21 +134,14 @@ def _get_model_input_and_target_from_batch(
                 target_obs[..., dim] = next_obs[..., dim]
         else:
             target_obs = next_obs
+        target_obs = model_util.to_tensor(target_obs).to(self.device)

-        model_in = self._get_model_input_from_np(obs, action, self.device)
+        model_in, *_ = self._get_model_input(obs, action)
         if self.learned_rewards:
-            target = (
-                torch.from_numpy(
-                    np.concatenate(
-                        [target_obs, np.expand_dims(reward, axis=reward.ndim)],
-                        axis=obs.ndim - 1,
-                    )
-                )
-                .float()
-                .to(self.device)
-            )
+            reward = model_util.to_tensor(reward).to(self.device).unsqueeze(reward.ndim)
+            target = torch.cat([target_obs, reward], dim=obs.ndim - 1)
         else:
-            target = torch.from_numpy(target_obs).float().to(self.device)
+            target = target_obs
         return model_in, target

     def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
@@ -199,7 +188,7 @@ def loss(
             (tensor and optional dict): as returned by ``model.loss()``.
         """
         assert target is None
-        model_in, target = self._get_model_input_and_target_from_batch(batch)
+        model_in, target = self._process_batch(batch)
         return self.model.loss(model_in, target=target)

     def update(
@@ -239,7 +228,7 @@ def eval_score(
         """
         assert target is None
         with torch.no_grad():
-            model_in, target = self._get_model_input_and_target_from_batch(batch)
+            model_in, target = self._process_batch(batch)
             return self.model.eval_score(model_in, target=target)

     def get_output_and_targets(
@@ -258,7 +247,7 @@ def get_output_and_targets(
             (tuple(tensor), tensor): the model outputs and the target for this batch.
         """
         with torch.no_grad():
-            model_in, target = self._get_model_input_and_target_from_batch(batch)
+            model_in, target = self._process_batch(batch)
             output = self.model.forward(model_in)
         return output, target

@@ -289,10 +278,7 @@ def sample(
         Returns:
             (tuple of two tensors): predicted next_observation (o_{t+1}) and rewards (r_{t+1}).
""" - obs = model_util.to_tensor(model_state["obs"]).to(self.device) - actions = model_util.to_tensor(act).to(self.device) - - model_in = self._get_model_input_from_tensors(obs, actions) + model_in, obs, action = self._get_model_input(model_state["obs"], act) if not hasattr(self.model, "sample_1d"): raise RuntimeError( "OneDTransitionRewardModel requires wrapped model to define method sample_1d"