From c0f8a6739223b049f28176675dc72f5bdbc49827 Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Fri, 22 Nov 2019 13:47:06 +0100 Subject: [PATCH] Normalizer: Retain attributes of attributes --- Orange/preprocess/normalize.py | 20 +++++--------------- Orange/tests/test_normalize.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/Orange/preprocess/normalize.py b/Orange/preprocess/normalize.py index 00783ab895f..d05504f2258 100644 --- a/Orange/preprocess/normalize.py +++ b/Orange/preprocess/normalize.py @@ -1,6 +1,6 @@ import numpy as np -from Orange.data import ContinuousVariable, Domain +from Orange.data import Domain from Orange.statistics import distribution from Orange.util import Reprable from .preprocess import Normalize @@ -51,12 +51,7 @@ def normalize_by_sd(self, dist, var): compute_val = Norm(var, avg, 1 / sd) else: compute_val = Norm(var, 0, 1 / sd) - - return ContinuousVariable( - var.name, - compute_value=compute_val, - sparse=var.sparse, - ) + return var.copy(compute_value=compute_val) def normalize_by_span(self, dist, var): dma, dmi = (dist.max(), dist.min()) if dist.shape[1] else (np.nan, np.nan) @@ -64,12 +59,7 @@ def normalize_by_span(self, dist, var): if diff < 1e-15: diff = 1 if self.zero_based: - return ContinuousVariable( - var.name, - compute_value=Norm(var, dmi, 1 / diff), - sparse=var.sparse) + compute_val = Norm(var, dmi, 1 / diff) else: - return ContinuousVariable( - var.name, - compute_value=Norm(var, (dma + dmi) / 2, 2 / diff), - sparse=var.sparse) + compute_val = Norm(var, (dma + dmi) / 2, 2 / diff) + return var.copy(compute_value=compute_val) diff --git a/Orange/tests/test_normalize.py b/Orange/tests/test_normalize.py index b372df0c6d6..1b93a2d2311 100644 --- a/Orange/tests/test_normalize.py +++ b/Orange/tests/test_normalize.py @@ -142,3 +142,18 @@ def test_datetime_normalization(self): [0., '2003-07-23', 'a', 'b', -1., '?', 0., 'b', '?', 'b', 0], [0., '1967-03-12', 'a', 'b', 1., 'b', -1.225, 'c', '?', 'c', 1]] self.compare_tables(data_norm, solution) + + def test_retain_vars_attributes(self): + data = Table("iris") + attributes = {"foo": "foo", "baz": 1} + data.domain.attributes[0].attributes = attributes + self.assertDictEqual( + Normalize(norm_type=Normalize.NormalizeBySD)( + data).domain.attributes[0].attributes, attributes) + self.assertDictEqual( + Normalize(norm_type=Normalize.NormalizeBySpan)( + data).domain.attributes[0].attributes, attributes) + + +if __name__ == "__main__": + unittest.main()