Skip to content

Commit

Permalink
Add test for accept_greedy
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 21, 2024
1 parent f63d502 commit cdbd65e
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions tests/test_acceptance_decision.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections import namedtuple

from functools import partial
import numpy as np
import pytest
from tranquilo.sample_points import get_sampler
from tranquilo.acceptance_decision import (
_accept_greedy,
accept_greedy,
_accept_simple,
_get_acceptance_result,
calculate_rho,
Expand Down Expand Up @@ -93,30 +94,46 @@ def test_accept_greedy(
state,
subproblem_solution,
):
"""Test accept greedy.
Tests that the best point is chosen in the acceptance step, even though it is added
to the history before the acceptance step.
"""
history = History(functype="scalar")

idxs = history.add_xs(np.arange(10).reshape(5, 2))
def criterion(x):
return np.sum(x**2)

history.add_evals(idxs.repeat(2), np.arange(10))
def _wrapped_criterion(eval_info, history):
for x_index, _ in eval_info.items():
xs = history.get_xs(x_index)
crit_value = criterion(xs)
history.add_evals(np.array([x_index]), crit_value)

def wrapped_criterion(eval_info):
indices = np.array(list(eval_info)).repeat(np.array(list(eval_info.values())))
history.add_evals(indices, -indices)
wrapped_criterion = partial(_wrapped_criterion, history=history)

res_got = _accept_greedy(
# Add existing xs to history and evaluate wrapped criterion
existing_xs = np.zeros((1, 2))
existing_xs_indices = history.add_xs(existing_xs)

eval_info = {x_index: 1 for x_index in existing_xs_indices}
wrapped_criterion(eval_info)

res_got = accept_greedy(
subproblem_solution=subproblem_solution,
state=state,
history=history,
wrapped_criterion=wrapped_criterion,
min_improvement=0.0,
n_evals=2,
)

assert res_got.accepted
assert res_got.index == 5
assert res_got.candidate_index == 5
assert_array_equal(res_got.x, subproblem_solution.x)
assert_array_equal(res_got.candidate_x, 1.0 + np.arange(2))
assert res_got.index == 0
assert res_got.candidate_index == 0
assert res_got.fval == 0.0
assert_array_equal(res_got.x, np.zeros(2))
assert_array_equal(res_got.candidate_x, np.zeros(2))


# ======================================================================================
Expand Down

0 comments on commit cdbd65e

Please sign in to comment.