From 8ac409b1ed35e9eb93f33ee7d80df6742ad7014c Mon Sep 17 00:00:00 2001 From: janezd Date: Fri, 19 Feb 2021 11:54:27 +0100 Subject: [PATCH] Table.concatenate: Refactor --- Orange/data/table.py | 50 ++++++++++++++++----------------- Orange/data/tests/test_table.py | 12 ++++++++ 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/Orange/data/table.py b/Orange/data/table.py index 059148a5c96..cf421c8aba8 100644 --- a/Orange/data/table.py +++ b/Orange/data/table.py @@ -898,11 +898,11 @@ def __repr__(self): @classmethod def concatenate(cls, tables, axis=0): """ - Concatenate tables into a new table, either horizontally or vertically. + Concatenate tables into a new table, either vertically or horizontally. - If axis=0 (horizontal concatenate), all tables must have the same domain. + If axis=0 (vertical concatenate), all tables must have the same domain. - If axis=1 (vertical), + If axis=1 (horizontal), - all variable names must be unique. - ids are copied from the first table. - weights are copied from the first table in which they are defined. @@ -915,12 +915,28 @@ def concatenate(cls, tables, axis=0): Returns: table (Table) """ + if axis not in (0, 1): + raise ValueError("invalid axis") + if not tables: + raise ValueError('need at least one table to concatenate') + + if len(tables) == 1: + return tables[0].copy() + if axis == 0: - return cls._concatenate_vertical(tables) - elif axis == 1: - return cls._concatenate_horizontal(tables) + conc = cls._concatenate_vertical(tables) else: - raise ValueError("invalid axis") + conc = cls._concatenate_horizontal(tables) + + # TODO: Add attributes = {} to __init__ + conc.attributes = getattr(conc, "attributes", {}) + for table in reversed(tables): + conc.attributes.update(table.attributes) + + names = [table.name for table in tables if table.name != "untitled"] + if names: + conc.name = names[0] + return conc @classmethod def _concatenate_vertical(cls, tables): @@ -941,10 +957,6 @@ def merge1d(arrs): def collect(attr): return [getattr(arr, attr) for arr in tables] - if not tables: - raise ValueError('need at least one table to concatenate') - if len(tables) == 1: - return tables[0].copy() domain = tables[0].domain if any(table.domain != domain for table in tables): raise ValueError('concatenated tables must have the same domain') @@ -957,22 +969,12 @@ def collect(attr): merge1d(collect("W")) ) conc.ids = np.hstack([t.ids for t in tables]) - names = [table.name for table in tables if table.name != "untitled"] - if names: - conc.name = names[0] - # TODO: Add attributes = {} to __init__ - conc.attributes = getattr(conc, "attributes", {}) - for table in reversed(tables): - conc.attributes.update(table.attributes) return conc @classmethod def _concatenate_horizontal(cls, tables): """ """ - if not tables: - raise ValueError('need at least one table to join') - def all_of(objs, names): return (tuple(getattr(obj, name) for obj in objs) for name in names) @@ -992,11 +994,7 @@ def stack(arrs): parts = all_of(doms, ("attributes", "class_vars", "metas")) domain = Domain(*(tuple(chain(*lst)) for lst in parts)) - table = cls.from_numpy(domain, Xs, Ys, Ms, W, ids=tables[0].ids) - for ta in reversed(table_attrss): - table.attributes.update(ta) - - return table + return cls.from_numpy(domain, Xs, Ys, Ms, W, ids=tables[0].ids) def add_column(self, variable, data, to_metas=None): """ diff --git a/Orange/data/tests/test_table.py b/Orange/data/tests/test_table.py index 0c4bfbb0ac1..836110276b3 100644 --- a/Orange/data/tests/test_table.py +++ b/Orange/data/tests/test_table.py @@ -183,6 +183,18 @@ def test_concatenate_horizontal(self): def test_concatenate_invalid_axis(self): self.assertRaises(ValueError, Table.concatenate, (), axis=2) + def test_concatenate_names(self): + a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg") + + tab1 = self._new_table((a, ), (c, ), (d, ), 0) + tab2 = self._new_table((e, ), (), (f, g), 1000) + tab3 = self._new_table((b, ), (), (), 1000) + tab2.name = "tab2" + tab3.name = "tab3" + + joined = Table.concatenate((tab1, tab2, tab3), axis=1) + self.assertEqual(joined.name, "tab2") + def test_with_column(self): a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg") col = np.arange(9, 14)