Implement LRSchedulerFactory
takuseno committed Oct 19, 2024
1 parent 677802d commit d324094
Showing 10 changed files with 232 additions and 57 deletions.
7 changes: 0 additions & 7 deletions d3rlpy/algos/transformer/decision_transformer.py
@@ -63,7 +63,6 @@ class DecisionTransformerConfig(TransformerConfig):
activation_type (str): Type of activation function.
position_encoding_type (d3rlpy.PositionEncodingType):
Type of positional encoding (``SIMPLE`` or ``GLOBAL``).
warmup_steps (int): Warmup steps for learning rate scheduler.
compile (bool): (experimental) Flag to enable JIT compilation.
"""

@@ -78,7 +77,6 @@ class DecisionTransformerConfig(TransformerConfig):
embed_dropout: float = 0.1
activation_type: str = "relu"
position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE
warmup_steps: int = 10000
compile: bool = False

def create(
@@ -116,10 +114,6 @@ def inner_create_impl(
optim = self._config.optim_factory.create(
transformer.named_modules(), lr=self._config.learning_rate
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
optim.optim,
lambda steps: min((steps + 1) / self._config.warmup_steps, 1),
)

# JIT compile
if self._config.compile:
@@ -134,7 +128,6 @@ def inner_create_impl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
scheduler=scheduler,
device=self._device,
)

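
Since warmup_steps is gone from DecisionTransformerConfig, the warmup now has to come in through the optimizer factory. A minimal migration sketch, assuming the config keeps the optim_factory field used in the hunk above (the GPTAdamWFactory choice and the device are illustrative):

import d3rlpy

# replaces the removed warmup_steps=10000 config field
dt = d3rlpy.algos.DecisionTransformerConfig(
    optim_factory=d3rlpy.models.GPTAdamWFactory(
        lr_scheduler_factory=d3rlpy.models.WarmupSchedulerFactory(
            warmup_steps=10000,
        ),
    ),
).create(device="cpu")
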
19 changes: 0 additions & 19 deletions d3rlpy/algos/transformer/torch/decision_transformer_impl.py
@@ -31,23 +31,6 @@ class DecisionTransformerModules(Modules):

class DecisionTransformerImpl(TransformerAlgoImplBase):
_modules: DecisionTransformerModules
_scheduler: torch.optim.lr_scheduler.LRScheduler

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DecisionTransformerModules,
scheduler: torch.optim.lr_scheduler.LRScheduler,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
device=device,
)
self._scheduler = scheduler

def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
# (1, T, A)
@@ -64,8 +47,6 @@ def inner_update(
loss = self.compute_loss(batch)
loss.backward()
self._modules.optim.step(grad_step)
self._scheduler.step()

return {"loss": float(loss.cpu().detach().numpy())}

def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor:
1 change: 1 addition & 0 deletions d3rlpy/models/__init__.py
@@ -1,4 +1,5 @@
from .builders import *
from .encoders import *
from .lr_schedulers import *
from .optimizers import *
from .q_functions import *
94 changes: 94 additions & 0 deletions d3rlpy/models/lr_schedulers.py
@@ -0,0 +1,94 @@
import dataclasses

from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler

from ..serializable_config import (
DynamicConfig,
generate_optional_config_generation,
)

__all__ = [
"LRSchedulerFactory",
"WarmupSchedulerFactory",
"CosineAnnealingLRFactory",
"make_lr_scheduler_field",
]


@dataclasses.dataclass()
class LRSchedulerFactory(DynamicConfig):
"""A factory class that creates a learning rate scheduler a lazy way."""

def create(self, optim: Optimizer) -> LRScheduler:
"""Returns a learning rate scheduler object.
Args:
optim: PyTorch optimizer.
Returns:
Learning rate scheduler.
"""
raise NotImplementedError


@dataclasses.dataclass()
class WarmupSchedulerFactory(LRSchedulerFactory):
r"""A warmup learning rate scheduler.
.. math::
lr = \min((t + 1) / \text{warmup\_steps}, 1)
Args:
warmup_steps: Warmup steps.
"""

warmup_steps: int

def create(self, optim: Optimizer) -> LRScheduler:
return LambdaLR(
optim,
lambda steps: min((steps + 1) / self.warmup_steps, 1),
)

@staticmethod
def get_type() -> str:
return "warmup"


@dataclasses.dataclass()
class CosineAnnealingLRFactory(LRSchedulerFactory):
"""A cosine annealing learning rate scheduler.
Args:
T_max: Maximum time step.
eta_min: Minimum learning rate.
last_epoch: Last epoch.
"""

T_max: int
eta_min: float = 0.0
last_epoch: int = -1

def create(self, optim: Optimizer) -> LRScheduler:
return CosineAnnealingLR(
optim,
T_max=self.T_max,
eta_min=self.eta_min,
last_epoch=self.last_epoch,
)

@staticmethod
def get_type() -> str:
return "cosine_annealing"


register_lr_scheduler_factory, make_lr_scheduler_field = (
generate_optional_config_generation(
LRSchedulerFactory,
)
)

register_lr_scheduler_factory(WarmupSchedulerFactory)
register_lr_scheduler_factory(CosineAnnealingLRFactory)
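
Because the factories register through generate_optional_config_generation, a downstream scheduler can plug into the same machinery. A sketch under that assumption (StepLRFactory is hypothetical and not part of this commit; register_lr_scheduler_factory is imported from the module directly since it is not in __all__):

import dataclasses

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler, StepLR

from d3rlpy.models.lr_schedulers import (
    LRSchedulerFactory,
    register_lr_scheduler_factory,
)


@dataclasses.dataclass()
class StepLRFactory(LRSchedulerFactory):
    """Hypothetical factory wrapping torch.optim.lr_scheduler.StepLR."""

    step_size: int
    gamma: float = 0.1

    def create(self, optim: Optimizer) -> LRScheduler:
        return StepLR(optim, step_size=self.step_size, gamma=self.gamma)

    @staticmethod
    def get_type() -> str:
        return "step_lr"


register_lr_scheduler_factory(StepLRFactory)
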
30 changes: 29 additions & 1 deletion d3rlpy/models/optimizers.py
@@ -3,8 +3,10 @@

from torch import nn
from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop
from torch.optim.lr_scheduler import LRScheduler

from ..serializable_config import DynamicConfig, generate_config_registration
from .lr_schedulers import LRSchedulerFactory, make_lr_scheduler_field

__all__ = [
"OptimizerWrapper",
@@ -46,16 +48,19 @@ class OptimizerWrapper:
_params: Sequence[nn.Parameter]
_optim: Optimizer
_clip_grad_norm: Optional[float]
_lr_scheduler: Optional[LRScheduler]

def __init__(
self,
params: Sequence[nn.Parameter],
optim: Optimizer,
clip_grad_norm: Optional[float] = None,
lr_scheduler: Optional[LRScheduler] = None,
):
self._params = params
self._optim = optim
self._clip_grad_norm = clip_grad_norm
self._lr_scheduler = lr_scheduler

def zero_grad(self) -> None:
self._optim.zero_grad()
@@ -67,12 +72,19 @@ def step(self, grad_step: int) -> None:
grad_step: Total gradient step. This can be used for learning rate
schedulers.
"""
# clip gradients
if self._clip_grad_norm:
nn.utils.clip_grad_norm_(
self._params, max_norm=self._clip_grad_norm
)

# update parameters
self._optim.step()

# schedule learning rate
if self._lr_scheduler:
self._lr_scheduler.step()

@property
def optim(self) -> Optimizer:
return self._optim
@@ -86,6 +98,9 @@ class OptimizerFactory(DynamicConfig):
"""

clip_grad_norm: Optional[float] = None
lr_scheduler_factory: Optional[LRSchedulerFactory] = (
make_lr_scheduler_field()
)

def create(
self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
@@ -97,7 +112,7 @@ def create(
lr (float): Learning rate.
Returns:
- Updater: Updater object.
+ OptimizerWrapper object.
"""
named_modules = list(named_modules)
params = _get_parameters_from_named_modules(named_modules)
@@ -106,6 +121,11 @@ def create(
params=params,
optim=optim,
clip_grad_norm=self.clip_grad_norm,
lr_scheduler=(
self.lr_scheduler_factory.create(optim)
if self.lr_scheduler_factory
else None
),
)

def create_optimizer(
@@ -126,6 +146,7 @@ class SGDFactory(OptimizerFactory):
Args:
clip_grad_norm: Maximum norm value of gradients to clip.
lr_scheduler_factory: LRSchedulerFactory.
momentum: momentum factor.
dampening: dampening for momentum.
weight_decay: weight decay (L2 penalty).
@@ -166,6 +187,7 @@ class AdamFactory(OptimizerFactory):
Args:
clip_grad_norm: Maximum norm value of gradients to clip.
lr_scheduler_factory: LRSchedulerFactory.
betas: coefficients used for computing running averages of
gradient and its square.
eps: term added to the denominator to improve numerical stability.
@@ -206,6 +228,8 @@ class AdamWFactory(OptimizerFactory):
factory = AdamWFactory(weight_decay=1e-4)
Args:
clip_grad_norm: Maximum norm value of gradients to clip.
lr_scheduler_factory: LRSchedulerFactory.
betas: coefficients used for computing running averages of
gradient and its square.
eps: term added to the denominator to improve numerical stability.
@@ -246,6 +270,8 @@ class RMSpropFactory(OptimizerFactory):
factory = RMSpropFactory(weight_decay=1e-4)
Args:
clip_grad_norm: Maximum norm value of gradients to clip.
lr_scheduler_factory: LRSchedulerFactory.
alpha: smoothing constant.
eps: term added to the denominator to improve numerical stability.
weight_decay: weight decay (L2 penalty).
@@ -289,6 +315,8 @@ class GPTAdamWFactory(OptimizerFactory):
factory = GPTAdamWFactory(weight_decay=1e-4)
Args:
clip_grad_norm: Maximum norm value of gradients to clip.
lr_scheduler_factory: LRSchedulerFactory.
betas: coefficients used for computing running averages of
gradient and its square.
eps: term added to the denominator to improve numerical stability.
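
Taken together, OptimizerFactory.create now builds the scheduler and OptimizerWrapper.step advances it right after the parameter update. A standalone sketch of that flow (the toy nn.Linear model and the hyperparameters are arbitrary):

import torch
from torch import nn

import d3rlpy

model = nn.Linear(4, 2)
factory = d3rlpy.models.AdamFactory(
    lr_scheduler_factory=d3rlpy.models.WarmupSchedulerFactory(warmup_steps=100),
)
optim = factory.create(model.named_modules(), lr=3e-4)

for grad_step in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optim.zero_grad()
    loss.backward()
    # updates parameters, then steps the warmup scheduler
    optim.step(grad_step)

# base lr scaled by min((scheduler_steps + 1) / 100, 1)
print(optim.optim.param_groups[0]["lr"])
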
59 changes: 59 additions & 0 deletions examples/lr_scheduler.py
@@ -0,0 +1,59 @@
import argparse

import gymnasium

import d3rlpy


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, default="Hopper-v2")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--gpu", action="store_true")
args = parser.parse_args()

env = gymnasium.make(args.env)
eval_env = gymnasium.make(args.env)

# fix seed
d3rlpy.seed(args.seed)
d3rlpy.envs.seed_env(env, args.seed)
d3rlpy.envs.seed_env(eval_env, args.seed)

# setup algorithm
sac = d3rlpy.algos.SACConfig(
batch_size=256,
actor_learning_rate=3e-4,
critic_learning_rate=3e-4,
actor_optim_factory=d3rlpy.models.AdamFactory(
# setup learning rate scheduler
lr_scheduler_factory=d3rlpy.models.WarmupSchedulerFactory(
warmup_steps=10000
),
),
critic_optim_factory=d3rlpy.models.AdamFactory(
# setup learning rate scheduler
lr_scheduler_factory=d3rlpy.models.WarmupSchedulerFactory(
warmup_steps=10000
),
),
temp_learning_rate=3e-4,
).create(device=args.gpu)

# replay buffer for experience replay
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)

# start training
sac.fit_online(
env,
buffer,
eval_env=eval_env,
n_steps=1000000,
n_steps_per_epoch=10000,
update_interval=1,
update_start_step=1000,
)


if __name__ == "__main__":
main()
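
The same hook covers the cosine annealing factory for offline training; a sketch assuming DiscreteCQLConfig and d3rlpy.datasets.get_cartpole from the existing API, neither of which this commit touches:

import d3rlpy

dataset, env = d3rlpy.datasets.get_cartpole()

cql = d3rlpy.algos.DiscreteCQLConfig(
    optim_factory=d3rlpy.models.AdamFactory(
        # anneal the learning rate over the full 100k gradient steps
        lr_scheduler_factory=d3rlpy.models.CosineAnnealingLRFactory(
            T_max=100000,
        ),
    ),
).create(device="cpu")

cql.fit(dataset, n_steps=100000, n_steps_per_epoch=10000)
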