From 1d305e1214c649bb679e352134081b8512dc4aaa Mon Sep 17 00:00:00 2001 From: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com> Date: Tue, 11 Jul 2023 00:25:15 +0300 Subject: [PATCH] Fix some MuJoCo env `TypeIssues` (#600) --- gymnasium/envs/mujoco/mujoco_env.py | 50 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/gymnasium/envs/mujoco/mujoco_env.py b/gymnasium/envs/mujoco/mujoco_env.py index c3490bd1f..7018daa1c 100644 --- a/gymnasium/envs/mujoco/mujoco_env.py +++ b/gymnasium/envs/mujoco/mujoco_env.py @@ -1,7 +1,8 @@ from os import path -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np +from numpy.typing import NDArray import gymnasium as gym from gymnasium import error, logger, spaces @@ -26,7 +27,7 @@ DEFAULT_SIZE = 480 -class BaseMujocoEnv(gym.Env): +class BaseMujocoEnv(gym.Env[np.float64, np.float32]): """Superclass for all MuJoCo environments.""" def __init__( @@ -67,7 +68,8 @@ def __init__( self.width = width self.height = height - self._initialize_simulation() # may use width and height + # may use width and height + self.model, self.data = self._initialize_simulation() self.init_qpos = self.data.qpos.ravel().copy() self.init_qvel = self.data.qvel.ravel().copy() @@ -100,39 +102,39 @@ def _set_action_space(self): # methods to override: # ---------------------------- - def reset_model(self): + def reset_model(self) -> NDArray[np.float64]: """ Reset the robot degrees of freedom (qpos and qvel). Implement this in each subclass. """ raise NotImplementedError - def _initialize_simulation(self): + def _initialize_simulation(self) -> Tuple[Any, Any]: """ Initialize MuJoCo simulation data structures mjModel and mjData. """ raise NotImplementedError - def _reset_simulation(self): + def _reset_simulation(self) -> None: """ Reset MuJoCo simulation data structures, mjModel and mjData. """ raise NotImplementedError - def _step_mujoco_simulation(self, ctrl, n_frames): + def _step_mujoco_simulation(self, ctrl, n_frames) -> None: """ Step over the MuJoCo simulation. """ raise NotImplementedError - def render(self): + def render(self) -> Union[NDArray[np.float64], None]: """ Render a frame from the MuJoCo simulation as specified by the render_mode. """ raise NotImplementedError # ----------------------------- - def _get_reset_info(self) -> Dict: + def _get_reset_info(self) -> Dict[str, float]: """Function that generates the `info` that is returned during a `reset()`.""" return {} @@ -153,17 +155,17 @@ def reset( self.render() return ob, info - def set_state(self, qpos, qvel): + def set_state(self, qpos, qvel) -> None: """ Set the joints position qpos and velocity qvel of the model. Override this method depending on the MuJoCo bindings used. """ assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) @property - def dt(self): + def dt(self) -> float: return self.model.opt.timestep * self.frame_skip - def do_simulation(self, ctrl, n_frames): + def do_simulation(self, ctrl, n_frames) -> None: """ Step the simulation n number of frames and applying a control action. """ @@ -178,11 +180,11 @@ def close(self): """Close all processes like rendering contexts""" raise NotImplementedError - def get_body_com(self, body_name): + def get_body_com(self, body_name) -> NDArray[np.float64]: """Return the cartesian position of a body frame""" raise NotImplementedError - def state_vector(self): + def state_vector(self) -> NDArray[np.float64]: """Return the position and velocity joint states of the model""" return np.concatenate([self.data.qpos.flat, self.data.qvel.flat]) @@ -229,9 +231,10 @@ def __init__( ) def _initialize_simulation(self): - self.model = mujoco_py.load_model_from_path(self.fullpath) - self.sim = mujoco_py.MjSim(self.model) - self.data = self.sim.data + model = mujoco_py.load_model_from_path(self.fullpath) + self.sim = mujoco_py.MjSim(model) + data = self.sim.data + return model, data def _reset_simulation(self): self.sim.reset() @@ -371,12 +374,15 @@ def __init__( self.model, self.data, default_camera_config ) - def _initialize_simulation(self): - self.model = mujoco.MjModel.from_xml_path(self.fullpath) + def _initialize_simulation( + self, + ): + model = mujoco.MjModel.from_xml_path(self.fullpath) # MjrContext will copy model.vis.global_.off* to con.off* - self.model.vis.global_.offwidth = self.width - self.model.vis.global_.offheight = self.height - self.data = mujoco.MjData(self.model) + model.vis.global_.offwidth = self.width + model.vis.global_.offheight = self.height + data = mujoco.MjData(model) + return model, data def _reset_simulation(self): mujoco.mj_resetData(self.model, self.data)