From d324094fd0cd3bcc5b4a9f7ba0bf50d0a70b0d73 Mon Sep 17 00:00:00 2001
From: takuseno
Date: Sat, 19 Oct 2024 18:05:46 +0900
Subject: [PATCH] Implement LRSchedulerFactory

---
 .../algos/transformer/decision_transformer.py |  7 --
 .../torch/decision_transformer_impl.py        | 19 ----
 d3rlpy/models/__init__.py                     |  1 +
 d3rlpy/models/lr_schedulers.py                | 94 +++++++++++++++++++
 d3rlpy/models/optimizers.py                   | 30 +++++-
 examples/lr_scheduler.py                      | 59 ++++++++++++
 reproductions/finetuning/iql_finetune.py      | 20 ++--
 reproductions/offline/decision_transformer.py |  7 +-
 reproductions/offline/iql.py                  | 19 +---
 tests/models/test_lr_schedulers.py            | 33 +++++++
 10 files changed, 232 insertions(+), 57 deletions(-)
 create mode 100644 d3rlpy/models/lr_schedulers.py
 create mode 100644 examples/lr_scheduler.py
 create mode 100644 tests/models/test_lr_schedulers.py

diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py
index a6cec2ba..bbfc82e9 100644
--- a/d3rlpy/algos/transformer/decision_transformer.py
+++ b/d3rlpy/algos/transformer/decision_transformer.py
@@ -63,7 +63,6 @@ class DecisionTransformerConfig(TransformerConfig):
         activation_type (str): Type of activation function.
         position_encoding_type (d3rlpy.PositionEncodingType):
             Type of positional encoding (``SIMPLE`` or ``GLOBAL``).
-        warmup_steps (int): Warmup steps for learning rate scheduler.
         compile (bool): (experimental) Flag to enable JIT compilation.
     """
 
@@ -78,7 +77,6 @@ class DecisionTransformerConfig(TransformerConfig):
     embed_dropout: float = 0.1
     activation_type: str = "relu"
     position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE
-    warmup_steps: int = 10000
     compile: bool = False
 
     def create(
@@ -116,10 +114,6 @@ def inner_create_impl(
         optim = self._config.optim_factory.create(
             transformer.named_modules(), lr=self._config.learning_rate
         )
-        scheduler = torch.optim.lr_scheduler.LambdaLR(
-            optim.optim,
-            lambda steps: min((steps + 1) / self._config.warmup_steps, 1),
-        )
 
         # JIT compile
         if self._config.compile:
@@ -134,7 +128,6 @@ def inner_create_impl(
             observation_shape=observation_shape,
             action_size=action_size,
             modules=modules,
-            scheduler=scheduler,
             device=self._device,
         )
 
diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py
index 23b7b5c0..0020d560 100644
--- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py
+++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py
@@ -31,23 +31,6 @@ class DecisionTransformerModules(Modules):
 
 class DecisionTransformerImpl(TransformerAlgoImplBase):
     _modules: DecisionTransformerModules
-    _scheduler: torch.optim.lr_scheduler.LRScheduler
-
-    def __init__(
-        self,
-        observation_shape: Shape,
-        action_size: int,
-        modules: DecisionTransformerModules,
-        scheduler: torch.optim.lr_scheduler.LRScheduler,
-        device: str,
-    ):
-        super().__init__(
-            observation_shape=observation_shape,
-            action_size=action_size,
-            modules=modules,
-            device=device,
-        )
-        self._scheduler = scheduler
 
     def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
         # (1, T, A)
@@ -64,8 +47,6 @@ def inner_update(
         loss = self.compute_loss(batch)
         loss.backward()
         self._modules.optim.step(grad_step)
-        self._scheduler.step()
-
         return {"loss": float(loss.cpu().detach().numpy())}
 
     def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor:
diff --git a/d3rlpy/models/__init__.py b/d3rlpy/models/__init__.py
index 2619768c..af03cc3c 100644
--- a/d3rlpy/models/__init__.py
+++ b/d3rlpy/models/__init__.py
@@ -1,4 +1,5 @@
 from .builders import *
 from .encoders import *
+from .lr_schedulers import *
 from .optimizers import *
 from .q_functions import *
diff --git a/d3rlpy/models/lr_schedulers.py b/d3rlpy/models/lr_schedulers.py
new file mode 100644
index 00000000..519ff212
--- /dev/null
+++ b/d3rlpy/models/lr_schedulers.py
@@ -0,0 +1,94 @@
+import dataclasses
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler
+
+from ..serializable_config import (
+    DynamicConfig,
+    generate_optional_config_generation,
+)
+
+__all__ = [
+    "LRSchedulerFactory",
+    "WarmupSchedulerFactory",
+    "CosineAnnealingLRFactory",
+    "make_lr_scheduler_field",
+]
+
+
+@dataclasses.dataclass()
+class LRSchedulerFactory(DynamicConfig):
+    """A factory class that lazily creates a learning rate scheduler."""
+
+    def create(self, optim: Optimizer) -> LRScheduler:
+        """Returns a learning rate scheduler object.
+
+        Args:
+            optim: PyTorch optimizer.
+
+        Returns:
+            Learning rate scheduler.
+        """
+        raise NotImplementedError
+
+
+@dataclasses.dataclass()
+class WarmupSchedulerFactory(LRSchedulerFactory):
+    r"""A warmup learning rate scheduler.
+
+    .. math::
+
+        lr = \text{initial\_lr} \times \min((t + 1) / \text{warmup\_steps}, 1)
+
+    Args:
+        warmup_steps: Warmup steps.
+    """
+
+    warmup_steps: int
+
+    def create(self, optim: Optimizer) -> LRScheduler:
+        return LambdaLR(
+            optim,
+            lambda steps: min((steps + 1) / self.warmup_steps, 1),
+        )
+
+    @staticmethod
+    def get_type() -> str:
+        return "warmup"
+
+
+@dataclasses.dataclass()
+class CosineAnnealingLRFactory(LRSchedulerFactory):
+    """A cosine annealing learning rate scheduler.
+
+    Args:
+        T_max: Maximum time step.
+        eta_min: Minimum learning rate.
+        last_epoch: Last epoch.
+    """
+
+    T_max: int
+    eta_min: float = 0.0
+    last_epoch: int = -1
+
+    def create(self, optim: Optimizer) -> LRScheduler:
+        return CosineAnnealingLR(
+            optim,
+            T_max=self.T_max,
+            eta_min=self.eta_min,
+            last_epoch=self.last_epoch,
+        )
+
+    @staticmethod
+    def get_type() -> str:
+        return "cosine_annealing"
+
+
+register_lr_scheduler_factory, make_lr_scheduler_field = (
+    generate_optional_config_generation(
+        LRSchedulerFactory,
+    )
+)
+
+register_lr_scheduler_factory(WarmupSchedulerFactory)
+register_lr_scheduler_factory(CosineAnnealingLRFactory)
diff --git a/d3rlpy/models/optimizers.py b/d3rlpy/models/optimizers.py
index a020bb39..82a9bbf6 100644
--- a/d3rlpy/models/optimizers.py
+++ b/d3rlpy/models/optimizers.py
@@ -3,8 +3,10 @@
 
 from torch import nn
 from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop
+from torch.optim.lr_scheduler import LRScheduler
 
 from ..serializable_config import DynamicConfig, generate_config_registration
+from .lr_schedulers import LRSchedulerFactory, make_lr_scheduler_field
 
 __all__ = [
     "OptimizerWrapper",
@@ -46,16 +48,19 @@ class OptimizerWrapper:
     _params: Sequence[nn.Parameter]
     _optim: Optimizer
     _clip_grad_norm: Optional[float]
+    _lr_scheduler: Optional[LRScheduler]
 
     def __init__(
         self,
         params: Sequence[nn.Parameter],
         optim: Optimizer,
         clip_grad_norm: Optional[float] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ):
         self._params = params
         self._optim = optim
         self._clip_grad_norm = clip_grad_norm
+        self._lr_scheduler = lr_scheduler
 
     def zero_grad(self) -> None:
         self._optim.zero_grad()
@@ -67,12 +72,19 @@ def step(self, grad_step: int) -> None:
             grad_step: Total gradient step. This can be used for learning rate
                 schedulers.
""" + # clip gradients if self._clip_grad_norm: nn.utils.clip_grad_norm_( self._params, max_norm=self._clip_grad_norm ) + + # update parameters self._optim.step() + # schedule learning rate + if self._lr_scheduler: + self._lr_scheduler.step() + @property def optim(self) -> Optimizer: return self._optim @@ -86,6 +98,9 @@ class OptimizerFactory(DynamicConfig): """ clip_grad_norm: Optional[float] = None + lr_scheduler_factory: Optional[LRSchedulerFactory] = ( + make_lr_scheduler_field() + ) def create( self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float @@ -97,7 +112,7 @@ def create( lr (float): Learning rate. Returns: - Updater: Updater object. + OptimizerWrapper object. """ named_modules = list(named_modules) params = _get_parameters_from_named_modules(named_modules) @@ -106,6 +121,11 @@ def create( params=params, optim=optim, clip_grad_norm=self.clip_grad_norm, + lr_scheduler=( + self.lr_scheduler_factory.create(optim) + if self.lr_scheduler_factory + else None + ), ) def create_optimizer( @@ -126,6 +146,7 @@ class SGDFactory(OptimizerFactory): Args: clip_grad_norm: Maximum norm value of gradients to clip. + lr_scheduler_factory: LRSchedulerFactory. momentum: momentum factor. dampening: dampening for momentum. weight_decay: weight decay (L2 penalty). @@ -166,6 +187,7 @@ class AdamFactory(OptimizerFactory): Args: clip_grad_norm: Maximum norm value of gradients to clip. + lr_scheduler_factory: LRSchedulerFactory. betas: coefficients used for computing running averages of gradient and its square. eps: term added to the denominator to improve numerical stability. @@ -206,6 +228,8 @@ class AdamWFactory(OptimizerFactory): factory = AdamWFactory(weight_decay=1e-4) Args: + clip_grad_norm: Maximum norm value of gradients to clip. + lr_scheduler_factory: LRSchedulerFactory. betas: coefficients used for computing running averages of gradient and its square. eps: term added to the denominator to improve numerical stability. @@ -246,6 +270,8 @@ class RMSpropFactory(OptimizerFactory): factory = RMSpropFactory(weight_decay=1e-4) Args: + clip_grad_norm: Maximum norm value of gradients to clip. + lr_scheduler_factory: LRSchedulerFactory. alpha: smoothing constant. eps: term added to the denominator to improve numerical stability. weight_decay: weight decay (L2 penalty). @@ -289,6 +315,8 @@ class GPTAdamWFactory(OptimizerFactory): factory = GPTAdamWFactory(weight_decay=1e-4) Args: + clip_grad_norm: Maximum norm value of gradients to clip. + lr_scheduler_factory: LRSchedulerFactory. betas: coefficients used for computing running averages of gradient and its square. eps: term added to the denominator to improve numerical stability. 
diff --git a/examples/lr_scheduler.py b/examples/lr_scheduler.py
new file mode 100644
index 00000000..10c52373
--- /dev/null
+++ b/examples/lr_scheduler.py
@@ -0,0 +1,59 @@
+import argparse
+
+import gymnasium
+
+import d3rlpy
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--env", type=str, default="Hopper-v2")
+    parser.add_argument("--seed", type=int, default=1)
+    parser.add_argument("--gpu", action="store_true")
+    args = parser.parse_args()
+
+    env = gymnasium.make(args.env)
+    eval_env = gymnasium.make(args.env)
+
+    # fix seed
+    d3rlpy.seed(args.seed)
+    d3rlpy.envs.seed_env(env, args.seed)
+    d3rlpy.envs.seed_env(eval_env, args.seed)
+
+    # setup algorithm
+    sac = d3rlpy.algos.SACConfig(
+        batch_size=256,
+        actor_learning_rate=3e-4,
+        critic_learning_rate=3e-4,
+        actor_optim_factory=d3rlpy.models.AdamFactory(
+            # setup learning rate scheduler
+            lr_scheduler_factory=d3rlpy.models.WarmupSchedulerFactory(
+                warmup_steps=10000
+            ),
+        ),
+        critic_optim_factory=d3rlpy.models.AdamFactory(
+            # setup learning rate scheduler
+            lr_scheduler_factory=d3rlpy.models.WarmupSchedulerFactory(
+                warmup_steps=10000
+            ),
+        ),
+        temp_learning_rate=3e-4,
+    ).create(device=args.gpu)
+
+    # replay buffer for experience replay
+    buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)
+
+    # start training
+    sac.fit_online(
+        env,
+        buffer,
+        eval_env=eval_env,
+        n_steps=1000000,
+        n_steps_per_epoch=10000,
+        update_interval=1,
+        update_start_step=1000,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/reproductions/finetuning/iql_finetune.py b/reproductions/finetuning/iql_finetune.py
index caa70e09..b90309ed 100644
--- a/reproductions/finetuning/iql_finetune.py
+++ b/reproductions/finetuning/iql_finetune.py
@@ -2,8 +2,6 @@
 import argparse
 import copy
 
-from torch.optim.lr_scheduler import CosineAnnealingLR
-
 import d3rlpy
 
 
@@ -26,6 +24,11 @@ def main() -> None:
     iql = d3rlpy.algos.IQLConfig(
         actor_learning_rate=3e-4,
         critic_learning_rate=3e-4,
+        actor_optim_factory=d3rlpy.models.AdamFactory(
+            lr_scheduler_factory=d3rlpy.models.CosineAnnealingLRFactory(
+                T_max=1000000
+            ),
+        ),
         batch_size=256,
         weight_temp=10.0,  # hyperparameter for antmaze
         max_weight=100.0,
@@ -33,29 +36,18 @@ def main() -> None:
         reward_scaler=reward_scaler,
     ).create(device=args.gpu)
 
-    # workaround for learning scheduler
-    iql.build_with_dataset(dataset)
-    assert iql.impl
-    scheduler = CosineAnnealingLR(
-        iql.impl._modules.actor_optim.optim,  # pylint: disable=protected-access
-        1000000,
-    )
-
-    def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None:
-        scheduler.step()
-
     # pretraining
     iql.fit(
         dataset,
         n_steps=1000000,
         n_steps_per_epoch=100000,
         save_interval=10,
-        callback=callback,
         evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
         experiment_name=f"IQL_pretraining_{args.dataset}_{args.seed}",
     )
 
     # reset learning rate
+    assert iql.impl
     for g in iql.impl._modules.actor_optim.optim.param_groups:
         g["lr"] = iql.config.actor_learning_rate
 
diff --git a/reproductions/offline/decision_transformer.py b/reproductions/offline/decision_transformer.py
index d9b014ae..144f01a5 100644
--- a/reproductions/offline/decision_transformer.py
+++ b/reproductions/offline/decision_transformer.py
@@ -29,7 +29,11 @@ def main() -> None:
         batch_size=64,
         learning_rate=1e-4,
         optim_factory=d3rlpy.models.AdamWFactory(
-            weight_decay=1e-4, clip_grad_norm=0.25
+            weight_decay=1e-4,
+            clip_grad_norm=0.25,
+            lr_scheduler_factory=d3rlpy.models.WarmupSchedulerFactory(
+                warmup_steps=10000
+            ),
         ),
         encoder_factory=d3rlpy.models.VectorEncoderFactory(
             [128],
@@ -41,7 +45,6 @@ def main() -> None:
         context_size=20,
         num_heads=1,
         num_layers=3,
-        warmup_steps=10000,
         max_timestep=1000,
     ).create(device=args.gpu)
 
diff --git a/reproductions/offline/iql.py b/reproductions/offline/iql.py
index e95d048e..63c18ad1 100644
--- a/reproductions/offline/iql.py
+++ b/reproductions/offline/iql.py
@@ -1,7 +1,5 @@
 import argparse
 
-from torch.optim.lr_scheduler import CosineAnnealingLR
-
 import d3rlpy
 
 
@@ -25,6 +23,11 @@ def main() -> None:
     iql = d3rlpy.algos.IQLConfig(
         actor_learning_rate=3e-4,
         critic_learning_rate=3e-4,
+        actor_optim_factory=d3rlpy.models.AdamFactory(
+            lr_scheduler_factory=d3rlpy.models.CosineAnnealingLRFactory(
+                T_max=500000
+            ),
+        ),
         batch_size=256,
         weight_temp=3.0,
         max_weight=100.0,
@@ -32,23 +35,11 @@ def main() -> None:
         reward_scaler=reward_scaler,
     ).create(device=args.gpu)
 
-    # workaround for learning scheduler
-    iql.build_with_dataset(dataset)
-    assert iql.impl
-    scheduler = CosineAnnealingLR(
-        iql.impl._modules.actor_optim,  # pylint: disable=protected-access
-        500000,
-    )
-
-    def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None:
-        scheduler.step()
-
     iql.fit(
         dataset,
         n_steps=500000,
         n_steps_per_epoch=1000,
         save_interval=10,
-        callback=callback,
         evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
         experiment_name=f"IQL_{args.dataset}_{args.seed}",
     )
diff --git a/tests/models/test_lr_schedulers.py b/tests/models/test_lr_schedulers.py
new file mode 100644
index 00000000..0f2b32f0
--- /dev/null
+++ b/tests/models/test_lr_schedulers.py
@@ -0,0 +1,33 @@
+import pytest
+import torch
+from torch.optim import SGD
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+
+from d3rlpy.models.lr_schedulers import (
+    CosineAnnealingLRFactory,
+    WarmupSchedulerFactory,
+)
+
+
+@pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
+def test_warmup_scheduler_factory(module: torch.nn.Module) -> None:
+    factory = WarmupSchedulerFactory(warmup_steps=1000)
+
+    lr_scheduler = factory.create(SGD(module.parameters(), 1e-4))
+
+    assert isinstance(lr_scheduler, LambdaLR)
+
+    # check serialization and deserialization
+    WarmupSchedulerFactory.deserialize(factory.serialize())
+
+
+@pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
+def test_cosine_annealing_lr_factory(module: torch.nn.Module) -> None:
+    factory = CosineAnnealingLRFactory(T_max=1000)
+
+    lr_scheduler = factory.create(SGD(module.parameters(), 1e-4))
+
+    assert isinstance(lr_scheduler, CosineAnnealingLR)
+
+    # check serialization and deserialization
+    CosineAnnealingLRFactory.deserialize(factory.serialize())
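
Note: beyond the built-in WarmupSchedulerFactory and CosineAnnealingLRFactory, the same registration mechanism admits user-defined schedulers. The snippet below is a minimal sketch and is not part of this patch: ExponentialLRFactory is a hypothetical name, and the sketch assumes the LRSchedulerFactory base class and the register_lr_scheduler_factory helper defined in d3rlpy/models/lr_schedulers.py above, together with PyTorch's ExponentialLR.

import dataclasses

from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR, LRScheduler

import d3rlpy
from d3rlpy.models.lr_schedulers import (
    LRSchedulerFactory,
    register_lr_scheduler_factory,
)


# hypothetical factory (not part of this patch) wrapping torch's ExponentialLR
@dataclasses.dataclass()
class ExponentialLRFactory(LRSchedulerFactory):
    gamma: float = 0.999

    def create(self, optim: Optimizer) -> LRScheduler:
        # decays the learning rate by `gamma` every scheduler step
        return ExponentialLR(optim, gamma=self.gamma)

    @staticmethod
    def get_type() -> str:
        return "exponential"


# register the factory so configs referencing "exponential" can be deserialized
register_lr_scheduler_factory(ExponentialLRFactory)

# attach it to an optimizer factory exactly like the built-in factories
optim_factory = d3rlpy.models.AdamFactory(
    lr_scheduler_factory=ExponentialLRFactory(gamma=0.999),
)

Because OptimizerWrapper.step() invokes the attached scheduler right after the parameter update, a factory registered this way is stepped once per gradient step.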