From 3e2dcc4c2feeaf892aca6f3d3719a94b5990346e Mon Sep 17 00:00:00 2001 From: John Balis Date: Thu, 25 Jan 2024 11:50:47 -0600 Subject: [PATCH] Adds infos to `EpisodeData` (#132) * initial commit for info's fix * tentative draft of info support for EpisodeData * typing fix * typing fixed, removed print * added some information to dataset stabdards started work on test for varying infos * added tests for error in response to infos with timestep variant structure, and test of using StepDataCallback to fix it * added explicit np.array dtype support, documentation, and tests * updated doc page * table syntax change * remove print * DataCollectorV0 -> DataCollector * rename test * move info shape check to add to buffer * add _get_info in dummy test envs * _get_info_at_step_index * fix pre-commit * fix tests * fix docs * fix pre-commit * refactor * simplify tests * fix pre-commit * remove redundant comments * fix pre-commit * fixes * fix basic_usage * fix episode_data repr * fix common * improe tests --------- Co-authored-by: rodrigodelazcano Co-authored-by: Omar Younis --- docs/content/basic_usage.md | 12 +- docs/content/dataset_standards.md | 7 + minari/data_collector/data_collector.py | 86 ++++++++---- minari/dataset/episode_data.py | 4 +- minari/dataset/minari_storage.py | 21 ++- tests/common.py | 132 ++++++++++++++++-- .../callbacks/test_step_data_callback.py | 92 ++++++++++-- tests/data_collector/test_data_collector.py | 35 +++-- tests/dataset/test_minari_dataset.py | 2 + tests/utils/test_dataset_creation.py | 69 ++++++++- 10 files changed, 387 insertions(+), 73 deletions(-) diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md index 6af27742..a6f9f558 100644 --- a/docs/content/basic_usage.md +++ b/docs/content/basic_usage.md @@ -77,7 +77,7 @@ for _ in range(total_episodes): if terminated or truncated: break -dataset = env.create_dataset(dataset_id="CartPole-v1-test-v0", +dataset = env.create_dataset(dataset_id="cartpole-test-v0", algorithm_name="Random-Policy", code_permalink="https://github.com/Farama-Foundation/Minari", author="Farama", @@ -96,7 +96,7 @@ Once the dataset has been created we can check if the Minari dataset id appears >>> import minari >>> local_datasets = minari.list_local_datasets() >>> local_datasets.keys() -dict_keys(['CartPole-v1-test-v0']) +dict_keys(['cartpole-test-v0']) ``` ```{eval-rst} @@ -125,7 +125,7 @@ env = gym.make('CartPole-v1') env = DataCollector(env, record_infos=True, max_buffer_steps=100000) total_episodes = 100 -dataset_name = "CartPole-v1-test-v0" +dataset_name = "cartpole-test-v0" dataset = None if dataset_name in minari.list_local_datasets(): dataset = minari.load_dataset(dataset_name) @@ -161,9 +161,9 @@ Minari will only be able to load datasets that are stored in your `local root di ```python >>> import minari ->>> dataset = minari.load_dataset('CartPole-v1-test-v0') +>>> dataset = minari.load_dataset('cartpole-test-v0') >>> dataset.name -'CartPole-v1-test-v0' +'cartpole-test-v0' ``` ### Download Remote Datasets @@ -323,7 +323,7 @@ From a :class:`minari.MinariDataset` object we can also recover the Gymnasium en ```python import minari -dataset = minari.load_dataset('CartPole-v1-test-v0') +dataset = minari.load_dataset('cartpole-test-v0') env = dataset.recover_environment() env.reset() diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index 280f5d80..a2af8ee5 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -554,5 +554,12 @@ The `sampled_episodes` variable will be a list of 10 `EpisodeData` elements, eac | `rewards` | `np.ndarray` | Rewards for each timestep. | | `terminations` | `np.ndarray` | Terminations for each timestep. | | `truncations` | `np.ndarray` | Truncations for each timestep. | +| `infos` | `dict` | A dictionary containing additional information. | As mentioned in the `Supported Spaces` section, many different observation and action spaces are supported so the data type for these fields are dependent on the environment being used. + +## Additional Information Formatting + +When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, an info dict must be provided from every call to the environment's `step` and `reset` function. The structure of the info dictionary must be the same across timesteps. + +Given that it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py. diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 80f893f4..28310e91 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -128,6 +128,7 @@ def __init__( ) self._record_infos = record_infos + self._reference_info = None self.max_buffer_steps = max_buffer_steps # Initialzie empty buffer @@ -136,11 +137,11 @@ def __init__( self._step_id = -1 self._episode_id = -1 - def _add_to_episode_buffer( + def _add_step_data( self, episode_buffer: EpisodeBuffer, - step_data: Union[StepData, Dict[str, StepData]], - ) -> EpisodeBuffer: + step_data: Union[StepData, Dict], + ): """Add step data dictionary to episode buffer. Args: @@ -150,31 +151,43 @@ def _add_to_episode_buffer( Returns: Dict: new dictionary episode buffer with added values from step_data """ + dict_data = dict(step_data) + if not self._record_infos: + dict_data = {k: v for k, v in step_data.items() if k != "infos"} + else: + assert self._reference_info is not None + if not _check_infos_same_shape( + self._reference_info, step_data["infos"] + ): + raise ValueError( + "Info structure inconsistent with info structure returned by original reset." + ) + + self._add_to_episode_buffer(episode_buffer, dict_data) + + def _add_to_episode_buffer( + self, + episode_buffer: EpisodeBuffer, + step_data: Dict[str, Any], + ): for key, value in step_data.items(): - if (not self._record_infos and key == "infos") or (value is None): + if value is None: continue if key not in episode_buffer: - if isinstance(value, dict): - episode_buffer[key] = self._add_to_episode_buffer({}, value) - else: - episode_buffer[key] = [value] + episode_buffer[key] = {} if isinstance(value, dict) else [] + + if isinstance(value, dict): + assert isinstance( + episode_buffer[key], dict + ), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}" + + self._add_to_episode_buffer(episode_buffer[key], value) else: - if isinstance(value, dict): - assert isinstance( - episode_buffer[key], dict - ), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}" - - episode_buffer[key] = self._add_to_episode_buffer( - episode_buffer[key], value - ) - else: - assert isinstance( - episode_buffer[key], list - ), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}" - episode_buffer[key].append(value) - - return episode_buffer + assert isinstance( + episode_buffer[key], list + ), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}" + episode_buffer[key].append(value) def step( self, action: ActType @@ -191,6 +204,9 @@ def step( terminated=terminated, truncated=truncated, ) + + # Force step data dictionary to include keys corresponding to Gymnasium step returns: + # actions, observations, rewards, terminations, truncations, and infos assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'." @@ -203,7 +219,7 @@ def step( ), "Actions are not in action space." self._step_id += 1 - self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data) + self._add_step_data(self._buffer[-1], step_data) if ( self.max_buffer_steps is not None @@ -219,7 +235,7 @@ def step( "observations": step_data["observations"], "infos": step_data["infos"], } - eps_buff = self._add_to_episode_buffer(eps_buff, previous_data) + self._add_step_data(eps_buff, previous_data) self._buffer.append(eps_buff) return obs, rew, terminated, truncated, info @@ -245,7 +261,7 @@ def reset( observation (ObsType): Observation of the initial state. info (dictionary): Auxiliary information complementing ``observation``. """ - autoseed_enabled = (not options) or options.get("minari_autoseed", False) + autoseed_enabled = (not options) or options.get("minari_autoseed", True) if seed is None and autoseed_enabled: seed = secrets.randbits(AUTOSEED_BIT_SIZE) @@ -253,6 +269,9 @@ def reset( step_data = self._step_data_callback(env=self.env, obs=obs, info=info) self._episode_id += 1 + if self._record_infos and self._reference_info is None: + self._reference_info = step_data["infos"] + assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'" @@ -262,7 +281,7 @@ def reset( "seed": str(None) if seed is None else seed, "id": self._episode_id } - episode_buffer = self._add_to_episode_buffer(episode_buffer, step_data) + self._add_step_data(episode_buffer, step_data) self._buffer.append(episode_buffer) return obs, info @@ -418,3 +437,16 @@ def close(self): self._buffer.clear() shutil.rmtree(self._tmp_dir.name) + + +def _check_infos_same_shape(info_1: dict, info_2: dict): + if info_1.keys() != info_2.keys(): + return False + for key in info_1.keys(): + if type(info_1[key]) is not type(info_2[key]): + return False + if isinstance(info_1[key], dict): + return _check_infos_same_shape(info_1[key], info_2[key]) + elif isinstance(info_1[key], np.ndarray): + return (info_1[key].shape == info_2[key].shape) and (info_1[key].dtype == info_2[key].dtype) + return True diff --git a/minari/dataset/episode_data.py b/minari/dataset/episode_data.py index 53786144..8d32d754 100644 --- a/minari/dataset/episode_data.py +++ b/minari/dataset/episode_data.py @@ -19,6 +19,7 @@ class EpisodeData: rewards: np.ndarray terminations: np.ndarray truncations: np.ndarray + infos: dict def __repr__(self) -> str: return ( @@ -30,7 +31,8 @@ def __repr__(self) -> str: f"actions={EpisodeData._repr_space_values(self.actions)}, " f"rewards=ndarray of {len(self.rewards)} floats, " f"terminations=ndarray of {len(self.terminations)} bools, " - f"truncations=ndarray of {len(self.truncations)} bools" + f"truncations=ndarray of {len(self.truncations)} bools, " + f"infos=dict with the following keys: {list(self.infos.keys())}" ")" ) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 860c2035..b076d0e4 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -88,7 +88,10 @@ def new( obj._action_space = action_space if env_spec is not None: - metadata["env_spec"] = env_spec.to_json() + try: + metadata["env_spec"] = env_spec.to_json() + except TypeError: + pass with h5py.File(obj._file_path, "a") as file: file.attrs.update(metadata) return obj @@ -161,6 +164,19 @@ def apply( ep_dicts = self.get_episodes(episode_indices) return map(function, ep_dicts) + def _decode_infos(self, infos: h5py.Group): + result = {} + for key in infos.keys(): + if isinstance(infos[key], h5py.Group): + result[key] = self._decode_infos(infos[key]) + elif isinstance(infos[key], h5py.Dataset): + result[key] = infos[key][()] + else: + raise ValueError( + "Infos are in an unsupported format; see Minari documentation for supported formats." + ) + return result + def _decode_space( self, hdf_ref: Union[h5py.Group, h5py.Dataset, h5py.Datatype], @@ -219,6 +235,9 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: "actions": self._decode_space( ep_group["actions"], self.action_space ), + "infos": self._decode_infos(ep_group["infos"]) + if "infos" in ep_group + else {}, } for key in {"rewards", "terminations", "truncations"}: group_value = ep_group[key] diff --git a/tests/common.py b/tests/common.py index 65f013de..8c4a99c7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -28,16 +28,42 @@ def __init__(self): low=-1, high=4, shape=(3,), dtype=np.float32 ) + def _get_info(self): + return {"timestep": np.array([self.timestep])} + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + return ( + self.observation_space.sample(), + 0, + terminated, + False, + self._get_info(), + ) def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {} + return self.observation_space.sample(), self._get_info() + + +class DummyInfoEnv(DummyBoxEnv): + def __init__(self, info=None): + super().__init__() + self.info = info if info is not None else {} + + def _get_info(self): + return self.info + + +class DummyInconsistentInfoEnv(DummyBoxEnv): + def __init__(self): + super().__init__() + + def _get_info(self): + return super()._get_info() if self.timestep % 2 == 0 else {} class DummyMultiDimensionalBoxEnv(gym.Env): @@ -76,16 +102,28 @@ def __init__(self): ) ) + def _get_info(self): + return {"timestep": np.array([self.timestep])} + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + return ( + self.observation_space.sample(), + 0, + terminated, + False, + self._get_info(), + ) def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {} + return ( + self.observation_space.sample(), + self._get_info(), + ) class DummyDictEnv(gym.Env): @@ -113,16 +151,29 @@ def __init__(self): } ) + def _get_info(self): + return { + "timestep": np.array([self.timestep]), + "component_1": {"next_timestep": np.array([self.timestep + 1])}, + } + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + return ( + self.observation_space.sample(), + 0, + terminated, + False, + self._get_info(), + ) def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {} + + return self.observation_space.sample(), self._get_info() class DummyTupleEnv(gym.Env): @@ -146,16 +197,23 @@ def __init__(self): ) ) + def _get_info(self): + return { + "info_1": np.ones((2, 2)), + "component_1": {"component_1_info_1": np.ones((2,))}, + } + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + return self.observation_space.sample(), 0, terminated, False, self._get_info() def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {} + + return self.observation_space.sample(), self._get_info() class DummyTextEnv(gym.Env): @@ -237,6 +295,18 @@ def register_dummy_envs(): max_episode_steps=5, ) + register( + id="DummyInfoEnv-v0", + entry_point="tests.common:DummyInfoEnv", + max_episode_steps=5, + ) + + register( + id="DummyInconsistentInfoEnv-v0", + entry_point="tests.common:DummyInconsistentInfoEnv", + max_episode_steps=5, + ) + register( id="DummyMultiDimensionalBoxEnv-v0", entry_point="tests.common:DummyMultiDimensionalBoxEnv", @@ -512,6 +582,20 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert total_steps == data.total_steps +def get_info_at_step_index(infos: Dict, step_index: int) -> Dict: + result = {} + for key in infos.keys(): + if isinstance(infos[key], dict): + result[key] = get_info_at_step_index(infos[key], step_index) + elif isinstance(infos[key], np.ndarray): + result[key] = infos[key][step_index] + else: + raise ValueError( + "Infos are in an unsupported format; see Minari documentation for supported formats." + ) + return result + + def _reconstuct_obs_or_action_at_index_recursive( data: Union[dict, tuple, np.ndarray], index: int ) -> Union[np.ndarray, dict, tuple]: @@ -604,16 +688,18 @@ def create_dummy_dataset_with_collecter_env_helper( def check_episode_data_integrity( - episode_data_list: List[EpisodeData], + episode_data_list: Union[List[EpisodeData], MinariDataset], observation_space: gym.spaces.Space, action_space: gym.spaces.Space, + info_sample: Optional[dict] = None, ): """Checks to see if a list of EpisodeData instances has consistent data and that the observations and actions are in the appropriate spaces. Args: episode_data_list (List[EpisodeData]): A list of EpisodeData instances representing episodes. - observation_space(gym.spaces.Space): The environment's observation space. - action_space(gym.spaces.Space): The environment's action space. + observation_space (gym.spaces.Space): The environment's observation space. + action_space (gym.spaces.Space): The environment's action space. + info_sample (dict): An info returned by the environment used to build the dataset. """ # verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct @@ -627,7 +713,14 @@ def check_episode_data_integrity( for i in range(episode.total_timesteps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) + if info_sample is not None: + assert check_infos_equal( + get_info_at_step_index(episode.infos, i), + info_sample + ) + assert observation_space.contains(obs) + for i in range(episode.total_timesteps): action = _reconstuct_obs_or_action_at_index_recursive(episode.actions, i) assert action_space.contains(action) @@ -637,6 +730,19 @@ def check_episode_data_integrity( assert episode.total_timesteps == len(episode.truncations) +def check_infos_equal(info_1: Dict, info_2: Dict) -> bool: + if info_1.keys() != info_2.keys(): + return False + for key in info_1.keys(): + if isinstance(info_1[key], dict): + return check_infos_equal(info_1[key], info_2[key]) + elif isinstance(info_1[key], np.ndarray): + return bool(np.all(info_1[key] == info_2[key])) + else: + return info_1[key] == info_2[key] + return True + + def _space_subset_helper(entry: Dict): return OrderedDict( @@ -648,7 +754,7 @@ def _space_subset_helper(entry: Dict): ) -def get_sample_buffer_for_dataset_from_env(env, num_episodes=10): +def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10): buffer = [] observations = [] @@ -667,7 +773,7 @@ def get_sample_buffer_for_dataset_from_env(env, num_episodes=10): truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) observations.append(_space_subset_helper(observation)) actions.append(_space_subset_helper(action)) diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index c16df008..d26cda5a 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -1,5 +1,6 @@ import gymnasium as gym import numpy as np +import pytest from gymnasium import spaces import minari @@ -7,6 +8,7 @@ from minari.data_collector.callbacks import StepDataCallback from tests.common import ( check_data_integrity, + check_env_recovery, check_env_recovery_with_subset_spaces, check_load_and_delete_dataset, register_dummy_envs, @@ -37,13 +39,17 @@ def __call__(self, env, **kwargs): return step_data +class CustomSubsetInfoPadStepDataCallback(StepDataCallback): + def __call__(self, env, **kwargs): + step_data = super().__call__(env, **kwargs) + if step_data["infos"] == {}: + step_data["infos"] = {"timestep": np.array([-1])} + return step_data + + def test_data_collector_step_data_callback(): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-dict-test-v0" - # delete the test dataset if it already exists - local_datasets = minari.list_local_datasets() - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) env = gym.make("DummyDictEnv-v0") @@ -74,23 +80,20 @@ def test_data_collector_step_data_callback(): ) num_episodes = 10 - # Step the environment, DataCollector wrapper will do the data collection job env.reset(seed=42) - for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) env.reset() - # Create Minari dataset and store locally dataset = env.create_dataset( dataset_id=dataset_id, algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + code_permalink=str(__file__), author="WillDudley", author_email="wdudley@farama.org", ) @@ -102,11 +105,78 @@ def test_data_collector_step_data_callback(): check_data_integrity(dataset.storage, dataset.episode_indices) - # check that the environment can be recovered from the dataset check_env_recovery_with_subset_spaces( env.env, dataset, action_space_subset, observation_space_subset ) env.close() - # check load and delete local dataset check_load_and_delete_dataset(dataset_id) + + +def test_data_collector_step_data_callback_info_correction(): + """Test DataCollector wrapper and Minari dataset creation.""" + dataset_id = "dummy-inconsistent-info-v0" + env = gym.make("DummyInconsistentInfoEnv-v0") + + env = DataCollector( + env, + record_infos=True, + step_data_callback=CustomSubsetInfoPadStepDataCallback, + ) + num_episodes = 10 + + env.reset(seed=42) + for episode in range(num_episodes): + terminated = False + truncated = False + while not terminated and not truncated: + action = env.action_space.sample() + _, _, terminated, truncated, _ = env.step(action) + + env.reset() + + dataset = minari.create_dataset_from_collector_env( + dataset_id=dataset_id, + collector_env=env, + algorithm_name="random_policy", + code_permalink=str(__file__), + author="WillDudley", + author_email="wdudley@farama.org", + ) + + assert isinstance(dataset, MinariDataset) + assert dataset.total_episodes == num_episodes + assert dataset.spec.total_episodes == num_episodes + assert len(dataset.episode_indices) == num_episodes + + check_data_integrity(dataset.storage, dataset.episode_indices) + + check_env_recovery(env.env, dataset) + + env.close() + check_load_and_delete_dataset(dataset_id) + + env = gym.make("DummyInconsistentInfoEnv-v0") + + env = DataCollector( + env, + record_infos=True, + ) + # here we are checking to make sure that if we have an environment changing its info + # structure across timesteps, it is caught by the data_collector + with pytest.raises( + ValueError, + match=r"Info structure inconsistent with info structure returned by original reset." + ): + + num_episodes = 10 + env.reset(seed=42) + for _ in range(num_episodes): + terminated = False + truncated = False + while not terminated and not truncated: + action = env.action_space.sample() + _, _, terminated, truncated, _ = env.step(action) + + env.reset() + env.close() diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index e183842f..ee932cdc 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -3,7 +3,12 @@ import pytest from minari import DataCollector, EpisodeData, MinariDataset, StepDataCallback -from tests.common import check_load_and_delete_dataset, register_dummy_envs +from tests.common import ( + check_infos_equal, + check_load_and_delete_dataset, + get_info_at_step_index, + register_dummy_envs, +) MAX_UINT64 = np.iinfo(np.uint64).max @@ -71,6 +76,8 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat else: action = episode.actions[index] + infos = get_info_at_step_index(episode.infos, index) + step_data = { "id": episode.id, "total_timesteps": 1, @@ -80,6 +87,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat "rewards": episode.rewards[index], "terminations": episode.terminations[index], "truncations": episode.truncations[index], + "infos": infos, } return EpisodeData(**step_data) @@ -103,10 +111,10 @@ def test_truncation_without_reset(dataset_id, env_id): env = DataCollector( env, step_data_callback=ForceTruncateStepDataCallback, + record_infos=True, ) env.reset() - for _ in range(num_steps): env.step(env.action_space.sample()) @@ -125,19 +133,20 @@ def test_truncation_without_reset(dataset_id, env_id): assert len(dataset.episode_indices) == num_episodes episodes_generator = dataset.iterate_episodes() - last_step = None + last_step = get_single_step_from_episode(next(episodes_generator), -1) for episode in episodes_generator: assert episode.total_timesteps == ForceTruncateStepDataCallback.episode_steps - if last_step is not None: - first_step = get_single_step_from_episode(episode, 0) - # Check that the last observation of the previous episode is carried over to the next episode - # as the reset observation. - if isinstance(first_step.observations, dict) or isinstance( - first_step.observations, tuple - ): - assert first_step.observations == last_step.observations - else: - assert np.array_equal(first_step.observations, last_step.observations) + first_step = get_single_step_from_episode(episode, 0) + # Check that the last observation of the previous episode is carried over to the next episode + # as the reset observation. + if isinstance(first_step.observations, dict) or isinstance( + first_step.observations, tuple + ): + assert first_step.observations == last_step.observations + else: + assert np.array_equal(first_step.observations, last_step.observations) + + check_infos_equal(last_step.infos, first_step.infos) last_step = get_single_step_from_episode(episode, -1) assert bool(last_step.truncations) is True diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index db7ba70e..26099c27 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -44,6 +44,7 @@ def test_episode_data(space: gym.Space): rewards=rewards, terminations=terminations, truncations=truncations, + infos={}, ) pattern = r"EpisodeData\(" @@ -55,6 +56,7 @@ def test_episode_data(space: gym.Space): pattern += r"rewards=.+, " pattern += r"terminations=.+, " pattern += r"truncations=.+" + pattern += r"infos=.+" pattern += r"\)" assert re.fullmatch(pattern, repr(episode_data)) diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index b1f878c0..6dfc7afc 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -11,6 +11,7 @@ check_data_integrity, check_env_recovery, check_env_recovery_with_subset_spaces, + check_episode_data_integrity, check_load_and_delete_dataset, get_sample_buffer_for_dataset_from_env, register_dummy_envs, @@ -37,7 +38,7 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): """Test DataCollector wrapper and Minari dataset creation.""" env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, record_infos=True) num_episodes = 10 # Step the environment, DataCollector wrapper will do the data collection job @@ -83,6 +84,9 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): assert len(dataset.episode_indices) == num_episodes check_data_integrity(dataset.storage, dataset.episode_indices) + check_episode_data_integrity( + dataset, dataset.spec.observation_space, dataset.spec.action_space + ) # check that the environment can be recovered from the dataset check_env_recovery(env.env, dataset, eval_env) @@ -93,6 +97,67 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): check_load_and_delete_dataset(dataset_id) +@pytest.mark.parametrize( + "info_override", + [ + None, {}, {"foo": np.ones((10, 10), dtype=np.float32)}, + {"int": 1}, {"bool": False}, + { + "value1": True, + "value2": 5, + "value3": { + "nested1": False, + "nested2": np.empty(10) + } + }, + ], +) +def test_record_infos_collector_env(info_override): + """Test DataCollector wrapper and Minari dataset creation including infos.""" + dataset_id = "dummy-mutable-info-box-test-v0" + env = gym.make("DummyInfoEnv-v0", info=info_override) + + env = DataCollector(env, record_infos=True) + num_episodes = 10 + + _, info_sample = env.reset(seed=42) + + for episode in range(num_episodes): + terminated = False + truncated = False + while not terminated and not truncated: + action = env.action_space.sample() + _, _, terminated, truncated, _ = env.step(action) + + env.reset() + + dataset = minari.create_dataset_from_collector_env( + dataset_id=dataset_id, + collector_env=env, + algorithm_name="random_policy", + code_permalink=CODELINK, + author="WillDudley", + author_email="wdudley@farama.org", + ) + + assert isinstance(dataset, MinariDataset) + assert dataset.total_episodes == num_episodes + assert dataset.spec.total_episodes == num_episodes + assert len(dataset.episode_indices) == num_episodes + + check_data_integrity(dataset.storage, dataset.episode_indices) + check_episode_data_integrity( + dataset, + dataset.spec.observation_space, + dataset.spec.action_space, + info_sample=info_sample, + ) + + env.close() + + check_load_and_delete_dataset(dataset_id) + + @pytest.mark.parametrize( "dataset_id,env_id", [ @@ -186,6 +251,7 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): assert len(dataset.episode_indices) == num_episodes check_data_integrity(dataset.storage, dataset.episode_indices) + check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) check_env_recovery(env, dataset, eval_env) check_load_and_delete_dataset(dataset_id) @@ -254,6 +320,7 @@ def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): assert len(dataset.episode_indices) == num_episodes check_data_integrity(dataset.storage, dataset.episode_indices) + check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) if is_env_needed: check_env_recovery_with_subset_spaces( env, dataset, action_space_subset, observation_space_subset