-
Notifications
You must be signed in to change notification settings - Fork 35
/
utils.py
52 lines (38 loc) · 1.34 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gym
import numpy as np
import torch
from gym.wrappers import AtariPreprocessing, TransformReward
from gym.wrappers import FrameStack as FrameStack_
from fourrooms import Fourrooms
class LazyFrames(object):
    """Hold a sequence of frames and join them into one array only on demand.

    The frame sequence is stored as given. Every array conversion, ``len()``
    call and index access builds a new array by concatenating the frames
    along axis 0.
    """

    def __init__(self, frames):
        """Keep a reference to ``frames`` (a sequence of arrays); no copy is made."""
        self._frames = frames

    def __array__(self, dtype=None, copy=None):
        """Return the frames concatenated along axis 0.

        Args:
            dtype: Optional dtype the result is converted to.
            copy: Accepted so the method matches the NumPy 2 ``__array__``
                protocol (NumPy may pass it as a keyword). The result is
                always a freshly allocated array because concatenation
                allocates, so ``copy=True`` is satisfied automatically and
                ``copy=False`` cannot be honoured; the value is therefore
                not inspected.

        Returns:
            A new ``numpy.ndarray`` holding all frames.
        """
        out = np.concatenate(self._frames, axis=0)
        if dtype is not None:
            # ``out`` is already a private array, so skip the second copy
            # when it has the requested dtype.
            out = out.astype(dtype, copy=False)
        return out

    def __len__(self):
        """Return the length of the concatenated array's first axis."""
        return len(self.__array__())

    def __getitem__(self, i):
        """Index into the concatenated array."""
        return self.__array__()[i]
class FrameStack(FrameStack_):
    """gym ``FrameStack`` wrapper whose ``_get_ob`` returns a ``LazyFrames``."""

    def __init__(self, env, k):
        """Wrap ``env``, forwarding ``k`` to the parent wrapper's constructor."""
        super().__init__(env, k)

    def _get_ob(self):
        """Return the currently buffered frames wrapped in ``LazyFrames``.

        Asserts that the number of buffered frames equals ``self.k``.
        """
        # Snapshot the buffer so the observation does not alias ``self.frames``.
        buffered = list(self.frames)
        assert len(buffered) == self.k
        return LazyFrames(buffered)
def make_env(env_name):
    """Build the environment named ``env_name``.

    ``'fourrooms'`` yields a fresh ``Fourrooms`` instance. Any other name is
    passed to ``gym.make``; when the resulting environment is an Atari
    environment it is additionally wrapped with grayscale/scaled Atari
    preprocessing (terminal on life loss), reward clipping to [-1, 1], and a
    4-frame stack.

    Returns:
        A tuple ``(env, is_atari)`` where ``is_atari`` is a bool telling
        whether the Atari wrappers were applied.
    """
    if env_name == 'fourrooms':
        return Fourrooms(), False

    env = gym.make(env_name)

    # Only touch ``gym.envs.atari`` when that submodule is actually present.
    is_atari = False
    if hasattr(gym.envs, 'atari'):
        is_atari = isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)

    if not is_atari:
        return env, is_atari

    def clip_reward(r):
        return np.clip(r, -1, 1)

    env = AtariPreprocessing(
        env,
        grayscale_obs=True,
        scale_obs=True,
        terminal_on_life_loss=True,
    )
    env = TransformReward(env, clip_reward)
    env = FrameStack(env, 4)
    return env, is_atari
def to_tensor(obs):
    """Convert an observation to a float ``torch.Tensor``.

    ``obs`` is first turned into a NumPy array with ``np.asarray`` (so any
    object supporting array conversion is accepted), wrapped with
    ``torch.from_numpy`` and then cast with ``.float()``.
    """
    return torch.from_numpy(np.asarray(obs)).float()