Skip to content

Commit

Permalink
MDS: Show Kruskal stress
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jan 23, 2023
1 parent a25a9d2 commit ff64af7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
18 changes: 18 additions & 0 deletions Orange/widgets/unsupervised/owmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def _add_controls_optimization(self):
sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed),
callback=self.__refresh_rate_combo_changed),
1, 1)
self.stress_label = QLabel("Kruskal Stress: -")
grid.addWidget(self.stress_label, 2, 0, 1, 3)

def __refresh_rate_combo_changed(self):
if self.task is not None:
Expand Down Expand Up @@ -392,17 +394,30 @@ def on_partial_result(self, result: Result):
if need_update:
self.graph.update_coordinates()
self.graph.update_density()
self.update_stress()

def on_done(self, result: Result):
assert isinstance(result.embedding, np.ndarray)
assert len(result.embedding) == len(self.effective_matrix)
self.embedding = result.embedding
self.graph.update_coordinates()
self.graph.update_density()
self.update_stress()
self.run_button.setText("Start")
self.step_button.setEnabled(True)
self.commit.deferred()

def update_stress(self):
if self.embedding is None or self.effective_matrix is None:
self.stress_label.setText(f"Kruskal Stress: -")
return

actual = scipy.spatial.distance.pdist(self.embedding)
actual = scipy.spatial.distance.squareform(actual)
stress = np.sqrt(np.sum((actual - self.effective_matrix) ** 2)
/ (np.sum(self.effective_matrix ** 2) or 1))
self.stress_label.setText(f"Kruskal Stress: {stress:.3f}")

def on_exception(self, ex: Exception):
if isinstance(ex, MemoryError):
self.Error.out_of_memory()
Expand Down Expand Up @@ -436,6 +451,7 @@ def jitter_coord(part):
# (Random or PCA), restarting the optimization if necessary.
if self.effective_matrix is None:
self.graph.reset_graph()
self.update_stress()
return

X = self.effective_matrix
Expand All @@ -451,6 +467,8 @@ def jitter_coord(part):
# restart the optimization if it was interrupted.
if self.task is not None:
self._run()
else:
self.update_stress()

def handleNewSignals(self):
self._initialize()
Expand Down
12 changes: 12 additions & 0 deletions Orange/widgets/unsupervised/tests/test_owmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,18 @@ def test_matrix_columns_default_label(self):
label_text = self.widget.controls.attr_label.currentText()
self.assertEqual(label_text, "labels")

def test_update_stress(self):
w = self.widget
w.effective_matrix = np.array([[0, 4, 1],
[4, 0, 1],
[1, 1, 0]]) # sum of squares is 36
w.embedding = [[0, 0], [0, 3],
[4, 3]]
# dists [[0, 3, 5], diff [[0, 1, 4], sqr [[0, 1, 16], sum = 52
# [3, 0, 4], [1, 0, 3], [1, 0, 9],
# [5, 4, 0]] [4, 3, 0]] [16, 9, 0]]
w.update_stress()
self.assertIn(f"{np.sqrt(52 / 36):.3f}", w.stress_label.text())

class TestOWMDSRunner(unittest.TestCase):
@classmethod
Expand Down

0 comments on commit ff64af7

Please sign in to comment.