Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Allow concurrent transformation of tables into new domains #4363

Merged
merged 2 commits into from
Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 77 additions & 74 deletions Orange/data/table.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import operator
import os
import threading
import warnings
import zlib
from collections import Iterable, Sequence, Sized
from functools import reduce
from itertools import chain
from numbers import Real, Integral
from threading import Lock, RLock
from threading import Lock

import bottleneck as bn
import numpy as np
Expand Down Expand Up @@ -40,11 +41,16 @@ def get_sample_datasets_dir():
dataset_dirs = ['', get_sample_datasets_dir()]


"""Domain conversion cache used in Table.from_table. It is global so that
chaining of domain conversions also works with caching even with descendants
of Table."""
_conversion_cache = None
_conversion_cache_lock = RLock()
class _ThreadLocal(threading.local):
    """Per-thread storage for state used by ``Table.from_table``.

    The domain-conversion cache is kept here — global, but thread-local —
    rather than as a class attribute of ``Table``, so that chained domain
    conversions are cached for descendants of ``Table`` as well, while
    concurrent conversions in different threads never share (or block on)
    one another's cache.
    """

    def __init__(self):
        super().__init__()
        # None means "no conversion currently in progress" for this
        # thread; Table.from_table installs a dict here for the duration
        # of the outermost conversion and resets it to None afterwards.
        self.conversion_cache = None


_thread_local = _ThreadLocal()


class DomainTransformationError(Exception):
Expand Down Expand Up @@ -310,8 +316,6 @@ def from_table(cls, domain, source, row_indices=...):
:rtype: Orange.data.Table
"""

global _conversion_cache

def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
is_sparse=False, variables=[]):
if not len(src_cols):
Expand Down Expand Up @@ -346,7 +350,7 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
a = np.empty((n_rows, len(src_cols)), dtype=dtype)
match_density = assure_column_dense

shared_cache = _conversion_cache
shared_cache = _thread_local.conversion_cache
for i, col in enumerate(src_cols):
if col is None:
a[:, i] = variables[i].Unknown
Expand Down Expand Up @@ -380,71 +384,70 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,

return a

with _conversion_cache_lock:
new_cache = _conversion_cache is None
try:
if new_cache:
_conversion_cache = {}
else:
cached = _conversion_cache.get((id(domain), id(source)))
if cached:
return cached
if domain is source.domain:
table = cls.from_table_rows(source, row_indices)
# assure resulting domain is the instance passed on input
table.domain = domain
# since sparse flags are not considered when checking for
# domain equality, fix manually.
table = assure_domain_conversion_sparsity(table, source)
return table

if isinstance(row_indices, slice):
start, stop, stride = row_indices.indices(source.X.shape[0])
n_rows = (stop - start) // stride
if n_rows < 0:
n_rows = 0
elif row_indices is ...:
n_rows = len(source)
else:
n_rows = len(row_indices)

self = cls()
self.domain = domain
conversion = domain.get_conversion(source.domain)
self.X = get_columns(row_indices, conversion.attributes, n_rows,
is_sparse=conversion.sparse_X,
variables=domain.attributes)
if self.X.ndim == 1:
self.X = self.X.reshape(-1, len(self.domain.attributes))

self.Y = get_columns(row_indices, conversion.class_vars, n_rows,
is_sparse=conversion.sparse_Y,
variables=domain.class_vars)

dtype = np.float64
if any(isinstance(var, StringVariable) for var in domain.metas):
dtype = np.object
self.metas = get_columns(row_indices, conversion.metas,
n_rows, dtype,
is_sparse=conversion.sparse_metas,
variables=domain.metas)
if self.metas.ndim == 1:
self.metas = self.metas.reshape(-1, len(self.domain.metas))
if source.has_weights():
self.W = source.W[row_indices]
else:
self.W = np.empty((n_rows, 0))
self.name = getattr(source, 'name', '')
if hasattr(source, 'ids'):
self.ids = source.ids[row_indices]
else:
cls._init_ids(self)
self.attributes = getattr(source, 'attributes', {})
_conversion_cache[(id(domain), id(source))] = self
return self
finally:
if new_cache:
_conversion_cache = None
new_cache = _thread_local.conversion_cache is None
try:
if new_cache:
_thread_local.conversion_cache = {}
else:
cached = _thread_local.conversion_cache.get((id(domain), id(source)))
if cached:
return cached
if domain is source.domain:
table = cls.from_table_rows(source, row_indices)
# assure resulting domain is the instance passed on input
table.domain = domain
# since sparse flags are not considered when checking for
# domain equality, fix manually.
table = assure_domain_conversion_sparsity(table, source)
return table

if isinstance(row_indices, slice):
start, stop, stride = row_indices.indices(source.X.shape[0])
n_rows = (stop - start) // stride
if n_rows < 0:
n_rows = 0
elif row_indices is ...:
n_rows = len(source)
else:
n_rows = len(row_indices)

self = cls()
self.domain = domain
conversion = domain.get_conversion(source.domain)
self.X = get_columns(row_indices, conversion.attributes, n_rows,
is_sparse=conversion.sparse_X,
variables=domain.attributes)
if self.X.ndim == 1:
self.X = self.X.reshape(-1, len(self.domain.attributes))

self.Y = get_columns(row_indices, conversion.class_vars, n_rows,
is_sparse=conversion.sparse_Y,
variables=domain.class_vars)

dtype = np.float64
if any(isinstance(var, StringVariable) for var in domain.metas):
dtype = np.object
self.metas = get_columns(row_indices, conversion.metas,
n_rows, dtype,
is_sparse=conversion.sparse_metas,
variables=domain.metas)
if self.metas.ndim == 1:
self.metas = self.metas.reshape(-1, len(self.domain.metas))
if source.has_weights():
self.W = source.W[row_indices]
else:
self.W = np.empty((n_rows, 0))
self.name = getattr(source, 'name', '')
if hasattr(source, 'ids'):
self.ids = source.ids[row_indices]
else:
cls._init_ids(self)
self.attributes = getattr(source, 'attributes', {})
_thread_local.conversion_cache[(id(domain), id(source))] = self
return self
finally:
if new_cache:
_thread_local.conversion_cache = None

def transform(self, domain):
"""
Expand Down
32 changes: 32 additions & 0 deletions Orange/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from unittest.mock import Mock, MagicMock, patch
from itertools import chain
from math import isnan
from threading import Thread
from time import sleep, time

import numpy as np
import scipy.sparse as sp
Expand Down Expand Up @@ -2661,5 +2663,35 @@ def test_from_table_sparse_metas_with_strings(self):
d = self.iris.transform(domain)
self.assertFalse(sp.issparse(d.metas))


class ConcurrencyTests(unittest.TestCase):
    """Checks that Table's domain conversion can run concurrently."""

    def test_from_table_non_blocking(self):
        """Ten threaded conversions must overlap, not run serially.

        Each conversion sleeps 0.1 s inside its compute_value, so a
        globally-locked implementation would need at least 1 s for ten
        threads; a non-blocking one finishes well under 0.5 s.
        """
        iris = Table("iris")

        def slow_compute_value(data):
            # Artificial delay to make any serialization measurable.
            sleep(0.1)
            return data.X[:, 0]

        slow_domain = Domain(
            [ContinuousVariable("a", compute_value=slow_compute_value)])

        def convert():
            Table.from_table(slow_domain, iris)

        begin = time()

        workers = [Thread(target=convert) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # if from_table was blocking these threads would need at least 0.1*10s
        elapsed = time() - begin
        self.assertLess(elapsed, 0.5)


# Allow running this test module directly (python test_table.py) in
# addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()