From 10352d4a97c19c771696eeefe69e0fb6d03972b0 Mon Sep 17 00:00:00 2001
From: daniel <1534513+dantp-ai@users.noreply.github.com>
Date: Tue, 26 Mar 2024 17:11:53 +0100
Subject: [PATCH] Fix mypy issues for AtariWrappers

* Use the type annotations of the Gym API (> 0.26)
* Rename some variables to avoid confusing mypy
* Add missing type annotations
* Add explicit types to some variables to satisfy mypy
* Assert that the observation-space shape is not None before using it

---
 examples/atari/atari_wrapper.py | 45 +++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 19 deletions(-)

diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index de10d5eb7..557ca7cd4 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -8,7 +8,6 @@
 import cv2
 import gymnasium as gym
 import numpy as np
-import numpy.typing as npt
 from gymnasium import Env
 
 from tianshou.env import BaseVectorEnv
@@ -108,7 +107,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
 
         if new_step_api:
             return max_frame, total_reward, term, trunc, info
-        return max_frame, total_reward, done, info
+        return max_frame, total_reward, done, info.get("TimeLimit.truncated", False), info
 
 
 class EpisodicLifeEnv(gym.Wrapper):
@@ -134,7 +133,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
             obs, reward, term, trunc, info = step_result
             done = term or trunc
             new_step_api = True
-
+        reward = float(reward)
         self.was_real_done = done
         # check current lives, make loss of life terminal, then update lives to
         # handle bonus lives
@@ -149,7 +148,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
         self.lives = lives
         if new_step_api:
             return obs, reward, term, trunc, info
-        return obs, reward, done, info
+        return obs, reward, done, info.get("TimeLimit.truncated", False), info
 
     def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]:
         """Calls the Gym environment reset, only when lives are exhausted.
@@ -199,11 +198,13 @@ class WarpFrame(gym.ObservationWrapper):
     def __init__(self, env: gym.Env) -> None:
         super().__init__(env)
         self.size = 84
+        obs_space = env.observation_space
+        assert isinstance(obs_space, gym.spaces.Box)
         self.observation_space = gym.spaces.Box(
-            low=np.min(env.observation_space.low),
-            high=np.max(env.observation_space.high),
+            low=np.min(obs_space.low),
+            high=np.max(obs_space.high),
             shape=(self.size, self.size),
-            dtype=env.observation_space.dtype,
+            dtype=obs_space.dtype,
         )
 
     def observation(self, frame: np.ndarray) -> np.ndarray:
@@ -220,14 +221,16 @@ class ScaledFloatFrame(gym.ObservationWrapper):
     def __init__(self, env: gym.Env) -> None:
         super().__init__(env)
-        low = np.min(env.observation_space.low)
-        high = np.max(env.observation_space.high)
+        obs_space = env.observation_space
+        assert isinstance(obs_space, gym.spaces.Box)
+        low = np.min(obs_space.low)
+        high = np.max(obs_space.high)
         self.bias = low
         self.scale = high - low
         self.observation_space = gym.spaces.Box(
             low=0.0,
             high=1.0,
-            shape=env.observation_space.shape,
+            shape=obs_space.shape,
             dtype=np.float32,
         )
 
@@ -261,7 +264,10 @@ def __init__(self, env: gym.Env, n_frames: int) -> None:
         super().__init__(env)
         self.n_frames: int = n_frames
         self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames)
-        shape = (n_frames, *env.observation_space.shape)
+        obs_space_shape = env.observation_space.shape
+        assert obs_space_shape is not None
+        shape = (n_frames, *obs_space_shape)
+        assert isinstance(env.observation_space, gym.spaces.Box)
         self.observation_space = gym.spaces.Box(
             low=np.min(env.observation_space.low),
             high=np.max(env.observation_space.high),
@@ -269,13 +275,13 @@ def __init__(self, env: gym.Env, n_frames: int) -> None:
             dtype=env.observation_space.dtype,
         )
 
-    def reset(self, **kwargs: Any) -> tuple[npt.NDArray, dict]:
+    def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]:
         obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
         for _ in range(self.n_frames):
             self.frames.append(obs)
         return (self._get_ob(), info) if return_info else (self._get_ob(), {})
 
-    def step(self, action):
+    def step(self, action: Any) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
         step_result = self.env.step(action)
         done: bool
         if len(step_result) == 4:
@@ -285,11 +291,12 @@ def step(self, action):
             obs, reward, term, trunc, info = step_result
             new_step_api = True
         self.frames.append(obs)
+        reward = float(reward)
         if new_step_api:
             return self._get_ob(), reward, term, trunc, info
-        return self._get_ob(), reward, done, info
+        return self._get_ob(), reward, done, info.get("TimeLimit.truncated", False), info
 
-    def _get_ob(self) -> npt.NDArray:
-        # the original wrapper use `LazyFrames` but since we use np buffer,
+    def _get_ob(self) -> np.ndarray:
+        # the original wrapper uses `LazyFrames` but since we use an np buffer,
         # it has no effect
         return np.stack(self.frames, axis=0)
@@ -379,7 +386,7 @@ def __init__(
         envpool_factory = None
         if use_envpool_if_available:
             if envpool_is_available:
-                envpool_factory = self.EnvPoolFactory(self)
-                log.info("Using envpool, because it available")
+                envpool_factory = self.EnvPoolFactoryAtari(self)
+                log.info("Using envpool, because it is available")
             else:
                 log.info("Not using envpool, because it is not available")
@@ -401,7 +408,7 @@ def create_env(self, mode: EnvMode) -> gym.Env:
             scale=self.scale,
         )
 
-    class EnvPoolFactory(EnvPoolFactory):
+    class EnvPoolFactoryAtari(EnvPoolFactory):
         """Atari-specific envpool creation.
 
         Since envpool internally handles the functions that are implemented through
         the wrappers in `wrap_deepmind`, it sets the creation keyword arguments
         accordingly.
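
Note on the `step` changes: MaxAndSkipEnv, EpisodicLifeEnv, and FrameStack all
declare the 5-tuple return type of the Gymnasium (> 0.26) API, so the legacy
4-tuple branch now also returns five values, recovering the truncation flag
that gym's TimeLimit wrapper records in `info`. A minimal sketch of that
pattern, assuming a plain `gymnasium.Wrapper`; the `FlattenStepAPI` name is
illustrative and not part of this diff:

    from typing import Any

    import gymnasium as gym


    class FlattenStepAPI(gym.Wrapper):
        """Normalize step() results to the 5-tuple Gymnasium (> 0.26) API."""

        def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
            step_result = self.env.step(action)
            if len(step_result) == 4:
                # Old Gym API: (obs, reward, done, info). Mirror the patch:
                # keep `done` in the terminated slot and pull the truncation
                # flag stored under "TimeLimit.truncated", defaulting to
                # False when the key is absent.
                obs, reward, done, info = step_result
                return obs, float(reward), done, info.get("TimeLimit.truncated", False), info
            # New API: (obs, reward, terminated, truncated, info).
            obs, reward, term, trunc, info = step_result
            return obs, float(reward), term, trunc, info

The `float(reward)` casts serve the same purpose as the added
`reward = float(reward)` lines: environments may hand back `np.float32`
rewards, which mypy does not accept for a slot annotated as `float`.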
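
Note on the repeated asserts: `env.observation_space` is typed as the base
`gym.Space`, which has no `low`/`high` attributes and whose `shape` is
declared Optional, so mypy rejects the old code even though it works at
runtime for Atari envs. A short sketch of the narrowing pattern the patch
applies in WarpFrame, ScaledFloatFrame, and FrameStack; the helper name
`stacked_box_space` is hypothetical:

    import gymnasium as gym
    import numpy as np


    def stacked_box_space(env: gym.Env, n_frames: int) -> gym.spaces.Box:
        obs_space = env.observation_space
        # Narrow gym.Space to Box: only Box defines `low` and `high`. At
        # runtime this also fails fast for e.g. a Dict observation space.
        assert isinstance(obs_space, gym.spaces.Box)
        # `Space.shape` is `tuple[int, ...] | None`; assert it away before
        # unpacking it into the stacked shape.
        assert obs_space.shape is not None
        return gym.spaces.Box(
            low=np.min(obs_space.low),
            high=np.max(obs_space.high),
            shape=(n_frames, *obs_space.shape),
            dtype=obs_space.dtype,
        )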
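
Note on the `EnvPoolFactoryAtari` rename: the nested class previously reused
the name of its imported base (`class EnvPoolFactory(EnvPoolFactory)`), so the
inner definition shadowed the base inside the enclosing class body; Python
still resolves the base to the module-level import at runtime, but per the
commit message the shadowing confused mypy. A reduced sketch; both class
bodies are stand-ins, not tianshou's real definitions:

    class EnvPoolFactory:
        """Stand-in for the factory base class imported at module level."""


    class AtariEnvFactory:
        # Before: `class EnvPoolFactory(EnvPoolFactory)` - the inner name
        # shadows the imported base it inherits from.
        class EnvPoolFactoryAtari(EnvPoolFactory):
            """Atari-specific envpool creation (unambiguous after the rename)."""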