Skip to content

Commit

Permalink
Merge pull request #6934 from VesnaT/scatterplort_errorbars
Browse files Browse the repository at this point in the history
[ENH] Scatter Plot: Error Bars
  • Loading branch information
lanzagar authored Dec 13, 2024
2 parents 4b1eabe + 92837cf commit ee97bd9
Show file tree
Hide file tree
Showing 7 changed files with 668 additions and 21 deletions.
14 changes: 14 additions & 0 deletions Orange/widgets/visualize/icons/interval-horizontal.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions Orange/widgets/visualize/icons/interval-vertical.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
201 changes: 183 additions & 18 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Callable
from typing import List, Callable, Optional
from xml.sax.saxutils import escape

import numpy as np
Expand All @@ -9,11 +9,12 @@
from sklearn.metrics import r2_score

from AnyQt.QtCore import Qt, QTimer, QPointF
from AnyQt.QtGui import QColor, QFont
from AnyQt.QtWidgets import QGroupBox
from AnyQt.QtGui import QColor, QFont, QFontMetrics
from AnyQt.QtWidgets import QGroupBox, QSizePolicy, QPushButton

import pyqtgraph as pg

from orangewidget.utils import load_styled_icon
from orangewidget.utils.combobox import ComboBoxSearch

from Orange.data import Table, Domain, DiscreteVariable, Variable
Expand All @@ -29,6 +30,7 @@
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase, \
ScatterBaseParameterSetter
from Orange.widgets.visualize.utils.error_bars_dialog import ErrorBarsDialog
from Orange.widgets.visualize.utils.vizrank import VizRankDialogAttrPair, \
VizRankMixin
from Orange.widgets.visualize.utils.customizableplot import Updater
Expand Down Expand Up @@ -150,15 +152,20 @@ def __init__(self, scatter_widget, parent):
self.parameter_setter = ParameterSetter(self)
self.reg_line_items = []
self.ellipse_items: List[pg.PlotCurveItem] = []
self.error_bars_items: List[pg.ErrorBarItem] = []
self.view_box.sigResized.connect(self.update_error_bars)
self.view_box.sigRangeChanged.connect(self.update_error_bars)

def clear(self):
super().clear()
self.reg_line_items.clear()
self.ellipse_items.clear()
self.error_bars_items.clear()

def update_coordinates(self):
super().update_coordinates()
self.update_axes()
self.update_error_bars()
# Don't update_regression line here: update_coordinates is always
# followed by update_point_props, which calls update_colors

Expand All @@ -168,6 +175,9 @@ def update_colors(self):
self.update_ellipse()

def jitter_coordinates(self, x, y):
if self.jitter_size == 0:
return x, y

def get_span(attr):
if attr.is_discrete:
# Assuming the maximal jitter size is 10, a span of 4 will
Expand All @@ -179,7 +189,7 @@ def get_span(attr):
return 0 # No jittering
span_x = get_span(self.master.attr_x)
span_y = get_span(self.master.attr_y)
if self.jitter_size == 0 or (span_x == 0 and span_y == 0):
if span_x == 0 and span_y == 0:
return x, y
return self._jitter_data(x, y, span_x, span_y)

Expand Down Expand Up @@ -333,6 +343,42 @@ def _add_ellipse(self, x: np.ndarray, y: np.ndarray, color: QColor) -> np.ndarra
self.plot_widget.addItem(ellipse)
self.ellipse_items.append(ellipse)

def update_jittering(self):
super().update_jittering()
self.update_error_bars()

def update_error_bars(self):
for item in self.error_bars_items:
self.plot_widget.removeItem(item)
self.error_bars_items.clear()
if not self.master.can_draw_regression_line():
return

x, y = self.get_coordinates()
if x is None:
return

top, bottom, left, right = self.master.get_errors_data()
if top is None and bottom is None and left is None and right is None:
return

px, py = self.view_box.viewPixelSize()
pen = pg.mkPen(color=QColor("#505050"))

# x axis
error_bars = pg.ErrorBarItem(x=x, y=y, left=left, right=right,
beam=py * 10, pen=pen)
error_bars.setZValue(-1)
self.plot_widget.addItem(error_bars)
self.error_bars_items.append(error_bars)

# y axis
error_bars = pg.ErrorBarItem(x=x, y=y, top=top, bottom=bottom,
beam=px * 10, pen=pen)
error_bars.setZValue(-1)
self.plot_widget.addItem(error_bars)
self.error_bars_items.append(error_bars)


