From 3d7fbffbc6cccfd349d7635433cfa136e2d41498 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 11 Jul 2023 09:09:17 +0200 Subject: [PATCH] Add startpoint_method to Problem It's not intuitive that `x_guesses` is part of `Problem`, but `startpoint_method`s are handled separately. See also discussion in #1017. This patch * adds `startpoint_method` to `Problem` * during PEtab import, sets `Problem.startpoint_method` based on the PEtab problem To be discussed: Do we want to keep the existing `startpoint_method` argument in the long term or should it be removed (deprecated)? --- pypesto/optimize/ess/cess.py | 10 +++++++++- pypesto/optimize/ess/ess.py | 18 +++++++++++------- pypesto/optimize/ess/function_evaluator.py | 20 ++++++++++++++++---- pypesto/optimize/ess/sacess.py | 9 ++++++++- pypesto/optimize/optimize.py | 13 ++++++++++++- pypesto/petab/importer.py | 1 + pypesto/problem/base.py | 21 ++++++++++++++++++++- pypesto/startpoint/base.py | 13 ++++++++----- 8 files changed, 85 insertions(+), 20 deletions(-) diff --git a/pypesto/optimize/ess/cess.py b/pypesto/optimize/ess/cess.py index 1d5ef9798..225d8bf6f 100644 --- a/pypesto/optimize/ess/cess.py +++ b/pypesto/optimize/ess/cess.py @@ -4,6 +4,7 @@ import os import time from typing import Dict, List, Optional +from warnings import warn import numpy as np @@ -107,7 +108,7 @@ def _initialize(self): def minimize( self, problem: Problem, - startpoint_method: StartpointMethod, + startpoint_method: StartpointMethod = None, ) -> pypesto.Result: """Minimize the given objective using CESS. @@ -117,7 +118,14 @@ def minimize( Problem to run ESS on. startpoint_method: Method for choosing starting points. + **Deprecated. Use ``problem.startpoint_method`` instead.** """ + if startpoint_method is not None: + warn( + "Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.", + DeprecationWarning, + ) + self._initialize() evaluator = FunctionEvaluator( diff --git a/pypesto/optimize/ess/ess.py b/pypesto/optimize/ess/ess.py index caa7bfb4d..130ad2a9d 100644 --- a/pypesto/optimize/ess/ess.py +++ b/pypesto/optimize/ess/ess.py @@ -28,6 +28,7 @@ import logging import time from typing import List, Optional, Tuple +from warnings import warn import numpy as np @@ -183,21 +184,24 @@ def minimize( Problem to run ESS on. startpoint_method: Method for choosing starting points. + **Deprecated. Use ``problem.startpoint_method`` instead.** refset: The initial RefSet or ``None`` to auto-generate. """ + if startpoint_method is not None: + warn( + "Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.", + DeprecationWarning, + ) + self._initialize() self.starttime = time.time() - if ( - refset is None and (problem is None or startpoint_method is None) - ) or ( - refset is not None - and (problem is not None or startpoint_method is not None) + if (refset is None and problem is None) or ( + refset is not None and problem is not None ): raise ValueError( - "Either `refset` or `problem` and `startpoint_method` " - "has to be provided." + "Either `refset` or `problem` has to be provided." ) # generate initial RefSet if not provided if refset is None: diff --git a/pypesto/optimize/ess/function_evaluator.py b/pypesto/optimize/ess/function_evaluator.py index ebb7191a5..fe6412a3e 100644 --- a/pypesto/optimize/ess/function_evaluator.py +++ b/pypesto/optimize/ess/function_evaluator.py @@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from typing import Optional, Sequence, Tuple +from warnings import warn import numpy as np @@ -31,17 +32,28 @@ class FunctionEvaluator: def __init__( self, problem: Problem, - startpoint_method: StartpointMethod, + startpoint_method: StartpointMethod = None, ): """Construct. Parameters ---------- - problem: The problem - startpoint_method: Method for choosing feasible parameters + problem: + The problem + startpoint_method: + Method for choosing feasible parameters + **Deprecated. Use ``problem.startpoint_method`` instead.** """ + if startpoint_method is not None: + warn( + "Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.", + DeprecationWarning, + ) + self.problem: Problem = problem - self.startpoint_method: StartpointMethod = startpoint_method + self.startpoint_method: StartpointMethod = ( + startpoint_method or problem.startpoint_method + ) self.n_eval: int = 0 self.n_eval_round: int = 0 diff --git a/pypesto/optimize/ess/sacess.py b/pypesto/optimize/ess/sacess.py index b477e2c5f..773f5b055 100644 --- a/pypesto/optimize/ess/sacess.py +++ b/pypesto/optimize/ess/sacess.py @@ -6,6 +6,7 @@ from multiprocessing import Manager, Process from multiprocessing.managers import SyncManager from typing import Any, Dict, List, Optional, Tuple +from warnings import warn import numpy as np @@ -96,9 +97,15 @@ def __init__( def minimize( self, problem: Problem, - startpoint_method: StartpointMethod, + startpoint_method: StartpointMethod = None, ): """Solve the given optimization problem.""" + if startpoint_method is not None: + warn( + "Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.", + DeprecationWarning, + ) + start_time = time.time() logger.debug( f"Running sacess with {self.num_workers} " diff --git a/pypesto/optimize/optimize.py b/pypesto/optimize/optimize.py index 70456e6a9..eb0a74e14 100644 --- a/pypesto/optimize/optimize.py +++ b/pypesto/optimize/optimize.py @@ -1,5 +1,6 @@ import logging from typing import Callable, Iterable, Union +from warnings import warn from ..engine import Engine, SingleCoreEngine from ..history import HistoryOptions @@ -50,6 +51,7 @@ def minimize( startpoint_method: Method for how to choose start points. False means the optimizer does not require start points, e.g. for the 'PyswarmOptimizer'. + **Deprecated. Use ``problem.startpoint_method`` instead.** result: A result object to append the optimization results to. For example, one might append more runs to a previous optimization. If None, @@ -88,7 +90,16 @@ def minimize( # startpoint method if startpoint_method is None: - startpoint_method = uniform + if problem.startpoint_method is None: + startpoint_method = uniform + else: + startpoint_method = problem.startpoint_method + else: + warn( + "Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.", + DeprecationWarning, + ) + # convert startpoint method to class instance startpoint_method = to_startpoint_method(startpoint_method) diff --git a/pypesto/petab/importer.py b/pypesto/petab/importer.py index 9c496318e..cdc6b19cf 100644 --- a/pypesto/petab/importer.py +++ b/pypesto/petab/importer.py @@ -740,6 +740,7 @@ def create_problem( x_names=x_ids, x_scales=x_scales, x_priors_defs=prior, + startpoint_method=self.create_startpoint_method(), **problem_kwargs, ) diff --git a/pypesto/problem/base.py b/pypesto/problem/base.py index ae3e2004a..ab802eee8 100644 --- a/pypesto/problem/base.py +++ b/pypesto/problem/base.py @@ -1,12 +1,21 @@ import copy import logging -from typing import Iterable, List, Optional, SupportsFloat, SupportsInt, Union +from typing import ( + Callable, + Iterable, + List, + Optional, + SupportsFloat, + SupportsInt, + Union, +) import numpy as np import pandas as pd from ..objective import ObjectiveBase from ..objective.priors import NegLogParameterPriors +from ..startpoint import StartpointMethod, to_startpoint_method, uniform SupportsFloatIterableOrValue = Union[Iterable[SupportsFloat], SupportsFloat] SupportsIntIterableOrValue = Union[Iterable[SupportsInt], SupportsInt] @@ -60,6 +69,9 @@ class Problem: copy_objective: Whethter to generate a deep copy of the objective function before potential modification the problem class performs on it. + startpoint_method: + Method for how to choose start points. ``False`` means the optimizer + does not require start points, e.g. for the ``PyswarmOptimizer``. Notes ----- @@ -90,6 +102,7 @@ def __init__( lb_init: Union[np.ndarray, List[float], None] = None, ub_init: Union[np.ndarray, List[float], None] = None, copy_objective: bool = True, + startpoint_method: Union[StartpointMethod, Callable, bool] = None, ): if copy_objective: objective = copy.deepcopy(objective) @@ -147,6 +160,12 @@ def __init__( self.normalize() self._check_x_guesses() + # startpoint method + if startpoint_method is None: + startpoint_method = uniform + # convert startpoint method to class instance + self.startpoint_method = to_startpoint_method(startpoint_method) + @property def lb(self) -> np.ndarray: """Return lower bounds of free parameters.""" diff --git a/pypesto/startpoint/base.py b/pypesto/startpoint/base.py index 038232915..d8ecde27f 100644 --- a/pypesto/startpoint/base.py +++ b/pypesto/startpoint/base.py @@ -1,13 +1,16 @@ """Startpoint base classes.""" +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Callable, Union +from typing import TYPE_CHECKING, Callable, Union import numpy as np from ..C import FVAL, GRAD from ..objective import ObjectiveBase -from ..problem import Problem + +if TYPE_CHECKING: + import pypesto class StartpointMethod(ABC): @@ -21,7 +24,7 @@ class StartpointMethod(ABC): def __call__( self, n_starts: int, - problem: Problem, + problem: pypesto.problem.Problem, ) -> np.ndarray: """Generate startpoints. @@ -42,7 +45,7 @@ class NoStartpoints(StartpointMethod): def __call__( self, n_starts: int, - problem: Problem, + problem: pypesto.problem.Problem, ) -> np.ndarray: """Generate a (n_starts, dim) nan matrix.""" startpoints = np.full(shape=(n_starts, problem.dim), fill_value=np.nan) @@ -78,7 +81,7 @@ def __init__( def __call__( self, n_starts: int, - problem: Problem, + problem: pypesto.problem.Problem, ) -> np.ndarray: """Generate checked startpoints.""" # shape: (n_guesses, dim)