Skip to content

Commit

Permalink
add: random tie break for expected error reduction
Browse files: browse the repository at this point in the history
  • Loading branch information
cosmic-cortex committed Dec 5, 2018
1 parent b7815d8 commit fad91e8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
12 changes: 9 additions & 3 deletions modAL/expected_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

from modAL.models import ActiveLearner
from modAL.utils.data import modALinput, data_vstack
from modAL.utils.selection import multi_argmax
from modAL.utils.selection import multi_argmax, shuffled_argmax
from modAL.uncertainty import _proba_uncertainty, _proba_entropy


def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str = 'binary',
p_subsample: np.float = 1.0, n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
p_subsample: np.float = 1.0, n_instances: int = 1,
random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
"""
Expected error reduction query strategy.
Expand All @@ -32,6 +33,8 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
calculating expected error. Significantly improves runtime
for large sample pools.
n_instances: The number of instances to be sampled.
random_tie_break: If True, shuffles utility scores to randomize the order. This
can be used to break the tie when the highest utility score is not unique.
Returns:
Expand Down Expand Up @@ -73,6 +76,9 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
else:
expected_error[x_idx] = np.inf

query_idx = multi_argmax(expected_error, n_instances)
if not random_tie_break:
query_idx = multi_argmax(expected_error, n_instances)
else:
query_idx = shuffled_argmax(expected_error, n_instances)

return query_idx, X[query_idx]
1 change: 1 addition & 0 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def test_eer(self):
X_training=X_training, y_training=y_training)

modAL.expected_error.expected_error_reduction(learner, X_pool)
modAL.expected_error.expected_error_reduction(learner, X_pool, random_tie_break=True)
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1)
modAL.expected_error.expected_error_reduction(learner, X_pool, loss='binary')
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1, loss='log')
Expand Down

0 comments on commit fad91e8

Please sign in to comment.