Skip to content

Commit

Permalink
Reorganize optimizer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent f3c5540 commit bb98403
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
Empty file added tests/optimizers/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions tests/optimizers/test_lr_schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pytest
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

from d3rlpy.optimizers.lr_schedulers import (
CosineAnnealingLRFactory,
WarmupSchedulerFactory,
)


@pytest.mark.parametrize("warmup_steps", [100])
@pytest.mark.parametrize("lr", [1e-4])
@pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
def test_warmup_scheduler_factory(
    warmup_steps: int, lr: float, module: torch.nn.Module
) -> None:
    """Check that WarmupSchedulerFactory builds a linear-warmup scheduler.

    The scheduler should start at ``lr / warmup_steps`` and reach the base
    ``lr`` after ``warmup_steps`` calls to ``step()``.
    """
    factory = WarmupSchedulerFactory(warmup_steps)

    lr_scheduler = factory.create(SGD(module.parameters(), lr=lr))

    # The warmup schedule is implemented on top of LambdaLR; verify the
    # type before exercising the scheduler.
    assert isinstance(lr_scheduler, LambdaLR)

    # At step 0 the learning rate is scaled down by 1 / warmup_steps.
    assert np.allclose(lr_scheduler.get_lr()[0], lr / warmup_steps)
    for _ in range(warmup_steps):
        lr_scheduler.step()
    # After warmup the base lr should be restored.  Use a tolerant float
    # comparison for consistency with the assertion above (an exact ``==``
    # on a float computed through the lambda schedule is fragile).
    assert np.allclose(lr_scheduler.get_lr()[0], lr)

    # check serialization and deserialization round-trip
    WarmupSchedulerFactory.deserialize(factory.serialize())


@pytest.mark.parametrize("T_max", [100])
@pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
def test_cosine_annealing_factory(T_max: int, module: torch.nn.Module) -> None:
    """CosineAnnealingLRFactory should produce a CosineAnnealingLR scheduler."""
    factory = CosineAnnealingLRFactory(T_max=T_max)

    # The created scheduler must be torch's CosineAnnealingLR.
    scheduler = factory.create(SGD(module.parameters()))
    assert isinstance(scheduler, CosineAnnealingLR)

    # Round-tripping the config through serialize/deserialize must not raise.
    CosineAnnealingLRFactory.deserialize(factory.serialize())
File renamed without changes.

0 comments on commit bb98403

Please sign in to comment.