Skip to content

Commit

Permalink
Merge data: Rename variables with duplicated names
Browse files Browse the repository at this point in the history
  • Loading branch information
VesnaT committed Oct 3, 2019
1 parent a0dea94 commit d89d4bf
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 32 deletions.
35 changes: 35 additions & 0 deletions Orange/data/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ def test_copy_copies_attributes(self):
# Attributes of original value should not change
self.assertEqual(var.attributes["a"], "b")

def test_rename(self):
    """A copy made under a new name keeps the type and all other state."""
    original = self.varcls_modified("x")
    renamed = original.copy(name="x2")

    self.assertIsInstance(renamed, type(original))
    self.assertIsNot(original, renamed)
    self.assertEqual(renamed.name, "x2")

    # The names differ on purpose; drop them so that the remaining
    # instance state can be compared in one go.
    for variable in (original, renamed):
        variable.__dict__.pop("_name")
    self.assertDictEqual(original.__dict__, renamed.__dict__)

def varcls_modified(self, name):
    """Return a `self.varcls` instance called *name* with altered settings.

    The instance gets an identity `_compute_value`, `sparse` set to True
    and a one-item `attributes` dict.
    """
    modified = self.varcls(name)
    modified.attributes = {"a": 1}
    modified.sparse = True
    modified._compute_value = lambda x: x
    return modified


class TestVariable(unittest.TestCase):
@classmethod
Expand Down Expand Up @@ -378,6 +395,12 @@ def test_mapper_from_no_values(self):
self.assertRaises(ValueError, mapper, sp.csr_matrix(arr), 0)
self.assertRaises(ValueError, mapper, sp.csc_matrix(arr), 0)

def varcls_modified(self, name):
    """Extend the inherited modified variable with values and ordering."""
    modified = super().varcls_modified(name)
    modified.values = ["A", "B"]
    modified.ordered = True
    return modified


@variabletest(ContinuousVariable)
class TestContinuousVariable(VariableTest):
Expand Down Expand Up @@ -419,6 +442,11 @@ def test_colors(self):
a.colors = ((3, 2, 1), (6, 5, 4), True)
self.assertEqual(a.colors, ((3, 2, 1), (6, 5, 4), True))

def varcls_modified(self, name):
    """Extend the inherited modified variable with a decimals setting."""
    modified = super().varcls_modified(name)
    modified.number_of_decimals = 5
    return modified


@variabletest(StringVariable)
class TestStringVariable(VariableTest):
Expand Down Expand Up @@ -538,6 +566,13 @@ def test_have_date_have_time_in_construct(self):
self.assertTrue(var.have_date)
self.assertFalse(var.have_time)

def varcls_modified(self, name):
    """Extend the inherited modified variable with decimals and date/time flags."""
    modified = super().varcls_modified(name)
    modified.number_of_decimals = 5
    modified.have_date = 1
    modified.have_time = 1
    return modified


