Skip to content

Commit

Permalink
Fix some MuJoCo env TypeIssues (#600)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kallinteris-Andreas authored Jul 10, 2023
1 parent ce54ce5 commit 1d305e1
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions gymnasium/envs/mujoco/mujoco_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from os import path
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium import error, logger, spaces
Expand All @@ -26,7 +27,7 @@
DEFAULT_SIZE = 480


class BaseMujocoEnv(gym.Env):
class BaseMujocoEnv(gym.Env[np.float64, np.float32]):
"""Superclass for all MuJoCo environments."""

def __init__(
Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(

self.width = width
self.height = height
self._initialize_simulation() # may use width and height
# may use width and height
self.model, self.data = self._initialize_simulation()

self.init_qpos = self.data.qpos.ravel().copy()
self.init_qvel = self.data.qvel.ravel().copy()
Expand Down Expand Up @@ -100,39 +102,39 @@ def _set_action_space(self):
# methods to override:
# ----------------------------

def reset_model(self):
def reset_model(self) -> NDArray[np.float64]:
"""
Reset the robot degrees of freedom (qpos and qvel).
Implement this in each subclass.
"""
raise NotImplementedError

def _initialize_simulation(self):
def _initialize_simulation(self) -> Tuple[Any, Any]:
"""
Initialize MuJoCo simulation data structures mjModel and mjData.
"""
raise NotImplementedError

def _reset_simulation(self):
def _reset_simulation(self) -> None:
"""
Reset MuJoCo simulation data structures, mjModel and mjData.
"""
raise NotImplementedError

def _step_mujoco_simulation(self, ctrl, n_frames):
def _step_mujoco_simulation(self, ctrl, n_frames) -> None:
"""
Step over the MuJoCo simulation.
"""
raise NotImplementedError

def render(self):
def render(self) -> Union[NDArray[np.float64], None]:
"""
Render a frame from the MuJoCo simulation as specified by the render_mode.
"""
raise NotImplementedError

# -----------------------------
def _get_reset_info(self) -> Dict:
def _get_reset_info(self) -> Dict[str, float]:
"""Function that generates the `info` that is returned during a `reset()`."""
return {}

Expand All @@ -153,17 +155,17 @@ def reset(
self.render()
return ob, info

def set_state(self, qpos, qvel):
def set_state(self, qpos, qvel) -> None:
"""
Set the joints position qpos and velocity qvel of the model. Override this method depending on the MuJoCo bindings used.
"""
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)

@property
def dt(self):
def dt(self) -> float:
return self.model.opt.timestep * self.frame_skip

def do_simulation(self, ctrl, n_frames):
def do_simulation(self, ctrl, n_frames) -> None:
"""
Step the simulation n number of frames and applying a control action.
"""
Expand All @@ -178,11 +180,11 @@ def close(self):
"""Close all processes like rendering contexts"""
raise NotImplementedError

def get_body_com(self, body_name):
def get_body_com(self, body_name) -> NDArray[np.float64]:
"""Return the cartesian position of a body frame"""
raise NotImplementedError

def state_vector(self):
def state_vector(self) -> NDArray[np.float64]:
"""Return the position and velocity joint states of the model"""
return np.concatenate([self.data.qpos.flat, self.data.qvel.flat])

Expand Down Expand Up @@ -229,9 +231,10 @@ def __init__(
)

def _initialize_simulation(self):
self.model = mujoco_py.load_model_from_path(self.fullpath)
self.sim = mujoco_py.MjSim(self.model)
self.data = self.sim.data
model = mujoco_py.load_model_from_path(self.fullpath)
self.sim = mujoco_py.MjSim(model)
data = self.sim.data
return model, data

def _reset_simulation(self):
self.sim.reset()
Expand Down Expand Up @@ -371,12 +374,15 @@ def __init__(
self.model, self.data, default_camera_config
)

def _initialize_simulation(self):
self.model = mujoco.MjModel.from_xml_path(self.fullpath)
def _initialize_simulation(
self,
):
model = mujoco.MjModel.from_xml_path(self.fullpath)
# MjrContext will copy model.vis.global_.off* to con.off*
self.model.vis.global_.offwidth = self.width
self.model.vis.global_.offheight = self.height
self.data = mujoco.MjData(self.model)
model.vis.global_.offwidth = self.width
model.vis.global_.offheight = self.height
data = mujoco.MjData(model)
return model, data

def _reset_simulation(self):
mujoco.mj_resetData(self.model, self.data)
Expand Down

0 comments on commit 1d305e1

Please sign in to comment.