Skip to content

Commit

Permalink
Distribution: Fix deepcopy and pickle
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jul 27, 2021
1 parent f9d78d6 commit 525e599
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
18 changes: 18 additions & 0 deletions Orange/statistics/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ def _get_variable(dat, variable, expected_type=None, expected_name=""):


class Distribution(np.ndarray):
def __array_finalize__(self, obj):
# defined in derived classes,
# pylint: disable=attribute-defined-outside-init
"""See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html"""
if obj is None:
return
self.variable = getattr(obj, 'variable', None)
self.unknowns = getattr(obj, 'unknowns', 0)

def __reduce__(self):
state = super().__reduce__()
newstate = state[2] + (self.variable, self.unknowns)
return state[0], state[1], newstate

def __setstate__(self, state):
super().__setstate__(state[:-2])
self.variable, self.unknowns = state[-2:]

def __eq__(self, other):
return (
np.array_equal(self, other) and
Expand Down
53 changes: 52 additions & 1 deletion Orange/tests/test_distribution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Test methods with long descriptive names can omit docstrings
# Test internal methods
# pylint: disable=missing-docstring, protected-access

import copy
import pickle
import unittest
from unittest.mock import Mock
import warnings
Expand Down Expand Up @@ -110,6 +111,32 @@ def test_fallback_with_weights_and_nan(self):
np.asarray(fallback), np.asarray(default))
np.testing.assert_almost_equal(fallback.unknowns, default.unknowns)

def test_pickle(self):
d = data.Table("zoo")
d1 = distribution.Discrete(d, 0)
dc = pickle.loads(pickle.dumps(d1))
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__reduce__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_deepcopy(self):
d = data.Table("zoo")
d1 = distribution.Discrete(d, 0)
dc = copy.deepcopy(d1)
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__deepcopy__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_equality(self):
d = data.Table("zoo")
d1 = distribution.Discrete(d, 0)
Expand Down Expand Up @@ -285,6 +312,30 @@ def test_construction(self):
self.assertEqual(disc2.unknowns, 0)
assert_dist_equal(disc2, dd)

def test_pickle(self):
d1 = distribution.Continuous(self.iris, 0)
dc = pickle.loads(pickle.dumps(d1))
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__reduce__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_deepcopy(self):
d1 = distribution.Continuous(self.iris, 0)
dc = copy.deepcopy(d1)
# This always worked because `other` wasn't required to have `unknowns`
self.assertEqual(d1, dc)
# This failed before implementing `__deepcopy__`
self.assertEqual(dc, d1)
self.assertEqual(hash(d1), hash(dc))
# Test that `dc` has the required attributes
self.assertEqual(dc.variable, d1.variable)
self.assertEqual(dc.unknowns, d1.unknowns)

def test_hash(self):
d = self.iris
petal_length = d.columns.petal_length
Expand Down

0 comments on commit 525e599

Please sign in to comment.