diff --git a/Orange/widgets/data/owcorrelations.py b/Orange/widgets/data/owcorrelations.py index 4bf7257b0fe..cb8940a481f 100644 --- a/Orange/widgets/data/owcorrelations.py +++ b/Orange/widgets/data/owcorrelations.py @@ -262,8 +262,8 @@ class Warning(OWWidget.Warning): def __init__(self): super().__init__() - self.data = None - self.cont_data = None + self.data = None # type: Table + self.cont_data = None # type: Table # GUI box = gui.vBox(self.mainArea) @@ -347,9 +347,9 @@ def set_data(self, data): self.Warning.not_enough_inst() else: domain = data.domain - cont_attrs = [a for a in domain.attributes if a.is_continuous] - cont_dom = Domain(cont_attrs, domain.class_vars, domain.metas) - cont_data = Table.from_table(cont_dom, data) + cont_vars = [a for a in domain.class_vars + domain.metas + + domain.attributes if a.is_continuous] + cont_data = Table.from_table(Domain(cont_vars), data) remover = Remove(Remove.RemoveConstant) cont_data = remover(cont_data) if remover.attr_results["removed"]: @@ -365,7 +365,11 @@ def set_data(self, data): def set_feature_model(self): self.feature_model.set_domain(self.cont_data and self.cont_data.domain) - self.feature = None + data = self.data + if self.cont_data and data.domain.has_continuous_class: + self.feature = self.cont_data.domain[data.domain.class_var.name] + else: + self.feature = None def apply(self): self.vizrank.initialize() diff --git a/Orange/widgets/data/tests/test_owcorrelations.py b/Orange/widgets/data/tests/test_owcorrelations.py index ecd4b4b0501..83679b8a5cb 100644 --- a/Orange/widgets/data/tests/test_owcorrelations.py +++ b/Orange/widgets/data/tests/test_owcorrelations.py @@ -28,6 +28,7 @@ def setUpClass(cls): cls.data_cont = Table("iris") cls.data_disc = Table("zoo") cls.data_mixed = Table("heart_disease") + cls.housing = Table("housing") def setUp(self): self.widget = self.create_widget(OWCorrelations) @@ -86,7 +87,7 @@ def test_input_data_one_instance(self): self.assertFalse(self.widget.Warning.not_enough_inst.is_shown()) def test_input_data_with_constant_features(self): - """Check correlation table for dataset with a constant columns""" + """Check correlation table for dataset with constant columns""" np.random.seed(0) # pylint: disable=no-member X = np.random.randint(3, size=(4, 3)).astype(float) @@ -118,6 +119,20 @@ def test_input_data_with_constant_features(self): self.send_signal(self.widget.Inputs.data, None) self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) + def test_input_data_cont_target(self): + """Check correlation table for dataset with continuous class variable""" + data = self.housing[:5, 11:] + self.send_signal(self.widget.Inputs.data, data) + time.sleep(0.1) + self.process_events() + self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 2) + self.assertEqual(self.widget.controls.feature.count(), 4) + self.assertEqual(self.widget.controls.feature.currentText(), "MEDV") + + data = self.housing[:5, 13:] + self.send_signal(self.widget.Inputs.data, data) + self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) + def test_output_data(self): """Check dataset on output""" self.send_signal(self.widget.Inputs.data, self.data_cont) @@ -230,8 +245,8 @@ def test_feature_combo(self): self.assertEqual(len(feature_combo.model()), len(cont_attributes) + 1) self.wait_until_stop_blocking() - self.send_signal(self.widget.Inputs.data, Table("housing")) - self.assertEqual(len(feature_combo.model()), 14) + self.send_signal(self.widget.Inputs.data, self.housing) + self.assertEqual(len(feature_combo.model()), 15) def test_select_feature(self): """Test feature selection"""