Skip to content

Commit

Permalink
Merge pull request #3323 from janezd/test-score-transformation
Browse files Browse the repository at this point in the history
Test and Score: Warn about transformation, raise error if all is nan
  • Loading branch information
markotoplak authored Nov 19, 2018
2 parents dda3505 + 674ebff commit 8784dbb
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 7 deletions.
8 changes: 8 additions & 0 deletions Orange/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from Orange.data import Table, Storage, Instance, Value
from Orange.data.filter import HasClass
from Orange.data.table import DomainTransformationError
from Orange.data.util import one_hot
from Orange.misc.wrapper_meta import WrapperMeta
from Orange.preprocess import Continuize, RemoveNaNColumns, SklImpute, Normalize
Expand Down Expand Up @@ -246,6 +247,13 @@ def __call__(self, data, ret=Value):
if isinstance(data, Instance):
data = Table(data.domain, [data])
if data.domain != self.domain:
if self.original_domain.attributes != data.domain.attributes \
and data.X.size \
and not np.isnan(data.X).all():
data = data.transform(self.original_domain)
if np.isnan(data.X).all():
raise DomainTransformationError(
"domain transformation produced no defined values")
data = data.transform(self.domain)
prediction = self.predict_storage(data)
elif isinstance(data, (list, tuple)):
Expand Down
4 changes: 4 additions & 0 deletions Orange/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def get_sample_datasets_dir():
_conversion_cache_lock = RLock()


class DomainTransformationError(Exception):
pass


class RowInstance(Instance):
sparse_x = None
sparse_y = None
Expand Down
8 changes: 8 additions & 0 deletions Orange/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from Orange.classification.rules import _RuleLearner
from Orange.data import (ContinuousVariable, DiscreteVariable,
Domain, Table, Variable)
from Orange.data.table import DomainTransformationError
from Orange.evaluation import CrossValidation
from Orange.tests.dummy_learners import DummyLearner, DummyMulticlassLearner
from Orange.tests import test_filename
Expand Down Expand Up @@ -145,6 +146,13 @@ def test_probs_from_value(self):
self.assertEqual(y2.shape, y.shape)
self.assertEqual(probs.shape, (nrows, 2, 4))

def test_incompatible_domain(self):
iris = Table("iris")
titanic = Table("titanic")
clf = DummyLearner()(iris)
with self.assertRaises(DomainTransformationError):
clf(titanic)


class ExpandProbabilitiesTest(unittest.TestCase):
def prepareTable(self, rows, attr, vars, class_var_domain):
Expand Down
3 changes: 2 additions & 1 deletion Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from Orange.base import Model
from Orange.data import ContinuousVariable, DiscreteVariable, Value
from Orange.data.table import DomainTransformationError
from Orange.widgets import gui, settings
from Orange.widgets.widget import OWWidget, Msg, Input, Output
from Orange.widgets.utils.itemmodels import TableModel
Expand Down Expand Up @@ -265,7 +266,7 @@ def _call_predictors(self):
or numpy.isnan(pred.results[0]).all():
try:
results = self.predict(pred.predictor, self.data)
except ValueError as err:
except (ValueError, DomainTransformationError) as err:
results = "{}: {}".format(pred.predictor.name, err)
self.predictors[inputid] = pred._replace(results=results)

Expand Down
29 changes: 23 additions & 6 deletions Orange/widgets/evaluate/owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from Orange.base import Learner
import Orange.classification
from Orange.data import Table, DiscreteVariable, ContinuousVariable
from Orange.data.table import DomainTransformationError
from Orange.data.filter import HasClass
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
import Orange.evaluation
Expand Down Expand Up @@ -209,6 +210,8 @@ class Error(OWWidget.Error):
memory_error = Msg("Not enough memory.")
no_class_values = Msg("Target variable has no values.")
only_one_class_var_value = Msg("Target variable has only one value.")
test_data_incompatible = Msg(
"Test data may be incompatible with train data.")

class Warning(OWWidget.Warning):
missing_data = \
Expand All @@ -221,6 +224,8 @@ class Warning(OWWidget.Warning):
class Information(OWWidget.Information):
data_sampled = Msg("Train data has been sampled")
test_data_sampled = Msg("Test data has been sampled")
test_data_transformed = Msg(
"Test data has been transformed to match the train data.")

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -551,14 +556,20 @@ def _update_stats_model(self):
name = learner_name(slot.learner)
head = QStandardItem(name)
head.setData(key, Qt.UserRole)
if isinstance(slot.results, Try.Fail):
head.setToolTip(str(slot.results.exception))
results = slot.results
if isinstance(results, Try.Fail):
head.setToolTip(str(results.exception))
head.setText("{} (error)".format(name))
head.setForeground(QtGui.QBrush(Qt.red))
errors.append("{name} failed with error:\n"
"{exc.__class__.__name__}: {exc!s}"
.format(name=name, exc=slot.results.exception))

if isinstance(results.exception, DomainTransformationError) \
and self.resampling == self.TestOnTest:
self.Error.test_data_incompatible()
self.Information.test_data_transformed.clear()
else:
errors.append("{name} failed with error:\n"
"{exc.__class__.__name__}: {exc!s}"
.format(name=name, exc=slot.results.exception)
)
row = [head]

if class_var is not None and class_var.is_discrete and \
Expand Down Expand Up @@ -744,7 +755,13 @@ def __update(self):
self.cancel()

self.Warning.test_data_unused.clear()
self.Error.test_data_incompatible.clear()
self.Warning.test_data_missing.clear()
self.Information.test_data_transformed(
shown=self.resampling == self.TestOnTest
and self.data is not None
and self.test_data is not None
and self.data.domain.attributes != self.test_data.domain.attributes)
self.warning()
self.Error.class_inconsistent.clear()
self.Error.too_many_folds.clear()
Expand Down
16 changes: 16 additions & 0 deletions Orange/widgets/evaluate/tests/test_owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def test_testOnTest(self):
self.widget.resampling = OWTestLearners.TestOnTest
self.send_signal(self.widget.Inputs.test_data, data)

def test_testOnTest_incompatible_domain(self):
iris = Table("iris")
self.send_signal(self.widget.Inputs.train_data, iris)
self.send_signal(self.widget.Inputs.learner, LogisticRegressionLearner(), 0)
self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
self.assertFalse(self.widget.Error.test_data_incompatible.is_shown())
self.widget.resampling = OWTestLearners.TestOnTest
# test data with the same class (otherwise the widget shows a different error)
# and a non-nan X
iris_test = iris.transform(Domain([ContinuousVariable()],
class_vars=iris.domain.class_vars))
iris_test.X[:, 0] = 1
self.send_signal(self.widget.Inputs.test_data, iris_test)
self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
self.assertTrue(self.widget.Error.test_data_incompatible.is_shown())

def test_CrossValidationByFeature(self):
data = Table("iris")
attrs = data.domain.attributes
Expand Down

0 comments on commit 8784dbb

Please sign in to comment.