Skip to content

Commit

Permalink
Add TimeAwareObservation support for environments without a spec (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Jan 8, 2025
1 parent fc74bb8 commit e6e3521
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand Down Expand Up @@ -35,7 +35,7 @@ repos:
- --show-source
- --statistics
- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
rev: v3.19.1
hooks:
- id: pyupgrade
args: ["--py38-plus"]
Expand All @@ -44,7 +44,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/python/black
rev: 24.8.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/pydocstyle
Expand Down
15 changes: 12 additions & 3 deletions gymnasium/wrappers/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,18 @@ def __init__(
if env.spec is not None and env.spec.max_episode_steps is not None:
self.max_timesteps = env.spec.max_episode_steps
else:
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)
# else we need to loop through the environment stack to check if a `TimeLimit` wrapper exists
wrapped_env = env
while isinstance(wrapped_env, gym.Wrapper):
if isinstance(wrapped_env, gym.wrappers.TimeLimit):
self.max_timesteps = wrapped_env._max_episode_steps
break
wrapped_env = wrapped_env.env

if not isinstance(wrapped_env, gym.wrappers.TimeLimit):
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)

self.timesteps: int = 0

Expand Down
25 changes: 24 additions & 1 deletion tests/wrappers/test_time_aware_observation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Test suite for TimeAwareObservation wrapper."""

import re
import warnings

import numpy as np
import pytest

import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.spaces import Box, Dict, Tuple
from gymnasium.wrappers import TimeAwareObservation
from gymnasium.wrappers import TimeAwareObservation, TimeLimit
from tests.testing_env import GenericTestEnv


Expand Down Expand Up @@ -41,6 +45,25 @@ def test_default(env_id):
assert wrapped_obs.shape[0] == obs.shape[0] + 1


def test_no_spec():
    """Test the TimeAwareObservation wrapper on a directly constructed environment."""
    bare_env = CartPoleEnv()

    # Without a `TimeLimit` wrapper in the stack, construction must be rejected.
    expected_message = re.escape(
        "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
    )
    with pytest.raises(ValueError, match=expected_message):
        TimeAwareObservation(bare_env)

    # With an explicit `TimeLimit`, the wrapper should pick up its step limit
    # and construction should not record any warning.
    time_limited_env = TimeLimit(bare_env, 100)
    with warnings.catch_warnings(record=True) as recorded_warnings:
        time_aware_env = TimeAwareObservation(time_limited_env)

    assert time_aware_env.max_timesteps == 100
    assert len(recorded_warnings) == 0


def test_no_flatten():
"""Test the TimeAwareObservation wrapper without flattening the space."""
env = GenericTestEnv(observation_space=Box(0, 1))
Expand Down

0 comments on commit e6e3521

Please sign in to comment.