This repository has been archived by the owner on May 6, 2024. It is now read-only.

Gymnasium migration #373

Open · wants to merge 18 commits into base: main
Changes from all commits
10 changes: 5 additions & 5 deletions .github/workflows/test_and_deploy.yml
@@ -18,7 +18,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-20.04, macos-latest]
fail-fast: false
steps:
@@ -62,14 +62,14 @@ jobs:
path: nle_test_ci_${{ github.sha }}.tar.gz

test_sdist:
name: Test sdist on MacOS w/ Py3.8
name: Test sdist on MacOS w/ Py3.11
needs: test_repo
runs-on: macos-latest
steps:
- name: Setup Python 3.8 env
- name: Setup Python 3.11 env
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.11
- name: Ensure latest pip & wheel
run: "python -m pip install -q --upgrade pip wheel"
- name: Install dependencies
@@ -138,7 +138,7 @@ jobs:
# NOTE: we assume that dist/ contains a built sdist (and only that).
# Yes, we could be more defensive, but What Could Go Wrong?™
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@master
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}
4 changes: 2 additions & 2 deletions .github/workflows/test_package.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
python-version: ["3.5", "3.6", "3.7", "3.8"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false
steps:
- name: Setup Python ${{ matrix.python-version }} env
@@ -27,4 +27,4 @@
run: "pip install nle"
- name: Check nethack is installed
run: |
python -c 'import nle, gym; e = gym.make("NetHack-v0"); e.reset(); e.step(0)'
python -c 'import nle; import gymnasium as gym; e = gym.make("NetHack-v0"); e.reset(); e.step(0)'
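
For context, the one-liner above compresses the core API change this PR tracks. Expanded into a script, the Gymnasium-style interaction looks roughly like this (a sketch, assuming Gymnasium's 0.26+ reset/step conventions):

```python
import gymnasium as gym

import nle  # noqa: F401  # importing nle registers the NetHack-* environments

env = gym.make("NetHack-v0")
obs, info = env.reset()  # Gymnasium reset() returns (observation, info)
# step() returns five values instead of classic gym's four
obs, reward, terminated, truncated, info = env.step(0)
env.close()
```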
4 changes: 4 additions & 0 deletions .gitignore
@@ -181,6 +181,7 @@ venv.bak/

# IDE
.idea/
.vscode/

# Rope project settings
.ropeproject
@@ -205,3 +206,6 @@ nle/version.py
nle_data/
nle/fbs/
nle/nethackdir

# Generated during tests
nle.ttyrec3.bz2
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/ambv/black
rev: 22.3.0
rev: 24.3.0
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/pycqa/flake8
rev: '3.9.2'
rev: '7.0.0'
hooks:
- id: flake8
additional_dependencies: [flake8-bugbear]
@@ -23,7 +23,7 @@ repos:
language: system
files: ^(src\/nle|include\/nle|win\/rl|sys\/unix\/nle).*\.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
- repo: https://github.com/pycqa/isort
rev: 5.8.0
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
13 changes: 7 additions & 6 deletions README.md
@@ -76,12 +76,12 @@ to add papers.
# Getting started

Starting with NLE environments is extremely simple, provided one is familiar
with other gym / RL environments.
with other Gymnasium / RL environments.


## Installation

NLE requires `python>=3.5`, `cmake>=3.15` to be installed and available both when building the
NLE requires `python>=3.8`, `cmake>=3.15` to be installed and available both when building the
package, and at runtime.

On **MacOS**, one can use `Homebrew` as follows:
@@ -136,7 +136,7 @@ README](docker/README.md).
After installation, one can try out any of the provided tasks as follows:

```python
>>> import gym
>>> import gymnasium as gym
>>> import nle
>>> env = gym.make("NetHackScore-v0")
>>> env.reset() # each reset generates a new dungeon
@@ -174,8 +174,9 @@ $ pip install "nle[agent]"
$ python -m nle.agent.agent --num_actors 80 --batch_size 32 --unroll_length 80 --learning_rate 0.0001 --entropy_cost 0.0001 --use_lstm --total_steps 1000000000
```

Plot the mean return over the last 100 episodes:
Plot the mean return over the last 100 episodes (requires gnuplotlib):
```bash
$ pip install gnuplotlib
$ python -m nle.scripts.plot
```
```
@@ -222,8 +223,8 @@ see [this document](./CONTRIBUTING.md).
NLE is a direct fork of [NetHack](https://github.com/nethack/nethack) and
therefore contains code that operates on many different levels of abstraction.
This ranges from low-level game logic, to the higher-level administration of
repeated nethack games, and finally to binding of these games to Python `gym`
environment.
repeated nethack games, and finally to binding of these games to Python
`gymnasium` environment.

If you want to learn more about the architecture of `nle` and how it works
under the hood, checkout the [architecture document](./doc/nle/ARCHITECTURE.md).
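
The README snippets above are truncated by the diff view; a complete episode loop under the new semantics might look like the following sketch, where Gymnasium splits gym's single `done` flag into `terminated` and `truncated` (the random policy is purely illustrative):

```python
import gymnasium as gym

import nle  # noqa: F401

env = gym.make("NetHackScore-v0")
obs, info = env.reset()  # each reset generates a new dungeon
done = False
while not done:
    action = env.action_space.sample()  # random actions, for illustration only
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.close()
```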
14 changes: 8 additions & 6 deletions nle/agent/agent.py
@@ -37,7 +37,7 @@
'`pip install "nle[agent]"`'
)

import gym # noqa: E402
import gymnasium as gym # noqa: E402

import nle # noqa: F401, E402
from nle import nethack # noqa: E402
@@ -141,7 +141,7 @@ def compute_policy_gradient_loss(logits, actions, advantages):


def create_env(name, *args, **kwargs):
return gym.make(name, observation_keys=("glyphs", "blstats"), *args, **kwargs)
return gym.make(name, *args, observation_keys=("glyphs", "blstats"), **kwargs)


def act(
@@ -350,8 +350,8 @@ def initial(self):
self.episode_return = torch.zeros(1, 1)
self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
initial_done = torch.ones(1, 1, dtype=torch.uint8)

result = _format_observations(self.gym_env.reset())
obs, reset_info = self.gym_env.reset()
result = _format_observations(obs)
result.update(
reward=initial_reward,
done=initial_done,
@@ -362,13 +362,15 @@ def initial(self):
return result

def step(self, action):
observation, reward, done, unused_info = self.gym_env.step(action.item())
observation, reward, done, truncated, unused_info = self.gym_env.step(
action.item()
)
self.episode_step += 1
self.episode_return += reward
episode_step = self.episode_step
episode_return = self.episode_return
if done:
observation = self.gym_env.reset()
observation, reset_info = self.gym_env.reset()
self.episode_return = torch.zeros(1, 1)
self.episode_step = torch.zeros(1, 1, dtype=torch.int32)

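
The wrapper changes above boil down to one unpacking pattern. A minimal sketch of that adapter, for downstream code that still expects the old 4-tuple (names here are illustrative, not part of the PR):

```python
import gymnasium as gym

import nle  # noqa: F401


def step_legacy(env: gym.Env, action: int):
    """Sketch: collapse Gymnasium's 5-tuple step back into the old
    (obs, reward, done, info) convention, resetting on episode end."""
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    if done:
        obs, _reset_info = env.reset()  # reset() now returns (obs, info)
    return obs, reward, done, info
```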
4 changes: 2 additions & 2 deletions nle/env/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import gym
from gym.envs import registration
import gymnasium as gym
from gymnasium.envs import registration

from nle.env.base import NLE, DUNGEON_SHAPE

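
The import swap above works because Gymnasium keeps gym's registration interface. A hedged sketch of what a registration call in this module looks like (the id and entry point here are hypothetical; see the actual file for the real registrations):

```python
from gymnasium.envs import registration

# Hypothetical example; nle/env/__init__.py registers the actual task ids.
registration.register(
    id="NetHackScore-v0",
    entry_point="nle.env.tasks:NetHackScore",
)
```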
74 changes: 50 additions & 24 deletions nle/env/base.py
@@ -9,7 +9,7 @@
import warnings
import weakref

import gym
import gymnasium as gym
import numpy as np

from nle import nethack
@@ -142,12 +142,16 @@ class NLE(gym.Env):

Examples:
>>> env = NLE()
>>> obs = env.reset()
>>> obs, reward, done, info = env.step(0)
>>> obs, reset_info = env.reset()
>>> obs, reward, done, truncation, info = env.step(0)
>>> env.render()
"""

metadata = {"render.modes": ["human", "ansi"]}
# Gym expects an fps rate > 0 for render checks, but
# NetHack doesn't have any. Setting it to 42 because
# that's always the answer to life, the universe and
# everything.
metadata = {"render_modes": ["human", "ansi", "full"], "render_fps": 42}

class StepStatus(enum.IntEnum):
"""Specifies the status of the terminal state.
@@ -193,6 +197,7 @@ def __init__(
allow_all_yn_questions=False,
allow_all_modes=False,
spawn_monsters=True,
render_mode="human",
):
"""Constructs a new NLE environment.

@@ -224,12 +229,16 @@
If set to False, only skip click through 'MORE' on death.
spawn_monsters: If False, disables normal NetHack behavior to randomly
create monsters.
render_mode (str): mode used to render the screen. One of
"human" | "ansi" | "full".
Defaults to "human", i.e. what a human would see playing the game.
"""
self.character = character
self._max_episode_steps = max_episode_steps
self._allow_all_yn_questions = allow_all_yn_questions
self._allow_all_modes = allow_all_modes
self._save_ttyrec_every = save_ttyrec_every
self.render_mode = render_mode

if actions is None:
actions = FULL_ACTIONS
@@ -329,6 +338,19 @@ def _get_observation(self, observation):
for key, i in zip(self._original_observation_keys, self._original_indices)
}

def _get_end_status(self, observation, done):
if self._check_abort(observation):
end_status = self.StepStatus.ABORTED
else:
end_status = self._is_episode_end(observation)
return self.StepStatus(done or end_status)

def _get_information(self, end_status):
info = {}
info["end_status"] = end_status
info["is_ascended"] = self.nethack.how_done() == nethack.ASCENDED
return info

def print_action_meanings(self):
for a_idx, a in enumerate(self.actions):
print(a_idx, a)
@@ -349,6 +371,7 @@ def step(self, action: int):
- (*float*): a reward; see ``self._reward_fn`` to see how it is
specified.
- (*bool*): True if the state is terminal, False otherwise.
- (*bool*): True if the episode is truncated, False otherwise.
- (*dict*): a dictionary of extra information (such as
`end_status`, i.e. a status info -- death, task win, etc. --
for the terminal state).
@@ -367,11 +390,7 @@

self.last_observation = observation

if self._check_abort(observation):
end_status = self.StepStatus.ABORTED
else:
end_status = self._is_episode_end(observation)
end_status = self.StepStatus(done or end_status)
end_status = self._get_end_status(observation, done)

reward = float(
self._reward_fn(last_observation, action, observation, end_status)
@@ -382,17 +401,21 @@
self._quit_game(observation, done)
done = True

info = {}
info["end_status"] = end_status
info["is_ascended"] = self.nethack.how_done() == nethack.ASCENDED
truncated = False

return self._get_observation(observation), reward, done, info
return (
self._get_observation(observation),
reward,
done,
truncated,
self._get_information(end_status),
)

def _in_moveloop(self, observation):
program_state = observation[self._program_state_index]
return program_state[3] # in_moveloop

def reset(self, wizkit_items=None):
def reset(self, seed=None, options=None):
"""Resets the environment.

Note:
@@ -401,19 +424,21 @@ def reset(self, wizkit_items=None):
fail in case Nethack is initialized with some uncommon options.

Returns:
[dict] Observation of the state as defined by
`self.observation_space`.
(tuple) (Observation of the state as defined by
`self.observation_space`,
Extra game state information)
"""
super().reset(seed=seed)

self._episode += 1
if self.savedir and self._episode % self._save_ttyrec_every == 0:
new_ttyrec = self._ttyrec_pattern % self._episode
else:
new_ttyrec = None
self.last_observation = self.nethack.reset(
new_ttyrec, wizkit_items=wizkit_items
)
self.last_observation = self.nethack.reset(new_ttyrec, options=options)

self._steps = 0
done = False

for _ in range(1000):
# Get past initial phase of game. This should make sure
@@ -430,9 +455,11 @@
warnings.warn(
"Not in moveloop after 1000 tries, aborting (ttyrec: %s)." % new_ttyrec
)
return self.reset(wizkit_items=wizkit_items)
return self.reset(seed=seed, options=options)

return self._get_observation(self.last_observation)
return self._get_observation(self.last_observation), self._get_information(
self._get_end_status(self.last_observation, done)
)

def close(self):
self._close_nethack()
@@ -475,8 +502,9 @@ def get_seeds(self):
"""
return self.nethack.get_current_seeds()

def render(self, mode="human"):
def render(self):
"""Renders the state of the environment."""
mode = self.render_mode

if mode == "human":
obs = self.last_observation
@@ -515,8 +543,6 @@ def render(self):
# TODO: Why return a string here but print in the other branches?
return "\n".join([line.tobytes().decode("utf-8") for line in chars])

return super().render(mode=mode)

def __repr__(self):
return "<%s>" % self.__class__.__name__

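
Putting the base-class changes together, a short usage sketch against the updated `NLE` signature (assuming `render_mode` in `__init__` and `seed`/`options` flowing through `reset()` as in the diff above):

```python
from nle.env.base import NLE

env = NLE(render_mode="ansi")
obs, info = env.reset(seed=42)  # seed is forwarded to gym.Env.reset()
obs, reward, done, truncated, info = env.step(0)
frame = env.render()  # mode now comes from self.render_mode, not an argument
env.close()
```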