Skip to content

Commit

Permalink
MDS: Show Kruskal stress
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Feb 11, 2023
1 parent a7d215e commit 0d145ab
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
24 changes: 22 additions & 2 deletions Orange/widgets/unsupervised/owmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def __init__(self):

self.embedding = None # type: Optional[np.ndarray]
self.effective_matrix = None # type: Optional[DistMatrix]
self.stress = None

self.size_model = self.gui.points_models[2]
self.size_model.order = \
Expand Down Expand Up @@ -241,6 +242,8 @@ def _add_controls_optimization(self):
sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed),
callback=self.__refresh_rate_combo_changed),
1, 1)
self.stress_label = QLabel("Kruskal Stress: -")
grid.addWidget(self.stress_label, 2, 0, 1, 3)

def __refresh_rate_combo_changed(self):
if self.task is not None:
Expand Down Expand Up @@ -392,17 +395,31 @@ def on_partial_result(self, result: Result):
if need_update:
self.graph.update_coordinates()
self.graph.update_density()
self.update_stress()

def on_done(self, result: Result):
assert isinstance(result.embedding, np.ndarray)
assert len(result.embedding) == len(self.effective_matrix)
self.embedding = result.embedding
self.graph.update_coordinates()
self.graph.update_density()
self.update_stress()
self.run_button.setText("Start")
self.step_button.setEnabled(True)
self.commit.deferred()

def update_stress(self):
self.stress = self._compute_stress()
stress_val = "-" if self.stress is None else f"{self.stress:.3f}"
self.stress_label.setText(f"Kruskal Stress: {stress_val}")

def _compute_stress(self):
if self.embedding is None or self.effective_matrix is None:
return None
point_stress = self.get_stress(self.embedding, self.effective_matrix)
return np.sqrt(2 * np.sum(point_stress)
/ (np.sum(self.effective_matrix ** 2) or 1))

def on_exception(self, ex: Exception):
if isinstance(ex, MemoryError):
self.Error.out_of_memory()
Expand Down Expand Up @@ -436,6 +453,7 @@ def jitter_coord(part):
# (Random or PCA), restarting the optimization if necessary.
if self.effective_matrix is None:
self.graph.reset_graph()
self.update_stress()
return

X = self.effective_matrix
Expand All @@ -451,6 +469,8 @@ def jitter_coord(part):
# restart the optimization if it was interrupted.
if self.task is not None:
self._run()
else:
self.update_stress()

def handleNewSignals(self):
self._initialize()
Expand All @@ -473,12 +493,12 @@ def setup_plot(self):

def get_size_data(self):
if self.attr_size == "Stress":
return self.stress(self.embedding, self.effective_matrix)
return self.get_stress(self.embedding, self.effective_matrix)
else:
return super().get_size_data()

@staticmethod
def stress(X, distD):
def get_stress(X, distD):
assert X.shape[0] == distD.shape[0] == distD.shape[1]
D1_c = scipy.spatial.distance.pdist(X, metric="euclidean")
D1 = scipy.spatial.distance.squareform(D1_c, checks=False)
Expand Down
23 changes: 22 additions & 1 deletion Orange/widgets/unsupervised/tests/test_owmds.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
# pylint: disable=missing-docstring,protected-access
import os
from itertools import chain
import unittest
Expand Down Expand Up @@ -320,6 +320,27 @@ def test_matrix_columns_default_label(self):
label_text = self.widget.controls.attr_label.currentText()
self.assertEqual(label_text, "labels")

def test_update_stress(self):
w = self.widget
w.effective_matrix = np.array([[0, 4, 1],
[4, 0, 1],
[1, 1, 0]]) # sum of squares is 36
w.embedding = np.array([[0, 0],
[0, 3],
[4, 3]])
# dists [[0, 3, 5], diff [[0, 1, 4], sqr [[0, 1, 16], sum = 52
# [3, 0, 4], [1, 0, 3], [1, 0, 9],
# [5, 4, 0]] [4, 3, 0]] [16, 9, 0]]
w.update_stress()
expected = np.sqrt(52 / 36)
self.assertAlmostEqual(w._compute_stress(), expected)
self.assertIn(f"{expected:.3f}", w.stress_label.text())

w.embedding = None
w.update_stress()
self.assertIsNone(w._compute_stress())
self.assertIn("-", w.stress_label.text())


class TestOWMDSRunner(unittest.TestCase):
@classmethod
Expand Down

0 comments on commit 0d145ab

Please sign in to comment.