diff --git a/Orange/data/table.py b/Orange/data/table.py
index 559750804e0..bfabff033a3 100644
--- a/Orange/data/table.py
+++ b/Orange/data/table.py
@@ -2,6 +2,7 @@
 import os
 import threading
 import warnings
+import weakref
 import zlib
 from collections import Iterable, Sequence, Sized
 from functools import reduce
@@ -316,6 +317,14 @@ def from_table(cls, domain, source, row_indices=...):
         :rtype: Orange.data.Table
         """
 
+        def valid_refs(weakrefs):
+            """Return True if every weak reference still has a live referent.
+
+            Cache entries are keyed by id(); a dead referent means that id may
+            since have been reused by another object, so the entry is stale.
+            """
+            return all(r() is not None for r in weakrefs)
+
         def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
                         is_sparse=False, variables=[]):
             if not len(src_cols):
@@ -356,10 +365,15 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
                     a[:, i] = variables[i].Unknown
                 elif not isinstance(col, Integral):
                     if isinstance(col, SharedComputeValue):
-                        if (id(col.compute_shared), id(source)) not in shared_cache:
-                            shared_cache[id(col.compute_shared), id(source)] = \
-                                col.compute_shared(source)
-                        shared = shared_cache[id(col.compute_shared), id(source)]
+                        key = (id(col.compute_shared), id(source))
+                        shared, weakrefs = shared_cache.get(key, (None, None))
+                        # `weakrefs` is None only on a miss, so a shared result
+                        # that is itself None still counts as a cache hit.
+                        if weakrefs is None or not valid_refs(weakrefs):
+                            shared = col.compute_shared(source)
+                            shared_cache[key] = shared, (
+                                weakref.ref(col.compute_shared), weakref.ref(source))
+
                         if row_indices is not ...:
                             a[:, i] = match_density(
                                 col(source, shared_data=shared)[row_indices])
@@ -389,8 +403,11 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
             if new_cache:
                 _thread_local.conversion_cache = {}
             else:
-                cached = _thread_local.conversion_cache.get((id(domain), id(source)))
-                if cached:
+                cached, weakrefs = _thread_local.conversion_cache.get(
+                    (id(domain), id(source)), (None, None))
+                # Test identity, not truthiness, so that a cached table is
+                # reused even when the table object itself is falsy.
+                if cached is not None and valid_refs(weakrefs):
                     return cached
             if domain is source.domain:
                 table = cls.from_table_rows(source, row_indices)
@@ -443,7 +460,8 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
             else:
                 cls._init_ids(self)
             self.attributes = getattr(source, 'attributes', {})
-            _thread_local.conversion_cache[(id(domain), id(source))] = self
+            _thread_local.conversion_cache[(id(domain), id(source))] = \
+                self, (weakref.ref(domain), weakref.ref(source))
             return self
         finally:
             if new_cache:
diff --git a/Orange/tests/test_table.py b/Orange/tests/test_table.py
index f826d503e76..7e511f0fd64 100644
--- a/Orange/tests/test_table.py
+++ b/Orange/tests/test_table.py
@@ -17,6 +17,7 @@
 from Orange import data
 from Orange.data import (filter, Unknown, Variable, Table, DiscreteVariable,
                          ContinuousVariable, Domain, StringVariable)
+from Orange.data.util import SharedComputeValue
 from Orange.tests import test_dirname
 from Orange.data.table import _optimize_indices
 
@@ -2693,5 +2694,145 @@ def run_from_table():
         self.assertLess(duration, 0.5)
 
 
+class PreprocessComputeValue:
+
+    def __init__(self, domain, callback):
+        self.domain = domain
+        self.callback = callback
+
+    def __call__(self, data_):
+        self.callback(data_)
+        data_.transform(self.domain)
+        return data_.X[:, 0]
+
+
+class PreprocessShared:
+
+    def __init__(self, domain, callback):
+        self.domain = domain
+        self.callback = callback
+
+    def __call__(self, data_):
+        self.callback(data_)
+        data_.transform(self.domain)
+        return True
+
+
+class PreprocessSharedComputeValue(SharedComputeValue):
+
+    def __init__(self, callback, shared):
+        super().__init__(compute_shared=shared)
+        self.callback = callback
+
+    # pylint: disable=arguments-differ
+    def compute(self, data_, shared_data):
+        self.callback(data_)
+        return data_.X[:, 0]
+
+
+def preprocess_domain_single(domain, callback):
+    """ Preprocess domain with single-source compute values.
+    """
+    return Domain([
+        ContinuousVariable(name=at.name,
+                           compute_value=PreprocessComputeValue(Domain([at]), callback))
+        for at in domain.attributes])
+
+
+def preprocess_domain_shared(domain, callback, callback_shared):
+    """ Preprocess domain with shared compute values.
+    """
+    shared = PreprocessShared(domain, callback_shared)
+    return Domain([
+        ContinuousVariable(name=at.name,
+                           compute_value=PreprocessSharedComputeValue(callback, shared))
+        for at in domain.attributes])
+
+
+def preprocess_domain_single_stupid(domain, callback):
+    """ Preprocess domain with single-source compute values with stupid
+    implementation: before applying it, instead of transforming just one column
+    into the input domain, do a needless transform of the whole domain.
+    """
+    return Domain([
+        ContinuousVariable(name=at.name,
+                           compute_value=PreprocessComputeValue(domain, callback))
+        for at in domain.attributes])
+
+
+class EfficientTransformTests(unittest.TestCase):
+
+    def setUp(self):
+        self.iris = Table("iris")
+
+    def test_simple(self):
+        call_cv = Mock()
+        d1 = preprocess_domain_single(self.iris.domain, call_cv)
+        self.iris.transform(d1)
+        self.assertEqual(4, call_cv.call_count)
+
+    def test_shared(self):
+        call_cv = Mock()
+        call_shared = Mock()
+        d1 = preprocess_domain_shared(self.iris.domain, call_cv, call_shared)
+        self.iris.transform(d1)
+        self.assertEqual(4, call_cv.call_count)
+        self.assertEqual(1, call_shared.call_count)
+
+    def test_simple_simple_shared(self):
+        call_cv = Mock()
+        d1 = preprocess_domain_single(self.iris.domain, call_cv)
+        d2 = preprocess_domain_single(d1, call_cv)
+        call_shared = Mock()
+        d3 = preprocess_domain_shared(d2, call_cv, call_shared)
+        self.iris.transform(d3)
+        self.assertEqual(1, call_shared.call_count)
+        self.assertEqual(12, call_cv.call_count)
+
+    def test_simple_simple_shared_simple(self):
+        call_cv = Mock()
+        d1 = preprocess_domain_single(self.iris.domain, call_cv)
+        d2 = preprocess_domain_single(d1, call_cv)
+        call_shared = Mock()
+        d3 = preprocess_domain_shared(d2, call_cv, call_shared)
+        d4 = preprocess_domain_single(d3, call_cv)
+        self.iris.transform(d4)
+        self.assertEqual(1, call_shared.call_count)
+        self.assertEqual(16, call_cv.call_count)
+
+    def test_simple_simple_shared_simple_shared_simple(self):
+        call_cv = Mock()
+        d1 = preprocess_domain_single(self.iris.domain, call_cv)
+        d2 = preprocess_domain_single(d1, call_cv)
+        call_shared = Mock()
+        d3 = preprocess_domain_shared(d2, call_cv, call_shared)
+        d4 = preprocess_domain_single(d3, call_cv)
+        d5 = preprocess_domain_shared(d4, call_cv, call_shared)
+        d6 = preprocess_domain_single(d5, call_cv)
+        self.iris.transform(d6)
+        self.assertEqual(2, call_shared.call_count)
+        self.assertEqual(24, call_cv.call_count)
+
+    def test_same_simple_shared_union(self):
+        call_cv = Mock()
+        call_shared = Mock()
+        call_cvs = Mock()
+        same_simple = preprocess_domain_single(self.iris.domain, call_cv)
+        s1 = preprocess_domain_shared(same_simple, call_cvs, call_shared)
+        s2 = preprocess_domain_shared(same_simple, call_cvs, call_shared)
+        ndom = Domain(s1.attributes + s2.attributes)
+        self.iris.transform(ndom)
+        self.assertEqual(2, call_shared.call_count)
+        self.assertEqual(4, call_cv.call_count)
+        self.assertEqual(8, call_cvs.call_count)
+
+    def test_simple_simple_stupid(self):
+        call_cv = Mock()
+        d1 = preprocess_domain_single_stupid(self.iris.domain, call_cv)
+        d2 = preprocess_domain_single_stupid(d1, call_cv)
+        self.iris.transform(d2)
+        self.assertEqual(8, call_cv.call_count)
+
+
 if __name__ == "__main__":
     unittest.main()