Include lr_scheduler as state_dict
takuseno committed Nov 3, 2024
1 parent bb98403 commit 1f2a9dc
Showing 5 changed files with 116 additions and 12 deletions.
d3rlpy/optimizers/optimizers.py (14 additions & 1 deletion)
@@ -1,5 +1,5 @@
import dataclasses
-from typing import Iterable, Optional, Sequence, Tuple
+from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple

from torch import nn
from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop
@@ -93,6 +93,19 @@ def step(self) -> None:
    def optim(self) -> Optimizer:
        return self._optim

    def state_dict(self) -> Mapping[str, Any]:
        return {
            "optim": self._optim.state_dict(),
            "lr_scheduler": (
                self._lr_scheduler.state_dict() if self._lr_scheduler else None
            ),
        }

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self._optim.load_state_dict(state_dict["optim"])
        if self._lr_scheduler:
            self._lr_scheduler.load_state_dict(state_dict["lr_scheduler"])


@dataclasses.dataclass()
class OptimizerFactory(DynamicConfig):
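In effect, OptimizerWrapper.state_dict() now bundles the optimizer state and (if present) the LR scheduler state into one dictionary, and load_state_dict() restores both. Below is a minimal sketch of that round trip, using only the APIs touched by this commit (OptimizerWrapper, CosineAnnealingLRFactory); the toy model, learning rate, and T_max value are illustrative, not taken from the repository.

import torch
from torch import nn

from d3rlpy.optimizers.lr_schedulers import CosineAnnealingLRFactory
from d3rlpy.optimizers.optimizers import OptimizerWrapper

# build a wrapper that owns both an optimizer and an LR scheduler
model = nn.Linear(10, 10)
params = list(model.parameters())
optim = torch.optim.SGD(params, lr=1e-3)
wrapper = OptimizerWrapper(
    params=params,
    optim=optim,
    compiled=False,
    lr_scheduler=CosineAnnealingLRFactory(100).create(optim),
)

# the state dict now carries both components under fixed keys
state = wrapper.state_dict()
assert set(state.keys()) == {"optim", "lr_scheduler"}

# a fresh wrapper restores optimizer and scheduler state in one call
model2 = nn.Linear(10, 10)
params2 = list(model2.parameters())
optim2 = torch.optim.SGD(params2, lr=1e-3)
wrapper2 = OptimizerWrapper(
    params=params2,
    optim=optim2,
    compiled=False,
    lr_scheduler=CosineAnnealingLRFactory(100).create(optim2),
)
wrapper2.load_state_dict(state)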
d3rlpy/torch_utility.py (2 additions & 5 deletions)
@@ -384,7 +384,7 @@ def __init__(
    def save(self, f: BinaryIO) -> None:
        # unwrap DDP
        modules = {
-            k: unwrap_ddp_model(v) if isinstance(v, nn.Module) else v.optim
+            k: unwrap_ddp_model(v) if isinstance(v, nn.Module) else v
            for k, v in self._modules.items()
        }
        states = {k: v.state_dict() for k, v in modules.items()}
@@ -393,10 +393,7 @@ def save(self, f: BinaryIO) -> None:
    def load(self, f: BinaryIO) -> None:
        chkpt = torch.load(f, map_location=map_location(self._device))
        for k, v in self._modules.items():
-            if isinstance(v, nn.Module):
-                v.load_state_dict(chkpt[k])
-            else:
-                v.optim.load_state_dict(chkpt[k])
+            v.load_state_dict(chkpt[k])

    @property
    def modules(self) -> Dict[str, Union[nn.Module, OptimizerWrapperProto]]:
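Because the wrapper now exposes state_dict()/load_state_dict(), Checkpointer.save() and Checkpointer.load() treat nn.Module and OptimizerWrapper entries uniformly, so a checkpoint carries the scheduler state without special-casing. Below is a minimal sketch of the save/load round trip, mirroring the usage in the updated tests; the module keys "fc" and "optim" and the in-memory buffer are illustrative assumptions.

from io import BytesIO

import torch

from d3rlpy.optimizers import OptimizerWrapper
from d3rlpy.optimizers.lr_schedulers import CosineAnnealingLRFactory
from d3rlpy.torch_utility import Checkpointer

fc = torch.nn.Linear(100, 100)
params = list(fc.parameters())
optim = torch.optim.Adam(params)
wrapper = OptimizerWrapper(
    params,
    optim,
    lr_scheduler=CosineAnnealingLRFactory(100).create(optim),
    compiled=False,
)
checkpointer = Checkpointer(modules={"fc": fc, "optim": wrapper}, device="cpu:0")

# every module, including the optimizer wrapper, contributes its own state_dict()
buf = BytesIO()
checkpointer.save(buf)

# loading goes through load_state_dict() on each module, so the scheduler state
# stored under "lr_scheduler" is restored alongside the optimizer state
checkpointer.load(BytesIO(buf.getvalue()))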
d3rlpy/types.py (7 additions & 1 deletion)
@@ -1,4 +1,4 @@
-from typing import Any, Sequence, Union
+from typing import Any, Mapping, Sequence, Union

import gym
import gymnasium
@@ -42,3 +42,9 @@ class OptimizerWrapperProto(Protocol):
    @property
    def optim(self) -> Optimizer:
        raise NotImplementedError

    def state_dict(self) -> Mapping[str, Any]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        raise NotImplementedError
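OptimizerWrapperProto is a typing.Protocol, so Checkpointer accepts any object that structurally provides optim, state_dict(), and load_state_dict(). Below is a hypothetical sketch of a minimal conforming wrapper; PlainOptimizerWrapper is illustrative and not part of d3rlpy.

from typing import Any, Mapping

from torch.optim import Optimizer


class PlainOptimizerWrapper:
    """Hypothetical example: satisfies OptimizerWrapperProto structurally,
    without inheriting from it, by exposing the three required members."""

    def __init__(self, optim: Optimizer):
        self._optim = optim

    @property
    def optim(self) -> Optimizer:
        return self._optim

    def state_dict(self) -> Mapping[str, Any]:
        # no scheduler here, so only the optimizer state is saved
        return {"optim": self._optim.state_dict(), "lr_scheduler": None}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self._optim.load_state_dict(state_dict["optim"])

Such an object could sit in a Checkpointer's modules mapping just like the real OptimizerWrapper, since only the protocol members are called during save and load.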
tests/optimizers/test_optimizers.py (65 additions & 0 deletions)
@@ -1,17 +1,82 @@
from typing import Optional

import pytest
import torch
from torch import nn
from torch.optim import SGD, Adam, AdamW, RMSprop

from d3rlpy.optimizers.lr_schedulers import (
    CosineAnnealingLRFactory,
    LRSchedulerFactory,
)
from d3rlpy.optimizers.optimizers import (
    AdamFactory,
    AdamWFactory,
    GPTAdamWFactory,
    OptimizerWrapper,
    RMSpropFactory,
    SGDFactory,
)


@pytest.mark.parametrize(
    "lr_scheduler_factory", [None, CosineAnnealingLRFactory(100)]
)
@pytest.mark.parametrize("compiled", [False, True])
@pytest.mark.parametrize("clip_grad_norm", [None, 1e-4])
def test_optimizer_wrapper(
    lr_scheduler_factory: Optional[LRSchedulerFactory],
    compiled: bool,
    clip_grad_norm: Optional[float],
) -> None:
    model = nn.Linear(100, 200)
    optim = SGD(model.parameters(), lr=1)
    lr_scheduler = (
        lr_scheduler_factory.create(optim) if lr_scheduler_factory else None
    )
    wrapper = OptimizerWrapper(
        params=list(model.parameters()),
        optim=optim,
        compiled=compiled,
        clip_grad_norm=clip_grad_norm,
        lr_scheduler=lr_scheduler,
    )

    loss = model(torch.rand(1, 100)).mean()
    loss.backward()

    # check zero grad
    wrapper.zero_grad()
    if compiled:
        assert model.weight.grad is None
        assert model.bias.grad is None
    else:
        assert torch.all(model.weight.grad == 0)
        assert torch.all(model.bias.grad == 0)

    # check step
    before_weight = torch.zeros_like(model.weight)
    before_weight.copy_(model.weight)
    before_bias = torch.zeros_like(model.bias)
    before_bias.copy_(model.bias)
    loss = model(torch.rand(1, 100)).mean()
    loss.backward()
    model.weight.grad.add_(1)
    model.weight.grad.mul_(10000)
    model.bias.grad.add_(1)
    model.bias.grad.mul_(10000)

    wrapper.step()
    assert torch.all(model.weight != before_weight)
    assert torch.all(model.bias != before_bias)

    # check clip_grad_norm
    if clip_grad_norm:
        assert torch.norm(model.weight.grad) < 1e-4
    else:
        assert torch.norm(model.weight.grad) > 1e-4


@pytest.mark.parametrize("lr", [1e-4])
@pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
def test_sgd_factory(lr: float, module: torch.nn.Module) -> None:
tests/test_torch_utility.py (28 additions & 5 deletions)
@@ -1,7 +1,7 @@
import copy
import dataclasses
from io import BytesIO
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Optional, Sequence
from unittest.mock import Mock

import numpy as np
@@ -10,6 +10,10 @@

from d3rlpy.dataset import TrajectoryMiniBatch, Transition, TransitionMiniBatch
from d3rlpy.optimizers import OptimizerWrapper
from d3rlpy.optimizers.lr_schedulers import (
    CosineAnnealingLRFactory,
    LRSchedulerFactory,
)
from d3rlpy.torch_utility import (
    GEGLU,
    Checkpointer,
@@ -454,11 +458,22 @@ def test_torch_trajectory_mini_batch(
    assert torch.all(torch_batch2.masks == torch_batch.masks)


-def test_checkpointer() -> None:
+@pytest.mark.parametrize(
+    "lr_scheduler_factory", [None, CosineAnnealingLRFactory(100)]
+)
+def test_checkpointer(
+    lr_scheduler_factory: Optional[LRSchedulerFactory],
+) -> None:
    fc1 = torch.nn.Linear(100, 100)
    fc2 = torch.nn.Linear(100, 100)
    params = list(fc1.parameters())
-    optim = OptimizerWrapper(params, torch.optim.Adam(params), False)
+    raw_optim = torch.optim.Adam(params)
+    lr_scheduler = (
+        lr_scheduler_factory.create(raw_optim) if lr_scheduler_factory else None
+    )
+    optim = OptimizerWrapper(
+        params, raw_optim, lr_scheduler=lr_scheduler, compiled=False
+    )
    checkpointer = Checkpointer(
        modules={"fc1": fc1, "fc2": fc2, "optim": optim}, device="cpu:0"
    )
@@ -468,7 +483,7 @@ def test_checkpointer() -> None:
    states = {
        "fc1": fc1.state_dict(),
        "fc2": fc2.state_dict(),
-        "optim": optim.optim.state_dict(),
+        "optim": optim.state_dict(),
    }
    torch.save(states, ref_bytes)

@@ -480,7 +495,15 @@ def test_checkpointer() -> None:
    fc1_2 = torch.nn.Linear(100, 100)
    fc2_2 = torch.nn.Linear(100, 100)
    params_2 = list(fc1_2.parameters())
-    optim_2 = OptimizerWrapper(params_2, torch.optim.Adam(params_2), False)
+    raw_optim_2 = torch.optim.Adam(params_2)
+    lr_scheduler_2 = (
+        lr_scheduler_factory.create(raw_optim_2)
+        if lr_scheduler_factory
+        else None
+    )
+    optim_2 = OptimizerWrapper(
+        params_2, raw_optim_2, lr_scheduler=lr_scheduler_2, compiled=False
+    )
    checkpointer = Checkpointer(
        modules={"fc1": fc1_2, "fc2": fc2_2, "optim": optim_2}, device="cpu:0"
    )
