diff --git a/.gitignore b/.gitignore index f106bb6b..6f3f15a7 100644 --- a/.gitignore +++ b/.gitignore @@ -90,4 +90,7 @@ venv.bak/ # mypy .mypy_cache/ .dmypy.json -dmypy.json \ No newline at end of file +dmypy.json + +# gym results +results/ diff --git a/games/abstract_game.py b/games/abstract_game.py index 1be8875c..bd551de3 100644 --- a/games/abstract_game.py +++ b/games/abstract_game.py @@ -7,7 +7,7 @@ class AbstractGame(ABC): """ @abstractmethod - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): pass @abstractmethod diff --git a/games/atari.py b/games/atari.py index 67e13bdc..87ddf8b8 100644 --- a/games/atari.py +++ b/games/atari.py @@ -138,10 +138,10 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): - self.env = gym.make("Breakout-v4") + def __init__(self, seed=None, render_mode=None): + self.env = gym.make("Breakout-v4", render_mode=render_mode) if seed is not None: - self.env.seed(seed) + self.env.reset(seed=seed) def step(self, action): """ @@ -153,7 +153,9 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done, _ = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated + observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA) observation = numpy.asarray(observation, dtype="float32") / 255.0 observation = numpy.moveaxis(observation, -1, 0) @@ -179,7 +181,7 @@ def reset(self): Returns: Initial observation of the game. """ - observation = self.env.reset() + observation, _ = self.env.reset() observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA) observation = numpy.asarray(observation, dtype="float32") / 255.0 observation = numpy.moveaxis(observation, -1, 0) diff --git a/games/breakout.py b/games/breakout.py index 8a078d90..9bb2dc19 100644 --- a/games/breakout.py +++ b/games/breakout.py @@ -138,10 +138,10 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): - self.env = gym.make("Breakout-v4") + def __init__(self, seed=None, render_mode=None): + self.env = gym.make("Breakout-v4", render_mode=render_mode) if seed is not None: - self.env.seed(seed) + self.env.reset(seed=seed) def step(self, action): """ @@ -153,7 +153,9 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done, _ = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated + observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA) observation = numpy.asarray(observation, dtype="float32") / 255.0 observation = numpy.moveaxis(observation, -1, 0) @@ -179,7 +181,7 @@ def reset(self): Returns: Initial observation of the game. """ - observation = self.env.reset() + observation, _ = self.env.reset() observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA) observation = numpy.asarray(observation, dtype="float32") / 255.0 observation = numpy.moveaxis(observation, -1, 0) diff --git a/games/cartpole.py b/games/cartpole.py index fa1e8bbf..f22e4f7c 100644 --- a/games/cartpole.py +++ b/games/cartpole.py @@ -133,10 +133,10 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): - self.env = gym.make("CartPole-v1") + def __init__(self, seed=None, render_mode=None): + self.env = gym.make("CartPole-v1", render_mode=render_mode) if seed is not None: - self.env.seed(seed) + self.env.reset(seed=seed) def step(self, action): """ @@ -148,7 +148,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done, _ = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return numpy.array([[observation]]), reward, done def legal_actions(self): @@ -171,7 +172,8 @@ def reset(self): Returns: Initial observation of the game. """ - return numpy.array([[self.env.reset()]]) + observation, _ = self.env.reset() + return numpy.array([[observation]]) def close(self): """ diff --git a/games/connect4.py b/games/connect4.py index a01e6551..51d28e0d 100644 --- a/games/connect4.py +++ b/games/connect4.py @@ -127,7 +127,7 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): self.env = Connect4() def step(self, action): @@ -140,7 +140,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return observation, reward * 10, done def to_play(self): @@ -172,7 +173,8 @@ def reset(self): Returns: Initial observation of the game. """ - return self.env.reset() + observation, _ = self.env.reset() + return observation def render(self): """ diff --git a/games/gomoku.py b/games/gomoku.py index 58fa89bb..11b459ef 100644 --- a/games/gomoku.py +++ b/games/gomoku.py @@ -133,7 +133,7 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): self.env = Gomoku() def step(self, action): @@ -146,7 +146,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return observation, reward, done def to_play(self): @@ -178,7 +179,8 @@ def reset(self): Returns: Initial observation of the game. """ - return self.env.reset() + observation, _ = self.env.reset() + return observation def close(self): """ diff --git a/games/gridworld.py b/games/gridworld.py index e6805db9..08a5a61a 100644 --- a/games/gridworld.py +++ b/games/gridworld.py @@ -138,11 +138,11 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): - self.env = gym.make("MiniGrid-Empty-Random-6x6-v0") + def __init__(self, seed=None, render_mode=None): + self.env = gym.make("MiniGrid-Empty-Random-6x6-v0", render_mode=render_mode) self.env = gym_minigrid.wrappers.ImgObsWrapper(self.env) if seed is not None: - self.env.seed(seed) + self.env.reset(seed=seed) def step(self, action): """ @@ -154,7 +154,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done, _ = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return numpy.array(observation), reward, done def legal_actions(self): @@ -177,7 +178,8 @@ def reset(self): Returns: Initial observation of the game. """ - return numpy.array(self.env.reset()) + observation, _ = self.env.reset() + return numpy.array(observation) def close(self): """ diff --git a/games/lunarlander.py b/games/lunarlander.py index bdb1f09f..2ae4871c 100644 --- a/games/lunarlander.py +++ b/games/lunarlander.py @@ -129,11 +129,11 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): self.env = DeterministicLunarLander() # self.env = gym.make("LunarLander-v2") if seed is not None: - self.env.seed(seed) + self.env.reset(seed=seed) def step(self, action): """ @@ -145,7 +145,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done, _ = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return numpy.array([[observation]]), reward / 3, done def legal_actions(self): @@ -168,7 +169,8 @@ def reset(self): Returns: Initial observation of the game. """ - return numpy.array([[self.env.reset()]]) + observation, _ = self.env.reset() + return numpy.array([[observation]]) def close(self): """ diff --git a/games/simple_grid.py b/games/simple_grid.py index f26ae429..5f7851a6 100644 --- a/games/simple_grid.py +++ b/games/simple_grid.py @@ -127,7 +127,7 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): self.env = GridEnv() def step(self, action): @@ -140,7 +140,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return [[observation]], reward * 10, done def legal_actions(self): @@ -163,7 +164,8 @@ def reset(self): Returns: Initial observation of the game. """ - return [[self.env.reset()]] + observation, _ = self.env.reset() + return [[observation]] def render(self): """ diff --git a/games/spiel.py b/games/spiel.py index 7bb3a194..a01a9f9f 100644 --- a/games/spiel.py +++ b/games/spiel.py @@ -145,7 +145,7 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): self.env = Spiel() def step(self, action): @@ -158,7 +158,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return observation, reward * 20, done def to_play(self): @@ -190,7 +191,8 @@ def reset(self): Returns: Initial observation of the game. """ - return self.env.reset() + observation, _ = self.env.reset() + return observation def render(self): """ diff --git a/games/tictactoe.py b/games/tictactoe.py index f331a9ae..b6722aee 100644 --- a/games/tictactoe.py +++ b/games/tictactoe.py @@ -127,7 +127,7 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): self.env = TicTacToe() def step(self, action): @@ -140,7 +140,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return observation, reward * 20, done def to_play(self): @@ -172,7 +173,8 @@ def reset(self): Returns: Initial observation of the game. """ - return self.env.reset() + observation, _ = self.env.reset() + return observation def render(self): """ diff --git a/games/twentyone.py b/games/twentyone.py index ea39af68..413440a6 100644 --- a/games/twentyone.py +++ b/games/twentyone.py @@ -139,7 +139,7 @@ class Game(AbstractGame): Game wrapper. """ - def __init__(self, seed=None): + def __init__(self, seed=None, render_mode=None): self.env = TwentyOne(seed) def step(self, action): @@ -152,7 +152,8 @@ def step(self, action): Returns: The new observation, the reward and a boolean if the game has ended. """ - observation, reward, done = self.env.step(action) + observation, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated return observation, reward * 10, done def to_play(self): @@ -184,7 +185,8 @@ def reset(self): Returns: Initial observation of the game. """ - return self.env.reset() + observation, _ = self.env.reset() + return observation def render(self): """ diff --git a/muzero.py b/muzero.py index f7601c9b..bae78eaf 100644 --- a/muzero.py +++ b/muzero.py @@ -434,7 +434,7 @@ def load_model(self, checkpoint_path=None, replay_buffer_path=None): """ # Load checkpoint if checkpoint_path: - checkpoint_path = pathlib.Path(checkpoint_path) + checkpoint_path = pathlib.Path(checkpoint_path).absolute() self.checkpoint = torch.load(checkpoint_path) print(f"\nUsing checkpoint from {checkpoint_path}")