Add TorchRL replay buffer; polish QGPO/GMPO/GMPG APIs.
zjowowen committed Sep 2, 2024
commit 9c603df · 1 parent d97a2fb
Showing 111 changed files with 7,748 additions and 329 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -108,7 +108,7 @@ from grl.utils.log import log
 from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

 def qgpo_pipeline(config):
-    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz",))
+    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
     qgpo.train()

     agent = qgpo.deploy()
2 changes: 1 addition & 1 deletion README.zh.md
@@ -105,7 +105,7 @@ from grl.utils.log import log
 from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

 def qgpo_pipeline(config):
-    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz",))
+    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", action_augment_num=config.train.parameter.action_augment_num))
     qgpo.train()

     agent = qgpo.deploy()
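Both READMEs change because the polished QGPO API now takes action_augment_num as an explicit dataset argument instead of an implicit default. A minimal sketch of the updated call; the import paths for QGPOAlgorithm and QGPOCustomizedDataset are assumed from the surrounding README text, and the configuration is assumed to define config.train.parameter.action_augment_num as in the diff above:

import torch

from grl.algorithms.qgpo import QGPOAlgorithm          # import path assumed from the README
from grl.datasets import QGPOCustomizedDataset         # import path assumed from the README
from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

# The number of augmented actions per state is now passed explicitly to the dataset.
dataset = QGPOCustomizedDataset(
    numpy_data_path="./data.npz",
    action_augment_num=config.train.parameter.action_augment_num,
)
qgpo = QGPOAlgorithm(config, dataset=dataset)
qgpo.train()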
99 changes: 48 additions & 51 deletions grl/algorithms/gmpg.py
@@ -8,6 +8,8 @@
 from easydict import EasyDict
 from rich.progress import track
 from tensordict import TensorDict
+from torchrl.data import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

 import wandb
 from grl.agents.gm import GPAgent
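The two new imports are what replace the torch.utils.data.DataLoader pipeline further down in this file: batches are drawn from a TensorDictReplayBuffer backed by the dataset's storage and sampled without replacement. A self-contained sketch of that pattern with toy data; the keys "s", "a", "r", the shapes, and the use of LazyTensorStorage are illustrative assumptions, not code from this commit:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

# A toy offline dataset of 1000 transitions stored as a single TensorDict.
data = TensorDict(
    {"s": torch.randn(1000, 17), "a": torch.randn(1000, 6), "r": torch.randn(1000, 1)},
    batch_size=[1000],
)

buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    batch_size=256,
    sampler=SamplerWithoutReplacement(),  # each transition is drawn at most once per pass
    prefetch=4,
    pin_memory=False,  # set True when batches are copied to GPU afterwards
)
buffer.extend(data)  # load the whole dataset into the buffer's storage

# Iterating the buffer yields TensorDict batches until the sampler is exhausted,
# which corresponds to one pass (epoch) over the data.
for batch in buffer:
    states, actions = batch["s"], batch["a"]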
@@ -759,7 +761,7 @@ def save_checkpoint(model, iteration=None, model_type=False):
             else:
                 raise NotImplementedError

-        def generate_fake_action(model, states, sample_per_state):
+        def generate_fake_action(model, states, action_augment_num):

             fake_actions_sampled = []
             for states in track(
@@ -769,7 +771,7 @@ def generate_fake_action(model, states, sample_per_state):

                 fake_actions_ = model.behaviour_policy_sample(
                     state=states,
-                    batch_size=sample_per_state,
+                    batch_size=action_augment_num,
                     t_span=(
                         torch.linspace(0.0, 1.0, config.parameter.t_span).to(
                             states.device
Expand All @@ -788,11 +790,22 @@ def evaluate(model, train_epoch, repeat=1):
evaluation_results = dict()

def policy(obs: np.ndarray) -> np.ndarray:
obs = torch.tensor(
obs,
dtype=torch.float32,
device=config.model.GPPolicy.device,
).unsqueeze(0)
if isinstance(obs, torch.Tensor):
obs = torch.tensor(
obs,
dtype=torch.float32,
device=config.model.GPPolicy.device,
).unsqueeze(0)
elif isinstance(obs, dict):
for key in obs:
obs[key] = torch.tensor(
obs[key],
dtype=torch.float32,
device=config.model.GPPolicy.device
).unsqueeze(0)
if obs[key].dim() == 1 and obs[key].shape[0] == 1:
obs[key] = obs[key].unsqueeze(1)
obs = TensorDict(obs, batch_size=[1])
action = (
model.sample(
condition=obs,
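The evaluation policy now also accepts dictionary observations and packs them into a batch-size-1 TensorDict before they are passed as the sampling condition. A sketch of that conversion in isolation; the observation keys and shapes are invented for illustration:

import numpy as np
import torch
from tensordict import TensorDict

obs = {
    "observation": np.zeros(17, dtype=np.float32),
    "goal": np.zeros(3, dtype=np.float32),
}

# Convert every entry to a float32 tensor with a leading batch dimension of 1 ...
obs = {
    key: torch.tensor(value, dtype=torch.float32).unsqueeze(0)
    for key, value in obs.items()
}
# ... and wrap the dict in a TensorDict so it can be consumed as a single condition.
condition = TensorDict(obs, batch_size=[1])
print(condition["observation"].shape)  # torch.Size([1, 17])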
@@ -855,13 +868,20 @@ def policy(obs: np.ndarray) -> np.ndarray:

         # ---------------------------------------
         # behavior training code ↓
         # ---------------------------------------
         behaviour_policy_optimizer = torch.optim.Adam(
             self.model["GPPolicy"].base_model.model.parameters(),
             lr=config.parameter.behaviour_policy.learning_rate,
         )

+        replay_buffer=TensorDictReplayBuffer(
+            storage=self.dataset.storage,
+            batch_size=config.parameter.behaviour_policy.batch_size,
+            sampler=SamplerWithoutReplacement(),
+            prefetch=10,
+            pin_memory=True,
+        )

         behaviour_policy_train_iter = 0
         for epoch in track(
             range(config.parameter.behaviour_policy.epochs),
@@ -870,22 +890,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
             if self.behaviour_policy_train_epoch >= epoch:
                 continue

-            sampler = torch.utils.data.RandomSampler(
-                self.dataset, replacement=False
-            )
-            data_loader = torch.utils.data.DataLoader(
-                self.dataset,
-                batch_size=config.parameter.behaviour_policy.batch_size,
-                shuffle=False,
-                sampler=sampler,
-                pin_memory=True,
-                drop_last=True,
-                num_workers=8,
-            )

             counter = 1
             behaviour_policy_loss_sum = 0
-            for data in data_loader:
+            for index, data in enumerate(replay_buffer):

                 behaviour_policy_loss = self.model[
                     "GPPolicy"
@@ -946,34 +953,29 @@ def policy(obs: np.ndarray) -> np.ndarray:
             lr=config.parameter.critic.learning_rate,
         )

+        replay_buffer=TensorDictReplayBuffer(
+            storage=self.dataset.storage,
+            batch_size=config.parameter.critic.batch_size,
+            sampler=SamplerWithoutReplacement(),
+            prefetch=10,
+            pin_memory=True,
+        )

         critic_train_iter = 0
         for epoch in track(
             range(config.parameter.critic.epochs), description="Critic training"
         ):
             if self.critic_train_epoch >= epoch:
                 continue

-            sampler = torch.utils.data.RandomSampler(
-                self.dataset, replacement=False
-            )
-            data_loader = torch.utils.data.DataLoader(
-                self.dataset,
-                batch_size=config.parameter.critic.batch_size,
-                shuffle=False,
-                sampler=sampler,
-                pin_memory=True,
-                drop_last=True,
-                num_workers=8,
-            )

             counter = 1

             v_loss_sum = 0.0
             v_sum = 0.0
             q_loss_sum = 0.0
             q_sum = 0.0
             q_target_sum = 0.0
-            for data in data_loader:
+            for index, data in enumerate(replay_buffer):

                 v_loss, next_v = self.model["GPPolicy"].critic.v_loss(
                     state=data["s"].to(config.model.GPPolicy.device),
@@ -1062,6 +1064,14 @@ def policy(obs: np.ndarray) -> np.ndarray:
             lr=config.parameter.guided_policy.learning_rate,
         )

+        replay_buffer=TensorDictReplayBuffer(
+            storage=self.dataset.storage,
+            batch_size=config.parameter.guided_policy.batch_size,
+            sampler=SamplerWithoutReplacement(),
+            prefetch=10,
+            pin_memory=True,
+        )

         guided_policy_train_iter = 0
         beta = config.parameter.guided_policy.beta
         for epoch in track(
@@ -1072,22 +1082,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
             if self.guided_policy_train_epoch >= epoch:
                 continue

-            sampler = torch.utils.data.RandomSampler(
-                self.dataset, replacement=False
-            )
-            data_loader = torch.utils.data.DataLoader(
-                self.dataset,
-                batch_size=config.parameter.guided_policy.batch_size,
-                shuffle=False,
-                sampler=sampler,
-                pin_memory=True,
-                drop_last=True,
-                num_workers=8,
-            )

             counter = 1
             guided_policy_loss_sum = 0.0
-            for data in data_loader:
+            for index, data in enumerate(replay_buffer):
                 if config.parameter.algorithm_type == "GMPG":
                     (
                         guided_policy_loss,