diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..469e681
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length=100
+ignore=W503,D104,D100,D401
+docstring-convention=google
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 50e214b..cfe4705 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,6 +110,9 @@ venv.bak/
 # mypy
 .mypy_cache/
 
-# vim hipsters
+# vim overlords
 *.swp
 *.swo
+
+# vscode
+.vscode
\ No newline at end of file
diff --git a/amlrt_project/models/factory.py b/amlrt_project/models/factory.py
new file mode 100644
index 0000000..7378ebe
--- /dev/null
+++ b/amlrt_project/models/factory.py
@@ -0,0 +1,146 @@
+"""Optimizer and learning rate scheduler factory."""
+
+
+from abc import ABC, abstractmethod
+from dataclasses import MISSING, dataclass, field
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+from torch.nn import Parameter
+from torch.optim import SGD, Adam, Optimizer
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
+
+
+class OptimFactory(ABC):
+    """Base class for optimizer factories."""
+
+    @abstractmethod
+    def __call__(self, parameters: Iterable[Parameter]) -> Optimizer:
+        """Create an optimizer."""
+        ...
+
+
+class SchedulerFactory(ABC):
+    """Base class for learning rate scheduler factories."""
+
+    @abstractmethod
+    def __call__(self, optim: Optimizer) -> Dict[str, Any]:
+        """Create a scheduler."""
+        ...
+
+
+@dataclass
+class OptimizerConfigurationFactory:
+    """Combine an optimizer factory and a scheduler factory.
+
+    Return the configuration Lightning requires.
+    Only support the usual case (one optimizer, one scheduler).
+    """
+    optim_factory: OptimFactory
+    scheduler_factory: Optional[SchedulerFactory] = None
+
+    def __call__(self, parameters: Iterable[Parameter]) -> Dict[str, Any]:
+        """Create the optimizer and scheduler for `parameters`."""
+        config = {}
+        optim = self.optim_factory(parameters)
+        config['optimizer'] = optim
+        if self.scheduler_factory is not None:
+            config['lr_scheduler'] = self.scheduler_factory(optim)
+        return config
+
+
+@dataclass
+class PlateauFactory(SchedulerFactory):
+    """Reduce the learning rate when `metric` is no longer improving."""
+    metric: str
+    """Metric to use; must be logged with Lightning."""
+    mode: str = "min"
+    """Minimize or maximize."""
+    factor: float = 0.1
+    """Multiply the learning rate by `factor`."""
+    patience: int = 10
+    """Wait `patience` epochs before reducing the learning rate."""
+
+    def __call__(self, optimizer: Optimizer) -> Dict[str, Any]:
+        """Create a scheduler."""
+        scheduler = ReduceLROnPlateau(
+            optimizer,
+            mode=self.mode,
+            factor=self.factor, patience=self.patience)
+        return dict(
+            scheduler=scheduler,
+            frequency=1,
+            interval='epoch',
+            monitor=self.metric)
+
+
+@dataclass
+class WarmupDecayFactory(SchedulerFactory):
+    r"""Increase the learning rate linearly from zero, then decay it.
+
+    With base learning rate $\tau$, step $s$, and `warmup` $w$, the linear warmup is:
+    $$\tau \frac{s}{w}.$$
+    The decay, following the warmup, is
+    $$\tau \gamma^{s-w},$$ where $\gamma$ is the hold rate.
+    """
+    gamma: float
+    r"""Hold rate; higher values decay more slowly. Limited to $\epsilon \le \gamma \le 1$."""
+    warmup: int
+    r"""Length of the linear warmup, in steps."""
+    eps: float = field(init=False, default=1e-16)
+    r"""Safety value: `gamma` must be larger than this."""
+
+    def __post_init__(self):
+        """Finish initialization."""
+        # Clip gamma to something that makes sense, just in case.
+        self.gamma = max(min(self.gamma, 1.0), self.eps)
+        # Same for warmup.
+        self.warmup = max(self.warmup, 0)
+
+    def __call__(self, optimizer: Optimizer) -> Dict[str, Any]:
+        """Create scheduler."""
+
+        def fn(step: int) -> float:
+            """Learning rate decay function."""
+            if step < self.warmup:
+                return step / self.warmup
+            elif step > self.warmup:
+                return self.gamma ** (step - self.warmup)
+            return 1.0
+
+        scheduler = LambdaLR(optimizer, fn)
+        return dict(scheduler=scheduler, frequency=1, interval='step')
+
+
+@dataclass
+class SGDFactory(OptimFactory):
+    """Factory for SGD optimizers."""
+    lr: float = MISSING  # Value is required.
+    momentum: float = 0
+    dampening: float = 0
+    weight_decay: float = 0
+    nesterov: bool = False
+
+    def __call__(self, parameters: Iterable[Parameter]) -> SGD:
+        """Create and initialize an SGD optimizer."""
+        return SGD(
+            parameters, lr=self.lr,
+            momentum=self.momentum, dampening=self.dampening,
+            weight_decay=self.weight_decay, nesterov=self.nesterov)
+
+
+@dataclass
+class AdamFactory(OptimFactory):
+    """Factory for Adam optimizers."""
+    lr: float = 1e-3  # `MISSING` if we want to require an explicit value.
+    betas: Tuple[float, float] = (0.9, 0.999)
+    eps: float = 1e-8
+    weight_decay: float = 0
+    amsgrad: bool = True  # NOTE: The pytorch default is False, for backward compatibility.
+
+    def __call__(self, parameters: Iterable[Parameter]) -> Adam:
+        """Create and initialize an Adam optimizer."""
+        return Adam(
+            parameters, lr=self.lr,
+            betas=self.betas, eps=self.eps,
+            weight_decay=self.weight_decay,
+            amsgrad=self.amsgrad)
diff --git a/amlrt_project/models/model_loader.py b/amlrt_project/models/model_loader.py
index 2c55664..af25928 100644
--- a/amlrt_project/models/model_loader.py
+++ b/amlrt_project/models/model_loader.py
@@ -1,11 +1,72 @@
 import logging
+from typing import Any, Dict, Optional, Type, Union
 
+from amlrt_project.models.factory import (AdamFactory, OptimFactory,
+                                          PlateauFactory, SchedulerFactory,
+                                          SGDFactory, WarmupDecayFactory)
 from amlrt_project.models.my_model import SimpleMLP
 
 logger = logging.getLogger(__name__)
 
 
-def load_model(hyper_params):  # pragma: no cover
+OPTS = {
+    'SGD': SGDFactory,
+    'sgd': SGDFactory,
+    'Adam': AdamFactory,
+    'adam': AdamFactory
+}
+
+SCHEDS = {
+    'Plateau': PlateauFactory,
+    'plateau': PlateauFactory,
+    'WarmupDecay': WarmupDecayFactory,
+    'warmupdecay': WarmupDecayFactory
+}
+
+
+def parse_opt_hp(hyper_params: Union[str, Dict[str, Any]]) -> OptimFactory:
+    """Parse the optimizer part of the config."""
+    if isinstance(hyper_params, str):
+        algo = hyper_params
+        args = {}
+    elif isinstance(hyper_params, dict):
+        algo = hyper_params['algo']
+        args = {key: hyper_params[key] for key in hyper_params if key != 'algo'}
+    else:
+        raise TypeError(f"hyper_params should be a str or a dict, got {type(hyper_params)}")
+
+    if algo not in OPTS:
+        raise ValueError(f'Optimizer {algo} not supported')
+    else:
+        algo: Type[OptimFactory] = OPTS[algo]
+
+    return algo(**args)
+
+
+def parse_sched_hp(
+    hyper_params: Optional[Union[str, Dict[str, Any]]]
+) -> Optional[SchedulerFactory]:
+    """Parse the scheduler part of the config."""
+    if hyper_params is None:
+        return None
+    elif isinstance(hyper_params, str):
+        algo = hyper_params
+        args = {}
+    elif isinstance(hyper_params, dict):
+        algo = hyper_params['algo']
+        args = {key: hyper_params[key] for key in hyper_params if key != 'algo'}
+    else:
+        raise TypeError(f"hyper_params should be a str or a dict, got {type(hyper_params)}")
+
+    if algo not in SCHEDS:
+        raise ValueError(f'Scheduler {algo} not supported')
+    else:
+        algo: Type[SchedulerFactory] = SCHEDS[algo]
+
+    return algo(**args)
+
+
+def load_model(hyper_params: Dict[str, Any]):  # pragma: no cover
     """Instantiate a model.
 
     Args:
@@ -22,7 +83,9 @@ def load_model(hyper_params):  # pragma: no cover
         raise ValueError('architecture {} not supported'.format(architecture))
     logger.info('selected architecture: {}'.format(architecture))
 
-    model = model_class(hyper_params)
+    optim_fact = parse_opt_hp(hyper_params.get('optimizer', 'SGD'))
+    sched_fact = parse_sched_hp(hyper_params.get('scheduler', None))
+    model = model_class(optim_fact, sched_fact, hyper_params)
     logger.info('model info:\n' + str(model) + '\n')
 
     return model
diff --git a/amlrt_project/models/my_model.py b/amlrt_project/models/my_model.py
index e44cdfe..9aee523 100644
--- a/amlrt_project/models/my_model.py
+++ b/amlrt_project/models/my_model.py
@@ -2,9 +2,12 @@ import typing
 
 import pytorch_lightning as pl
-from torch import nn
+from torch import FloatTensor, LongTensor, nn
 
-from amlrt_project.models.optim import load_loss, load_optimizer
+from amlrt_project.models.factory import (OptimFactory,
+                                          OptimizerConfigurationFactory,
+                                          SchedulerFactory)
+from amlrt_project.models.optim import load_loss
 from amlrt_project.utils.hp_utils import check_and_log_hp
 
 logger = logging.getLogger(__name__)
@@ -13,6 +16,19 @@ class BaseModel(pl.LightningModule):
     """Base class for Pytorch Lightning model - useful to reuse the same *_step methods."""
 
+    def __init__(
+        self,
+        model: nn.Module,
+        loss_fn: nn.Module,
+        optim_fact: OptimFactory,
+        sched_fact: SchedulerFactory,
+    ):
+        """Initialize the LightningModule with the actual model and loss."""
+        super().__init__()
+        self.model = model
+        self.loss_fn = loss_fn
+        self.opt_fact = OptimizerConfigurationFactory(optim_fact, sched_fact)
+
     def configure_optimizers(self):
         """Returns the combination of optimizer(s) and learning rate scheduler(s) to train with.
 
@@ -23,15 +39,17 @@ def configure_optimizers(self):
         See https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html for more info
         on the expected returned elements.
         """
-        # we use the generic loading function from the `model_loader` module, but it could be made
-        # a direct part of the model (useful if we want layer-dynamic optimization)
-        return load_optimizer(self.hparams, self)
+        return self.opt_fact(self.model.parameters())
+
+    def forward(self, input_data: FloatTensor) -> FloatTensor:
+        """Invoke the model."""
+        return self.model(input_data)
 
     def _generic_step(
         self,
-        batch: typing.Any,
+        batch: typing.Tuple[FloatTensor, LongTensor],
         batch_idx: int,
-    ) -> typing.Any:
+    ) -> FloatTensor:
         """Runs the prediction + evaluation step for training/validation/testing."""
         input_data, targets = batch
         preds = self(input_data)  # calls the forward pass of the model
@@ -62,22 +80,25 @@ class SimpleMLP(BaseModel):  # pragma: no cover
     Inherits from the given framework's model class. This is a simple MLP model.
     """
 
-    def __init__(self, hyper_params: typing.Dict[typing.AnyStr, typing.Any]):
+    def __init__(
+        self,
+        optim_fact: OptimFactory,
+        sched_fact: SchedulerFactory,
+        hyper_params: typing.Dict[typing.AnyStr, typing.Any]
+    ):
         """__init__.
 
         Args:
+            optim_fact (OptimFactory): factory for the optimizer.
+            sched_fact (SchedulerFactory): factory for the scheduler.
             hyper_params (dict): hyper parameters from the config file.
         """
-        super(SimpleMLP, self).__init__()
-
+        # TODO: Place this in a factory.
         check_and_log_hp(['hidden_dim', 'num_classes'], hyper_params)
-        self.save_hyperparameters(hyper_params)  # they will become available via model.hparams
-        num_classes = hyper_params['num_classes']
-        hidden_dim = hyper_params['hidden_dim']
-        self.loss_fn = load_loss(hyper_params)  # 'load_loss' could be part of the model itself...
-
-        self.flatten = nn.Flatten()
-        self.mlp_layers = nn.Sequential(
+        num_classes: int = hyper_params['num_classes']
+        hidden_dim: int = hyper_params['hidden_dim']
+        flatten = nn.Flatten()
+        mlp_layers = nn.Sequential(
             nn.Linear(
                 784, hidden_dim,
             ),  # The input size for the linear layer is determined by the previous operations
@@ -86,9 +107,9 @@ def __init__(self, hyper_params: typing.Dict[typing.AnyStr, typing.Any]):
                 hidden_dim, num_classes
             ),  # Here we get exactly num_classes logits at the output
         )
+        model = nn.Sequential(flatten, mlp_layers)
 
-    def forward(self, x):
-        """Model forward."""
-        x = self.flatten(x)  # Flatten is necessary to pass from CNNs to MLP
-        x = self.mlp_layers(x)
-        return x
+        super().__init__(
+            model, load_loss(hyper_params),
+            optim_fact, sched_fact)
+        self.save_hyperparameters()  # they will become available via model.hparams
diff --git a/amlrt_project/models/optim.py b/amlrt_project/models/optim.py
index a124684..b952da0 100644
--- a/amlrt_project/models/optim.py
+++ b/amlrt_project/models/optim.py
@@ -1,32 +1,10 @@
 import logging
 
 import torch
-from torch import optim
 
 logger = logging.getLogger(__name__)
 
 
-def load_optimizer(hyper_params, model):  # pragma: no cover
-    """Instantiate the optimizer.
-
-    Args:
-        hyper_params (dict): hyper parameters from the config file
-        model (obj): A neural network model object.
-
-    Returns:
-        optimizer (obj): The optimizer for the given model
-    """
-    optimizer_name = hyper_params["optimizer"]
-    # __TODO__ fix optimizer list
-    if optimizer_name == "adam":
-        optimizer = optim.Adam(model.parameters())
-    elif optimizer_name == "sgd":
-        optimizer = optim.SGD(model.parameters())
-    else:
-        raise ValueError("optimizer {} not supported".format(optimizer_name))
-    return optimizer
-
-
 def load_loss(hyper_params):  # pragma: no cover
     r"""Instantiate the loss.
 
diff --git a/examples/local_orion/config.yaml b/examples/local_orion/config.yaml
index 6c9346c..42c8743 100644
--- a/examples/local_orion/config.yaml
+++ b/examples/local_orion/config.yaml
@@ -4,17 +4,18 @@ optimizer: adam
 loss: cross_entropy
 max_epoch: 5
 exp_name: my_exp_1
+num_workers: 0
 # set to null to avoid setting a seed (can speed up GPU computation, but
 # results will not be reproducible)
 seed: 1234
 
 # architecture
+hidden_dim: 'orion~uniform(32,256,discrete=True)'
 num_classes: 10
 architecture: simple_mlp
-hidden_dim: 'orion~uniform(32,256,discrete=True)'
 
 # early stopping
 early_stopping:
   metric: val_loss
   mode: min
-  patience: 3
\ No newline at end of file
+  patience: 3
diff --git a/setup.py b/setup.py
index bcb89f3..0c156fa 100644
--- a/setup.py
+++ b/setup.py
@@ -11,18 +11,18 @@
         'gitpython==3.1.27',
         'jupyter==1.0.0',
         'jinja2==3.1.2',
-        'myst-parser==0.18.0',
+        'myst-parser==2.0.0',
         'orion>=0.2.4.post1',
         'pyyaml==6.0',
         'pytest==7.1.2',
         'pytest-cov==3.0.0',
         'pytorch_lightning==1.8.3',
         'pytype==2023.1.17',
-        'sphinx==5.1.1',
-        'sphinx-autoapi==1.9.0',
-        'sphinx-rtd-theme==1.0.0',
+        'sphinx==7.2.6',
+        'sphinx-autoapi==3.0.0',
+        'sphinx-rtd-theme==1.3.0',
         'sphinxcontrib-napoleon==0.7',
-        'sphinxcontrib-katex==0.8.6',
+        'sphinxcontrib-katex==0.9.9',
         'tqdm==4.64.0',
         'torch==1.12.0',
         'torchvision==0.13.0',
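
Usage sketch (reviewer note, not part of the patch): how the new config parsing in model_loader.py is meant to be exercised, based only on the code above. The hyperparameter values (lr, momentum, gamma, warmup) are illustrative assumptions, not defaults taken from the repository.

    from amlrt_project.models.model_loader import parse_opt_hp, parse_sched_hp

    # String form: the named factory is built with its declared defaults ('adam' -> AdamFactory()).
    optim_fact = parse_opt_hp('adam')

    # Dict form: 'algo' selects the factory; the remaining keys are forwarded as keyword arguments.
    optim_fact = parse_opt_hp({'algo': 'sgd', 'lr': 0.01, 'momentum': 0.9})
    sched_fact = parse_sched_hp({'algo': 'WarmupDecay', 'gamma': 0.99, 'warmup': 1000})

    # A missing 'scheduler' entry maps to None, so no lr_scheduler appears in the Lightning config.
    assert parse_sched_hp(None) is None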