From ee53b4cb8d8586d5939f3c808672252a7bf4c34a Mon Sep 17 00:00:00 2001
From: Shreyans Jain
Date: Wed, 22 Nov 2023 10:02:08 +0530
Subject: [PATCH] added dataset_size attribute to minari datasets (#158)

* added dataset_size attribute to minari datasets

* made changes as per review comments

* changed the way dataset_size is integrated with both methods and updated the unit tests

* corrected pre-commit failures

* moved the get_dataset_size method inside the MinariStorage class under the name get_size

* removed duplicate data_path attribute

* code cleanup in test file

* removed more comments

---------

Co-authored-by: Shreyans Jain
---
 minari/cli.py                        |   5 +
 minari/dataset/minari_storage.py     |  18 ++-
 minari/utils.py                      |  13 ++-
 tests/dataset/test_minari_storage.py | 157 +++++++++++++++++++++++++++
 4 files changed, 191 insertions(+), 2 deletions(-)

diff --git a/minari/cli.py b/minari/cli.py
index 1f1da582..57c57e1a 100644
--- a/minari/cli.py
+++ b/minari/cli.py
@@ -31,12 +31,16 @@ def _show_dataset_table(datasets, table_title):
     table.add_column("Name", justify="left", style="cyan", no_wrap=True)
     table.add_column("Total Episodes", justify="right", style="green")
     table.add_column("Total Steps", justify="right", style="green")
+    table.add_column("Dataset Size", justify="left", style="green")
     table.add_column("Description", justify="left", style="yellow")
     table.add_column("Author", justify="left", style="magenta")
     table.add_column("Email", justify="left", style="magenta")
 
     for dst_metadata in datasets.values():
         author = dst_metadata.get("author", "Unknown")
+        dataset_size = dst_metadata.get("dataset_size", "Unknown")
+        if dataset_size != "Unknown":
+            dataset_size = f"{str(dataset_size)} MB"
         author_email = dst_metadata.get("author_email", "Unknown")
 
         assert isinstance(dst_metadata["dataset_id"], str)
@@ -46,6 +50,7 @@ def _show_dataset_table(datasets, table_title):
             dst_metadata["dataset_id"],
             str(dst_metadata["total_episodes"]),
             str(dst_metadata["total_steps"]),
+            dataset_size,
             "Coming soon ...",
             author,
             author_email,
diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py
index 02493bae..258f18cb 100644
--- a/minari/dataset/minari_storage.py
+++ b/minari/dataset/minari_storage.py
@@ -37,7 +37,6 @@ def __init__(self, data_path: PathLike):
         if not os.path.exists(file_path):
             raise ValueError(f"No data found in data path {data_path}")
         self._file_path = file_path
-
         self._observation_space = None
         self._action_space = None
 
@@ -261,6 +260,23 @@ def update_episodes(self, episodes: Iterable[dict]):
             file.attrs.modify("total_episodes", total_episodes)
             file.attrs.modify("total_steps", total_steps)
 
+    def get_size(self):
+        """Returns the dataset size in MB.
+
+        Returns:
+            datasize (float): size of the dataset in MB
+        """
+        datasize_list = []
+        if os.path.exists(self.data_path):
+
+            for filename in os.listdir(self.data_path):
+                datasize = os.path.getsize(os.path.join(self.data_path, filename))
+                datasize_list.append(datasize)
+
+            datasize = np.round(np.sum(datasize_list) / 1000000, 1)
+
+        return datasize
+
     def update_from_storage(self, storage: MinariStorage):
         """Update the dataset using another MinariStorage.
 
diff --git a/minari/utils.py b/minari/utils.py
index 7d89cdf1..40e9c942 100644
--- a/minari/utils.py
+++ b/minari/utils.py
@@ -555,8 +555,13 @@ def create_dataset_from_buffers(
         env_spec=env_spec,
     )
 
+    # Update the metadata before writing the episodes as well: for small environments, the missing metadata otherwise causes a size difference of a few tenths of a MB, which makes the dataset_size unit tests fail.
     storage.update_metadata(metadata)
     storage.update_episodes(buffer)
+
+    metadata['dataset_size'] = storage.get_size()
+    storage.update_metadata(metadata)
+
     return MinariDataset(storage)
 
 
@@ -618,7 +623,13 @@ def create_dataset_from_collector_env(
     )
 
     collector_env.save_to_disk(dataset_path, metadata)
-    return MinariDataset(dataset_path)
+
+    # The dataset size can only be computed after the data has been saved to disk, so update the dataset metadata after `save_to_disk`.
+
+    dataset = MinariDataset(dataset_path)
+    metadata['dataset_size'] = dataset.storage.get_size()
+    dataset.storage.update_metadata(metadata)
+    return dataset
 
 
 def get_normalized_score(dataset: MinariDataset, returns: np.ndarray) -> np.ndarray:
diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py
index ed059197..4690c0ac 100644
--- a/tests/dataset/test_minari_storage.py
+++ b/tests/dataset/test_minari_storage.py
@@ -1,8 +1,24 @@
+import copy
+import os
+
+import gymnasium as gym
 import numpy as np
 import pytest
 from gymnasium import spaces
 
+import minari
+from minari import DataCollectorV0
 from minari.dataset.minari_storage import MinariStorage
+from tests.common import (
+    check_data_integrity,
+    check_load_and_delete_dataset,
+    register_dummy_envs,
+)
+
+
+register_dummy_envs()
+
+file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets")
 
 
 def _generate_episode_dict(
@@ -170,3 +186,144 @@ def test_episode_metadata(tmp_dataset_dir):
 
     ep_indices = [1, 4, 5]
     storage.update_episode_metadata(ep_metadatas, episode_indices=ep_indices)
+
+
+@pytest.mark.parametrize(
+    "dataset_id,env_id",
+    [
+        ("cartpole-test-v0", "CartPole-v1"),
+        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
+        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
+        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
+        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
+        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
+    ],
+)
+def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id):
+    """Test the get_size method on a dataset created with create_dataset_from_collector_env."""
+    # delete the test dataset if it already exists
+    local_datasets = minari.list_local_datasets()
+    if dataset_id in local_datasets:
+        minari.delete_dataset(dataset_id)
+
+    env = gym.make(env_id)
+
+    env = DataCollectorV0(env)
+    num_episodes = 100
+
+    # Step the environment, the DataCollectorV0 wrapper will do the data collection job
+    env.reset(seed=42)
+
+    for episode in range(num_episodes):
+        done = False
+        while not done:
+            action = env.action_space.sample()  # random policy
+            _, _, terminated, truncated, _ = env.step(action)
+            done = terminated or truncated
+
+        env.reset()
+
+    # Create Minari dataset and store locally
+    dataset = minari.create_dataset_from_collector_env(
+        dataset_id=dataset_id,
+        collector_env=env,
+        algorithm_name="random_policy",
+        code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
+        author="WillDudley",
+        author_email="wdudley@farama.org",
+    )
+
+    assert dataset.storage.metadata['dataset_size'] == dataset.storage.get_size()
+
+    check_data_integrity(dataset.storage, dataset.episode_indices)
+
+    env.close()
+
+    check_load_and_delete_dataset(dataset_id)
+
+
+@pytest.mark.parametrize(
+    "dataset_id,env_id",
+    [
+        ("cartpole-test-v0", "CartPole-v1"),
+        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
+        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
+        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
"DummyTextEnv-v0"), + ("dummy-combo-test-v0", "DummyComboEnv-v0"), + ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), + ], +) +def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): + """Test get_dataset_size method for dataset made using create_dataset_from_buffers method.""" + buffer = [] + + # delete the test dataset if it already exists + local_datasets = minari.list_local_datasets() + if dataset_id in local_datasets: + minari.delete_dataset(dataset_id) + + env = gym.make(env_id) + + observations = [] + actions = [] + rewards = [] + terminations = [] + truncations = [] + + num_episodes = 10 + + observation, info = env.reset(seed=42) + + # Step the environment, DataCollectorV0 wrapper will do the data collection job + observation, _ = env.reset() + observations.append(observation) + for episode in range(num_episodes): + terminated = False + truncated = False + + while not terminated and not truncated: + action = env.action_space.sample() # User-defined policy function + observation, reward, terminated, truncated, _ = env.step(action) + observations.append(observation) + actions.append(action) + rewards.append(reward) + terminations.append(terminated) + truncations.append(truncated) + + episode_buffer = { + "observations": copy.deepcopy(observations), + "actions": copy.deepcopy(actions), + "rewards": np.asarray(rewards), + "terminations": np.asarray(terminations), + "truncations": np.asarray(truncations), + } + buffer.append(episode_buffer) + + observations.clear() + actions.clear() + rewards.clear() + terminations.clear() + truncations.clear() + + observation, _ = env.reset() + observations.append(observation) + + # Create Minari dataset and store locally + dataset = minari.create_dataset_from_buffers( + dataset_id=dataset_id, + env=env, + buffer=buffer, + algorithm_name="random_policy", + code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + author="WillDudley", + author_email="wdudley@farama.org", + ) + + assert dataset.storage.metadata['dataset_size'] == dataset.storage.get_size() + + check_data_integrity(dataset.storage, dataset.episode_indices) + + env.close() + + check_load_and_delete_dataset(dataset_id)