Skip to content

Commit

Permalink
Merge pull request #5972 from markotoplak/fix-prediction-classless
Browse files Browse the repository at this point in the history
[FIX] Predictions: allow predicting probabilities for classless data
  • Loading branch information
janezd authored May 17, 2022
2 parents 369f2c3 + f86ed47 commit 1b044ca
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 6 deletions.
11 changes: 8 additions & 3 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ def set_data(self, data):
self.predictionsview))

self._set_target_combos()
if self.is_discrete_class:
self.openContext(self.class_var.values)
self.openContext(self.class_var.values if self.is_discrete_class else ())
self._invalidate_predictions()

def _store_selection(self):
Expand Down Expand Up @@ -267,10 +266,16 @@ def _set_target_combos(self):
self.target_class = self.TARGET_AVERAGE
else:
self.shown_probs = self.NO_PROBS
model = prob_combo.model()
for v in (self.DATA_PROBS, self.BOTH_PROBS):
item = model.item(v)
item.setFlags(item.flags() & ~Qt.ItemIsEnabled)

def _update_control_visibility(self):
for widget in self._prob_controls:
widget.setVisible(self.is_discrete_class)
widget.setVisible(self.is_discrete_class
or any(slot.predictor.domain.has_discrete_class
for slot in self.predictors))

for widget in self._target_controls:
widget.setVisible(self.is_discrete_class and self.show_scores)
Expand Down
53 changes: 50 additions & 3 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class TestOWPredictions(WidgetTest):
def setUp(self):
self.widget = self.create_widget(OWPredictions) # type: OWPredictions
self.iris = Table("iris")
self.iris_classless = self.iris.transform(Domain(self.iris.domain.attributes, []))
self.housing = Table("housing")

def test_rowCount_from_model(self):
Expand Down Expand Up @@ -624,12 +625,12 @@ def test_missing_target_reg(self):
def _mock_predictors(self):
def pred(values):
slot = Mock()
slot.predictor.domain.class_var = DiscreteVariable("c", tuple(values))
slot.predictor.domain = Domain([], DiscreteVariable("c", tuple(values)))
return slot

def predc():
slot = Mock()
slot.predictor.domain.class_var = ContinuousVariable("c")
slot.predictor.domain = Domain([], ContinuousVariable("c"))
return slot

widget = self.widget
Expand Down Expand Up @@ -746,16 +747,32 @@ def test_update_delegates_continuous(self):

widget.data = Table.from_list(Domain([], ContinuousVariable("c")), [])

# only regression
all_predictors = widget.predictors
widget.predictors = [widget.predictors[-1]]
widget._update_control_visibility()
self.assertTrue(widget.controls.shown_probs.isHidden())
self.assertTrue(widget.controls.target_class.isHidden())

# regression and classification
widget.predictors = all_predictors
widget._update_control_visibility()
self.assertFalse(widget.controls.shown_probs.isHidden())
self.assertTrue(widget.controls.target_class.isHidden())

widget._set_class_values()
self.assertEqual(widget.class_values, list("abcde"))

widget._set_target_combos()
self.assertEqual(widget.shown_probs, widget.NO_PROBS)

def is_enabled(prob_item):
return widget.controls.shown_probs.model().item(prob_item).flags() & Qt.ItemIsEnabled
self.assertTrue(is_enabled(widget.NO_PROBS))
self.assertTrue(is_enabled(widget.MODEL_PROBS))
self.assertFalse(is_enabled(widget.DATA_PROBS))
self.assertFalse(is_enabled(widget.BOTH_PROBS))

def test_delegate_ranges(self):
widget = self.widget

Expand Down Expand Up @@ -816,7 +833,6 @@ def predict(self, X):
delegate = widget.predictionsview.itemDelegateForColumn(2)
self.assertIsInstance(delegate, ClassificationItemDelegate)


class _Scorer(TargetScore):
# pylint: disable=arguments-differ
def compute_score(self, _, target, **__):
Expand Down Expand Up @@ -940,6 +956,37 @@ def test_output_regression(self):
out.metas,
np.hstack([pred.results.predicted.T for pred in widget.predictors]))

def test_classless(self):
widget = self.widget
iris012 = self.iris
purge = Remove(class_flags=Remove.RemoveUnusedValues)
iris01 = purge(iris012[:100])
iris12 = purge(iris012[50:])

bayes01 = NaiveBayesLearner()(iris01)
bayes12 = NaiveBayesLearner()(iris12)
bayes012 = NaiveBayesLearner()(iris012)

self.send_signal(widget.Inputs.data, self.iris_classless)
self.send_signal(widget.Inputs.predictors, bayes01, 0)
self.send_signal(widget.Inputs.predictors, bayes12, 1)
self.send_signal(widget.Inputs.predictors, bayes012, 2)

for i, pred in enumerate(widget.predictors):
p = pred.results.unmapped_probabilities
p[0] = 10 + 100 * i + np.arange(p.shape[1])
pred.results.unmapped_predicted[:] = i

widget.shown_probs = widget.NO_PROBS
widget._commit_predictions()
out = self.get_output(widget.Outputs.predictions)
self.assertEqual(list(out.metas[0]), [0, 1, 2])

widget.shown_probs = widget.MODEL_PROBS
widget._commit_predictions()
out = self.get_output(widget.Outputs.predictions)
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 110, 111, 2, 210, 211, 212])

@patch("Orange.widgets.evaluate.owpredictions.usable_scorers",
Mock(return_value=[_Scorer]))
def test_change_target(self):
Expand Down

0 comments on commit 1b044ca

Please sign in to comment.