Skip to content

Commit

Permalink
Merge pull request #4370 from markotoplak/from_table-cache-less
Browse files Browse the repository at this point in the history
[FIX] Table.from_table: fix caching with reused ids
  • Loading branch information
janezd authored Feb 1, 2020
2 parents 098d97d + b175dfa commit 796bd26
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 7 deletions.
26 changes: 19 additions & 7 deletions Orange/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import threading
import warnings
import weakref
import zlib
from collections import Iterable, Sequence, Sized
from functools import reduce
Expand Down Expand Up @@ -316,6 +317,12 @@ def from_table(cls, domain, source, row_indices=...):
:rtype: Orange.data.Table
"""

def valid_refs(weakrefs):
    """Tell whether every weak reference in `weakrefs` is still alive.

    Each item is called; a result of ``None`` marks it as dead.

    :param weakrefs: iterable of callables (e.g. ``weakref.ref`` objects)
    :return: ``False`` if any reference is dead, ``True`` otherwise
        (including for an empty iterable)
    :rtype: bool
    """
    return all(ref() is not None for ref in weakrefs)

def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
is_sparse=False, variables=[]):
if not len(src_cols):
Expand Down Expand Up @@ -356,10 +363,13 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
a[:, i] = variables[i].Unknown
elif not isinstance(col, Integral):
if isinstance(col, SharedComputeValue):
if (id(col.compute_shared), id(source)) not in shared_cache:
shared_cache[id(col.compute_shared), id(source)] = \
col.compute_shared(source)
shared = shared_cache[id(col.compute_shared), id(source)]
shared, weakrefs = shared_cache.get((id(col.compute_shared), id(source)),
(None, None))
if shared is None or not valid_refs(weakrefs):
shared, _ = shared_cache[(id(col.compute_shared), id(source))] = \
col.compute_shared(source), \
(weakref.ref(col.compute_shared), weakref.ref(source))

if row_indices is not ...:
a[:, i] = match_density(
col(source, shared_data=shared)[row_indices])
Expand Down Expand Up @@ -389,8 +399,9 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
if new_cache:
_thread_local.conversion_cache = {}
else:
cached = _thread_local.conversion_cache.get((id(domain), id(source)))
if cached:
cached, weakrefs = \
_thread_local.conversion_cache.get((id(domain), id(source)), (None, None))
if cached and valid_refs(weakrefs):
return cached
if domain is source.domain:
table = cls.from_table_rows(source, row_indices)
Expand Down Expand Up @@ -443,7 +454,8 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
else:
cls._init_ids(self)
self.attributes = getattr(source, 'attributes', {})
_thread_local.conversion_cache[(id(domain), id(source))] = self
_thread_local.conversion_cache[(id(domain), id(source))] = \
self, (weakref.ref(domain), weakref.ref(source))
return self
finally:
if new_cache:
Expand Down
141 changes: 141 additions & 0 deletions Orange/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from Orange import data
from Orange.data import (filter, Unknown, Variable, Table, DiscreteVariable,
ContinuousVariable, Domain, StringVariable)
from Orange.data.util import SharedComputeValue
from Orange.tests import test_dirname
from Orange.data.table import _optimize_indices

Expand Down Expand Up @@ -2693,5 +2694,145 @@ def run_from_table():
self.assertLess(duration, 0.5)


class PreprocessComputeValue:
    """Test helper: a plain compute value that reports every invocation.

    Calling an instance notifies `callback`, transforms the incoming data
    into the stored `domain` (the transformed table is not used) and
    returns the first column of the incoming data's ``X``.
    """

    def __init__(self, domain, callback):
        self.callback = callback
        self.domain = domain

    def __call__(self, data_):
        # Report the call first, so the tests can count invocations.
        self.callback(data_)
        # Only the act of transforming matters here; the result is dropped.
        data_.transform(self.domain)
        first_column = data_.X[:, 0]
        return first_column


class PreprocessShared:
    """Test helper used as the shared part of a shared compute value.

    Calling an instance notifies `callback`, transforms the incoming data
    into the stored `domain` (the transformed table is not used) and
    returns ``True`` as the shared data.
    """

    def __init__(self, domain, callback):
        self.callback = callback
        self.domain = domain

    def __call__(self, data_):
        # Report the call first, so the tests can count invocations.
        self.callback(data_)
        # Only the act of transforming matters here; the result is dropped.
        data_.transform(self.domain)
        shared_data = True
        return shared_data


class PreprocessSharedComputeValue(SharedComputeValue):
    """Test helper: a shared compute value that reports every `compute` call.

    `shared` is handed to the base class as ``compute_shared``; `callback`
    is notified each time `compute` runs.
    """

    def __init__(self, callback, shared):
        super().__init__(compute_shared=shared)
        self.callback = callback

    # pylint: disable=arguments-differ
    def compute(self, data_, shared_data):
        """Notify the callback and return the first column of ``data_.X``.

        `shared_data` is accepted but not used.
        """
        self.callback(data_)
        first_column = data_.X[:, 0]
        return first_column


def preprocess_domain_single(domain, callback):
    """Build a domain of plain (single-source) compute values.

    For each attribute of `domain` a same-named continuous variable is
    created; its compute value is a `PreprocessComputeValue` over a domain
    holding just that one attribute, reporting to `callback`.
    """
    derived = []
    for attr in domain.attributes:
        compute_value = PreprocessComputeValue(Domain([attr]), callback)
        derived.append(
            ContinuousVariable(name=attr.name, compute_value=compute_value))
    return Domain(derived)


def preprocess_domain_shared(domain, callback, callback_shared):
    """Build a domain whose variables all use one shared compute step.

    A single `PreprocessShared` (reporting to `callback_shared`) is created
    for `domain` and passed to every variable's
    `PreprocessSharedComputeValue` (reporting to `callback`).
    """
    # One instance for the whole domain: every variable refers to the same
    # shared object.
    shared_step = PreprocessShared(domain, callback_shared)
    derived = []
    for attr in domain.attributes:
        compute_value = PreprocessSharedComputeValue(callback, shared_step)
        derived.append(
            ContinuousVariable(name=attr.name, compute_value=compute_value))
    return Domain(derived)


def preprocess_domain_single_stupid(domain, callback):
    """Build a domain of plain compute values, implemented wastefully.

    Unlike a compute value that transforms only its own column into the
    input domain, each compute value here is given the whole `domain`, so
    applying it performs a needless transform of the entire domain.
    """
    derived = []
    for attr in domain.attributes:
        # Deliberately the full domain rather than Domain([attr]).
        compute_value = PreprocessComputeValue(domain, callback)
        derived.append(
            ContinuousVariable(name=attr.name, compute_value=compute_value))
    return Domain(derived)


class EfficientTransformTests(unittest.TestCase):
    """Count compute-value invocations while transforming the iris table
    through chains of derived domains.

    Each test wires `Mock` callbacks into the helper domains, transforms
    once, and asserts the number of calls that reached the callbacks.
    """

    def setUp(self):
        self.iris = Table("iris")

    def test_simple(self):
        """One plain layer: four compute-value calls are expected."""
        cv_calls = Mock()
        layer = preprocess_domain_single(self.iris.domain, cv_calls)
        self.iris.transform(layer)
        self.assertEqual(4, cv_calls.call_count)

    def test_shared(self):
        """One shared layer: four compute calls, one shared call."""
        cv_calls = Mock()
        shared_calls = Mock()
        layer = preprocess_domain_shared(
            self.iris.domain, cv_calls, shared_calls)
        self.iris.transform(layer)
        self.assertEqual(4, cv_calls.call_count)
        self.assertEqual(1, shared_calls.call_count)

    def test_simple_simple_shared(self):
        """Chain plain -> plain -> shared: 12 compute calls, one shared."""
        cv_calls = Mock()
        shared_calls = Mock()
        plain_1 = preprocess_domain_single(self.iris.domain, cv_calls)
        plain_2 = preprocess_domain_single(plain_1, cv_calls)
        shared_3 = preprocess_domain_shared(plain_2, cv_calls, shared_calls)
        self.iris.transform(shared_3)
        self.assertEqual(1, shared_calls.call_count)
        self.assertEqual(12, cv_calls.call_count)

    def test_simple_simple_shared_simple(self):
        """Chain plain -> plain -> shared -> plain: 16 compute calls,
        one shared call."""
        cv_calls = Mock()
        shared_calls = Mock()
        plain_1 = preprocess_domain_single(self.iris.domain, cv_calls)
        plain_2 = preprocess_domain_single(plain_1, cv_calls)
        shared_3 = preprocess_domain_shared(plain_2, cv_calls, shared_calls)
        plain_4 = preprocess_domain_single(shared_3, cv_calls)
        self.iris.transform(plain_4)
        self.assertEqual(1, shared_calls.call_count)
        self.assertEqual(16, cv_calls.call_count)

    def test_simple_simple_shared_simple_shared_simple(self):
        """Six-layer chain with two shared layers: 24 compute calls,
        two shared calls."""
        cv_calls = Mock()
        shared_calls = Mock()
        plain_1 = preprocess_domain_single(self.iris.domain, cv_calls)
        plain_2 = preprocess_domain_single(plain_1, cv_calls)
        shared_3 = preprocess_domain_shared(plain_2, cv_calls, shared_calls)
        plain_4 = preprocess_domain_single(shared_3, cv_calls)
        shared_5 = preprocess_domain_shared(plain_4, cv_calls, shared_calls)
        plain_6 = preprocess_domain_single(shared_5, cv_calls)
        self.iris.transform(plain_6)
        self.assertEqual(2, shared_calls.call_count)
        self.assertEqual(24, cv_calls.call_count)

    def test_same_simple_shared_union(self):
        """Two shared domains built over the same plain domain, merged.

        Expected: each shared step runs once (two in total), the common
        plain layer's callback fires four times, and the shared compute
        values' callback fires eight times.
        """
        cv_calls = Mock()
        shared_calls = Mock()
        shared_cv_calls = Mock()
        common = preprocess_domain_single(self.iris.domain, cv_calls)
        first = preprocess_domain_shared(common, shared_cv_calls, shared_calls)
        second = preprocess_domain_shared(common, shared_cv_calls, shared_calls)
        union = Domain(first.attributes + second.attributes)
        self.iris.transform(union)
        self.assertEqual(2, shared_calls.call_count)
        self.assertEqual(4, cv_calls.call_count)
        self.assertEqual(8, shared_cv_calls.call_count)

    def test_simple_simple_stupid(self):
        """Two wasteful plain layers: eight compute calls are expected."""
        cv_calls = Mock()
        wasteful_1 = preprocess_domain_single_stupid(
            self.iris.domain, cv_calls)
        wasteful_2 = preprocess_domain_single_stupid(wasteful_1, cv_calls)
        self.iris.transform(wasteful_2)
        self.assertEqual(8, cv_calls.call_count)


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 796bd26

Please sign in to comment.