Skip to content

Commit

Permalink
OWTreeLearner: report error instead of crashing when can't binarize
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Dec 23, 2016
1 parent 8bc04e1 commit dba233e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
25 changes: 25 additions & 0 deletions Orange/widgets/classify/owclassificationtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from Orange.data import Table
from Orange.modelling.tree import TreeLearner
from Orange.classification.tree import TreeLearner as ClassificationTreeLearner
from Orange.widgets.model.owtree import OWTreeLearner
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.widgets.widget import Msg


class OWTreeLearner(OWTreeLearner):
Expand All @@ -21,6 +23,29 @@ class OWTreeLearner(OWTreeLearner):
"limit_majority", "sufficient_majority", 51, 100),) + \
OWTreeLearner.spin_boxes[-1:]

class Error(OWTreeLearner.Error):
    # Extends the parent widget's error group with one extra message.
    # Placeholders are filled, in order, with: the attribute name, its
    # number of values, and the largest number of values that binarization
    # can handle.
    cannot_binarize = Msg(
        "Binarization cannot handle '{}'\nbecause it has {} values. "
        "Binarization can handle up to {}.\n"
        "Disable 'Induce binary tree' to proceed."
    )

def check_data(self):
    """Validate the input data and report attributes that cannot be binarized.

    Clears any previously shown ``cannot_binarize`` error and defers to the
    parent's ``check_data``. When binary trees are requested, the discrete
    attribute with the most values is compared against
    ``ClassificationTreeLearner.MAX_BINARIZATION``; if it exceeds that
    limit, the ``cannot_binarize`` error is shown.

    Returns:
        bool: ``False`` if the parent check fails or some discrete attribute
        has more values than binarization can handle, ``True`` otherwise.
    """
    self.Error.cannot_binarize.clear()
    if not super().check_data():
        return False
    if not self.binary_trees:
        return True

    # Select by value count only. Taking max() over (count, attr) tuples
    # would fall back to comparing the variable objects themselves whenever
    # two attributes have the same number of values, which raises TypeError
    # unless the variable type defines ordering.
    widest = max(
        (attr for attr in self.data.domain.attributes if attr.is_discrete),
        key=lambda attr: len(attr.values),
        default=None)
    if widest is None:
        # No discrete attributes, so there is nothing to binarize.
        return True

    limit = ClassificationTreeLearner.MAX_BINARIZATION
    n_values = len(widest.values)
    if n_values > limit:
        self.Error.cannot_binarize(widest.name, n_values, limit)
        return False
    return True

def learner_kwargs(self):
opts = super().learner_kwargs()
opts['sufficient_majority'] = \
Expand Down
54 changes: 54 additions & 0 deletions Orange/widgets/classify/tests/test_owclassificationtree.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
from Orange.data import Table, Domain, DiscreteVariable
from Orange.classification.tree import TreeLearner as ClassificationTreeLearner
from Orange.base import Model
from Orange.widgets.classify.owclassificationtree import OWTreeLearner
from Orange.widgets.tests.base import (WidgetTest, DefaultParameterMapping,
ParameterMapping, WidgetLearnerTestMixin)


class TestOWClassificationTree(WidgetTest, WidgetLearnerTestMixin):
# Class-level fixture: after running the parent class's setUpClass, load the
# "iris" table and keep it on the class as ``cls.iris`` so that test methods
# can send it to the widget without reloading it.
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.iris = Table("iris")

def setUp(self):
self.widget = self.create_widget(
OWTreeLearner, stored_settings={"auto_apply": False})
Expand Down Expand Up @@ -34,3 +41,50 @@ def test_parameters_unchecked(self):
for par, val in zip(self.parameters, (None, 2, 1))]
self.test_parameters()

def test_cannot_binarize(self):
    """The binarization error appears only when a discrete attribute has
    too many values while 'Induce binary tree' is enabled."""
    widget = self.widget
    error_shown = widget.Error.cannot_binarize.is_shown
    self.assertFalse(error_shown())
    self.send_signal("Data", self.iris)

    # The widget outputs ClassificationTreeLearner.
    # If not, below tests may not make sense
    learner = self.get_output("Learner")
    dlearner = learner.get_learner(learner.CLASSIFICATION)
    # assertTrue(x, cls) would treat the class as the failure message and
    # only check truthiness; the type must be checked explicitly.
    self.assertIsInstance(dlearner, ClassificationTreeLearner)

    # No error on Iris
    max_binarization = dlearner.MAX_BINARIZATION
    self.assertFalse(error_shown())

    # Error when too many values
    domain = Domain([
        DiscreteVariable(
            values=[str(x) for x in range(max_binarization + 1)])],
        DiscreteVariable(values="01"))
    self.send_signal("Data", Table(domain, [[0, 0], [1, 1]]))
    self.assertTrue(error_shown())
    # No more error on Iris
    self.send_signal("Data", self.iris)
    self.assertFalse(error_shown())

    # Checking and unchecking binarization works
    widget.controls.binary_trees.click()
    self.assertFalse(widget.binary_trees)
    widget.unconditional_apply()
    self.send_signal("Data", Table(domain, [[0, 0], [1, 1]]))
    self.assertFalse(error_shown())
    widget.controls.binary_trees.click()
    widget.unconditional_apply()
    self.assertTrue(error_shown())
    widget.controls.binary_trees.click()
    widget.unconditional_apply()
    self.assertFalse(error_shown())

    # If something is wrong with the data, no error appears
    domain = Domain([
        DiscreteVariable(
            values=[str(x) for x in range(max_binarization + 1)])],
        DiscreteVariable(values="01"))
    self.send_signal("Data", Table(domain))
    self.assertFalse(error_shown())

0 comments on commit dba233e

Please sign in to comment.