From 53429a241a40228e0f64f9c67caf9900828b70fa Mon Sep 17 00:00:00 2001
From: Marko Toplak
Date: Thu, 7 Dec 2023 16:32:33 +0100
Subject: [PATCH] Fix impute.Model for derived domains

The compute_value was missing transformation into the variable space it was working upon.
---
 Orange/preprocess/impute.py |  7 +++----
 Orange/tests/test_impute.py | 23 ++++++++++++++++++++++-
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/Orange/preprocess/impute.py b/Orange/preprocess/impute.py
index c67c4a97434..912a5469770 100644
--- a/Orange/preprocess/impute.py
+++ b/Orange/preprocess/impute.py
@@ -172,7 +172,7 @@ def copy(self):
         return FixedValueByType(*self.defaults.values())
 
 
-class ReplaceUnknownsModel(Reprable):
+class ReplaceUnknownsModel(Transformation):
     """
     Replace unknown values with predicted values using a `Orange.base.Model`
 
@@ -185,15 +185,14 @@ class ReplaceUnknownsModel(Reprable):
     """
     def __init__(self, variable, model):
         assert model.domain.class_var == variable
-        self.variable = variable
+        super().__init__(variable)
         self.model = model
 
     def __call__(self, data):
         if isinstance(data, Orange.data.Instance):
             data = Orange.data.Table.from_list(data.domain, [data])
         domain = data.domain
-        column = data.get_column(self.variable, copy=True)
-
+        column = data.transform(self._target_domain).get_column(self.variable, copy=True)
         mask = np.isnan(column)
         if not np.any(mask):
             return column
diff --git a/Orange/tests/test_impute.py b/Orange/tests/test_impute.py
index 9c6fba5b336..1278e4e89de 100644
--- a/Orange/tests/test_impute.py
+++ b/Orange/tests/test_impute.py
@@ -9,7 +9,7 @@
 from Orange import preprocess
 from Orange.preprocess import impute, SklImpute
 from Orange import data
-from Orange.data import Unknown, Table
+from Orange.data import Unknown, Table, Domain
 
 from Orange.classification import MajorityLearner, SimpleTreeLearner
 from Orange.regression import MeanLearner
@@ -293,6 +293,27 @@ def test_bad_domain(self):
         self.assertRaises(ValueError, imputer, data=table,
                           variable=table.domain[0])
 
+    def test_missing_imputed_columns(self):
+        housing = Table("housing")
+
+        learner = SimpleTreeLearner(min_instances=10, max_depth=10)
+        method = preprocess.impute.Model(learner)
+
+        ivar = method(housing, housing.domain.attributes[0])
+        imputed = housing.transform(
+            Domain([ivar],
+                   housing.domain.class_var)
+        )
+        removed_imputed = imputed.transform(
+            Domain([], housing.domain.class_var))
+
+        r = removed_imputed.transform(imputed.domain)
+
+        no_class = removed_imputed.transform(Domain(removed_imputed.domain.attributes, None))
+        impute_model_prediction_for_unknowns = ivar.compute_value.model(no_class[0])
+
+        np.testing.assert_equal(r.X, impute_model_prediction_for_unknowns)
+
 
 class TestRandom(unittest.TestCase):
     def test_replacement(self):