class OWScatterPlot(OWDataProjectionWidget, VizRankMixin(ScatterPlotVizRank)):
"""Scatterplot visualization with explorative analysis and intelligent
Expand All @@ -355,6 +401,12 @@ class Outputs(OWDataProjectionWidget.Outputs):
auto_sample = Setting(True)
attr_x = ContextSetting(None)
attr_y = ContextSetting(None)
attr_x_upper = ContextSetting(None)
attr_x_lower = ContextSetting(None)
attr_x_is_abs = Setting(False)
attr_y_upper = ContextSetting(None)
attr_y_lower = ContextSetting(None)
attr_y_is_abs = Setting(False)
tooltip_shows_all = Setting(True)

GRAPH_CLASS = OWScatterPlotGraph
Expand All @@ -376,6 +428,10 @@ def __init__(self):
self.xy_model: DomainModel = None
self.cb_attr_x: ComboBoxSearch = None
self.cb_attr_y: ComboBoxSearch = None
self.button_attr_x: QPushButton = None
self.button_attr_y: QPushButton = None
self.__x_axis_dlg: ErrorBarsDialog = None
self.__y_axis_dlg: ErrorBarsDialog = None
self.sampling: QGroupBox = None
self._xy_invalidated: bool = True

Expand Down Expand Up @@ -425,37 +481,94 @@ def _add_controls_axis(self):
spacing=2 if gui.is_macstyle() else 8)
dmod = DomainModel
self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)

hor_icon, ver_icon = self.__get_bar_icons()
width = 3 * QFontMetrics(self.font()).horizontalAdvance("m")
hbox = gui.hBox(self.attr_box, spacing=0)
self.cb_attr_x = gui.comboBox(
self.attr_box, self, "attr_x", label="Axis x:",
hbox, self, "attr_x", label="Axis x:",
callback=self.set_attr_from_combo,
model=self.xy_model, **common_options,
)
self.button_attr_x = gui.button(
hbox, self, "", callback=self.__on_x_button_clicked,
autoDefault=False, width=width, enabled=False,
sizePolicy=(QSizePolicy.Fixed, QSizePolicy.Fixed)
)
self.button_attr_x.setIcon(hor_icon)

hbox = gui.hBox(self.attr_box, spacing=0)
self.cb_attr_y = gui.comboBox(
self.attr_box, self, "attr_y", label="Axis y:",
hbox, self, "attr_y", label="Axis y:",
callback=self.set_attr_from_combo,
model=self.xy_model, **common_options,
)
self.button_attr_y = gui.button(
hbox, self, "", callback=self.__on_y_button_clicked,
autoDefault=False, width=width, enabled=False,
sizePolicy=(QSizePolicy.Fixed, QSizePolicy.Fixed)
)
self.button_attr_y.setIcon(ver_icon)

vizrank_box = gui.hBox(self.attr_box)
button = self.vizrank_button("Find Informative Projections")
vizrank_box.layout().addWidget(button)
self.vizrankSelectionChanged.connect(self.set_attr)

self.__x_axis_dlg = ErrorBarsDialog(self)
self.__x_axis_dlg.changed.connect(self.__on_x_dlg_changed)
self.__y_axis_dlg = ErrorBarsDialog(self)
self.__y_axis_dlg.changed.connect(self.__on_y_dlg_changed)

def __on_x_button_clicked(self):
self.__show_bars_dlg(
self.__x_axis_dlg, self.button_attr_x,
self.attr_x_upper, self.attr_x_lower, self.attr_x_is_abs)

def __on_y_button_clicked(self):
self.__show_bars_dlg(
self.__y_axis_dlg, self.button_attr_y,
self.attr_y_upper, self.attr_y_lower, self.attr_y_is_abs)

def __show_bars_dlg(self, dlg, button, upper, lower, is_abs):
pos = button.mapToGlobal(button.rect().bottomLeft())
dlg.show_dlg(self.data.domain,
pos.x(), pos.y(),
upper, lower, is_abs)

def __on_x_dlg_changed(self):
self.attr_x_upper, self.attr_x_lower, self.attr_x_is_abs = \
self.__x_axis_dlg.get_data()
self.graph.update_error_bars()

def __on_y_dlg_changed(self):
self.attr_y_upper, self.attr_y_lower, self.attr_y_is_abs = \
self.__y_axis_dlg.get_data()
self.graph.update_error_bars()

def _add_controls_sampling(self):
self.sampling = gui.auto_commit(
self.controlArea, self, "auto_sample", "Sample", box="Sampling",
callback=self.switch_sampling, commit=lambda: self.add_data(1))
self.sampling.setVisible(False)

@property
def effective_variables(self):
return [self.attr_x, self.attr_y] if self.attr_x and self.attr_y else []
def effective_variables(self) -> list[Variable]:
variables = []
if self.attr_x and self.attr_y:
variables.append(self.attr_x)
if self.attr_x.name != self.attr_y.name:
variables.append(self.attr_y)
for var in (self.attr_x_upper, self.attr_x_lower,
self.attr_y_upper, self.attr_y_lower):
# set is not used to preserve order
if var and var not in variables:
variables.append(var)
return variables

@property
def effective_data(self):
eff_var = self.effective_variables
if eff_var and self.attr_x.name == self.attr_y.name:
eff_var = [self.attr_x]
return self.data.transform(Domain(eff_var))
return self.data.transform(Domain(self.effective_variables))

def init_vizrank(self):
err_msg = ""
Expand Down Expand Up @@ -523,6 +636,14 @@ def check_data(self):
len(self.data.domain.variables) == 0):
self.data = None

def enable_controls(self):
super().enable_controls()
enabled = bool(self.data) and \
self.data.domain.has_continuous_attributes(include_class=True,
include_metas=True)
self.button_attr_x.setEnabled(enabled)
self.button_attr_y.setEnabled(enabled)

def get_embedding(self):
self.valid_data = None
if self.data is None:
Expand All @@ -541,6 +662,31 @@ def get_embedding(self):
msg.missing_coords(self.attr_x.name, self.attr_y.name)
return np.vstack((x_data, y_data)).T

def get_errors_data(self) -> tuple[
Optional[np.ndarray], Optional[np.ndarray],
Optional[np.ndarray], Optional[np.ndarray]
]:
x_data = self.get_column(self.attr_x)
y_data = self.get_column(self.attr_y)
top, bottom, left, right = [None] * 4
if self.attr_x_upper:
right = self.get_column(self.attr_x_upper)
if self.attr_x_is_abs:
right = right - x_data
if self.attr_x_lower:
left = self.get_column(self.attr_x_lower)
if self.attr_x_is_abs:
left = x_data - left
if self.attr_y_upper:
top = self.get_column(self.attr_y_upper)
if self.attr_y_is_abs:
top = top - y_data
if self.attr_y_lower:
bottom = self.get_column(self.attr_y_lower)
if self.attr_y_is_abs:
bottom = y_data - bottom
return top, bottom, left, right

# Tooltip
def _point_tooltip(self, point_id, skip_attrs=()):
point_data = self.data[point_id]
Expand Down Expand Up @@ -580,6 +726,8 @@ def init_attr_values(self):
self.attr_x = self.xy_model[0] if self.xy_model else None
self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
else self.attr_x
self.attr_x_upper, self.attr_x_lower = None, None
self.attr_y_upper, self.attr_y_lower = None, None

def switch_sampling(self):
self.__timer.stop()
Expand All @@ -588,15 +736,15 @@ def switch_sampling(self):
self.__timer.start()

@OWDataProjectionWidget.Inputs.data_subset
def set_subset_data(self, subset_data):
def set_subset_data(self, subset: Optional[Table]):
self.warning()
if isinstance(subset_data, SqlTable):
if subset_data.approx_len() < AUTO_DL_LIMIT:
subset_data = Table(subset_data)
if isinstance(subset, SqlTable):
if subset.approx_len() < AUTO_DL_LIMIT:
subset = Table(subset)
else:
self.warning("Data subset does not support large Sql tables")
subset_data = None
super().set_subset_data(subset_data)
subset = None
super().set_subset_data(subset)

# called when all signals are received, so the graph is updated only once
def handleNewSignals(self):
Expand All @@ -608,12 +756,17 @@ def handleNewSignals(self):
self.attr_x, self.attr_y = self.attribute_selection_list
else:
self.attr_x, self.attr_y = None, None
self.attr_x_upper, self.attr_x_lower = None, None
self.attr_y_upper, self.attr_y_lower = None, None
self._invalidated = self._invalidated or self._xy_invalidated
self._xy_invalidated = False
super().handleNewSignals()
if self._domain_invalidated:
self.graph.update_axes()
self.graph.update_error_bars()
self._domain_invalidated = False
if self.attribute_selection_list:
self.graph.update_error_bars()
can_plot = self.can_draw_regression_line()
self.cb_reg_line.setEnabled(can_plot)
self.graph.controls.show_ellipse.setEnabled(can_plot)
Expand Down Expand Up @@ -706,6 +859,18 @@ def migrate_context(cls, context, version):
if values["attr_x"][1] % 100 == 1 or values["attr_y"][1] % 100 == 1:
raise IncompatibleContext()

__HorizontalBarIcon = None
__VerticalBarIcon = None

@classmethod
def __get_bar_icons(cls):
if cls.__HorizontalBarIcon is None:
cls.__HorizontalBarIcon = load_styled_icon(
"Orange.widgets.visualize", "icons/interval-horizontal.svg")
cls.__VerticalBarIcon = load_styled_icon(
"Orange.widgets.visualize", "icons/interval-vertical.svg")
return cls.__HorizontalBarIcon, cls.__VerticalBarIcon


if __name__ == "__main__": # pragma: no cover
table = Table("iris")
Expand Down
Loading

0 comments on commit ee97bd9

Please sign in to comment.