Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] VizRankDialog: Use extended thread pool to prevent segfaults #3669

Merged
merged 1 commit into from
Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions Orange/widgets/data/owcorrelations.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class KMeansCorrelationHeuristic:
def __init__(self, data):
self.n_attributes = len(data.domain.attributes)
self.data = data
self.states = None
self.clusters = None
self.n_clusters = int(np.sqrt(self.n_attributes))

def get_clusters_of_attributes(self):
Expand All @@ -84,16 +84,15 @@ def get_states(self, initial_state):
:param initial_state: initial state; None if this is the first call
:return: generator of tuples of states
"""
if self.states is not None:
return chain([initial_state], self.states)

clusters = self.get_clusters_of_attributes()
if self.clusters is None:
self.clusters = self.get_clusters_of_attributes()
clusters = self.clusters

# combinations within clusters
self.states = chain.from_iterable(combinations(cluster.instances, 2)
for cluster in clusters)
states0 = chain.from_iterable(combinations(cluster.instances, 2)
for cluster in clusters)
if self.n_clusters == 1:
return self.states
return states0

# combinations among clusters - closest clusters first
centroids = [c.centroid for c in clusters]
Expand All @@ -104,8 +103,13 @@ def get_states(self, initial_state):
states = ((min((c1, c2)), max((c1, c2))) for i in np.argsort(distances)
for c1 in clusters[cluster_combs[i][0]].instances
for c2 in clusters[cluster_combs[i][1]].instances)
self.states = chain(self.states, states)
return self.states
states = chain(states0, states)

if initial_state is not None:
while next(states) != initial_state:
pass
return chain([initial_state], states)
return states


class CorrelationRank(VizRankDialogAttrPair):
Expand Down
12 changes: 12 additions & 0 deletions Orange/widgets/data/tests/test_owcorrelations.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,14 @@ def test_row_for_state(self):
self.assertEqual(row[1].data(Qt.DisplayRole), self.attrs[0].name)
self.assertEqual(row[2].data(Qt.DisplayRole), self.attrs[1].name)

def test_iterate_states(self):
self.assertListEqual(list(self.vizrank.iterate_states(None)),
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
self.assertListEqual(list(self.vizrank.iterate_states((1, 0))),
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
self.assertListEqual(list(self.vizrank.iterate_states((2, 1))),
[(2, 1), (3, 0), (3, 1), (3, 2)])

def test_iterate_states_by_feature(self):
self.vizrank.sel_feature_index = 2
states = self.vizrank.iterate_states_by_feature()
Expand Down Expand Up @@ -345,3 +353,7 @@ def test_get_states_one_cluster(self):
states = set(heuristic.get_states(None))
self.assertEqual(len(states), 1)
self.assertSetEqual(states, {(0, 1)})


if __name__ == "__main__":
unittest.main()
17 changes: 10 additions & 7 deletions Orange/widgets/visualize/tests/test_owlinearprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from Orange.widgets.visualize.owlinearprojection import (
OWLinearProjection, LinearProjectionVizRank
)
from Orange.widgets.visualize.utils import Worker
from Orange.widgets.visualize.utils import run_vizrank


class TestOWLinearProjection(WidgetTest, AnchorProjectionWidgetTestMixin,
Expand Down Expand Up @@ -205,16 +205,14 @@ def setUp(self):

def test_discrete_class(self):
self.send_signal(self.widget.Inputs.data, self.data)
worker = Worker(self.vizrank)
self.vizrank.keep_running = True
worker.do_work()
run_vizrank(self.vizrank.compute_score,
self.vizrank.iterate_states(None), [], Mock())

def test_continuous_class(self):
data = Table("housing")[::100]
self.send_signal(self.widget.Inputs.data, data)
worker = Worker(self.vizrank)
self.vizrank.keep_running = True
worker.do_work()
run_vizrank(self.vizrank.compute_score,
self.vizrank.iterate_states(None), [], Mock())

def test_set_attrs(self):
self.send_signal(self.widget.Inputs.data, self.data)
Expand All @@ -230,3 +228,8 @@ def test_set_attrs(self):
self.assertNotEqual(self.widget.model_selected[:], model_selected)
c2 = self.get_output(self.widget.Outputs.components)
self.assertNotEqual(c1.domain.attributes, c2.domain.attributes)


if __name__ == "__main__":
import unittest
unittest.main()
132 changes: 132 additions & 0 deletions Orange/widgets/visualize/tests/test_vizrankdialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from itertools import chain
import unittest
from unittest.mock import Mock
from queue import Queue

from AnyQt.QtGui import QStandardItem

from Orange.data import Table
from Orange.widgets.visualize.utils import (
VizRankDialog, Result, run_vizrank, QueuedScore
)
from Orange.widgets.tests.base import WidgetTest


def compute_score(x):
return (x[0] + 1) / (x[1] + 1)


class TestRunner(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.data = Table("iris")

def test_Result(self):
res = Result(queue=Queue(), scores=[])
self.assertIsInstance(res.queue, Queue)
self.assertIsInstance(res.scores, list)

def test_run_vizrank(self):
scores, task = [], Mock()
# run through all states
task.is_interruption_requested.return_value = False
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
res = run_vizrank(compute_score, chain(states), scores, task)

next_state = self.assertQueueEqual(
res.queue, [0, 0, 0, 3, 2, 5], compute_score,
states, states[1:] + [None])
self.assertIsNone(next_state)
res_scores = sorted([compute_score(x) for x in states])
self.assertListEqual(res.scores, res_scores)
self.assertIsNot(scores, res.scores)
self.assertEqual(task.set_partial_result.call_count, 6)

def test_run_vizrank_interrupt(self):
scores, task = [], Mock()
# interrupt calculation in third iteration
task.is_interruption_requested.side_effect = lambda: \
True if task.is_interruption_requested.call_count > 2 else False
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
res = run_vizrank(compute_score, chain(states), scores, task)

next_state = self.assertQueueEqual(
res.queue, [0, 0], compute_score, states[:2], states[1:3])
self.assertEqual(next_state, (0, 3))
res_scores = sorted([compute_score(x) for x in states[:2]])
self.assertListEqual(res.scores, res_scores)
self.assertIsNot(scores, res.scores)
self.assertEqual(task.set_partial_result.call_count, 2)

# continue calculation through all states
task.is_interruption_requested.side_effect = lambda: False
i = states.index(next_state)
res = run_vizrank(compute_score, chain(states[i:]), res_scores, task)

next_state = self.assertQueueEqual(
res.queue, [0, 3, 2, 5], compute_score, states[2:],
states[3:] + [None])
self.assertIsNone(next_state)
res_scores = sorted([compute_score(x) for x in states])
self.assertListEqual(res.scores, res_scores)
self.assertIsNot(scores, res.scores)
self.assertEqual(task.set_partial_result.call_count, 6)

def assertQueueEqual(self, queue, positions, f, states, next_states):
self.assertIsInstance(queue, Queue)
for qs in (QueuedScore(position=p, score=f(s), state=s, next_state=ns)
for p, s, ns in zip(positions, states, next_states)):
result = queue.get_nowait()
self.assertEqual(result.position, qs.position)
self.assertEqual(result.state, qs.state)
self.assertEqual(result.next_state, qs.next_state)
self.assertEqual(result.score, qs.score)
next_state = result.next_state
return next_state


class TestVizRankDialog(WidgetTest):
def test_on_partial_result(self):
def iterate_states(initial_state):
if initial_state is not None:
return chain(states[states.index(initial_state):])
return chain(states)

def invoke_on_partial_result():
widget.on_partial_result(run_vizrank(
widget.compute_score,
widget.iterate_states(widget.saved_state),
widget.scores, task
))

task = Mock()
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

widget = VizRankDialog(None)
widget.progressBarInit()
widget.compute_score = compute_score
widget.iterate_states = iterate_states
widget.row_for_state = lambda sc, _: [QStandardItem(str(sc))]

# interrupt calculation in third iteration
task.is_interruption_requested.side_effect = lambda: \
True if task.is_interruption_requested.call_count > 2 else False
invoke_on_partial_result()
self.assertEqual(widget.rank_model.rowCount(), 2)
for row, score in enumerate(
sorted([compute_score(x) for x in states[:2]])):
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
self.assertEqual(widget.saved_progress, 2)

# continue calculation through all states
task.is_interruption_requested.side_effect = lambda: False
invoke_on_partial_result()
self.assertEqual(widget.rank_model.rowCount(), 6)
for row, score in enumerate(
sorted([compute_score(x) for x in states])):
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
self.assertEqual(widget.saved_progress, 6)


if __name__ == "__main__":
unittest.main()
Loading