Fix mypy issues for AtariWrappers
  * Use type annotations from the Gym API (version > 0.26)
  * Rename some variables to avoid confusing mypy
  * Add missing type annotations
  * Add explicit types to some variables to make mypy happy
  * Assert that the observation space's shape is not None before using it (see the sketch after this list)
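The observation-space fixes follow a common mypy narrowing pattern; below is a minimal, self-contained sketch of that pattern (the `ExampleWrapper` class and its defaults are made up for illustration and are not part of the commit):

```python
import gymnasium as gym
import numpy as np


class ExampleWrapper(gym.ObservationWrapper):
    """Minimal sketch of the type-narrowing used to satisfy mypy."""

    def __init__(self, env: gym.Env, n_frames: int = 4) -> None:
        super().__init__(env)
        # Bind the space once, then narrow it: `Env.observation_space` is typed
        # as the generic `Space`, which has no `.low`/`.high` attributes.
        obs_space = env.observation_space
        assert isinstance(obs_space, gym.spaces.Box)
        # `Space.shape` is Optional, so guard it before unpacking.
        assert obs_space.shape is not None
        self.observation_space = gym.spaces.Box(
            low=np.min(obs_space.low),
            high=np.max(obs_space.high),
            shape=(n_frames, *obs_space.shape),
            dtype=obs_space.dtype,
        )

    def observation(self, observation: np.ndarray) -> np.ndarray:
        return observation
```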
dantp-ai committed Mar 26, 2024
1 parent 56608da commit 10352d4
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions examples/atari/atari_wrapper.py
@@ -8,7 +8,6 @@
import cv2
import gymnasium as gym
import numpy as np
import numpy.typing as npt
from gymnasium import Env

from tianshou.env import BaseVectorEnv
@@ -108,7 +107,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
if new_step_api:
return max_frame, total_reward, term, trunc, info

return max_frame, total_reward, done, info
return max_frame, total_reward, done, info.get("TimeLimit.truncated", False), info


class EpisodicLifeEnv(gym.Wrapper):
@@ -134,7 +133,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
obs, reward, term, trunc, info = step_result
done = term or trunc
new_step_api = True

reward = float(reward)
self.was_real_done = done
# check current lives, make loss of life terminal, then update lives to
# handle bonus lives
@@ -149,7 +148,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
self.lives = lives
if new_step_api:
return obs, reward, term, trunc, info
return obs, reward, done, info
return obs, reward, done, info.get("TimeLimit.truncated", False), info

def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]:
"""Calls the Gym environment reset, only when lives are exhausted.
@@ -199,11 +198,13 @@ class WarpFrame(gym.ObservationWrapper):
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
self.size = 84
obs_space = env.observation_space
assert isinstance(obs_space, gym.spaces.Box)
self.observation_space = gym.spaces.Box(
low=np.min(env.observation_space.low),
high=np.max(env.observation_space.high),
low=np.min(obs_space.low),
high=np.max(obs_space.high),
shape=(self.size, self.size),
dtype=env.observation_space.dtype,
dtype=obs_space.dtype,
)

def observation(self, frame: np.ndarray) -> np.ndarray:
@@ -220,14 +221,16 @@ class ScaledFloatFrame(gym.ObservationWrapper):

def __init__(self, env: gym.Env) -> None:
super().__init__(env)
low = np.min(env.observation_space.low)
high = np.max(env.observation_space.high)
obs_space = env.observation_space
assert isinstance(obs_space, gym.spaces.Box)
low = np.min(obs_space.low)
high = np.max(obs_space.high)
self.bias = low
self.scale = high - low
self.observation_space = gym.spaces.Box(
low=0.0,
high=1.0,
shape=env.observation_space.shape,
shape=obs_space.shape,
dtype=np.float32,
)

@@ -261,21 +264,24 @@ def __init__(self, env: gym.Env, n_frames: int) -> None:
super().__init__(env)
self.n_frames: int = n_frames
self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames)
shape = (n_frames, *env.observation_space.shape)
obs_space_shape = env.observation_space.shape
assert obs_space_shape is not None
shape = (n_frames, *obs_space_shape)
assert isinstance(env.observation_space, gym.spaces.Box)
self.observation_space = gym.spaces.Box(
low=np.min(env.observation_space.low),
high=np.max(env.observation_space.high),
shape=shape,
dtype=env.observation_space.dtype,
)

def reset(self, **kwargs: Any) -> tuple[npt.NDArray, dict]:
def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]:
obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
for _ in range(self.n_frames):
self.frames.append(obs)
return (self._get_ob(), info) if return_info else (self._get_ob(), {})

def step(self, action):
def step(self, action: Any) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
step_result = self.env.step(action)
done: bool
if len(step_result) == 4:
@@ -285,11 +291,12 @@ def step(self, action):
obs, reward, term, trunc, info = step_result
new_step_api = True
self.frames.append(obs)
reward = float(reward)
if new_step_api:
return self._get_ob(), reward, term, trunc, info
return self._get_ob(), reward, done, info
return self._get_ob(), reward, done, info.get("TimeLimit.truncated", False), info

def _get_ob(self) -> npt.NDArray:
def _get_ob(self) -> np.ndarray:
# the original wrapper use `LazyFrames` but since we use np buffer,
# it has no effect
return np.stack(self.frames, axis=0)
@@ -379,7 +386,7 @@ def __init__(
envpool_factory = None
if use_envpool_if_available:
if envpool_is_available:
envpool_factory = self.EnvPoolFactory(self)
envpool_factory = self.EnvPoolFactoryAtari(self)
log.info("Using envpool, because it available")
else:
log.info("Not using envpool, because it is not available")
@@ -401,7 +408,7 @@ def create_env(self, mode: EnvMode) -> gym.Env:
scale=self.scale,
)

class EnvPoolFactory(EnvPoolFactory):
class EnvPoolFactoryAtari(EnvPoolFactory):
"""Atari-specific envpool creation.
Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`,
it sets the creation keyword arguments accordingly.
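
A side note on the recurring return-statement change above: the old-step-API branches now return a five-element tuple whose fourth element is pulled from `info`, so they match the declared return type. A rough sketch of that shape (the `CompatStepWrapper` name is invented here, and the real wrappers contain more logic):

```python
from typing import Any

import gymnasium as gym


class CompatStepWrapper(gym.Wrapper):
    """Sketch: always return the 5-tuple promised by the annotation."""

    def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        step_result = self.env.step(action)
        if len(step_result) == 5:  # new Gym API: (obs, reward, terminated, truncated, info)
            obs, reward, term, trunc, info = step_result
            return obs, float(reward), term, trunc, info
        # Old 4-tuple API: recover a truncation flag from `info` so the return
        # value still has five elements, as the annotation requires.
        obs, reward, done, info = step_result
        return obs, float(reward), done, info.get("TimeLimit.truncated", False), info
```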
