Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] owpca: Handle the case of 0 total variance in the PCA solution #1897

Merged
merged 2 commits into from
Jan 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions Orange/widgets/unsupervised/owpca.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numbers

from AnyQt.QtWidgets import QFormLayout, QLineEdit
from AnyQt.QtGui import QColor
from AnyQt.QtCore import Qt, QTimer
Expand Down Expand Up @@ -41,6 +43,11 @@ class OWPCA(widget.OWWidget):

graph_name = "plot.plotItem"

class Warning(widget.OWWidget.Warning):
trivial_components = widget.Msg(
"All components of the PCA are trivial (explain 0 variance). "
"Input data is constant (or near constant).")

def __init__(self):
super().__init__()
self.data = None
Expand Down Expand Up @@ -186,12 +193,15 @@ def fit(self):
pca = self._pca_projector(data)
variance_ratio = pca.explained_variance_ratio_
cumulative = numpy.cumsum(variance_ratio)
self.components_spin.setRange(0, len(cumulative))

self._pca = pca
self._variance_ratio = variance_ratio
self._cumulative = cumulative
self._setup_plot()
if numpy.isfinite(cumulative[-1]):
self.components_spin.setRange(0, len(cumulative))
self._pca = pca
self._variance_ratio = variance_ratio
self._cumulative = cumulative
self._setup_plot()
else:
self.Warning.trivial_components()

self.unconditional_commit()

Expand All @@ -204,6 +214,7 @@ def clear(self):
self.plot_horlabels = []
self.plot_horlines = []
self.plot.clear()
self.Warning.trivial_components.clear()

def get_model(self):
if self.rpca is None:
Expand All @@ -222,6 +233,9 @@ def get_model(self):

def _setup_plot(self):
self.plot.clear()
if self._pca is None:
return

explained_ratio = self._variance_ratio
explained = self._cumulative
p = min(len(self._variance_ratio), self.maxp)
Expand Down Expand Up @@ -279,7 +293,9 @@ def _on_cut_changed(self, line):
self._set_horline_pos()

if self._pca is not None:
self.variance_covered = self._cumulative[components - 1] * 100
var = self._cumulative[components - 1]
if numpy.isfinite(var):
self.variance_covered = int(var * 100)

if current != self._nselected_components():
self._invalidate_selection()
Expand All @@ -295,7 +311,10 @@ def _update_selection_component_spin(self):
cut = len(self._variance_ratio)
else:
cut = self.ncomponents
self.variance_covered = self._cumulative[cut - 1] * 100

var = self._cumulative[cut - 1]
if numpy.isfinite(var):
self.variance_covered = int(var) * 100

if numpy.floor(self._line.value()) + 1 != cut:
self._line.setValue(cut - 1)
Expand Down Expand Up @@ -339,7 +358,8 @@ def _nselected_components(self):
var_max = self._cumulative[max_comp - 1]
if var_max != numpy.floor(self.variance_covered / 100.0):
cut = max_comp
self.variance_covered = var_max * 100
assert numpy.isfinite(var_max)
self.variance_covered = int(var_max * 100)
else:
self.ncomponents = cut = numpy.searchsorted(
self._cumulative, self.variance_covered / 100.0) + 1
Expand Down Expand Up @@ -391,6 +411,20 @@ def send_report(self):
))
self.report_plot()

@classmethod
def migrate_settings(cls, settings, version):
if "variance_covered" in settings:
# Due to the error in gh-1896 the variance_covered was persisted
# as a NaN value, causing a TypeError in the widgets `__init__`.
vc = settings["variance_covered"]
if isinstance(vc, numbers.Real):
if numpy.isfinite(vc):
vc = int(vc)
else:
vc = 100
settings["variance_covered"] = vc


def main():
import gc
from AnyQt.QtWidgets import QApplication
Expand Down
10 changes: 9 additions & 1 deletion Orange/widgets/unsupervised/tests/test_owpca.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
class TestOWDistanceMatrix(WidgetTest):

def setUp(self):
self.widget = self.create_widget(OWPCA)
self.widget = self.create_widget(OWPCA) # type: OWPCA

def test_set_variance100(self):
iris = Table("iris")[:5]
self.widget.set_data(iris)
self.widget.variance_covered = 100
self.widget._update_selection_variance_spin()

def test_constant_data(self):
data = Table("iris")[::5]
data.X[:, :] = 1.0
self.send_signal("Data", data)
self.assertTrue(self.widget.Warning.trivial_components.is_shown())
self.assertIsNone(self.get_output("Transformed Data"))
self.assertIsNone(self.get_output("Components"))