Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fix unpickling domains: do not pickle indices (which can cause problems) #6317

Merged
merged 5 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions Orange/data/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,39 @@ def __init__(self, attributes, class_vars=None, metas=None, source=None):
if not all(var.is_primitive() for var in self._variables):
raise TypeError("variables must be primitive")

self._indices = dict(chain.from_iterable(
((var, idx), (var.name, idx), (idx, idx))
for idx, var in enumerate(self._variables)))
self._indices.update(chain.from_iterable(
((var, -1-idx), (var.name, -1-idx), (-1-idx, -1-idx))
for idx, var in enumerate(self.metas)))
self._indices = None

self.anonymous = False

self._hash = None # cache for __hash__()

def _ensure_indices(self):
if self._indices is None:
indices = dict(chain.from_iterable(
((var, idx), (var.name, idx), (idx, idx))
for idx, var in enumerate(self._variables)))
indices.update(chain.from_iterable(
((var, -1-idx), (var.name, -1-idx), (-1-idx, -1-idx))
for idx, var in enumerate(self.metas)))
self._indices = indices

def __setstate__(self, state):
self.__dict__.update(state)
self._variables = self.attributes + self.class_vars
self._indices = None
self._hash = None

def __getstate__(self):
# Do not pickle dictionaries because unpickling dictionaries that
# include objects that redefine __hash__ as keys is sometimes problematic
# (when said objects do not have __dict__ filled yet in but are used as
# keys in a restored dictionary).
state = self.__dict__.copy()
del state["_variables"]
del state["_indices"]
del state["_hash"]
return state

# noinspection PyPep8Naming
@classmethod
def from_numpy(cls, X, Y=None, metas=None):
Expand Down Expand Up @@ -289,7 +311,7 @@ def __getitem__(self, idx):
"""
if isinstance(idx, slice):
return self._variables[idx]

self._ensure_indices()
index = self._indices.get(idx)
if index is None:
var = self._get_equivalent(idx)
Expand All @@ -306,6 +328,7 @@ def __contains__(self, item):
Return `True` if the item (`str`, `int`, :class:`Variable`) is
in the domain.
"""
self._ensure_indices()
return item in self._indices or self._get_equivalent(item) is not None

def __iter__(self):
Expand Down Expand Up @@ -334,7 +357,7 @@ def index(self, var):
Return the index of the given variable or meta attribute, represented
with an instance of :class:`Variable`, `int` or `str`.
"""

self._ensure_indices()
idx = self._indices.get(var)
if idx is not None:
return idx
Expand Down
33 changes: 29 additions & 4 deletions Orange/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pkgutil
import unittest

import traceback
import warnings

import numpy as np
Expand Down Expand Up @@ -273,6 +272,9 @@ def test_predict_proba(self):
data = Table("heart_disease")
for learner in all_learners():
with self.subTest(learner.__name__):
# Skip slow tests
if issubclass(learner, _RuleLearner):
continue
if learner in (ThresholdLearner, CalibratedLearner):
model = learner(LogisticRegressionLearner())(data)
else:
Expand Down Expand Up @@ -416,11 +418,9 @@ def test_all_models_work_after_unpickling(self):
if learner in (ThresholdLearner, CalibratedLearner):
continue
# Skip slow tests
if isinstance(learner, _RuleLearner):
if issubclass(learner, _RuleLearner):
continue
with self.subTest(learner.__name__):
if "RandomForest" not in learner.__name__:
continue
learner = learner()
for ds in datasets:
model = learner(ds)
Expand All @@ -435,6 +435,31 @@ def test_all_models_work_after_unpickling(self):
err_msg='%s does not return same values when unpickled %s'
% (learner.__class__.__name__, ds.name))

def test_all_models_work_after_unpickling_pca(self):
datasets = [Table('iris'), Table('titanic')]
for learner in list(all_learners()):
# calibration, threshold learners' __init__ require arguments
if learner in (ThresholdLearner, CalibratedLearner):
continue
# Skip slow tests
if issubclass(learner, _RuleLearner):
continue
with self.subTest(learner.__name__):
learner = learner()
for ds in datasets:
pca_ds = Orange.projection.PCA()(ds)(ds)
model = learner(pca_ds)
s = pickle.dumps(model, 0)
model2 = pickle.loads(s)

np.testing.assert_almost_equal(
Table.from_table(model.domain, ds).X,
Table.from_table(model2.domain, ds).X)
np.testing.assert_almost_equal(
model(ds), model2(ds),
err_msg='%s does not return same values when unpickled %s'
% (learner.__class__.__name__, ds.name))

def test_adequacy_all_learners(self):
for learner in all_learners():
# calibration, threshold learners' __init__ requires arguments
Expand Down
1 change: 1 addition & 0 deletions Orange/tests/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def test_get_item_similar_vars(self):
metas=[var1, var2]
)
# pylint: disable=protected-access
domain._ensure_indices()
self.assertDictEqual(
{-1: -1, -2: -2, var1: -1, var2: -2, var1.name: -1, var2.name: -2},
domain._indices
Expand Down