-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
VizRankDialog: Use extended thread pool to prevent segfaults
When fed large datasets, correlations widget exited with segmentation fault, (probably) due to insufficient stack size for created task.
- Loading branch information
Showing
5 changed files
with
251 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from itertools import chain | ||
import unittest | ||
from unittest.mock import Mock | ||
from queue import Queue | ||
|
||
from AnyQt.QtGui import QStandardItem | ||
|
||
from Orange.data import Table | ||
from Orange.widgets.visualize.utils import ( | ||
VizRankDialog, Result, run_vizrank, QueuedScore | ||
) | ||
from Orange.widgets.tests.base import WidgetTest | ||
|
||
|
||
def compute_score(x): | ||
return (x[0] + 1) / (x[1] + 1) | ||
|
||
|
||
class TestRunner(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.data = Table("iris") | ||
|
||
def test_Result(self): | ||
res = Result(queue=Queue(), scores=[]) | ||
self.assertIsInstance(res.queue, Queue) | ||
self.assertIsInstance(res.scores, list) | ||
|
||
def test_run_vizrank(self): | ||
scores, task = [], Mock() | ||
# run through all states | ||
task.is_interruption_requested.return_value = False | ||
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] | ||
res = run_vizrank(compute_score, chain(states), scores, task) | ||
|
||
next_state = self.assertQueueEqual( | ||
res.queue, [0, 0, 0, 3, 2, 5], compute_score, | ||
states, states[1:] + [None]) | ||
self.assertIsNone(next_state) | ||
res_scores = sorted([compute_score(x) for x in states]) | ||
self.assertListEqual(res.scores, res_scores) | ||
self.assertIsNot(scores, res.scores) | ||
self.assertEqual(task.set_partial_result.call_count, 6) | ||
|
||
def test_run_vizrank_interrupt(self): | ||
scores, task = [], Mock() | ||
# interrupt calculation in third iteration | ||
task.is_interruption_requested.side_effect = lambda: \ | ||
True if task.is_interruption_requested.call_count > 2 else False | ||
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] | ||
res = run_vizrank(compute_score, chain(states), scores, task) | ||
|
||
next_state = self.assertQueueEqual( | ||
res.queue, [0, 0], compute_score, states[:2], states[1:3]) | ||
self.assertEqual(next_state, (0, 3)) | ||
res_scores = sorted([compute_score(x) for x in states[:2]]) | ||
self.assertListEqual(res.scores, res_scores) | ||
self.assertIsNot(scores, res.scores) | ||
self.assertEqual(task.set_partial_result.call_count, 2) | ||
|
||
# continue calculation through all states | ||
task.is_interruption_requested.side_effect = lambda: False | ||
i = states.index(next_state) | ||
res = run_vizrank(compute_score, chain(states[i:]), res_scores, task) | ||
|
||
next_state = self.assertQueueEqual( | ||
res.queue, [0, 3, 2, 5], compute_score, states[2:], | ||
states[3:] + [None]) | ||
self.assertIsNone(next_state) | ||
res_scores = sorted([compute_score(x) for x in states]) | ||
self.assertListEqual(res.scores, res_scores) | ||
self.assertIsNot(scores, res.scores) | ||
self.assertEqual(task.set_partial_result.call_count, 6) | ||
|
||
def assertQueueEqual(self, queue, positions, f, states, next_states): | ||
self.assertIsInstance(queue, Queue) | ||
for qs in (QueuedScore(position=p, score=f(s), state=s, next_state=ns) | ||
for p, s, ns in zip(positions, states, next_states)): | ||
result = queue.get_nowait() | ||
self.assertEqual(result.position, qs.position) | ||
self.assertEqual(result.state, qs.state) | ||
self.assertEqual(result.next_state, qs.next_state) | ||
self.assertEqual(result.score, qs.score) | ||
next_state = result.next_state | ||
return next_state | ||
|
||
|
||
class TestVizRankDialog(WidgetTest): | ||
def test_on_partial_result(self): | ||
def iterate_states(initial_state): | ||
if initial_state is not None: | ||
return chain(states[states.index(initial_state):]) | ||
return chain(states) | ||
|
||
def invoke_on_partial_result(): | ||
widget.on_partial_result(run_vizrank( | ||
widget.compute_score, | ||
widget.iterate_states(widget.saved_state), | ||
widget.scores, task | ||
)) | ||
|
||
task = Mock() | ||
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] | ||
|
||
widget = VizRankDialog(None) | ||
widget.progressBarInit() | ||
widget.compute_score = compute_score | ||
widget.iterate_states = iterate_states | ||
widget.row_for_state = lambda sc, _: [QStandardItem(str(sc))] | ||
|
||
# interrupt calculation in third iteration | ||
task.is_interruption_requested.side_effect = lambda: \ | ||
True if task.is_interruption_requested.call_count > 2 else False | ||
invoke_on_partial_result() | ||
self.assertEqual(widget.rank_model.rowCount(), 2) | ||
for row, score in enumerate( | ||
sorted([compute_score(x) for x in states[:2]])): | ||
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score)) | ||
self.assertEqual(widget.saved_progress, 2) | ||
|
||
# continue calculation through all states | ||
task.is_interruption_requested.side_effect = lambda: False | ||
invoke_on_partial_result() | ||
self.assertEqual(widget.rank_model.rowCount(), 6) | ||
for row, score in enumerate( | ||
sorted([compute_score(x) for x in states])): | ||
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score)) | ||
self.assertEqual(widget.saved_progress, 6) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.