Skip to content

Commit

Permalink
Fixed unchecked None value in SubprocVecEnv (DLR-RM#808)
Browse files Browse the repository at this point in the history
* Fixed unchecked None value in SubprocVecEnv

* Fixed unchecked None value in DummyVecEnv

* Fix formatting

* Update test and changelog

* Improve test

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
ScheiklP and araffin authored Apr 12, 2022
1 parent 39a4f93 commit ed308a7
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Bug Fixes:
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
- Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP)

Deprecations:
^^^^^^^^^^^^^
Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def step_wait(self) -> VecEnvStepReturn:
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))

def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
seeds = list()
if seed is None:
seed = np.random.randint(0, 2**32 - 1)
seeds = []
for idx, env in enumerate(self.envs):
seeds.append(env.seed(seed + idx))
return seeds
Expand Down
2 changes: 2 additions & 0 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def step_wait(self) -> VecEnvStepReturn:
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
if seed is None:
seed = np.random.randint(0, 2**32 - 1)
for idx, remote in enumerate(self.remotes):
remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]
Expand Down
37 changes: 35 additions & 2 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def reset(self):
return self.state

def step(self, action):
reward = 1
reward = float(np.random.rand())
self._choose_next_state()
self.current_step += 1
done = self.current_step >= self.ep_length
Expand All @@ -45,7 +45,9 @@ def render(self, mode="human"):
return np.zeros((4, 4, 3))

def seed(self, seed=None):
pass
if seed is not None:
np.random.seed(seed)
self.observation_space.seed(seed)

@staticmethod
def custom_method(dim_0=1, dim_1=1):
Expand Down Expand Up @@ -440,3 +442,34 @@ def make_monitored_env():

vec_env = VecFrameStack(vec_env, n_stack=2)
assert vec_env.env_is_wrapped(Monitor) == [False, True]


@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vec_seeding(vec_env_class):
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))

# For SubprocVecEnv check for all starting methods
start_methods = [None]
if vec_env_class != DummyVecEnv:
all_methods = {"forkserver", "spawn", "fork"}
available_methods = multiprocessing.get_all_start_methods()
start_methods = list(all_methods.intersection(available_methods))

for start_method in start_methods:
if start_method is not None:
vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)

n_envs = 3
vec_env = vec_env_class([make_env] * n_envs)
# Seed with no argument
vec_env.seed()
obs = vec_env.reset()
_, rewards, _, _ = vec_env.step(np.array([vec_env.action_space.sample() for _ in range(n_envs)]))
# Seed should be different per process
assert not np.allclose(obs[0], obs[1])
assert not np.allclose(rewards[0], rewards[1])
assert not np.allclose(obs[1], obs[2])
assert not np.allclose(rewards[1], rewards[2])

vec_env.close()

0 comments on commit ed308a7

Please sign in to comment.