From 5f089c72a47a8ab2e805ca9afb29e798c43c2055 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Wed, 18 Jan 2023 13:07:26 -0500 Subject: [PATCH 01/12] Added the optimizer and scheduler factories --- amlrt_project/models/factory.py | 149 ++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 amlrt_project/models/factory.py diff --git a/amlrt_project/models/factory.py b/amlrt_project/models/factory.py new file mode 100644 index 0000000..98c8aec --- /dev/null +++ b/amlrt_project/models/factory.py @@ -0,0 +1,149 @@ +"""Optimizer and learning rate scheduler factory.""" + + +from abc import ABC, abstractmethod +from dataclasses import MISSING, dataclass, field +from enum import Enum, StrEnum +from typing import Any, ClassVar, Dict, Iterable, Optional, Tuple, Type + + +from torch.nn import Parameter +from torch.optim import SGD, Adam, Optimizer +from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau + + + +class OptimFactory(ABC): + """Base class for optimizer factories.""" + + @abstractmethod + def __call__(self, parameters: Iterable[Parameter]) -> Optimizer: + ... + + +class Interval(StrEnum): + """Interval supported by Lightning. + + Using an enum help avoiding invalid values. + """ + EPOCH = "epoch" + STEP = "step" + + +class SchedulerFactory(ABC): + """Base class for learning rate scheduler factories.""" + + @abstractmethod + def __call__(self, optim: Optimizer) -> Dict[str, Any]: + ... + + +@dataclass +class OptimizerConfigurationFactory: + """Combine an optimizer factory and a scheduler factory. + + Return the configuration Lightning requires. + Only support the usual case (one optim, one scheduler.) + """ + optim_factory: OptimFactory + scheduler_factory: Optional[SchedulerFactory] = None + + def __call__(self, parameters: Iterable[Parameter]) -> Dict[str, Any]: + """Create the optimizer and scheduler, for `parameters`.""" + config = {} + optim = self.optim_factory(parameters) + config['optimizer'] = optim + if self.scheduler_factory is not None: + config['lr_scheduler'] = self.scheduler_factory(optim) + return config + + +@dataclass +class PlateauFactory(SchedulerFactory): + """Reduce the learning rate when `metric` is no longer improving.""" + metric: str + """Metric to use, must be logged with Lightning.""" + mode: str = "min" + """Minimize or maximize.""" + factor: float = 0.1 + """Multiply the learning rate by `factor`.""" + patience: int = 10 + """Wait `patience` epoch before reducing the learning rate.""" + + def __call__(self, optimizer: Optimizer) -> Dict[str, Any]: + scheduler = ReduceLROnPlateau( + optimizer, + mode=self.mode, + factor=self.factor, patience=self.patience) + return dict( + scheduler=scheduler, + frequency=1, + interval=Interval.EPOCH, + monitor=self.metric) + + +@dataclass +class WarmupDecayFactory(SchedulerFactory): + r"""Increase the learning rate linearly from zero, then decay it. + + With base learning rate $\tau$, step $s$, and `warmup` $w$, the linear warmup is: + $$\tau \frac{s}{w}.$$ + The decay, following the warmup, is + $$\tau \gamma^{s-w},$$ where $\gamma$ is the hold rate. + """ + gamma: float + r"""Hold rate; higher value decay more slowly. Limited to $\eps \le \gamma \le 1.$""" + warmup: int + r"""Length of the linear warmup.""" + eps: float = field(init=False, default=1e-16) + r"""Safety value: `gamma` must be larger than this.""" + + def __post_init__(self): + self.gamma = max(min(self.gamma, 1.0), self.eps) + # Clip gamma to something that make sense, just in case. 
+ self.warmup = max(self.warmup, 0) + # Same for warmup. + + def __call__(self, optimizer: Optimizer) -> Dict[str, Any]: + + def fn(step: int) -> float: + """Learning rate decay function.""" + if step < self.warmup: + return step / self.warmup + elif step > self.warmup: + return self.gamma ** (step - self.warmup) + return 1.0 + + scheduler = LambdaLR(optimizer, fn) + return dict(scheduler=scheduler, frequency=1, interval=Interval.STEP) + + +@dataclass +class SGDFactory(OptimFactory): + lr: float = MISSING + momentum: float = 0 + dampening: float = 0 + weight_decay: float = 0 + nesterov: bool = False + + def __call__(self, parameters: Iterable[Parameter]) -> SGD: + return SGD( + parameters, lr=self.lr, + momentum=self.momentum, dampening=self.dampening, + weight_decay=self.weight_decay, nesterov=self.nesterov) + + +@dataclass +class AdamFactory(OptimFactory): + lr: float = MISSING + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0 + amsgrad: bool = True # NOTE: The pytorch default is False, for backward compatibility. + + def __call__(self, parameters: Iterable[Parameter]) -> Adam: + return Adam( + parameters, lr=self.lr, + betas=self.betas, eps=self.eps, + weight_decay=self.weight_decay, + amsgrad=self.amsgrad) \ No newline at end of file From 4110b63d77a6810054feec69f1d79a099e08ef96 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Wed, 18 Jan 2023 13:09:02 -0500 Subject: [PATCH 02/12] Backward compatible interface for the factories --- amlrt_project/models/model_loader.py | 57 +++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/amlrt_project/models/model_loader.py b/amlrt_project/models/model_loader.py index 2c55664..e1b7b04 100644 --- a/amlrt_project/models/model_loader.py +++ b/amlrt_project/models/model_loader.py @@ -1,11 +1,62 @@ +import dataclasses +from typing import Any, Dict, Optional, Tuple, Type, Union import logging from amlrt_project.models.my_model import SimpleMLP +from amlrt_project.models.factory import AdamFactory, OptimFactory, SGDFactory, PlateauFactory, SchedulerFactory, WarmupDecayFactory logger = logging.getLogger(__name__) -def load_model(hyper_params): # pragma: no cover +OPTS = { + 'SGD': SGDFactory, + 'Adam': AdamFactory +} + +SCHEDS = { + 'Plateau': PlateauFactory, + 'WarmupDecay': WarmupDecayFactory +} + + +def parse_opt_hp(hyper_params: Union[str, Dict[str, Any]]) -> OptimFactory: + """Parse the optimizer part of the config.""" + if isinstance(hyper_params, str): + algo = hyper_params + args = {} + else: + algo = hyper_params['algo'] + args = {key: hyper_params[key] for key in hyper_params if key != 'algo'} + + if algo not in OPTS: + raise ValueError(f'Optimizer {algo} not supported') + else: + algo: Type[OptimFactory] = OPTS[algo] + + return algo(**args) + + +def parse_sched_hp(hyper_params: Optional[Union[str, Dict[str, Any]]]) -> Optional[SchedulerFactory]: + """Parse the scheduler part of the config.""" + if hyper_params is None: + return None + elif isinstance(hyper_params, str): + algo = hyper_params + args = {} + else: + algo = hyper_params['algo'] + args = {key: hyper_params[key] for key in hyper_params if key != 'algo'} + + if algo not in SCHEDS: + raise ValueError(f'Scheduler {algo} not supported') + else: + algo: Type[SchedulerFactory] = SCHEDS[algo] + + return algo(**args) + + + +def load_model(hyper_params: Dict[str, Any]): # pragma: no cover """Instantiate a model. 
Args: @@ -22,7 +73,9 @@ def load_model(hyper_params): # pragma: no cover raise ValueError('architecture {} not supported'.format(architecture)) logger.info('selected architecture: {}'.format(architecture)) - model = model_class(hyper_params) + optim_fact = parse_opt_hp(hyper_params.get('optimizer', 'SGD')) + sched_fact = parse_sched_hp(hyper_params.get('scheduler', None)) + model = model_class(optim_fact, sched_fact, hyper_params) logger.info('model info:\n' + str(model) + '\n') return model From 3dc16b94003fdb382632c7b9954de8e1bbab3e4a Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Wed, 18 Jan 2023 13:10:24 -0500 Subject: [PATCH 03/12] Use the factories in the Lightning Module --- amlrt_project/models/my_model.py | 60 ++++++++++++++++++++------------ amlrt_project/models/optim.py | 21 ----------- 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/amlrt_project/models/my_model.py b/amlrt_project/models/my_model.py index 9a43b60..90cffb1 100644 --- a/amlrt_project/models/my_model.py +++ b/amlrt_project/models/my_model.py @@ -1,10 +1,11 @@ import logging import typing -from torch import nn +from torch import nn, FloatTensor, LongTensor import pytorch_lightning as pl -from amlrt_project.models.optim import load_loss, load_optimizer +from amlrt_project.models.optim import load_loss +from amlrt_project.models.factory import OptimFactory, OptimizerConfigurationFactory, SchedulerFactory from amlrt_project.utils.hp_utils import check_and_log_hp @@ -14,6 +15,19 @@ class BaseModel(pl.LightningModule): """Base class for Pytorch Lightning model - useful to reuse the same *_step methods.""" + def __init__( + self, + model: nn.Module, + loss_fn: nn.Module, + optim_fact: OptimFactory, + sched_fact: SchedulerFactory, + ): + """Initialize the LightningModule, with the actual model and loss""" + super().__init__() + self.model = model + self.loss_fn = loss_fn + self.opt_fact = OptimizerConfigurationFactory(optim_fact, sched_fact) + def configure_optimizers(self): """Returns the combination of optimizer(s) and learning rate scheduler(s) to train with. @@ -24,15 +38,16 @@ def configure_optimizers(self): See https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html for more info on the expected returned elements. """ - # we use the generic loading function from the `model_loader` module, but it could be made - # a direct part of the model (useful if we want layer-dynamic optimization) - return load_optimizer(self.hparams, self) + return self.opt_fact(self.model.parameters()) + + def forward(self, input_data: FloatTensor) -> FloatTensor: + return self.model(input_data) def _generic_step( self, - batch: typing.Any, + batch: typing.Tuple[FloatTensor, LongTensor], batch_idx: int, - ) -> typing.Any: + ) -> FloatTensor: """Runs the prediction + evaluation step for training/validation/testing.""" input_data, targets = batch preds = self(input_data) # calls the forward pass of the model @@ -63,22 +78,23 @@ class SimpleMLP(BaseModel): # pragma: no cover Inherits from the given framework's model class. This is a simple MLP model. """ - def __init__(self, hyper_params: typing.Dict[typing.AnyStr, typing.Any]): + def __init__( + self, + optim_fact: OptimFactory, + sched_fact: SchedulerFactory, + hyper_params: typing.Dict[typing.AnyStr, typing.Any] + ): """__init__. Args: hyper_params (dict): hyper parameters from the config file. """ - super(SimpleMLP, self).__init__() - + # TODO: Place this in a factory. 
check_and_log_hp(['hidden_dim', 'num_classes'], hyper_params) - self.save_hyperparameters(hyper_params) # they will become available via model.hparams - num_classes = hyper_params['num_classes'] - hidden_dim = hyper_params['hidden_dim'] - self.loss_fn = load_loss(hyper_params) # 'load_loss' could be part of the model itself... - - self.flatten = nn.Flatten() - self.mlp_layers = nn.Sequential( + num_classes: int = hyper_params['num_classes'] + hidden_dim: int = hyper_params['hidden_dim'] + flatten = nn.Flatten() + mlp_layers = nn.Sequential( nn.Linear( 784, hidden_dim, ), # The input size for the linear layer is determined by the previous operations @@ -87,9 +103,9 @@ def __init__(self, hyper_params: typing.Dict[typing.AnyStr, typing.Any]): hidden_dim, num_classes ), # Here we get exactly num_classes logits at the output ) + model = nn.Sequential(flatten, mlp_layers) - def forward(self, x): - """Model forward.""" - x = self.flatten(x) # Flatten is necessary to pass from CNNs to MLP - x = self.mlp_layers(x) - return x + super().__init__( + model, load_loss(hyper_params), + optim_fact, sched_fact) + self.save_hyperparameters() # they will become available via model.hparams \ No newline at end of file diff --git a/amlrt_project/models/optim.py b/amlrt_project/models/optim.py index 3b5f709..7bd405d 100644 --- a/amlrt_project/models/optim.py +++ b/amlrt_project/models/optim.py @@ -7,27 +7,6 @@ logger = logging.getLogger(__name__) -def load_optimizer(hyper_params, model): # pragma: no cover - """Instantiate the optimizer. - - Args: - hyper_params (dict): hyper parameters from the config file - model (obj): A neural network model object. - - Returns: - optimizer (obj): The optimizer for the given model - """ - optimizer_name = hyper_params["optimizer"] - # __TODO__ fix optimizer list - if optimizer_name == "adam": - optimizer = optim.Adam(model.parameters()) - elif optimizer_name == "sgd": - optimizer = optim.SGD(model.parameters()) - else: - raise ValueError("optimizer {} not supported".format(optimizer_name)) - return optimizer - - def load_loss(hyper_params): # pragma: no cover r"""Instantiate the loss. From bc804b9354a2b1eda80929c577517c532b83d09d Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Fri, 20 Jan 2023 15:43:05 -0500 Subject: [PATCH 04/12] Linting, with configuration. 
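This adds a .flake8 configuration: 100-character lines, the Google docstring convention, and W503/D100/D104/D401 ignored. As an illustrative sketch only (this function is not part of the patch), a docstring in the following style is accepted under these settings, including a non-imperative first line, since D401 is ignored:

    def scale(value: int, factor: int = 2) -> int:
        """Returns the scaled value.

        Args:
            value (int): number to scale.
            factor (int): multiplier applied to `value`.

        Returns:
            int: `value` multiplied by `factor`.
        """
        return value * factor
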
--- .flake8 | 4 ++++ amlrt_project/models/factory.py | 17 ++++++++++++----- amlrt_project/models/model_loader.py | 12 +++++++----- amlrt_project/models/my_model.py | 15 +++++++++------ amlrt_project/models/optim.py | 1 - 5 files changed, 32 insertions(+), 17 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..469e681 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length=100 +ignore=W503,D104,D100,D401 +docstring-convention=google \ No newline at end of file diff --git a/amlrt_project/models/factory.py b/amlrt_project/models/factory.py index 98c8aec..1c30fa7 100644 --- a/amlrt_project/models/factory.py +++ b/amlrt_project/models/factory.py @@ -3,21 +3,20 @@ from abc import ABC, abstractmethod from dataclasses import MISSING, dataclass, field -from enum import Enum, StrEnum -from typing import Any, ClassVar, Dict, Iterable, Optional, Tuple, Type - +from enum import StrEnum +from typing import Any, Dict, Iterable, Optional, Tuple from torch.nn import Parameter from torch.optim import SGD, Adam, Optimizer from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau - class OptimFactory(ABC): """Base class for optimizer factories.""" @abstractmethod def __call__(self, parameters: Iterable[Parameter]) -> Optimizer: + """Create an optimizer.""" ... @@ -35,6 +34,7 @@ class SchedulerFactory(ABC): @abstractmethod def __call__(self, optim: Optimizer) -> Dict[str, Any]: + """Create a scheduler.""" ... @@ -71,6 +71,7 @@ class PlateauFactory(SchedulerFactory): """Wait `patience` epoch before reducing the learning rate.""" def __call__(self, optimizer: Optimizer) -> Dict[str, Any]: + """Create a scheduler.""" scheduler = ReduceLROnPlateau( optimizer, mode=self.mode, @@ -99,12 +100,14 @@ class WarmupDecayFactory(SchedulerFactory): r"""Safety value: `gamma` must be larger than this.""" def __post_init__(self): + """Finish initialization.""" self.gamma = max(min(self.gamma, 1.0), self.eps) # Clip gamma to something that make sense, just in case. self.warmup = max(self.warmup, 0) # Same for warmup. def __call__(self, optimizer: Optimizer) -> Dict[str, Any]: + """Create scheduler.""" def fn(step: int) -> float: """Learning rate decay function.""" @@ -120,6 +123,7 @@ def fn(step: int) -> float: @dataclass class SGDFactory(OptimFactory): + """Factory for SGD optimizers.""" lr: float = MISSING momentum: float = 0 dampening: float = 0 @@ -127,6 +131,7 @@ class SGDFactory(OptimFactory): nesterov: bool = False def __call__(self, parameters: Iterable[Parameter]) -> SGD: + """Create and initialize a SGD optimizer.""" return SGD( parameters, lr=self.lr, momentum=self.momentum, dampening=self.dampening, @@ -135,6 +140,7 @@ def __call__(self, parameters: Iterable[Parameter]) -> SGD: @dataclass class AdamFactory(OptimFactory): + """Factory for ADAM optimizers.""" lr: float = MISSING betas: Tuple[float, float] = (0.9, 0.999) eps: float = 1e-8 @@ -142,8 +148,9 @@ class AdamFactory(OptimFactory): amsgrad: bool = True # NOTE: The pytorch default is False, for backward compatibility. 
def __call__(self, parameters: Iterable[Parameter]) -> Adam: + """Create and initialize an ADAM optimizer.""" return Adam( parameters, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, - amsgrad=self.amsgrad) \ No newline at end of file + amsgrad=self.amsgrad) diff --git a/amlrt_project/models/model_loader.py b/amlrt_project/models/model_loader.py index e1b7b04..2bffee2 100644 --- a/amlrt_project/models/model_loader.py +++ b/amlrt_project/models/model_loader.py @@ -1,9 +1,10 @@ -import dataclasses -from typing import Any, Dict, Optional, Tuple, Type, Union import logging +from typing import Any, Dict, Optional, Type, Union +from amlrt_project.models.factory import (AdamFactory, OptimFactory, + PlateauFactory, SchedulerFactory, + SGDFactory, WarmupDecayFactory) from amlrt_project.models.my_model import SimpleMLP -from amlrt_project.models.factory import AdamFactory, OptimFactory, SGDFactory, PlateauFactory, SchedulerFactory, WarmupDecayFactory logger = logging.getLogger(__name__) @@ -36,7 +37,9 @@ def parse_opt_hp(hyper_params: Union[str, Dict[str, Any]]) -> OptimFactory: return algo(**args) -def parse_sched_hp(hyper_params: Optional[Union[str, Dict[str, Any]]]) -> Optional[SchedulerFactory]: +def parse_sched_hp( + hyper_params: Optional[Union[str, Dict[str, Any]]] +) -> Optional[SchedulerFactory]: """Parse the scheduler part of the config.""" if hyper_params is None: return None @@ -55,7 +58,6 @@ def parse_sched_hp(hyper_params: Optional[Union[str, Dict[str, Any]]]) -> Option return algo(**args) - def load_model(hyper_params: Dict[str, Any]): # pragma: no cover """Instantiate a model. diff --git a/amlrt_project/models/my_model.py b/amlrt_project/models/my_model.py index 7516538..9aee523 100644 --- a/amlrt_project/models/my_model.py +++ b/amlrt_project/models/my_model.py @@ -1,13 +1,13 @@ import logging import typing - import pytorch_lightning as pl -from torch import nn, FloatTensor, LongTensor +from torch import FloatTensor, LongTensor, nn +from amlrt_project.models.factory import (OptimFactory, + OptimizerConfigurationFactory, + SchedulerFactory) from amlrt_project.models.optim import load_loss -from amlrt_project.models.factory import OptimFactory, OptimizerConfigurationFactory, SchedulerFactory - from amlrt_project.utils.hp_utils import check_and_log_hp logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ def __init__( optim_fact: OptimFactory, sched_fact: SchedulerFactory, ): - """Initialize the LightningModule, with the actual model and loss""" + """Initialize the LightningModule, with the actual model and loss.""" super().__init__() self.model = model self.loss_fn = loss_fn @@ -42,6 +42,7 @@ def configure_optimizers(self): return self.opt_fact(self.model.parameters()) def forward(self, input_data: FloatTensor) -> FloatTensor: + """Invoke the model.""" return self.model(input_data) def _generic_step( @@ -88,6 +89,8 @@ def __init__( """__init__. Args: + optim_fact (OptimFactory): factory for the optimizer. + sched_fact (SchedulerFactory): factory for the scheduler. hyper_params (dict): hyper parameters from the config file. """ # TODO: Place this in a factory. 
@@ -109,4 +112,4 @@ def __init__( super().__init__( model, load_loss(hyper_params), optim_fact, sched_fact) - self.save_hyperparameters() # they will become available via model.hparams \ No newline at end of file + self.save_hyperparameters() # they will become available via model.hparams diff --git a/amlrt_project/models/optim.py b/amlrt_project/models/optim.py index c8ccb48..b952da0 100644 --- a/amlrt_project/models/optim.py +++ b/amlrt_project/models/optim.py @@ -1,7 +1,6 @@ import logging import torch -from torch import optim logger = logging.getLogger(__name__) From 59a884f444354253f11d47a0703fc2e7651e2132 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Mon, 23 Jan 2023 10:53:31 -0500 Subject: [PATCH 05/12] Remove StrEnum (Python 3.11) --- amlrt_project/models/factory.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/amlrt_project/models/factory.py b/amlrt_project/models/factory.py index 1c30fa7..215fdad 100644 --- a/amlrt_project/models/factory.py +++ b/amlrt_project/models/factory.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from dataclasses import MISSING, dataclass, field -from enum import StrEnum from typing import Any, Dict, Iterable, Optional, Tuple from torch.nn import Parameter @@ -20,15 +19,6 @@ def __call__(self, parameters: Iterable[Parameter]) -> Optimizer: ... -class Interval(StrEnum): - """Interval supported by Lightning. - - Using an enum help avoiding invalid values. - """ - EPOCH = "epoch" - STEP = "step" - - class SchedulerFactory(ABC): """Base class for learning rate scheduler factories.""" @@ -79,7 +69,7 @@ def __call__(self, optimizer: Optimizer) -> Dict[str, Any]: return dict( scheduler=scheduler, frequency=1, - interval=Interval.EPOCH, + interval='epoch', monitor=self.metric) @@ -118,7 +108,7 @@ def fn(step: int) -> float: return 1.0 scheduler = LambdaLR(optimizer, fn) - return dict(scheduler=scheduler, frequency=1, interval=Interval.STEP) + return dict(scheduler=scheduler, frequency=1, interval='step') @dataclass From e05d66278cf3678a1a2d3e4973830cbee7c1fbc5 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Mon, 23 Jan 2023 11:10:52 -0500 Subject: [PATCH 06/12] Added lower case names for the factories. --- amlrt_project/models/model_loader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/amlrt_project/models/model_loader.py b/amlrt_project/models/model_loader.py index 2bffee2..d478b8f 100644 --- a/amlrt_project/models/model_loader.py +++ b/amlrt_project/models/model_loader.py @@ -11,12 +11,16 @@ OPTS = { 'SGD': SGDFactory, - 'Adam': AdamFactory + 'sgd': SGDFactory, + 'Adam': AdamFactory, + 'adam': AdamFactory } SCHEDS = { 'Plateau': PlateauFactory, - 'WarmupDecay': WarmupDecayFactory + 'plateau': PlateauFactory, + 'WarmupDecay': WarmupDecayFactory, + 'warmupdecay': WarmupDecayFactory } From 087fa47e2d34d72dc0c5a22f42b81b21db0cf39b Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Mon, 23 Jan 2023 11:30:49 -0500 Subject: [PATCH 07/12] Added default learning rate for Adam. 
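With this change, AdamFactory no longer requires an explicit learning rate: `lr` now defaults to 1e-3, while SGDFactory keeps it as a required value. A minimal sketch of the resulting usage, assuming `model` is an ordinary torch.nn.Module defined elsewhere:

    from amlrt_project.models.factory import AdamFactory

    factory = AdamFactory()                   # lr falls back to 1e-3
    optimizer = factory(model.parameters())   # torch.optim.Adam with the factory defaults (amsgrad=True)
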
--- .gitignore | 5 ++++- amlrt_project/models/factory.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 50e214b..cfe4705 100644 --- a/.gitignore +++ b/.gitignore @@ -110,6 +110,9 @@ venv.bak/ # mypy .mypy_cache/ -# vim hipsters +# vim overlords *.swp *.swo + +# vscode +.vscode \ No newline at end of file diff --git a/amlrt_project/models/factory.py b/amlrt_project/models/factory.py index 215fdad..5c53329 100644 --- a/amlrt_project/models/factory.py +++ b/amlrt_project/models/factory.py @@ -114,7 +114,7 @@ def fn(step: int) -> float: @dataclass class SGDFactory(OptimFactory): """Factory for SGD optimizers.""" - lr: float = MISSING + lr: float = MISSING # Value is required. momentum: float = 0 dampening: float = 0 weight_decay: float = 0 @@ -131,7 +131,7 @@ def __call__(self, parameters: Iterable[Parameter]) -> SGD: @dataclass class AdamFactory(OptimFactory): """Factory for ADAM optimizers.""" - lr: float = MISSING + lr: float = 1e-3 # `MISSING` if we want to require an explicit value. betas: Tuple[float, float] = (0.9, 0.999) eps: float = 1e-8 weight_decay: float = 0 From 40389d5a5a0a3f05dd4222ed29fb0c3ebd5b6df7 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Rondeau Date: Mon, 23 Jan 2023 11:56:42 -0500 Subject: [PATCH 08/12] Explicit check for a dict when parsing HPs. --- amlrt_project/models/model_loader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/amlrt_project/models/model_loader.py b/amlrt_project/models/model_loader.py index d478b8f..af25928 100644 --- a/amlrt_project/models/model_loader.py +++ b/amlrt_project/models/model_loader.py @@ -29,9 +29,11 @@ def parse_opt_hp(hyper_params: Union[str, Dict[str, Any]]) -> OptimFactory: if isinstance(hyper_params, str): algo = hyper_params args = {} - else: + elif isinstance(hyper_params, dict): algo = hyper_params['algo'] args = {key: hyper_params[key] for key in hyper_params if key != 'algo'} + else: + raise TypeError(f"hyper_params should be a str or a dict, got {type(hyper_params)}") if algo not in OPTS: raise ValueError(f'Optimizer {algo} not supported') @@ -50,9 +52,11 @@ def parse_sched_hp( elif isinstance(hyper_params, str): algo = hyper_params args = {} - else: + elif isinstance(hyper_params, dict): algo = hyper_params['algo'] args = {key: hyper_params[key] for key in hyper_params if key != 'algo'} + else: + raise TypeError(f"hyper_params should be a str or a dict, got {type(hyper_params)}") if algo not in SCHEDS: raise ValueError(f'Scheduler {algo} not supported') From 037d14a4dd49e58ba3dd10d2edd2861938c8aaed Mon Sep 17 00:00:00 2001 From: mirkobronzi Date: Tue, 7 Nov 2023 16:50:41 -0500 Subject: [PATCH 09/12] trying new version for sphinx --- setup.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index bcb89f3..e77c5ad 100644 --- a/setup.py +++ b/setup.py @@ -10,19 +10,19 @@ 'flake8-docstrings==1.6.0', 'gitpython==3.1.27', 'jupyter==1.0.0', - 'jinja2==3.1.2', - 'myst-parser==0.18.0', + 'jinja2', + 'myst-parser', 'orion>=0.2.4.post1', 'pyyaml==6.0', 'pytest==7.1.2', 'pytest-cov==3.0.0', 'pytorch_lightning==1.8.3', 'pytype==2023.1.17', - 'sphinx==5.1.1', - 'sphinx-autoapi==1.9.0', - 'sphinx-rtd-theme==1.0.0', - 'sphinxcontrib-napoleon==0.7', - 'sphinxcontrib-katex==0.8.6', + 'sphinx', + 'sphinx-autoapi', + 'sphinx-rtd-theme', + 'sphinxcontrib-napoleon', + 'sphinxcontrib-katex', 'tqdm==4.64.0', 'torch==1.12.0', 'torchvision==0.13.0', From 9740d71eb30dfa40fc7f8f5e67f990b81ede9ba5 Mon Sep 17 00:00:00 
2001 From: mirkobronzi Date: Tue, 7 Nov 2023 16:56:26 -0500 Subject: [PATCH 10/12] fixed orion by adding missing hyper-parameter --- examples/local_orion/config.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/local_orion/config.yaml b/examples/local_orion/config.yaml index 6c9346c..42c8743 100644 --- a/examples/local_orion/config.yaml +++ b/examples/local_orion/config.yaml @@ -4,17 +4,18 @@ optimizer: adam loss: cross_entropy max_epoch: 5 exp_name: my_exp_1 +num_workers: 0 # set to null to avoid setting a seed (can speed up GPU computation, but # results will not be reproducible) seed: 1234 # architecture +hidden_dim: 'orion~uniform(32,256,discrete=True)' num_classes: 10 architecture: simple_mlp -hidden_dim: 'orion~uniform(32,256,discrete=True)' # early stopping early_stopping: metric: val_loss mode: min - patience: 3 \ No newline at end of file + patience: 3 From 785c898e875fd2941df111c82e19840ea669e5f8 Mon Sep 17 00:00:00 2001 From: mirkobronzi Date: Tue, 7 Nov 2023 17:01:38 -0500 Subject: [PATCH 11/12] fixed version of updated libraries --- setup.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index e77c5ad..0c156fa 100644 --- a/setup.py +++ b/setup.py @@ -10,19 +10,19 @@ 'flake8-docstrings==1.6.0', 'gitpython==3.1.27', 'jupyter==1.0.0', - 'jinja2', - 'myst-parser', + 'jinja2==3.1.2', + 'myst-parser==2.0.0', 'orion>=0.2.4.post1', 'pyyaml==6.0', 'pytest==7.1.2', 'pytest-cov==3.0.0', 'pytorch_lightning==1.8.3', 'pytype==2023.1.17', - 'sphinx', - 'sphinx-autoapi', - 'sphinx-rtd-theme', - 'sphinxcontrib-napoleon', - 'sphinxcontrib-katex', + 'sphinx==7.2.6', + 'sphinx-autoapi==3.0.0', + 'sphinx-rtd-theme==1.3.0', + 'sphinxcontrib-napoleon==0.7', + 'sphinxcontrib-katex==0.9.9', 'tqdm==4.64.0', 'torch==1.12.0', 'torchvision==0.13.0', From e9cf628547d51d4bbdaee5bb7f4429fe61e4db77 Mon Sep 17 00:00:00 2001 From: mirkobronzi Date: Tue, 7 Nov 2023 17:14:39 -0500 Subject: [PATCH 12/12] sorted comments so that a comment is before the related line of code --- amlrt_project/models/factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/amlrt_project/models/factory.py b/amlrt_project/models/factory.py index 5c53329..7378ebe 100644 --- a/amlrt_project/models/factory.py +++ b/amlrt_project/models/factory.py @@ -91,10 +91,10 @@ class WarmupDecayFactory(SchedulerFactory): def __post_init__(self): """Finish initialization.""" - self.gamma = max(min(self.gamma, 1.0), self.eps) # Clip gamma to something that make sense, just in case. - self.warmup = max(self.warmup, 0) + self.gamma = max(min(self.gamma, 1.0), self.eps) # Same for warmup. + self.warmup = max(self.warmup, 0) def __call__(self, optimizer: Optimizer) -> Dict[str, Any]: """Create scheduler."""
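
Taken together, the series moves optimizer and scheduler construction into small factories built from the hyper-parameter file. A rough sketch of the resulting flow, using names from the patches above; the hyper_params values here are illustrative only, not taken from a shipped config:

    from amlrt_project.models.factory import OptimizerConfigurationFactory
    from amlrt_project.models.model_loader import parse_opt_hp, parse_sched_hp

    hyper_params = {
        'optimizer': {'algo': 'adam', 'lr': 1e-3},               # a plain string such as 'adam' also works
        'scheduler': {'algo': 'Plateau', 'metric': 'val_loss'},  # omit or set to None for no scheduler
    }
    optim_fact = parse_opt_hp(hyper_params['optimizer'])
    sched_fact = parse_sched_hp(hyper_params.get('scheduler'))
    configure = OptimizerConfigurationFactory(optim_fact, sched_fact)
    # BaseModel.configure_optimizers() effectively returns configure(model.parameters()), i.e.
    # {'optimizer': Adam(...),
    #  'lr_scheduler': {'scheduler': ReduceLROnPlateau(...), 'frequency': 1,
    #                   'interval': 'epoch', 'monitor': 'val_loss'}}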