Skip to content

Commit

Permalink
Return any modeling errors encountered in Scheduler.generate_candidat…
Browse files Browse the repository at this point in the history
…es (#2967)

Summary: Pull Request resolved: #2967

Reviewed By: ItsMrLin

Differential Revision: D64981673

fbshipit-source-id: ac851a876d296cf241fed3431ee0c3f10a071bd0
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 25, 2024
1 parent 5366cde commit 796637f
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def generate_candidates(
self,
num_trials: int = 1,
reduce_state_generator_runs: bool = False,
) -> list[BaseTrial]:
) -> tuple[list[BaseTrial], Exception | None]:
"""Fetch the latest data and generate new candidate trials.
Args:
Expand All @@ -505,7 +505,7 @@ def generate_candidates(
Returns:
List of trials, empty if generation is not possible.
"""
new_trials = self._get_next_trials(
new_trials, err = self._get_next_trials(
num_trials=num_trials,
n=self.options.batch_size,
)
Expand All @@ -518,7 +518,7 @@ def generate_candidates(
new_generator_runs=new_generator_runs,
reduce_state_generator_runs=reduce_state_generator_runs,
)
return new_trials
return new_trials, err

def run_n_trials(
self,
Expand Down Expand Up @@ -1776,16 +1776,19 @@ def _prepare_trials(

existing_candidate_trials = self.candidate_trials[:n]
n_new = min(n - len(existing_candidate_trials), max_new_trials)
new_trials = (
new_trials, _err = (
self._get_next_trials(num_trials=n_new, n=self.options.batch_size)
if n_new > 0
else []
else (
[],
None,
)
)
return existing_candidate_trials, new_trials

def _get_next_trials(
self, num_trials: int = 1, n: int | None = None
) -> list[BaseTrial]:
) -> tuple[list[BaseTrial], Exception | None]:
"""Produce up to `num_trials` new generator runs from the underlying
generation strategy and create new trials with them. Logs errors
encountered during generation.
Expand All @@ -1805,7 +1808,7 @@ def _get_next_trials(
self.logger.info(completion_str)
self.markdown_messages["Optimization complete"] = completion_str
self._optimization_complete = True
return []
return [], err
except DataRequiredError as err:
# TODO[T62606107]: consider adding a `more_data_required` property to
# check to generation strategy to avoid running into this exception.
Expand All @@ -1815,7 +1818,7 @@ def _get_next_trials(
"Model requires more data to generate more trials."
)
self.logger.debug(f"Message from generation strategy: {err}")
return []
return [], err
except MaxParallelismReachedException as err:
# TODO[T62606107]: consider adding a `step_max_parallelism_reached`
# check to generation strategy to avoid running into this exception.
Expand All @@ -1825,7 +1828,7 @@ def _get_next_trials(
"Max parallelism currently reached."
)
self.logger.debug(f"Message from generation strategy: {err}")
return []
return [], err
except AxGenerationException as err:
if self._log_next_no_trials_reason:
self.logger.info(
Expand All @@ -1834,7 +1837,7 @@ def _get_next_trials(
f"{err}."
)
self.logger.debug(f"Message from generation strategy: {err}")
return []
return [], err
except OptimizationConfigRequired as err:
if self._log_next_no_trials_reason:
self.logger.info(
Expand All @@ -1843,7 +1846,12 @@ def _get_next_trials(
"to be set before generating more trials."
)
self.logger.debug(f"Message from generation strategy: {err}")
return []
return [], err
except Exception as err:
self.logger.exception(
f"An unexpected error occurred while generating trials. {err}"
)
return [], err

if self.options.trial_type == TrialType.TRIAL and any(
len(generator_run_list[0].arms) > 1 or len(generator_run_list) > 1
Expand Down Expand Up @@ -1873,7 +1881,7 @@ def _get_next_trials(
)

trials.append(trial)
return trials
return trials, None

def _choose_analyses(self) -> list[Analysis]:
"""
Expand Down

0 comments on commit 796637f

Please sign in to comment.