Skip to content

Commit

Permalink
Polynomial Regression: Refactoring, modernizing
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jul 11, 2021
1 parent 198c687 commit 8ba0dc2
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 84 deletions.
121 changes: 54 additions & 67 deletions orangecontrib/educational/widgets/owpolynomialregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def fit(*args, **kwargs):
return MeanModel(Continuous(np.empty(0)))


class OWUnivariateRegression(OWBaseLearner):
class OWPolynomialRegression(OWBaseLearner):
name = "Polynomial Regression"
description = "Univariate regression with polynomial expansion."
keywords = ["polynomial regression", "regression",
Expand All @@ -50,8 +50,8 @@ class Outputs(OWBaseLearner.Outputs):
replaces = [
"Orange.widgets.regression.owunivariateregression."
"OWUnivariateRegression",
"orangecontrib.prototypes.widgets.owpolynomialregression."
"OWPolynomialRegression"
"orangecontrib.prototypes.widgets.owpolynomialregression.",
"orangecontrib.educational.widgets.owunivariateregression."
]

LEARNER = PolynomialLearner
Expand All @@ -77,21 +77,22 @@ class Outputs(OWBaseLearner.Outputs):
graph_name = 'plot'

class Error(OWBaseLearner.Error):
all_none = Msg("One of the features has no defined values.")
no_cont_variables = Msg("Polynomial Regression requires at least two numeric variables.")
same_dep_indepvar = Msg("Dependent and independent variables must be differnt.")

def add_main_layout(self):

all_none = \
Msg("All rows have undefined data.")
no_cont_variables =\
Msg("Regression requires at least two numeric variables.")
same_dep_indepvar =\
Msg("Dependent and independent variables must be differnt.")

def __init__(self):
super().__init__()
self.data = None
self.learner = None

self.scatterplot_item = None
self.plot_item = None

self.x_label = 'x'
self.y_label = 'y'

def add_main_layout(self):
self.rmse = ""
self.mae = ""
self.regressor_name = self.default_learner_name
Expand All @@ -101,76 +102,53 @@ def add_main_layout(self):
order=DomainModel.MIXED)

