diff --git a/diambra/arena/stable_baselines/make_sb_env.py b/diambra/arena/stable_baselines/make_sb_env.py index 294a95a..d66b1ea 100644 --- a/diambra/arena/stable_baselines/make_sb_env.py +++ b/diambra/arena/stable_baselines/make_sb_env.py @@ -19,9 +19,9 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti """ Create a wrapped, monitored VecEnv. :param game_id: (str) the game environment ID - :param env_settings: (dict) parameters for DIAMBRA Arena environment - :param wrappers_settings: (dict) parameters for environment wrapping function - :param episode_recording_settings: (dict) parameters for environment recording wrapping function + :param env_settings: (EnvironmentSettings) parameters for DIAMBRA Arena environment + :param wrappers_settings: (WrappersSettings) parameters for environment wrapping function + :param episode_recording_settings: (RecordingSettings) parameters for environment recording wrapping function :param start_index: (int) start rank index :param allow_early_resets: (bool) allows early reset of the environment :param start_method: (str) method used to start the subprocesses. See SubprocVecEnv doc for more information @@ -42,10 +42,9 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti def _make_sb_env(rank, seed): # Seed management - if seed is None: - env_settings.seed = int(time.time()) + rank - else: - env_settings.seed = seed + rank + env_settings.seed = int(time.time()) if seed is None else seed + env_settings.seed += rank + def _init(): env = diambra.arena.make(game_id, env_settings, wrappers_settings, episode_recording_settings, render_mode, rank=rank) @@ -53,7 +52,7 @@ def _init(): env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), allow_early_resets=allow_early_resets) return env - set_global_seeds(env_settings.seed + rank) + set_global_seeds(env_settings.seed) return _init # If not wanting vectorized envs diff --git a/diambra/arena/stable_baselines3/make_sb3_env.py b/diambra/arena/stable_baselines3/make_sb3_env.py index 5f010ac..3c94a89 100644 --- a/diambra/arena/stable_baselines3/make_sb3_env.py +++ b/diambra/arena/stable_baselines3/make_sb3_env.py @@ -36,10 +36,9 @@ def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSett def _make_sb3_env(rank, seed): # Seed management - if seed is None: - env_settings.seed = int(time.time()) + rank - else: - env_settings.seed = seed + rank + env_settings.seed = int(time.time()) if seed is None else seed + env_settings.seed += rank + def _init(): env = diambra.arena.make(game_id, env_settings, wrappers_settings, episode_recording_settings, render_mode, rank=rank) @@ -49,7 +48,7 @@ def _init(): os.makedirs(log_dir, exist_ok=True) env = Monitor(env, log_dir, allow_early_resets=allow_early_resets) return env - set_random_seed(env_settings.seed + rank) + set_random_seed(env_settings.seed) return _init # If not wanting vectorized envs