diff --git a/Orange/base.py b/Orange/base.py index 32a4e712b0e..4251234cf04 100644 --- a/Orange/base.py +++ b/Orange/base.py @@ -9,6 +9,7 @@ from Orange.data import Table, Storage, Instance, Value from Orange.data.filter import HasClass +from Orange.data.table import DomainTransformationError from Orange.data.util import one_hot from Orange.misc.wrapper_meta import WrapperMeta from Orange.preprocess import Continuize, RemoveNaNColumns, SklImpute, Normalize @@ -246,6 +247,13 @@ def __call__(self, data, ret=Value): if isinstance(data, Instance): data = Table(data.domain, [data]) if data.domain != self.domain: + if self.original_domain.attributes != data.domain.attributes \ + and data.X.size \ + and not np.isnan(data.X).all(): + data = data.transform(self.original_domain) + if np.isnan(data.X).all(): + raise DomainTransformationError( + "domain transformation produced no defined values") data = data.transform(self.domain) prediction = self.predict_storage(data) elif isinstance(data, (list, tuple)): diff --git a/Orange/data/table.py b/Orange/data/table.py index c8492fbd456..65bcb62a50b 100644 --- a/Orange/data/table.py +++ b/Orange/data/table.py @@ -44,6 +44,10 @@ def get_sample_datasets_dir(): _conversion_cache_lock = RLock() +class DomainTransformationError(Exception): + pass + + class RowInstance(Instance): sparse_x = None sparse_y = None diff --git a/Orange/tests/test_classification.py b/Orange/tests/test_classification.py index de591b85c74..000639ee464 100644 --- a/Orange/tests/test_classification.py +++ b/Orange/tests/test_classification.py @@ -19,6 +19,7 @@ from Orange.classification.rules import _RuleLearner from Orange.data import (ContinuousVariable, DiscreteVariable, Domain, Table, Variable) +from Orange.data.table import DomainTransformationError from Orange.evaluation import CrossValidation from Orange.tests.dummy_learners import DummyLearner, DummyMulticlassLearner from Orange.tests import test_filename @@ -145,6 +146,13 @@ def test_probs_from_value(self): self.assertEqual(y2.shape, y.shape) self.assertEqual(probs.shape, (nrows, 2, 4)) + def test_incompatible_domain(self): + iris = Table("iris") + titanic = Table("titanic") + clf = DummyLearner()(iris) + with self.assertRaises(DomainTransformationError): + clf(titanic) + class ExpandProbabilitiesTest(unittest.TestCase): def prepareTable(self, rows, attr, vars, class_var_domain): diff --git a/Orange/widgets/evaluate/owpredictions.py b/Orange/widgets/evaluate/owpredictions.py index 82ffd579c6d..dd564aa3702 100644 --- a/Orange/widgets/evaluate/owpredictions.py +++ b/Orange/widgets/evaluate/owpredictions.py @@ -23,6 +23,7 @@ from Orange.base import Model from Orange.data import ContinuousVariable, DiscreteVariable, Value +from Orange.data.table import DomainTransformationError from Orange.widgets import gui, settings from Orange.widgets.widget import OWWidget, Msg, Input, Output from Orange.widgets.utils.itemmodels import TableModel @@ -265,7 +266,7 @@ def _call_predictors(self): or numpy.isnan(pred.results[0]).all(): try: results = self.predict(pred.predictor, self.data) - except ValueError as err: + except (ValueError, DomainTransformationError) as err: results = "{}: {}".format(pred.predictor.name, err) self.predictors[inputid] = pred._replace(results=results) diff --git a/Orange/widgets/evaluate/owtestlearners.py b/Orange/widgets/evaluate/owtestlearners.py index 16ec4c57635..76ae5accf63 100644 --- a/Orange/widgets/evaluate/owtestlearners.py +++ b/Orange/widgets/evaluate/owtestlearners.py @@ -32,6 +32,7 @@ from Orange.base import Learner import Orange.classification from Orange.data import Table, DiscreteVariable, ContinuousVariable +from Orange.data.table import DomainTransformationError from Orange.data.filter import HasClass from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT import Orange.evaluation @@ -209,6 +210,8 @@ class Error(OWWidget.Error): memory_error = Msg("Not enough memory.") no_class_values = Msg("Target variable has no values.") only_one_class_var_value = Msg("Target variable has only one value.") + test_data_incompatible = Msg( + "Test data may be incompatible with train data.") class Warning(OWWidget.Warning): missing_data = \ @@ -221,6 +224,8 @@ class Warning(OWWidget.Warning): class Information(OWWidget.Information): data_sampled = Msg("Train data has been sampled") test_data_sampled = Msg("Test data has been sampled") + test_data_transformed = Msg( + "Test data has been transformed to match the train data.") def __init__(self): super().__init__() @@ -551,14 +556,20 @@ def _update_stats_model(self): name = learner_name(slot.learner) head = QStandardItem(name) head.setData(key, Qt.UserRole) - if isinstance(slot.results, Try.Fail): - head.setToolTip(str(slot.results.exception)) + results = slot.results + if isinstance(results, Try.Fail): + head.setToolTip(str(results.exception)) head.setText("{} (error)".format(name)) head.setForeground(QtGui.QBrush(Qt.red)) - errors.append("{name} failed with error:\n" - "{exc.__class__.__name__}: {exc!s}" - .format(name=name, exc=slot.results.exception)) - + if isinstance(results.exception, DomainTransformationError) \ + and self.resampling == self.TestOnTest: + self.Error.test_data_incompatible() + self.Information.test_data_transformed.clear() + else: + errors.append("{name} failed with error:\n" + "{exc.__class__.__name__}: {exc!s}" + .format(name=name, exc=slot.results.exception) + ) row = [head] if class_var is not None and class_var.is_discrete and \ @@ -744,7 +755,13 @@ def __update(self): self.cancel() self.Warning.test_data_unused.clear() + self.Error.test_data_incompatible.clear() self.Warning.test_data_missing.clear() + self.Information.test_data_transformed( + shown=self.resampling == self.TestOnTest + and self.data is not None + and self.test_data is not None + and self.data.domain.attributes != self.test_data.domain.attributes) self.warning() self.Error.class_inconsistent.clear() self.Error.too_many_folds.clear() diff --git a/Orange/widgets/evaluate/tests/test_owtestlearners.py b/Orange/widgets/evaluate/tests/test_owtestlearners.py index 1a15a82f264..475331a1c7f 100644 --- a/Orange/widgets/evaluate/tests/test_owtestlearners.py +++ b/Orange/widgets/evaluate/tests/test_owtestlearners.py @@ -76,6 +76,22 @@ def test_testOnTest(self): self.widget.resampling = OWTestLearners.TestOnTest self.send_signal(self.widget.Inputs.test_data, data) + def test_testOnTest_incompatible_domain(self): + iris = Table("iris") + self.send_signal(self.widget.Inputs.train_data, iris) + self.send_signal(self.widget.Inputs.learner, LogisticRegressionLearner(), 0) + self.get_output(self.widget.Outputs.evaluations_results, wait=5000) + self.assertFalse(self.widget.Error.test_data_incompatible.is_shown()) + self.widget.resampling = OWTestLearners.TestOnTest + # test data with the same class (otherwise the widget shows a different error) + # and a non-nan X + iris_test = iris.transform(Domain([ContinuousVariable()], + class_vars=iris.domain.class_vars)) + iris_test.X[:, 0] = 1 + self.send_signal(self.widget.Inputs.test_data, iris_test) + self.get_output(self.widget.Outputs.evaluations_results, wait=5000) + self.assertTrue(self.widget.Error.test_data_incompatible.is_shown()) + def test_CrossValidationByFeature(self): data = Table("iris") attrs = data.domain.attributes