box = gui.vBox(self.controlArea, "Predictor")
self.comboBoxAttributesX = gui.comboBox(
gui.comboBox(
box, self, value='x_var', model=self.var_model, callback=self.apply)
self.expansion_spin = gui.spin(
gui.spin(
box, self, "polynomialexpansion", label="Polynomial degree: ",
minv=0, maxv=10, alignment=Qt.AlignmentFlag.AlignRight,
callback=self.apply)
gui.checkBox(
box, self, "fit_intercept",
label="Fit intercept", callback=self.apply, stateWhenDisabled=True,
tooltip="Add an intercept term;\n"
"This option is always checked if the model is set on input."
)
"This is always checked if the model is defined on input.")

box = gui.vBox(self.controlArea, "Target")
self.comboBoxAttributesY = gui.comboBox(
gui.comboBox(
box, self, value="y_var", model=self.var_model, callback=self.apply)

self.error_bars_checkbox = gui.checkBox(
gui.checkBox(
widget=box, master=self, value='error_bars_enabled',
label="Show error bars", callback=self.apply)

gui.rubber(self.controlArea)

# info box
info_box = gui.vBox(self.controlArea, "Info")
self.regressor_label = gui.label(
gui.label(
widget=info_box, master=self,
label="Regressor: %(regressor_name).30s")
gui.label(widget=info_box, master=self,
label="Mean absolute error: %(mae).6s")
gui.label(widget=info_box, master=self,
label="Root mean square error: %(rmse).6s")

gui.label(
widget=info_box, master=self,
label="Mean absolute error: %(mae).6s")
gui.label(
widget=info_box, master=self,
label="Root mean square error: %(rmse).6s")

# main area GUI
self.plotview = pg.PlotWidget(background="w")
self.plot = self.plotview.getPlotItem()

axis_color = self.palette().color(QPalette.Text)
axis_pen = QPen(axis_color)

axis_pen = QPen(self.palette().color(QPalette.Text))
tickfont = QFont(self.font())
tickfont.setPixelSize(max(int(tickfont.pixelSize() * 2 // 3), 11))

axis = self.plot.getAxis("bottom")
axis.setLabel(self.x_label)
axis.setPen(axis_pen)
axis.setTickFont(tickfont)

axis = self.plot.getAxis("left")
axis.setLabel(self.y_label)
axis.setPen(axis_pen)
axis.setTickFont(tickfont)

for axis in ("bottom", "left"):
axis = self.plot.getAxis(axis)
axis.setPen(axis_pen)
axis.setTickFont(tickfont)
self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0),
disableAutoRange=True)

self.mainArea.layout().addWidget(self.plotview)

def send_report(self):
if self.data is None:
return
caption = report.render_items_vert((
("Polynomial Expansion", self.polynomialexpansion),
("Fit intercept",
self._has_intercept and ["No", "Yes"][self.fit_intercept])
))
self.report_plot()
if caption:
self.report_caption(caption)
def add_bottom_buttons(self):
pass

def clear(self):
self.data = None
Expand All @@ -190,7 +168,6 @@ def clear_plot(self):
self.scatterplot_item = None

self.remove_error_items()

self.plotview.clear()

@check_sql_input
Expand Down Expand Up @@ -220,19 +197,21 @@ def set_data(self, data):
@Inputs.learner
def set_learner(self, learner):
self.learner = learner
self.controls.fit_intercept.setDisabled(learner is not None)
self.regressor_name = (learner.name if learner is not None
else self.default_learner_name)
if learner is None:
self.controls.fit_intercept.setDisabled(False)
self.regressor_name = self.default_learner_name
else:
self.controls.fit_intercept.setDisabled(True)
self.regressor_name = learner.name

def handleNewSignals(self):
self.apply()

def plot_scatter_points(self, x_data, y_data):
if self.scatterplot_item:
self.plotview.removeItem(self.scatterplot_item)
self.n_points = len(x_data)
self.scatterplot_item = pg.ScatterPlotItem(
x=x_data, y=y_data, data=np.arange(self.n_points),
x=x_data, y=y_data,
symbol="o", size=10, pen=pg.mkPen(0.2), brush=pg.mkBrush(0.7),
antialias=True)
self.scatterplot_item.opts["useCache"] = False
Expand Down Expand Up @@ -303,7 +282,7 @@ def ss(x):
for i in range(not self._has_intercept,
1 + self.polynomialexpansion)]
else:
return ["1"] * self._has_intercept + \
return ["intercept"] * self._has_intercept + \
[name] * (self.polynomialexpansion >= 1) + \
[f"{name}^{i}" for i in range(2, 1 + self.polynomialexpansion)]

Expand Down Expand Up @@ -344,8 +323,7 @@ def apply(self):
self.data)

# all lines has nan
if sum(math.isnan(line[0]) or math.isnan(line.get_class())
for line in data_table) == len(data_table):
if np.all(np.isnan(data_table.X.flatten()) | np.isnan(data_table.Y)):
self.Error.all_none()
self.clear_plot()
return
Expand Down Expand Up @@ -430,7 +408,7 @@ def send_data(self):
x = data_table.X[valid_mask]
x = polyfeatures.fit_transform(x)

out_array = np.concatenate((x, data_table.Y[np.newaxis].T[valid_mask]), axis=1)
out_array = np.hstack((x, data_table.Y[np.newaxis].T[valid_mask]))

out_domain = Domain(
[ContinuousVariable(name)
Expand All @@ -441,8 +419,17 @@ def send_data(self):

self.Outputs.data.send(None)

def add_bottom_buttons(self):
pass
def send_report(self):
if self.data is None:
return
caption = report.render_items_vert((
("Polynomial Expansion", self.polynomialexpansion),
("Fit intercept",
self._has_intercept and ["No", "Yes"][self.fit_intercept])
))
self.report_plot()
if caption:
self.report_caption(caption)

@classmethod
def migrate_settings(cls, settings, version):
Expand All @@ -455,4 +442,4 @@ def migrate_settings(cls, settings, version):
if __name__ == "__main__":
learner = RidgeRegressionLearner(alpha=1.0)
iris = Table('iris')
WidgetPreview(OWUnivariateRegression).run(set_data=iris)#, set_learner=learner)
WidgetPreview(OWPolynomialRegression).run(set_data=iris) #, set_learner=learner)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from Orange.data import Table, Domain, ContinuousVariable
from Orange.widgets.tests.base import WidgetTest
from orangecontrib.educational.widgets.owpolynomialregression \
import OWUnivariateRegression
import OWPolynomialRegression
from Orange.regression import (LinearRegressionLearner,
RandomForestRegressionLearner)
from Orange.regression.tree import TreeLearner as TreeRegressionLearner
Expand All @@ -13,7 +13,7 @@
class TestOWPolynomialRegression(WidgetTest):

def setUp(self):
self.widget = self.create_widget(OWUnivariateRegression) # type: OWUnivariateRegression
self.widget = self.create_widget(OWPolynomialRegression) # type: OWPolynomialRegression
self.data = Table.from_file("iris")
self.data_housing = Table.from_file("housing")

Expand Down Expand Up @@ -60,11 +60,9 @@ def test_add_main_layout(self):
self.assertEqual(w.learner, None)
self.assertEqual(w.scatterplot_item, None)
self.assertEqual(w.plot_item, None)
self.assertEqual(w.x_label, 'x')
self.assertEqual(w.y_label, 'y')

self.assertEqual(
w.regressor_label.text(), "Regressor: Linear Regression")
w.controls.regressor_name.text(), "Regressor: Linear Regression")

def test_send_report(self):
# check if nothing happens when polynomialexpansion is None
Expand Down Expand Up @@ -108,24 +106,22 @@ def test_set_learner(self):
self.assertEqual(self.widget.learner, lin)

self.assertEqual(
w.regressor_label.text(), "Regressor: Linear Regression")
w.controls.regressor_name.text(), "Regressor: Linear Regression")

tree = TreeRegressionLearner
tree.name = "Tree Learner"

self.widget.set_learner(tree)
self.assertEqual(self.widget.learner, tree)
self.assertEqual(
w.regressor_label.text(), "Regressor: Tree Learner")

w.controls.regressor_name.text(), "Regressor: Tree Learner")

def test_plot_scatter_points(self):
x_data = [1, 2, 3]
y_data = [2, 3, 4]

self.widget.plot_scatter_points(x_data, y_data)

self.assertEqual(self.widget.n_points, len(x_data))
self.assertNotEqual(self.widget.scatterplot_item, None)

# check case when scatter plot allready exist
Expand All @@ -134,7 +130,6 @@ def test_plot_scatter_points(self):

self.widget.plot_scatter_points(x_data, y_data)

self.assertEqual(self.widget.n_points, len(x_data))
self.assertNotEqual(self.widget.scatterplot_item, None)

def test_plot_regression_line(self):
Expand All @@ -156,7 +151,8 @@ def test_plot_regression_line(self):
def test_plot_error_bars(self):
w = self.widget

w.error_bars_checkbox.click()
check = w.controls.error_bars_enabled
check.click()

x_data = [1, 2, 3]
y_data = [2, 3, 4]
Expand All @@ -165,12 +161,12 @@ def test_plot_error_bars(self):
self.widget.plot_error_bars(x_data, y_data, y_data_fake)
self.assertEqual(len(w.error_plot_items), len(x_data))

w.error_bars_checkbox.click()
check.click()

self.widget.plot_error_bars(x_data, y_data, y_data_fake)
self.assertEqual(len(w.error_plot_items), 0)

w.error_bars_checkbox.click()
check.click()

self.send_signal(w.Inputs.data, self.data)
self.assertEqual(len(w.error_plot_items), len(self.data))
Expand Down Expand Up @@ -215,19 +211,20 @@ def test_data_output(self):

self.assertIsNone(self.get_output(w.Outputs.data))
self.widget.set_data(self.data)
self.widget.expansion_spin.setValue(1)
spin = self.widget.controls.polynomialexpansion
spin.setValue(1)
self.widget.send_data()
self.assertEqual(len(self.get_output(w.Outputs.data).domain.attributes), 2)

self.widget.expansion_spin.setValue(2)
spin.setValue(2)
self.widget.send_data()
self.assertEqual(len(self.get_output(w.Outputs.data).domain.attributes), 3)

self.widget.expansion_spin.setValue(3)
spin.setValue(3)
self.widget.send_data()
self.assertEqual(len(self.get_output(w.Outputs.data).domain.attributes), 4)

self.widget.expansion_spin.setValue(4)
spin.setValue(4)
self.widget.send_data()
self.assertEqual(len(self.get_output(w.Outputs.data).domain.attributes), 5)

Expand Down

0 comments on commit 8ba0dc2

Please sign in to comment.