feat: model based bracket optimizers (wip)

eddiebergman committed Jan 24, 2025
1 parent ff08feb commit 33bfb9b
Showing 6 changed files with 349 additions and 29 deletions.
68 changes: 65 additions & 3 deletions neps/optimizers/algorithms.py
@@ -25,7 +25,7 @@

from neps.optimizers.ask_and_tell import AskAndTell # noqa: F401
from neps.optimizers.bayesian_optimization import BayesianOptimization
-from neps.optimizers.bracket_optimizer import BracketOptimizer
+from neps.optimizers.bracket_optimizer import BracketOptimizer, GPSampler
from neps.optimizers.grid_search import GridSearch
from neps.optimizers.ifbo import IFBO
from neps.optimizers.models.ftpfn import FTPFNSurrogate
@@ -117,17 +117,19 @@ def _bo(
    )


-def _bracket_optimizer(  # noqa: C901
+def _bracket_optimizer(  # noqa: C901, PLR0912, PLR0915
    pipeline_space: SearchSpace,
    *,
    bracket_type: Literal["successive_halving", "hyperband", "asha", "async_hb"],
    eta: int,
    sampler: Literal["uniform", "prior", "priorband"] | PriorBandArgs | Sampler,
+   bayesian_optimization: int | float | None,
    sample_prior_first: bool | Literal["highest_fidelity"],
    # NOTE: This is the only argument to get a default, since it
    # is not required for hyperband style algorithms, only single bracket
    # style ones.
-   early_stopping_rate: int | None = None,
+   early_stopping_rate: int | None,
+   device: torch.device | None,
) -> BracketOptimizer:
    """Initialise a bracket optimizer.
@@ -149,6 +151,19 @@ def _bracket_optimizer(  # noqa: C901
            This is only used for Successive Halving and ASHA. If not
            `None`, the bracket type must be one of those.
+       bayesian_optimization:
+           * If `None`, no bayesian optimization is used at any point.
+           * If a number `N`, then once `N * maximum_fidelity` worth of fidelity
+             has been evaluated, proceed with bayesian optimization when
+             sampling a new configuration.
+
+               !!! example
+
+                   If `maximum_fidelity` is 100 and `bayesian_optimization` is `10`,
+                   the underlying bracket algorithm keeps being used until
+                   `sum(config.fidelity) >= 100 * 10`, at which point we switch
+                   to sampling with bayesian optimization.
        sampler: The type of sampling procedure to use:

            * If "uniform", samples uniformly from the space when it needs to sample
@@ -162,6 +177,7 @@ def _bracket_optimizer(  # noqa: C901
            * If a `Sampler` object, samples from the space using the sampler.
        sample_prior_first: Whether to sample the prior configuration first.
+       device: If using Bayesian Optimization, the device to use for the optimization.
    """
    assert pipeline_space.fidelity is not None
    fidelity_name, fidelity = pipeline_space.fidelity
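
The switch-over rule documented in the new docstring is simple: track the total
fidelity spent across all evaluated configurations and hand sampling over to
bayesian optimization once it crosses `N * maximum_fidelity`. A minimal sketch
of that rule, assuming only a list of per-config fidelities (the function name
and signature are illustrative, not part of the NePS API):

def should_use_bo(
    fidelities_evaluated: list[float],
    n: float,
    maximum_fidelity: float,
) -> bool:
    """True once `n * maximum_fidelity` worth of fidelity has been spent."""
    return sum(fidelities_evaluated) >= n * maximum_fidelity

# With maximum_fidelity=100 and n=10, brackets run until 1000 fidelity-units
# have been consumed in total, e.g. ten full-fidelity evaluations.
assert should_use_bo([100.0] * 10, n=10, maximum_fidelity=100)
assert not should_use_bo([100.0] * 9, n=10, maximum_fidelity=100)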
@@ -243,8 +259,36 @@ def _bracket_optimizer(  # noqa: C901
        case _:
            raise ValueError(f"Unknown sampler: {sampler}")

+   # TODO: This should be lifted out of this function and have the caller
+   # pass in a `GPSampler`.
+   # TODO: Better name and parametrization of this if not going with the above.
+   gp_sampler: GPSampler | None = None
+   if bayesian_optimization is not None:
+       if bayesian_optimization <= 0:
+           raise ValueError("bayesian_optimization should be greater than 0")
+
+       # TODO: Parametrize?
+       modelling_strategy = "joint"
+       two_stage_batch_sample_size = 10
+       fidelity_max = fidelity.upper
+       threshold = bayesian_optimization
+
+       # Notably, we include the fidelity in what we model here.
+       bo_encoder = ConfigEncoder.from_space(pipeline_space, include_fidelity=True)
+       gp_sampler = GPSampler(
+           space=pipeline_space,
+           encoder=bo_encoder,
+           threshold=threshold,
+           fidelity_name=fidelity_name,
+           fidelity_max=fidelity_max,
+           modelling_strategy=modelling_strategy,
+           two_stage_batch_sample_size=two_stage_batch_sample_size,
+           device=device,
+       )
+
    return BracketOptimizer(
        pipeline_space=pipeline_space,
+       gp_sampler=gp_sampler,
        encoder=encoder,
        eta=eta,
        rung_to_fid=rung_to_fidelity,
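
The `modelling_strategy = "joint"` default above, together with
`include_fidelity=True` on the encoder, suggests the surrogate is fit on
(configuration, fidelity) pairs rather than on configurations alone. A minimal
illustration of that joint encoding, assuming nothing about `GPSampler`'s
internals (`encode_joint` is a hypothetical name, not NePS code):

import numpy as np

def encode_joint(configs: np.ndarray, fidelities: np.ndarray) -> np.ndarray:
    # Append the fidelity as one extra input column so a single surrogate
    # can be trained on (config, fidelity) -> objective observations.
    return np.column_stack([configs, fidelities])

X = encode_joint(np.random.rand(5, 3), np.array([1.0, 3.0, 9.0, 27.0, 81.0]))
assert X.shape == (5, 4)  # 3 config dimensions + 1 fidelity dimension

Presumably a two-stage strategy would instead draw
`two_stage_batch_sample_size` candidate configurations and rank them with a
model at the target fidelity, but the diff only names these knobs without
showing their implementation.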
@@ -491,6 +535,9 @@ def successive_halving(
        early_stopping_rate=early_stopping_rate,
        sampler=sampler,
        sample_prior_first=sample_prior_first,
+       # TODO: Implement this
+       bayesian_optimization=None,
+       device=None,
)


@@ -551,6 +598,10 @@ def hyperband(
        eta=eta,
        sampler=sampler,
        sample_prior_first=sample_prior_first,
+       early_stopping_rate=None,
+       # TODO: Implement this
+       bayesian_optimization=None,
+       device=None,
)


@@ -611,6 +662,9 @@ def asha(
        early_stopping_rate=early_stopping_rate,
        sampler=sampler,
        sample_prior_first=sample_prior_first,
+       # TODO: Implement this
+       bayesian_optimization=None,
+       device=None,
)


@@ -667,6 +721,10 @@ def async_hb(
        eta=eta,
        sampler=sampler,
        sample_prior_first=sample_prior_first,
+       early_stopping_rate=None,
+       # TODO: Implement this
+       bayesian_optimization=None,
+       device=None,
)


@@ -717,6 +775,10 @@ def priorband(
        eta=eta,
        sampler="priorband",
        sample_prior_first=sample_prior_first,
+       # TODO: Implement this
+       early_stopping_rate=0 if base in ("successive_halving", "asha") else None,
+       bayesian_optimization=None,
+       device=None,
)


(Diffs for the remaining 5 changed files were not loaded.)
