This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Commit

[bug-fix] Fixed incorrect normalization in OneDTransitionRewardModel.sample() which was causing PETS to break in some problems.
luisenp committed Jan 10, 2022
1 parent da0c537 commit 6cd4b2b
Showing 3 changed files with 16 additions and 14 deletions.
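At a glance, the fix changes what _get_model_input hands back to its callers and, in turn, which observation sample() works with. A minimal before/after sketch, paraphrased from the diff below (assumption: method bodies and surrounding class details are elided, and the variable names simply mirror the diff):

# Before: _get_model_input returned the preprocessed obs/action tensors
# together with the normalized model input, so sample() ended up reusing the
# already-preprocessed observation downstream.
model_in, obs, action = self._get_model_input(model_state["obs"], act)

# After: sample() keeps the raw observation from model_state and asks
# _get_model_input only for the normalized network input.
obs = model_util.to_tensor(model_state["obs"]).to(self.device)
model_in = self._get_model_input(model_state["obs"], act)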
1 change: 1 addition & 0 deletions .gitignore
@@ -23,3 +23,4 @@ wheels/
.installed.cfg
*.egg
MANIFEST
+exp/
16 changes: 8 additions & 8 deletions mbrl/examples/conf/overrides/pets_halfcheetah.yaml
@@ -2,22 +2,22 @@
env: "pets_halfcheetah"
term_fn: "no_termination"
obs_process_fn: mbrl.env.pets_halfcheetah.HalfCheetahEnv.preprocess_fn
-learned_rewards: true
+learned_rewards: false
num_steps: 300000
trial_length: 1000

num_elites: 5
-model_lr: 2e-4
-model_wd: 3e-5
+model_lr: 0.00028
+model_wd: 0.00010
model_batch_size: 32
validation_ratio: 0
no_delta_list: [ 0 ]
freq_train_model: 1000
-patience: 25
-num_epochs_train_model: 25
+patience: 12
+num_epochs_train_model: 12

planning_horizon: 30
cem_num_iters: 5
-cem_elite_ratio: 0.1
-cem_population_size: 350
-cem_alpha: 0.1
+cem_elite_ratio: 0.16
+cem_population_size: 400
+cem_alpha: 0.12
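As a usage note, these overrides are consumed by Hydra when launching the bundled PETS example; a typical invocation (assuming the standard mbrl-lib examples entry point) looks like:

python -m mbrl.examples.main algorithm=pets overrides=pets_halfcheetah

Individual values can still be adjusted from the command line in the usual Hydra way, for example by appending overrides.model_lr=0.0005 (a hypothetical value) to the command above.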
13 changes: 7 additions & 6 deletions mbrl/models/one_dim_tr_model.py
@@ -114,15 +114,15 @@ def _get_model_input(
self,
obs: mbrl.types.TensorType,
action: mbrl.types.TensorType,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
if self.obs_process_fn:
obs = self.obs_process_fn(obs)
obs = model_util.to_tensor(obs).to(self.device)
action = model_util.to_tensor(action).to(self.device)
model_in = torch.cat([obs, action], dim=obs.ndim - 1)
if self.input_normalizer:
model_in = self.input_normalizer.normalize(model_in).float()
-return model_in, obs, action
+return model_in

def _process_batch(
self, batch: mbrl.types.TransitionBatch
@@ -136,13 +136,13 @@ def _process_batch(
target_obs = next_obs
target_obs = model_util.to_tensor(target_obs).to(self.device)

-model_in, *_ = self._get_model_input(obs, action)
+model_in = self._get_model_input(obs, action)
if self.learned_rewards:
reward = model_util.to_tensor(reward).to(self.device).unsqueeze(reward.ndim)
target = torch.cat([target_obs, reward], dim=obs.ndim - 1)
else:
target = target_obs
-return model_in, target
+return model_in.float(), target.float()

def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
"""Calls forward method of base model with the given input and args."""
@@ -207,7 +207,7 @@ def update(
(tensor and optional dict): as returned by `model.loss().`
"""
assert target is None
-model_in, target = self._get_model_input_and_target_from_batch(batch)
+model_in, target = self._process_batch(batch)
return self.model.update(model_in, optimizer, target=target)

def eval_score(
@@ -278,7 +278,8 @@ def sample(
Returns:
(tuple of two tensors): predicted next_observation (o_{t+1}) and rewards (r_{t+1}).
"""
-model_in, obs, action = self._get_model_input(model_state["obs"], act)
+obs = model_util.to_tensor(model_state["obs"]).to(self.device)
+model_in = self._get_model_input(model_state["obs"], act)
if not hasattr(self.model, "sample_1d"):
raise RuntimeError(
"OneDTransitionRewardModel requires wrapped model to define method sample_1d"
