diff --git a/README.md b/README.md
index 544a611..db7c8e2 100644
--- a/README.md
+++ b/README.md
@@ -105,12 +105,12 @@ wandb offline
 import gym

 from grl.algorithms.qgpo import QGPOAlgorithm
-from grl.datasets import QGPOCustomizedDataset
+from grl.datasets import QGPOCustomizedTensorDictDataset
 from grl.utils.log import log
 from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

 def qgpo_pipeline(config):
-    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
+    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedTensorDictDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
     qgpo.train()
     agent = qgpo.deploy()
diff --git a/README.zh.md b/README.zh.md
index f6c99d1..f5d8628 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -102,12 +102,12 @@ wandb offline
 import gym

 from grl.algorithms.qgpo import QGPOAlgorithm
-from grl.datasets import QGPOCustomizedDataset
+from grl.datasets import QGPOCustomizedTensorDictDataset
 from grl.utils.log import log
 from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

 def qgpo_pipeline(config):
-    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
+    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedTensorDictDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
     qgpo.train()
     agent = qgpo.deploy()