PickleContinuousVariable = create_pickling_tests(
"PickleContinuousVariable",
Expand Down
39 changes: 22 additions & 17 deletions Orange/data/variable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import re
import warnings
from collections import Iterable
Expand Down Expand Up @@ -215,7 +214,8 @@ def __contains__(self, other):

def __hash__(self):
if self.variable.is_discrete:
# It is not possible to hash the id and the domain value to the same number as required by __eq__.
# It is not possible to hash the id and the domain value to the
# same number as required by __eq__.
# hash(1) == hash(Value(DiscreteVariable("var", ["red", "green", "blue"]), 1)) == hash("green")
# User should hash directly ids or domain values instead.
raise TypeError("unhashable type - cannot hash values of discrete variables!")
Expand Down Expand Up @@ -451,8 +451,10 @@ def __reduce__(self):
# Use make to unpickle variables.
return make_variable, (self.__class__, self._compute_value, self.name), self.__dict__

def copy(self, compute_value):
var = type(self)(self.name, compute_value=compute_value, sparse=self.sparse)
def copy(self, compute_value=None, name=None, **kwargs):
    """Create a new variable of the same type as this one.

    `name` and `compute_value` fall back to this variable's own values
    when they are omitted (or otherwise falsy). Additional keyword
    arguments are passed on to the constructor. The new variable gets its
    own copy of the `attributes` dict.
    """
    init_args = dict(
        name=name or self.name,
        compute_value=compute_value or self.compute_value,
        sparse=self.sparse)
    duplicate = type(self)(**init_args, **kwargs)
    duplicate.attributes = dict(self.attributes)
    return duplicate

Expand Down Expand Up @@ -507,6 +509,10 @@ def number_of_decimals(self):
def format_str(self):
return self._format_str

@format_str.setter
def format_str(self, value):
    """Store *value* in `_format_str`, the attribute the getter returns."""
    self._format_str = value

@property
def colors(self):
if self._colors is None:
Expand Down Expand Up @@ -559,9 +565,12 @@ def repr_val(self, val):

str_val = repr_val

def copy(self, compute_value=None):
var = type(self)(self.name, self.number_of_decimals, compute_value, sparse=self.sparse)
var.attributes = dict(self.attributes)
def copy(self, compute_value=None, name=None, **kwargs):
    """Copy the variable, carrying over its decimal formatting settings.

    The number of decimals is handed to the parent `copy`; the
    `adjust_decimals` flag and the format string are then set on the
    result from this variable.
    """
    duplicate = super().copy(
        compute_value=compute_value,
        name=name,
        number_of_decimals=self.number_of_decimals,
        **kwargs)
    duplicate.adjust_decimals = self.adjust_decimals
    duplicate.format_str = self._format_str
    return duplicate


Expand Down Expand Up @@ -795,11 +804,9 @@ def ordered_values(values):
except ValueError:
return sorted(values)

def copy(self, compute_value=None):
    """Return a `DiscreteVariable` with this variable's name, values,
    ordering and sparsity, and with its own copy of `attributes`."""
    duplicate = DiscreteVariable(
        self.name, self.values, self.ordered, compute_value,
        sparse=self.sparse)
    duplicate.attributes = dict(self.attributes)
    return duplicate
def copy(self, compute_value=None, name=None, **_):
    """Copy the variable through the parent `copy`, passing on this
    variable's values and ordering.

    Any further keyword arguments are accepted but ignored.
    """
    duplicate = super().copy(
        compute_value=compute_value,
        name=name,
        values=self.values,
        ordered=self.ordered)
    return duplicate


class StringVariable(Variable):
Expand Down Expand Up @@ -913,11 +920,9 @@ def __init__(self, *args, have_date=0, have_time=0, **kwargs):
self.have_date = have_date
self.have_time = have_time

def copy(self, compute_value=None):
    """Copy the variable through the parent `copy` and set the
    `have_date` and `have_time` flags on the result."""
    duplicate = super().copy(compute_value=compute_value)
    for flag in ("have_date", "have_time"):
        setattr(duplicate, flag, getattr(self, flag))
    return duplicate
def copy(self, compute_value=None, name=None, **_):
    """Copy the variable through the parent `copy`, passing on the
    `have_date` and `have_time` flags.

    Any further keyword arguments are accepted but ignored.
    """
    duplicate = super().copy(
        compute_value=compute_value,
        name=name,
        have_date=self.have_date,
        have_time=self.have_time)
    return duplicate

@staticmethod
def _tzre_sub(s, _subtz=re.compile(r'([+-])(\d\d):(\d\d)$').sub):
Expand Down
45 changes: 30 additions & 15 deletions Orange/widgets/data/owmergedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import Orange
from Orange.data import StringVariable, ContinuousVariable, Variable
from Orange.data.util import hstack
from Orange.data.util import hstack, get_unique_names_duplicates
from Orange.widgets import widget, gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.itemmodels import DomainModel
Expand Down Expand Up @@ -217,7 +217,8 @@ class Outputs:
resizing_enabled = False

class Warning(widget.OWWidget.Warning):
    # Warning messages declared by this widget on top of the base class.
    duplicate_names = Msg("Duplicate variable names in output.")
    # The placeholder is filled with the text passed when the warning is
    # raised.
    renamed_vars = Msg(
        "Some variables have been renamed to avoid duplicates.\n{}")

class Error(widget.OWWidget.Error):
matching_numeric_with_nonnum = Msg(
Expand Down Expand Up @@ -379,19 +380,9 @@ def dataInfoText(data):
f"{len(data.domain) + len(data.domain.metas)} variables"

def commit(self):
# Merge the two inputs and send the result to `Outputs.data`; `None` is
# sent instead when `self.data` or `self.extra_data` is falsy.
# NOTE(review): this span looks like the pre-change body (up to the first
# `send`) followed by its post-change replacement (from `clear_messages`
# on), as rendered by the diff view without markers or indentation --
# confirm against the repository before relying on the statement order.
self.Error.clear()
self.Warning.duplicate_names.clear()
if not self.data or not self.extra_data:
merged_data = None
else:
merged_data = self.merge()
if merged_data:
merged_domain = merged_data.domain
# Collect the names of all variables and metas to detect duplicates.
var_names = [var.name for var in chain(merged_domain.variables,
merged_domain.metas)]
if len(set(var_names)) != len(var_names):
self.Warning.duplicate_names()
self.Outputs.data.send(merged_data)
self.clear_messages()
merged = None if not self.data or not self.extra_data else self.merge()
self.Outputs.data.send(merged)

def send_report(self):
# pylint: disable=invalid-sequence-index
Expand Down Expand Up @@ -544,6 +535,7 @@ def _join_table_by_indices(self, reduced_extra, lefti, righti, rightu):
domain = Orange.data.Domain(
*(getattr(self.data.domain, x) + getattr(reduced_extra.domain, x)
for x in ("attributes", "class_vars", "metas")))
domain = self._domain_rename_duplicates(domain)
X = self._join_array_by_indices(self.data.X, reduced_extra.X, lefti, righti)
Y = self._join_array_by_indices(
np.c_[self.data.Y], np.c_[reduced_extra.Y], lefti, righti)
Expand All @@ -566,6 +558,29 @@ def _join_table_by_indices(self, reduced_extra, lefti, righti, rightu):

return table

def _domain_rename_duplicates(self, domain):
    """Return a domain in which no two variables share a name.

    The names of all attributes, class variables and metas are checked
    together. Every variable whose name is changed by
    `get_unique_names_duplicates` is replaced with a copy carrying the
    new name; all other variables are reused as they are. If anything
    was renamed, the `renamed_vars` warning is shown with the original
    (duplicated) names.
    """
    attrs, cvars, metas = [], [], []
    # One target list per variable, aligned with `variables` below, so
    # that every variable is appended to the list for the role it had
    # in `domain`.
    targets = ([attrs] * len(domain.attributes)
               + [cvars] * len(domain.class_vars)
               + [metas] * len(domain.metas))

    variables = domain.variables + domain.metas
    proposed_names = [var.name for var in variables]
    unique_names = get_unique_names_duplicates(proposed_names)

    duplicates = set()
    for p_name, u_name, var, target in zip(proposed_names, unique_names,
                                           variables, targets):
        if p_name != u_name:
            duplicates.add(p_name)
            var = var.copy(name=u_name)
        target.append(var)

    if duplicates:
        # A set has no defined iteration order; sort the names so the
        # warning text is the same on every run.
        self.Warning.renamed_vars(", ".join(sorted(duplicates)))
    return Orange.data.Domain(attrs, cvars, metas)

@staticmethod
def _join_array_by_indices(left, right, lefti, righti, string_cols=None):
"""Join (horizontally) two arrays, taking pairs of rows given in indices
Expand Down
16 changes: 16 additions & 0 deletions Orange/widgets/data/tests/test_owmergedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,22 @@ def test_invalide_pairs(self):
self.assertFalse(widget.Error.matching_index_with_sth.is_shown())
self.assertTrue(widget.Error.matching_numeric_with_nonnum.is_shown())

def test_duplicate_names(self):
    """Merging inputs with clashing variable names renames the metas
    and shows the `renamed_vars` warning."""
    widget = self.widget

    data_domain = Domain(
        [ContinuousVariable("C1")],
        metas=[DiscreteVariable("Feature", values=["A", "B"])])
    data = Table(data_domain, np.array([[1.], [0.]]),
                 metas=np.array([[1.], [0.]]))

    extra_domain = Domain(
        [ContinuousVariable("C1")],
        metas=[StringVariable("Feature")])
    extra_data = Table(extra_domain, np.array([[1.], [0.]]),
                       metas=np.array([["A"], ["B"]]))

    self.send_signal(widget.Inputs.data, data)
    self.send_signal(widget.Inputs.extra_data, extra_data)

    self.assertTrue(widget.Warning.renamed_vars.is_shown())
    merged_data = self.get_output(widget.Outputs.data)
    meta_names = [meta.name for meta in merged_data.domain.metas]
    self.assertListEqual(meta_names, ["Feature (1)", "Feature (2)"])


# Run the tests through unittest when this module is executed directly.
if __name__ == "__main__":
unittest.main()

0 comments on commit d89d4bf

Please sign in to comment.