Skip to content

Commit

Permalink
Predictions: Output annotated table
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jan 26, 2024
1 parent f49e020 commit 14c0d52
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 7 deletions.
22 changes: 15 additions & 7 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import OWWidget, Msg, Input, Output, MultiInput
from Orange.widgets.utils.itemmodels import TableModel
from Orange.widgets.utils.annotated_data import lazy_annotated_table
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.state_summary import format_summary_details
from Orange.widgets.utils.colorpalettes import LimitedDiscretePalette
Expand Down Expand Up @@ -72,7 +73,8 @@ class Inputs:
predictors = MultiInput("Predictors", Model, filter_none=True)

class Outputs:
predictions = Output("Predictions", Orange.data.Table)
predictions = Output("Predictions", Orange.data.Table, default=True)
annotated = Output("Annotated Predictions", Orange.data.Table)
evaluation_results = Output("Evaluation Results", Results)

class Warning(OWWidget.Warning):
Expand Down Expand Up @@ -815,6 +817,7 @@ def _commit_evaluation_results(self):
def _commit_predictions(self):
if not self.data:
self.Outputs.predictions.send(None)
self.Outputs.annotated.send(None)
return

newmetas = []
Expand Down Expand Up @@ -855,12 +858,17 @@ def _commit_predictions(self):
# Reorder rows as they are ordered in view
shown_rows = datamodel.mapFromSourceRows(rows)
rows = rows[numpy.argsort(shown_rows)]
predictions = predictions[rows]
elif datamodel.sortColumn() >= 0 \
or predmodel is not None and predmodel.sortColumn() > 0:
# No selection: output all, but in the shown order
predictions = predictions[datamodel.mapToSourceRows(...)]
self.Outputs.predictions.send(predictions)
selected = predictions[rows]
annotated_data = lazy_annotated_table(predictions, rows)
else:
if datamodel.sortColumn() >= 0 \
or predmodel is not None and predmodel.sortColumn() > 0:
predictions = predictions[datamodel.mapToSourceRows(...)]
selected = predictions
annotated_data = predictions
self.Outputs.predictions.send(selected)
self.Outputs.annotated.send(annotated_data)


def _add_classification_out_columns(self, slot, newmetas, newcolumns, index):
pred = slot.predictor
Expand Down
46 changes: 46 additions & 0 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from Orange.evaluation import Results
from Orange.widgets.tests.utils import excepthook_catch, \
possible_duplicate_table, simulate
from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_FEATURE_NAME
from Orange.widgets.utils.colorpalettes import LimitedDiscretePalette


Expand Down Expand Up @@ -515,6 +516,51 @@ def test_select(self):
for index in self.widget.dataview.selectionModel().selectedIndexes()}
self.assertEqual(sel, {(1, col) for col in range(5)})

def test_selection_output(self):
log_reg_iris = LogisticRegressionLearner()(self.iris)
self.send_signal(self.widget.Inputs.predictors, log_reg_iris)
self.send_signal(self.widget.Inputs.data, self.iris)

selmodel = self.widget.dataview.selectionModel()
pred_model = self.widget.predictionsview.model()

selmodel.select(self.widget.dataview.model().index(1, 0), QItemSelectionModel.Select)
selmodel.select(self.widget.dataview.model().index(3, 0), QItemSelectionModel.Select)
output = self.get_output(self.widget.Outputs.predictions)
self.assertEqual(len(output), 2)
self.assertEqual(output[0], self.iris[1])
self.assertEqual(output[1], self.iris[3])
output = self.get_output(self.widget.Outputs.annotated)
self.assertEqual(len(output), len(self.iris))
col = output.get_column(ANNOTATED_DATA_FEATURE_NAME)
self.assertEqual(np.sum(col), 2)
self.assertEqual(col[1], 1)
self.assertEqual(col[3], 1)

pred_model.sort(0)
output = self.get_output(self.widget.Outputs.predictions)
self.assertEqual(len(output), 2)
self.assertEqual(output[0], self.iris[1])
self.assertEqual(output[1], self.iris[3])
output = self.get_output(self.widget.Outputs.annotated)
self.assertEqual(len(output), len(self.iris))
col = output.get_column(ANNOTATED_DATA_FEATURE_NAME)
self.assertEqual(np.sum(col), 2)
self.assertEqual(col[1], 1)
self.assertEqual(col[3], 1)

pred_model.sort(0, Qt.DescendingOrder)
output = self.get_output(self.widget.Outputs.predictions)
self.assertEqual(len(output), 2)
self.assertEqual(output[0], self.iris[3])
self.assertEqual(output[1], self.iris[1])
output = self.get_output(self.widget.Outputs.annotated)
self.assertEqual(len(output), len(self.iris))
col = output.get_column(ANNOTATED_DATA_FEATURE_NAME)
self.assertEqual(np.sum(col), 2)
self.assertEqual(col[1], 1)
self.assertEqual(col[3], 1)

def test_select_data_first(self):
log_reg_iris = LogisticRegressionLearner()(self.iris)
self.send_signal(self.widget.Inputs.data, self.iris)
Expand Down

0 comments on commit 14c0d52

Please sign in to comment.