From d89d4bfae87e79c48f1d2a52f5a38327350ca10a Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Thu, 3 Oct 2019 10:28:48 +0200 Subject: [PATCH] Merge data: Rename variables with duplicated names --- Orange/data/tests/test_variable.py | 35 +++++++++++++++ Orange/data/variable.py | 39 +++++++++------- Orange/widgets/data/owmergedata.py | 45 ++++++++++++------- Orange/widgets/data/tests/test_owmergedata.py | 16 +++++++ 4 files changed, 103 insertions(+), 32 deletions(-) diff --git a/Orange/data/tests/test_variable.py b/Orange/data/tests/test_variable.py index 5aba4928287..e66aee48534 100644 --- a/Orange/data/tests/test_variable.py +++ b/Orange/data/tests/test_variable.py @@ -60,6 +60,23 @@ def test_copy_copies_attributes(self): # Attributes of original value should not change self.assertEqual(var.attributes["a"], "b") + def test_rename(self): + var = self.varcls_modified("x") + var2 = var.copy(name="x2") + self.assertIsInstance(var2, type(var)) + self.assertIsNot(var, var2) + self.assertEqual(var2.name, "x2") + var.__dict__.pop("_name") + var2.__dict__.pop("_name") + self.assertDictEqual(var.__dict__, var2.__dict__) + + def varcls_modified(self, name): + var = self.varcls(name) + var._compute_value = lambda x: x + var.sparse = True + var.attributes = {"a": 1} + return var + class TestVariable(unittest.TestCase): @classmethod @@ -378,6 +395,12 @@ def test_mapper_from_no_values(self): self.assertRaises(ValueError, mapper, sp.csr_matrix(arr), 0) self.assertRaises(ValueError, mapper, sp.csc_matrix(arr), 0) + def varcls_modified(self, name): + var = super().varcls_modified(name) + var.values = ["A", "B"] + var.ordered = True + return var + @variabletest(ContinuousVariable) class TestContinuousVariable(VariableTest): @@ -419,6 +442,11 @@ def test_colors(self): a.colors = ((3, 2, 1), (6, 5, 4), True) self.assertEqual(a.colors, ((3, 2, 1), (6, 5, 4), True)) + def varcls_modified(self, name): + var = super().varcls_modified(name) + var.number_of_decimals = 5 + return var + @variabletest(StringVariable) class TestStringVariable(VariableTest): @@ -538,6 +566,13 @@ def test_have_date_have_time_in_construct(self): self.assertTrue(var.have_date) self.assertFalse(var.have_time) + def varcls_modified(self, name): + var = super().varcls_modified(name) + var.number_of_decimals = 5 + var.have_date = 1 + var.have_time = 1 + return var + PickleContinuousVariable = create_pickling_tests( "PickleContinuousVariable", diff --git a/Orange/data/variable.py b/Orange/data/variable.py index 0fcbc0869eb..4e4bf1b49a5 100644 --- a/Orange/data/variable.py +++ b/Orange/data/variable.py @@ -1,4 +1,3 @@ -import collections import re import warnings from collections import Iterable @@ -215,7 +214,8 @@ def __contains__(self, other): def __hash__(self): if self.variable.is_discrete: - # It is not possible to hash the id and the domain value to the same number as required by __eq__. + # It is not possible to hash the id and the domain value to the + # same number as required by __eq__. # hash(1) == hash(Value(DiscreteVariable("var", ["red", "green", "blue"]), 1)) == hash("green") # User should hash directly ids or domain values instead. raise TypeError("unhashable type - cannot hash values of discrete variables!") @@ -451,8 +451,10 @@ def __reduce__(self): # Use make to unpickle variables. return make_variable, (self.__class__, self._compute_value, self.name), self.__dict__ - def copy(self, compute_value): - var = type(self)(self.name, compute_value=compute_value, sparse=self.sparse) + def copy(self, compute_value=None, name=None, **kwargs): + var = type(self)(name=name or self.name, + compute_value=compute_value or self.compute_value, + sparse=self.sparse, **kwargs) var.attributes = dict(self.attributes) return var @@ -507,6 +509,10 @@ def number_of_decimals(self): def format_str(self): return self._format_str + @format_str.setter + def format_str(self, value): + self._format_str = value + @property def colors(self): if self._colors is None: @@ -559,9 +565,12 @@ def repr_val(self, val): str_val = repr_val - def copy(self, compute_value=None): - var = type(self)(self.name, self.number_of_decimals, compute_value, sparse=self.sparse) - var.attributes = dict(self.attributes) + def copy(self, compute_value=None, name=None, **kwargs): + var = super().copy(compute_value=compute_value, name=name, + number_of_decimals=self.number_of_decimals, + **kwargs) + var.adjust_decimals = self.adjust_decimals + var.format_str = self._format_str return var @@ -795,11 +804,9 @@ def ordered_values(values): except ValueError: return sorted(values) - def copy(self, compute_value=None): - var = DiscreteVariable(self.name, self.values, self.ordered, - compute_value, sparse=self.sparse) - var.attributes = dict(self.attributes) - return var + def copy(self, compute_value=None, name=None, **_): + return super().copy(compute_value=compute_value, name=name, + values=self.values, ordered=self.ordered) class StringVariable(Variable): @@ -913,11 +920,9 @@ def __init__(self, *args, have_date=0, have_time=0, **kwargs): self.have_date = have_date self.have_time = have_time - def copy(self, compute_value=None): - copy = super().copy(compute_value=compute_value) - copy.have_date = self.have_date - copy.have_time = self.have_time - return copy + def copy(self, compute_value=None, name=None, **_): + return super().copy(compute_value=compute_value, name=name, + have_date=self.have_date, have_time=self.have_time) @staticmethod def _tzre_sub(s, _subtz=re.compile(r'([+-])(\d\d):(\d\d)$').sub): diff --git a/Orange/widgets/data/owmergedata.py b/Orange/widgets/data/owmergedata.py index 32aa4b04d33..6a99c72194a 100644 --- a/Orange/widgets/data/owmergedata.py +++ b/Orange/widgets/data/owmergedata.py @@ -9,7 +9,7 @@ import Orange from Orange.data import StringVariable, ContinuousVariable, Variable -from Orange.data.util import hstack +from Orange.data.util import hstack, get_unique_names_duplicates from Orange.widgets import widget, gui from Orange.widgets.settings import Setting from Orange.widgets.utils.itemmodels import DomainModel @@ -217,7 +217,8 @@ class Outputs: resizing_enabled = False class Warning(widget.OWWidget.Warning): - duplicate_names = Msg("Duplicate variable names in output.") + renamed_vars = Msg("Some variables have been renamed " + "to avoid duplicates.\n{}") class Error(widget.OWWidget.Error): matching_numeric_with_nonnum = Msg( @@ -379,19 +380,9 @@ def dataInfoText(data): f"{len(data.domain) + len(data.domain.metas)} variables" def commit(self): - self.Error.clear() - self.Warning.duplicate_names.clear() - if not self.data or not self.extra_data: - merged_data = None - else: - merged_data = self.merge() - if merged_data: - merged_domain = merged_data.domain - var_names = [var.name for var in chain(merged_domain.variables, - merged_domain.metas)] - if len(set(var_names)) != len(var_names): - self.Warning.duplicate_names() - self.Outputs.data.send(merged_data) + self.clear_messages() + merged = None if not self.data or not self.extra_data else self.merge() + self.Outputs.data.send(merged) def send_report(self): # pylint: disable=invalid-sequence-index @@ -544,6 +535,7 @@ def _join_table_by_indices(self, reduced_extra, lefti, righti, rightu): domain = Orange.data.Domain( *(getattr(self.data.domain, x) + getattr(reduced_extra.domain, x) for x in ("attributes", "class_vars", "metas"))) + domain = self._domain_rename_duplicates(domain) X = self._join_array_by_indices(self.data.X, reduced_extra.X, lefti, righti) Y = self._join_array_by_indices( np.c_[self.data.Y], np.c_[reduced_extra.Y], lefti, righti) @@ -566,6 +558,29 @@ def _join_table_by_indices(self, reduced_extra, lefti, righti, rightu): return table + def _domain_rename_duplicates(self, domain): + """Check for duplicate variable names in domain. If any, rename + the variables, by replacing them with new ones (names are + appended a number). """ + attrs, cvars, metas = [], [], [] + n_attrs, n_cvars, n_metas = (len(domain.attributes), + len(domain.class_vars), len(domain.metas)) + lists = [attrs] * n_attrs + [cvars] * n_cvars + [metas] * n_metas + + variables = domain.variables + domain.metas + proposed_names = [m.name for m in variables] + unique_names = get_unique_names_duplicates(proposed_names) + duplicates = set() + for p_name, u_name, var, c in zip(proposed_names, unique_names, + variables, lists): + if p_name != u_name: + duplicates.add(p_name) + var = var.copy(name=u_name) + c.append(var) + if duplicates: + self.Warning.renamed_vars(", ".join(duplicates)) + return Orange.data.Domain(attrs, cvars, metas) + @staticmethod def _join_array_by_indices(left, right, lefti, righti, string_cols=None): """Join (horizontally) two arrays, taking pairs of rows given in indices diff --git a/Orange/widgets/data/tests/test_owmergedata.py b/Orange/widgets/data/tests/test_owmergedata.py index 5243f14bd11..be3873242f4 100644 --- a/Orange/widgets/data/tests/test_owmergedata.py +++ b/Orange/widgets/data/tests/test_owmergedata.py @@ -937,6 +937,22 @@ def test_invalide_pairs(self): self.assertFalse(widget.Error.matching_index_with_sth.is_shown()) self.assertTrue(widget.Error.matching_numeric_with_nonnum.is_shown()) + def test_duplicate_names(self): + domain = Domain([ContinuousVariable("C1")], + metas=[DiscreteVariable("Feature", values=["A", "B"])]) + data = Table(domain, np.array([[1.], [0.]]), + metas=np.array([[1.], [0.]])) + domain = Domain([ContinuousVariable("C1")], + metas=[StringVariable("Feature")]) + extra_data = Table(domain, np.array([[1.], [0.]]), + metas=np.array([["A"], ["B"]])) + self.send_signal(self.widget.Inputs.data, data) + self.send_signal(self.widget.Inputs.extra_data, extra_data) + self.assertTrue(self.widget.Warning.renamed_vars.is_shown()) + merged_data = self.get_output(self.widget.Outputs.data) + self.assertListEqual([m.name for m in merged_data.domain.metas], + ["Feature (1)", "Feature (2)"]) + if __name__ == "__main__": unittest.main()