Skip to content

Commit

Permalink
Remove private function _accept_greedy
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 21, 2024
1 parent e3c1ac2 commit f63d502
Showing 1 changed file with 45 additions and 78 deletions.
123 changes: 45 additions & 78 deletions src/tranquilo/acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,52 @@ def accept_greedy(
AcceptanceResult
"""
out = _accept_greedy(
subproblem_solution=subproblem_solution,
state=state,
history=history,
wrapped_criterion=wrapped_criterion,
min_improvement=min_improvement,
n_evals=1,
candidate_x = subproblem_solution.x
candidate_index = history.add_xs(candidate_x)
wrapped_criterion({candidate_index: 1})

candidate_fval = np.mean(history.get_fvals(candidate_index))
actual_improvement = -(candidate_fval - state.fval)

rho = calculate_rho(
actual_improvement=actual_improvement,
expected_improvement=subproblem_solution.expected_improvement,
)
return out

best_x, best_fval, best_index = history.get_best()

if best_fval < candidate_fval:
candidate_x = best_x
candidate_fval = best_fval
candidate_index = best_index
overall_improvement = -(candidate_fval - state.fval)
else:
overall_improvement = actual_improvement

is_accepted = overall_improvement >= min_improvement

if np.isfinite(candidate_fval):
res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=1,
)
else:
res = _get_acceptance_result(
candidate_x=state.x,
candidate_fval=state.fval,
candidate_index=state.index,
rho=-np.inf,
is_accepted=False,
old_state=state,
n_evals=1,
)

return res


def _accept_classic(
Expand Down Expand Up @@ -272,76 +309,6 @@ def accept_classic_line_search(
return res


def _accept_greedy(
subproblem_solution,
state,
history,
*,
wrapped_criterion,
min_improvement,
n_evals,
):
"""Do a simple greedy acceptance step for a trustregion algorithm.
Args:
subproblem_solution (SubproblemResult): Result of the subproblem solution.
state (State): Namedtuple containing the trustregion, criterion value of
previously accepted point, indices of model points, etc.
wrapped_criterion (callable): The criterion function.
min_improvement (float): Minimum improvement required to accept a point.
Returns:
AcceptanceResult
"""
candidate_x = subproblem_solution.x
candidate_index = history.add_xs(candidate_x)
wrapped_criterion({candidate_index: n_evals})

candidate_fval = np.mean(history.get_fvals(candidate_index))
actual_improvement = -(candidate_fval - state.fval)

rho = calculate_rho(
actual_improvement=actual_improvement,
expected_improvement=subproblem_solution.expected_improvement,
)

best_x, best_fval, best_index = history.get_best()

if best_fval < candidate_fval:
candidate_x = best_x
candidate_fval = best_fval
candidate_index = best_index
overall_improvement = -(candidate_fval - state.fval)
else:
overall_improvement = actual_improvement

is_accepted = overall_improvement >= min_improvement

if np.isfinite(candidate_fval):
res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=n_evals,
)
else:
res = _get_acceptance_result(
candidate_x=state.x,
candidate_fval=state.fval,
candidate_index=state.index,
rho=-np.inf,
is_accepted=False,
old_state=state,
n_evals=n_evals,
)

return res


def _accept_simple(
subproblem_solution,
state,
Expand Down

0 comments on commit f63d502

Please sign in to comment.