Merge pull request #100 from mila-iqia/factory
Factory
mirkobronzi authored Nov 7, 2023
2 parents c00ec06 + e9cf628 commit cf75c1b
Showing 8 changed files with 270 additions and 54 deletions.
4 changes: 4 additions & 0 deletions .flake8
@@ -0,0 +1,4 @@
[flake8]
max-line-length=100
ignore=W503,D104,D100,D401
docstring-convention=google
5 changes: 4 additions & 1 deletion .gitignore
@@ -110,6 +110,9 @@ venv.bak/
# mypy
.mypy_cache/

# vim hipsters
# vim overlords
*.swp
*.swo

# vscode
.vscode
146 changes: 146 additions & 0 deletions amlrt_project/models/factory.py
@@ -0,0 +1,146 @@
"""Optimizer and learning rate scheduler factory."""


from abc import ABC, abstractmethod
from dataclasses import MISSING, dataclass, field
from typing import Any, Dict, Iterable, Optional, Tuple

from torch.nn import Parameter
from torch.optim import SGD, Adam, Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau


class OptimFactory(ABC):
"""Base class for optimizer factories."""

@abstractmethod
def __call__(self, parameters: Iterable[Parameter]) -> Optimizer:
"""Create an optimizer."""
...


class SchedulerFactory(ABC):
"""Base class for learning rate scheduler factories."""

@abstractmethod
def __call__(self, optim: Optimizer) -> Dict[str, Any]:
"""Create a scheduler."""
...


@dataclass
class OptimizerConfigurationFactory:
"""Combine an optimizer factory and a scheduler factory.
Returns the configuration dictionary that Lightning requires.
Only supports the usual case (one optimizer, one scheduler).
"""
optim_factory: OptimFactory
scheduler_factory: Optional[SchedulerFactory] = None

def __call__(self, parameters: Iterable[Parameter]) -> Dict[str, Any]:
"""Create the optimizer and scheduler, for `parameters`."""
config = {}
optim = self.optim_factory(parameters)
config['optimizer'] = optim
if self.scheduler_factory is not None:
config['lr_scheduler'] = self.scheduler_factory(optim)
return config


@dataclass
class PlateauFactory(SchedulerFactory):
"""Reduce the learning rate when `metric` is no longer improving."""
metric: str
"""Metric to use, must be logged with Lightning."""
mode: str = "min"
"""Minimize or maximize."""
factor: float = 0.1
"""Multiply the learning rate by `factor`."""
patience: int = 10
"""Wait `patience` epochs before reducing the learning rate."""

def __call__(self, optimizer: Optimizer) -> Dict[str, Any]:
"""Create a scheduler."""
scheduler = ReduceLROnPlateau(
optimizer,
mode=self.mode,
factor=self.factor, patience=self.patience)
return dict(
scheduler=scheduler,
frequency=1,
interval='epoch',
monitor=self.metric)


@dataclass
class WarmupDecayFactory(SchedulerFactory):
r"""Increase the learning rate linearly from zero, then decay it.
With base learning rate $\tau$, step $s$, and `warmup` $w$, the linear warmup is:
$$\tau \frac{s}{w}.$$
The decay, following the warmup, is
$$\tau \gamma^{s-w},$$ where $\gamma$ is the hold rate.
"""
gamma: float
r"""Hold rate; higher value decay more slowly. Limited to $\eps \le \gamma \le 1.$"""
warmup: int
r"""Length of the linear warmup."""
eps: float = field(init=False, default=1e-16)
r"""Safety value: `gamma` must be larger than this."""

def __post_init__(self):
"""Finish initialization."""
# Clip gamma to something that makes sense, just in case.
self.gamma = max(min(self.gamma, 1.0), self.eps)
# Same for warmup.
self.warmup = max(self.warmup, 0)

def __call__(self, optimizer: Optimizer) -> Dict[str, Any]:
"""Create scheduler."""

def fn(step: int) -> float:
"""Learning rate decay function."""
if step < self.warmup:
return step / self.warmup
elif step > self.warmup:
return self.gamma ** (step - self.warmup)
return 1.0

scheduler = LambdaLR(optimizer, fn)
return dict(scheduler=scheduler, frequency=1, interval='step')
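# Worked example (added for illustration, not part of the original commit):
# with warmup=4 and gamma=0.9, the LambdaLR multiplier fn(step) above gives
#   fn(0)=0.00, fn(1)=0.25, fn(2)=0.50, fn(3)=0.75   # linear warmup
#   fn(4)=1.00, fn(5)=0.90, fn(6)=0.81, ...          # geometric decay
# so the effective learning rate is the optimizer's base lr times fn(step).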


@dataclass
class SGDFactory(OptimFactory):
"""Factory for SGD optimizers."""
lr: float = MISSING # Value is required.
momentum: float = 0
dampening: float = 0
weight_decay: float = 0
nesterov: bool = False

def __call__(self, parameters: Iterable[Parameter]) -> SGD:
"""Create and initialize a SGD optimizer."""
return SGD(
parameters, lr=self.lr,
momentum=self.momentum, dampening=self.dampening,
weight_decay=self.weight_decay, nesterov=self.nesterov)


@dataclass
class AdamFactory(OptimFactory):
"""Factory for ADAM optimizers."""
lr: float = 1e-3 # `MISSING` if we want to require an explicit value.
betas: Tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 0
amsgrad: bool = True # NOTE: The pytorch default is False, for backward compatibility.

def __call__(self, parameters: Iterable[Parameter]) -> Adam:
"""Create and initialize an ADAM optimizer."""
return Adam(
parameters, lr=self.lr,
betas=self.betas, eps=self.eps,
weight_decay=self.weight_decay,
amsgrad=self.amsgrad)
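
Taken together, these factories are plain callables that compose by hand. The sketch below is not part of this commit; the hyper-parameter values and the nn.Linear model are placeholders chosen only to show the configuration dictionary the factories produce for Lightning:

from torch import nn

from amlrt_project.models.factory import (AdamFactory,
                                          OptimizerConfigurationFactory,
                                          WarmupDecayFactory)

model = nn.Linear(784, 10)  # placeholder module; any nn.Module's parameters work

optim_fact = AdamFactory(lr=1e-3)
sched_fact = WarmupDecayFactory(gamma=0.999, warmup=1000)
config_fact = OptimizerConfigurationFactory(optim_fact, sched_fact)

config = config_fact(model.parameters())
# config['optimizer']    -> torch.optim.Adam over model.parameters()
# config['lr_scheduler'] -> {'scheduler': LambdaLR(...), 'frequency': 1, 'interval': 'step'}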
67 changes: 65 additions & 2 deletions amlrt_project/models/model_loader.py
@@ -1,11 +1,72 @@
import logging
from typing import Any, Dict, Optional, Type, Union

from amlrt_project.models.factory import (AdamFactory, OptimFactory,
PlateauFactory, SchedulerFactory,
SGDFactory, WarmupDecayFactory)
from amlrt_project.models.my_model import SimpleMLP

logger = logging.getLogger(__name__)


def load_model(hyper_params): # pragma: no cover
OPTS = {
'SGD': SGDFactory,
'sgd': SGDFactory,
'Adam': AdamFactory,
'adam': AdamFactory
}

SCHEDS = {
'Plateau': PlateauFactory,
'plateau': PlateauFactory,
'WarmupDecay': WarmupDecayFactory,
'warmupdecay': WarmupDecayFactory
}


def parse_opt_hp(hyper_params: Union[str, Dict[str, Any]]) -> OptimFactory:
"""Parse the optimizer part of the config."""
if isinstance(hyper_params, str):
algo = hyper_params
args = {}
elif isinstance(hyper_params, dict):
algo = hyper_params['algo']
args = {key: hyper_params[key] for key in hyper_params if key != 'algo'}
else:
raise TypeError(f"hyper_params should be a str or a dict, got {type(hyper_params)}")

if algo not in OPTS:
raise ValueError(f'Optimizer {algo} not supported')
else:
algo: Type[OptimFactory] = OPTS[algo]

return algo(**args)


def parse_sched_hp(
hyper_params: Optional[Union[str, Dict[str, Any]]]
) -> Optional[SchedulerFactory]:
"""Parse the scheduler part of the config."""
if hyper_params is None:
return None
elif isinstance(hyper_params, str):
algo = hyper_params
args = {}
elif isinstance(hyper_params, dict):
algo = hyper_params['algo']
args = {key: hyper_params[key] for key in hyper_params if key != 'algo'}
else:
raise TypeError(f"hyper_params should be a str or a dict, got {type(hyper_params)}")

if algo not in SCHEDS:
raise ValueError(f'Scheduler {algo} not supported')
else:
algo: Type[SchedulerFactory] = SCHEDS[algo]

return algo(**args)


def load_model(hyper_params: Dict[str, Any]): # pragma: no cover
"""Instantiate a model.
Args:
@@ -22,7 +83,9 @@ def load_model(hyper_params): # pragma: no cover
raise ValueError('architecture {} not supported'.format(architecture))
logger.info('selected architecture: {}'.format(architecture))

model = model_class(hyper_params)
optim_fact = parse_opt_hp(hyper_params.get('optimizer', 'SGD'))
sched_fact = parse_sched_hp(hyper_params.get('scheduler', None))
model = model_class(optim_fact, sched_fact, hyper_params)
logger.info('model info:\n' + str(model) + '\n')

return model
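
For reference, parse_opt_hp and parse_sched_hp accept either a bare string naming the algorithm or a dict whose 'algo' key names it and whose remaining keys are passed to the factory. A hypothetical hyper_params fragment (the values are invented for illustration):

hyper_params = {
    'optimizer': {'algo': 'Adam', 'lr': 1e-3, 'weight_decay': 0.01},
    # 'optimizer': 'adam' also works and falls back to AdamFactory defaults;
    # SGD needs the dict form with at least 'lr', since SGDFactory marks lr as required.
    'scheduler': {'algo': 'WarmupDecay', 'gamma': 0.999, 'warmup': 1000},
    # omit 'scheduler' (or set it to None) to train without a scheduler.
}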
65 changes: 43 additions & 22 deletions amlrt_project/models/my_model.py
@@ -2,9 +2,12 @@
import typing

import pytorch_lightning as pl
from torch import nn
from torch import FloatTensor, LongTensor, nn

from amlrt_project.models.optim import load_loss, load_optimizer
from amlrt_project.models.factory import (OptimFactory,
OptimizerConfigurationFactory,
SchedulerFactory)
from amlrt_project.models.optim import load_loss
from amlrt_project.utils.hp_utils import check_and_log_hp

logger = logging.getLogger(__name__)
@@ -13,6 +16,19 @@
class BaseModel(pl.LightningModule):
"""Base class for Pytorch Lightning model - useful to reuse the same *_step methods."""

def __init__(
self,
model: nn.Module,
loss_fn: nn.Module,
optim_fact: OptimFactory,
sched_fact: SchedulerFactory,
):
"""Initialize the LightningModule, with the actual model and loss."""
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.opt_fact = OptimizerConfigurationFactory(optim_fact, sched_fact)

def configure_optimizers(self):
"""Returns the combination of optimizer(s) and learning rate scheduler(s) to train with.
@@ -23,15 +39,17 @@ def configure_optimizers(self):
See https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html for more info
on the expected returned elements.
"""
# we use the generic loading function from the `model_loader` module, but it could be made
# a direct part of the model (useful if we want layer-dynamic optimization)
return load_optimizer(self.hparams, self)
return self.opt_fact(self.model.parameters())

def forward(self, input_data: FloatTensor) -> FloatTensor:
"""Invoke the model."""
return self.model(input_data)

def _generic_step(
self,
batch: typing.Any,
batch: typing.Tuple[FloatTensor, LongTensor],
batch_idx: int,
) -> typing.Any:
) -> FloatTensor:
"""Runs the prediction + evaluation step for training/validation/testing."""
input_data, targets = batch
preds = self(input_data) # calls the forward pass of the model
@@ -62,22 +80,25 @@ class SimpleMLP(BaseModel): # pragma: no cover
Inherits from the given framework's model class. This is a simple MLP model.
"""
def __init__(self, hyper_params: typing.Dict[typing.AnyStr, typing.Any]):
def __init__(
self,
optim_fact: OptimFactory,
sched_fact: SchedulerFactory,
hyper_params: typing.Dict[typing.AnyStr, typing.Any]
):
"""__init__.
Args:
optim_fact (OptimFactory): factory for the optimizer.
sched_fact (SchedulerFactory): factory for the scheduler.
hyper_params (dict): hyper parameters from the config file.
"""
super(SimpleMLP, self).__init__()

# TODO: Place this in a factory.
check_and_log_hp(['hidden_dim', 'num_classes'], hyper_params)
self.save_hyperparameters(hyper_params) # they will become available via model.hparams
num_classes = hyper_params['num_classes']
hidden_dim = hyper_params['hidden_dim']
self.loss_fn = load_loss(hyper_params) # 'load_loss' could be part of the model itself...

self.flatten = nn.Flatten()
self.mlp_layers = nn.Sequential(
num_classes: int = hyper_params['num_classes']
hidden_dim: int = hyper_params['hidden_dim']
flatten = nn.Flatten()
mlp_layers = nn.Sequential(
nn.Linear(
784, hidden_dim,
), # The input size for the linear layer is determined by the previous operations
@@ -86,9 +107,9 @@ def __init__(self, hyper_params: typing.Dict[typing.AnyStr, typing.Any]):
hidden_dim, num_classes
), # Here we get exactly num_classes logits at the output
)
model = nn.Sequential(flatten, mlp_layers)

def forward(self, x):
"""Model forward."""
x = self.flatten(x) # Flatten is necessary to pass from CNNs to MLP
x = self.mlp_layers(x)
return x
super().__init__(
model, load_loss(hyper_params),
optim_fact, sched_fact)
self.save_hyperparameters() # they will become available via model.hparams
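
With this refactor, BaseModel owns the network, the loss, and an OptimizerConfigurationFactory, and configure_optimizers simply delegates to it. A hypothetical direct construction is sketched below (normally load_model does this wiring; the hyper-parameter values and the 'val_loss' metric name are placeholders, and load_loss may expect extra keys not visible in this diff):

from amlrt_project.models.factory import AdamFactory, PlateauFactory
from amlrt_project.models.my_model import SimpleMLP

model = SimpleMLP(
    optim_fact=AdamFactory(lr=1e-3),
    sched_fact=PlateauFactory(metric='val_loss', patience=5),
    hyper_params={'hidden_dim': 256, 'num_classes': 10},
)
# Lightning will later call model.configure_optimizers(), which returns
# self.opt_fact(self.model.parameters()) as built in BaseModel.__init__.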
22 changes: 0 additions & 22 deletions amlrt_project/models/optim.py
@@ -1,32 +1,10 @@
import logging

import torch
from torch import optim

logger = logging.getLogger(__name__)


def load_optimizer(hyper_params, model): # pragma: no cover
"""Instantiate the optimizer.
Args:
hyper_params (dict): hyper parameters from the config file
model (obj): A neural network model object.
Returns:
optimizer (obj): The optimizer for the given model
"""
optimizer_name = hyper_params["optimizer"]
# __TODO__ fix optimizer list
if optimizer_name == "adam":
optimizer = optim.Adam(model.parameters())
elif optimizer_name == "sgd":
optimizer = optim.SGD(model.parameters())
else:
raise ValueError("optimizer {} not supported".format(optimizer_name))
return optimizer


def load_loss(hyper_params): # pragma: no cover
r"""Instantiate the loss.
