added dataset_size attribute to minari datasets (#158)
* added dataset_size attribute to minari datasets

* made changes as per review comments

* changed the way dataset_size is integrated with both methods; updated unit tests

* corrected pre-commit failures

* moved the get_dataset_size method inside the MinariStorage class under the name get_size

* removed duplicate data_path attribute

* code cleanup in test file

* removed more comments

---------

Co-authored-by: Shreyans Jain <[email protected]>
shreyansjainn and Shreyans Jain authored Nov 22, 2023
1 parent ff11da1 commit ee53b4c
Showing 4 changed files with 191 additions and 2 deletions.
5 changes: 5 additions & 0 deletions minari/cli.py
@@ -31,12 +31,16 @@ def _show_dataset_table(datasets, table_title):
table.add_column("Name", justify="left", style="cyan", no_wrap=True)
table.add_column("Total Episodes", justify="right", style="green")
table.add_column("Total Steps", justify="right", style="green")
table.add_column("Dataset Size", justify="left", style="green")
table.add_column("Description", justify="left", style="yellow")
table.add_column("Author", justify="left", style="magenta")
table.add_column("Email", justify="left", style="magenta")

for dst_metadata in datasets.values():
author = dst_metadata.get("author", "Unknown")
dataset_size = dst_metadata.get("dataset_size", "Unknown")
if dataset_size != "Unknown":
dataset_size = f"{str(dataset_size)} MB"
author_email = dst_metadata.get("author_email", "Unknown")

assert isinstance(dst_metadata["dataset_id"], str)
@@ -46,6 +50,7 @@ def _show_dataset_table(datasets, table_title):
dst_metadata["dataset_id"],
str(dst_metadata["total_episodes"]),
str(dst_metadata["total_steps"]),
dataset_size,
"Coming soon ...",
author,
author_email,
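The CLI change above reads the new dataset_size field straight from each dataset's metadata and falls back to "Unknown" when the field is absent. As a rough illustration (not part of this commit), the same lookup can be done programmatically, assuming minari.list_local_datasets() returns the same metadata dictionaries the CLI table consumes:

import minari

# Metadata for every locally stored dataset, keyed by dataset id.
local_datasets = minari.list_local_datasets()

for dataset_id, metadata in local_datasets.items():
    # Datasets created before this change may not carry the field,
    # so mirror the CLI's fallback to "Unknown".
    size = metadata.get("dataset_size", "Unknown")
    if size != "Unknown":
        size = f"{size} MB"
    print(f"{dataset_id}: {size}")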
18 changes: 17 additions & 1 deletion minari/dataset/minari_storage.py
@@ -37,7 +37,6 @@ def __init__(self, data_path: PathLike):
        if not os.path.exists(file_path):
            raise ValueError(f"No data found in data path {data_path}")
        self._file_path = file_path

        self._observation_space = None
        self._action_space = None

@@ -261,6 +260,23 @@ def update_episodes(self, episodes: Iterable[dict]):
file.attrs.modify("total_episodes", total_episodes)
file.attrs.modify("total_steps", total_steps)

def get_size(self):
"""Returns the dataset size in MB.
Returns:
datasize (float): size of the dataset in MB
"""
datasize_list = []
if os.path.exists(self.data_path):

for filename in os.listdir(self.data_path):
datasize = os.path.getsize(os.path.join(self.data_path, filename))
datasize_list.append(datasize)

datasize = np.round(np.sum(datasize_list) / 1000000, 1)

return datasize

def update_from_storage(self, storage: MinariStorage):
"""Update the dataset using another MinariStorage.
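get_size sums the sizes of the files under the storage's data_path and reports the total in MB, rounded to one decimal place. A standalone sketch of the same computation (illustrative only; directory_size_mb is a hypothetical helper, not part of Minari) can be handy for sanity-checking the recorded value:

import os

import numpy as np


def directory_size_mb(data_path: str) -> float:
    """Approximate a dataset's size the way MinariStorage.get_size does:
    sum the sizes of the files directly under data_path and convert to MB."""
    sizes = []
    if os.path.exists(data_path):
        for filename in os.listdir(data_path):
            sizes.append(os.path.getsize(os.path.join(data_path, filename)))
    return float(np.round(np.sum(sizes) / 1_000_000, 1))

Note that the division by 1,000,000 means the value is reported in decimal megabytes rather than mebibytes.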
13 changes: 12 additions & 1 deletion minari/utils.py
@@ -555,8 +555,13 @@ def create_dataset_from_buffers(
        env_spec=env_spec,
    )

    # call `update_metadata` beforehand as well: for small envs, the absence of metadata
    # causes a difference of a few tenths of a MB, leading to unit-test failures.
    storage.update_metadata(metadata)
    storage.update_episodes(buffer)

    metadata['dataset_size'] = storage.get_size()
    storage.update_metadata(metadata)

    return MinariDataset(storage)


@@ -618,7 +623,13 @@ def create_dataset_from_collector_env(
    )

    collector_env.save_to_disk(dataset_path, metadata)
    return MinariDataset(dataset_path)

    # the dataset size can only be calculated after saving to disk, so the dataset
    # metadata is updated after the `save_to_disk` call

    dataset = MinariDataset(dataset_path)
    metadata['dataset_size'] = dataset.storage.get_size()
    dataset.storage.update_metadata(metadata)
    return dataset


def get_normalized_score(dataset: MinariDataset, returns: np.ndarray) -> np.ndarray:
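Both creation paths can only measure the dataset once its data is actually on disk, which is why dataset_size is recorded with a second update_metadata call: after the episodes are stored in create_dataset_from_buffers, and after save_to_disk in create_dataset_from_collector_env. A hedged sketch of the resulting round trip, assuming a locally available dataset id such as "cartpole-test-v0" (a placeholder) and that minari.load_dataset is used to reopen it:

import minari

# Placeholder id; any dataset created after this change should carry the new field.
dataset = minari.load_dataset("cartpole-test-v0")

# The size recorded at creation time should match a fresh measurement of the files on disk.
recorded = dataset.storage.metadata["dataset_size"]
assert recorded == dataset.storage.get_size()
print(f"dataset_size: {recorded} MB")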
157 changes: 157 additions & 0 deletions tests/dataset/test_minari_storage.py
@@ -1,8 +1,24 @@
import copy
import os

import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces

import minari
from minari import DataCollectorV0
from minari.dataset.minari_storage import MinariStorage
from tests.common import (
    check_data_integrity,
    check_load_and_delete_dataset,
    register_dummy_envs,
)


register_dummy_envs()

file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets")


def _generate_episode_dict(
@@ -170,3 +186,144 @@ def test_episode_metadata(tmp_dataset_dir):

    ep_indices = [1, 4, 5]
    storage.update_episode_metadata(ep_metadatas, episode_indices=ep_indices)


@pytest.mark.parametrize(
    "dataset_id,env_id",
    [
        ("cartpole-test-v0", "CartPole-v1"),
        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
    ],
)
def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id):
    """Test get_dataset_size method for dataset made using create_dataset_from_collector_env method."""
    # delete the test dataset if it already exists
    local_datasets = minari.list_local_datasets()
    if dataset_id in local_datasets:
        minari.delete_dataset(dataset_id)

    env = gym.make(env_id)

    env = DataCollectorV0(env)
    num_episodes = 100

    # Step the environment, DataCollectorV0 wrapper will do the data collection job
    env.reset(seed=42)

    for episode in range(num_episodes):
        done = False
        while not done:
            action = env.action_space.sample()  # User-defined policy function
            _, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

        env.reset()

    # Create Minari dataset and store locally
    dataset = minari.create_dataset_from_collector_env(
        dataset_id=dataset_id,
        collector_env=env,
        algorithm_name="random_policy",
        code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
        author="WillDudley",
        author_email="[email protected]",
    )

    assert dataset.storage.metadata['dataset_size'] == dataset.storage.get_size()

    check_data_integrity(dataset.storage, dataset.episode_indices)

    env.close()

    check_load_and_delete_dataset(dataset_id)


@pytest.mark.parametrize(
    "dataset_id,env_id",
    [
        ("cartpole-test-v0", "CartPole-v1"),
        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
        ("dummy-text-test-v0", "DummyTextEnv-v0"),
        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
    ],
)
def test_minari_get_dataset_size_from_buffer(dataset_id, env_id):
    """Test get_dataset_size method for dataset made using create_dataset_from_buffers method."""
    buffer = []

    # delete the test dataset if it already exists
    local_datasets = minari.list_local_datasets()
    if dataset_id in local_datasets:
        minari.delete_dataset(dataset_id)

    env = gym.make(env_id)

    observations = []
    actions = []
    rewards = []
    terminations = []
    truncations = []

    num_episodes = 10

    observation, info = env.reset(seed=42)

    # Step the environment manually and record transitions into the episode buffer
    observation, _ = env.reset()
    observations.append(observation)
    for episode in range(num_episodes):
        terminated = False
        truncated = False

        while not terminated and not truncated:
            action = env.action_space.sample()  # User-defined policy function
            observation, reward, terminated, truncated, _ = env.step(action)
            observations.append(observation)
            actions.append(action)
            rewards.append(reward)
            terminations.append(terminated)
            truncations.append(truncated)

        episode_buffer = {
            "observations": copy.deepcopy(observations),
            "actions": copy.deepcopy(actions),
            "rewards": np.asarray(rewards),
            "terminations": np.asarray(terminations),
            "truncations": np.asarray(truncations),
        }
        buffer.append(episode_buffer)

        observations.clear()
        actions.clear()
        rewards.clear()
        terminations.clear()
        truncations.clear()

        observation, _ = env.reset()
        observations.append(observation)

    # Create Minari dataset and store locally
    dataset = minari.create_dataset_from_buffers(
        dataset_id=dataset_id,
        env=env,
        buffer=buffer,
        algorithm_name="random_policy",
        code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
        author="WillDudley",
        author_email="[email protected]",
    )

    assert dataset.storage.metadata['dataset_size'] == dataset.storage.get_size()

    check_data_integrity(dataset.storage, dataset.episode_indices)

    env.close()

    check_load_and_delete_dataset(dataset_id)
