Polish dataset
zjowowen committed Nov 29, 2024
1 parent a78ab24 commit 4949f6d
Showing 3 changed files with 73 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -98,12 +98,12 @@ wandb offline
 import gym
 
 from grl.algorithms.qgpo import QGPOAlgorithm
-from grl.datasets import QGPOCustomizedDataset
+from grl.datasets import QGPOCustomizedTensorDictDataset
 from grl.utils.log import log
 from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config
 
 def qgpo_pipeline(config):
-    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
+    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedTensorDictDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
     qgpo.train()
 
     agent = qgpo.deploy()
4 changes: 2 additions & 2 deletions README.zh.md
@@ -97,12 +97,12 @@ wandb offline
 import gym
 
 from grl.algorithms.qgpo import QGPOAlgorithm
-from grl.datasets import QGPOCustomizedDataset
+from grl.datasets import QGPOCustomizedTensorDictDataset
 from grl.utils.log import log
 from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config
 
 def qgpo_pipeline(config):
-    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
+    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedTensorDictDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
     qgpo.train()
 
     agent = qgpo.deploy()
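Both README hunks make the same change: the call site swaps in the new class name, and the constructor arguments are untouched. For code still using the old name, a minimal migration sketch, assuming grl is installed and ./data.npz exists (the action_augment_num literal here is a hypothetical stand-in for the config field the READMEs read):

from grl.datasets import QGPOCustomizedTensorDictDataset

dataset = QGPOCustomizedTensorDictDataset(
    numpy_data_path="./data.npz",
    action_augment_num=16,  # hypothetical; the README uses config.train.parameter.action_augment_num
)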
69 changes: 69 additions & 0 deletions grl/datasets/gp.py
@@ -870,3 +870,72 @@ def __getitem__(self, index):

     def __len__(self):
         return self.len
+
+
+class GPAtariVisualTensorDictDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        max_size: int = 1000000,
+    ):
+        self.len = 0
+        self.storage = LazyMemmapStorage(max_size=max_size)
+        self.episode_counter = 0
+
+    def extend_data(self, episode_data: List):
+        # Append one episode's transitions to the storage.
+
+        # Pair the positional episode tensors with their key names.
+        keys = ["obs", "action", "done", "next_obs", "reward"]
+
+        collated_data = {keys[i]: episode_data[i] for i in range(len(keys))}
+
+        len_after_extend = self.len + collated_data["obs"].shape[0]
+        self.storage.set(
+            range(self.len, len_after_extend),
+            TensorDict(
+                {
+                    "s": collated_data["obs"],
+                    "a": collated_data["action"],
+                    "r": collated_data["reward"],
+                    "s_": collated_data["next_obs"],
+                    "d": collated_data["done"],
+                    "episode": torch.tensor([self.episode_counter] * collated_data["obs"].shape[0]),
+                    "step": torch.arange(collated_data["obs"].shape[0]),
+                },
+                batch_size=[collated_data["obs"].shape[0]],
+            ),
+        )
+        self.len = len_after_extend
+        self.episode_counter += 1
+        log.debug(
+            f"{collated_data['obs'].shape[0]} data loaded in GPAtariVisualTensorDictDataset"
+        )
+
+    def __getitem__(self, index):
+        """
+        Overview:
+            Get data by index.
+        Arguments:
+            index (:obj:`int`): Index of data
+        Returns:
+            data (:obj:`dict`): Data dict
+        .. note::
+            The data dict contains the following keys:
+            s (:obj:`torch.Tensor`): State
+            a (:obj:`torch.Tensor`): Action
+            r (:obj:`torch.Tensor`): Reward
+            s_ (:obj:`torch.Tensor`): Next state
+            d (:obj:`torch.Tensor`): Is finished
+            episode (:obj:`torch.Tensor`): Episode index
+            step (:obj:`torch.Tensor`): Step index within the episode
+        """
+        data = self.storage.get(index=index)
+        return data
+
+    def __len__(self):
+        return self.len
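
Below the diff, a minimal usage sketch of the new GPAtariVisualTensorDictDataset, assuming torchrl's LazyMemmapStorage and tensordict are available and that extend_data receives episode tensors in the positional order obs, action, done, next_obs, reward, as the code above expects; the [4, 84, 84] frame shape and the dtypes are illustrative, not prescribed by the class:

import torch

from grl.datasets.gp import GPAtariVisualTensorDictDataset

# Build the dataset and feed it one synthetic 10-step episode.
dataset = GPAtariVisualTensorDictDataset(max_size=100000)

T = 10
episode = [
    torch.zeros(T, 4, 84, 84),         # obs: stacked grayscale frames (illustrative shape)
    torch.zeros(T, dtype=torch.long),  # action
    torch.zeros(T, dtype=torch.bool),  # done
    torch.zeros(T, 4, 84, 84),         # next_obs
    torch.zeros(T),                    # reward
]
dataset.extend_data(episode)

sample = dataset[0]
print(len(dataset))       # 10
print(sample["s"].shape)  # torch.Size([4, 84, 84])
print(sample["episode"], sample["step"])  # tensor(0) tensor(0)

Because LazyMemmapStorage memory-maps its buffer to disk, the default max_size of one million transitions does not have to fit in RAM; __getitem__ returns a TensorDict keyed s, a, r, s_, d plus the episode and step bookkeeping indices.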
