Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fix deepcopy and pickle for classes derived from np.ndarray #5536

Merged
merged 2 commits into from
Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion Orange/statistics/contingency.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,20 @@ def __reduce__(self):
return (
_create_discrete,
(Discrete, np.copy(self), self.col_variable, self.row_variable,
self.col_unknowns, self.row_unknowns)
self.col_unknowns, self.row_unknowns, self.unknowns)
)

def __array_finalize__(self, obj):
    """Carry the contingency metadata over from ``obj`` to this array.

    This is the ndarray-subclass hook described in "Subclassing ndarray"
    in the NumPy user guide (originally referenced here as
    http://docs.scipy.org/doc/numpy/user/basics.subclassing.html).

    Args:
        obj: The object this array was derived from, or ``None``, in
            which case nothing is copied.
    """
    # The attributes are defined in __new__,
    # pylint: disable=attribute-defined-outside-init
    if obj is None:
        return
    # `obj` is not required to have these attributes (e.g. a plain
    # array); any attribute it lacks is set to None on the new instance.
    for name in ("col_variable", "row_variable",
                 "col_unknowns", "row_unknowns", "unknowns"):
        setattr(self, name, getattr(obj, name, None))


class Continuous:
def __init__(self, dat, col_variable=None, row_variable=None,
Expand Down
20 changes: 20 additions & 0 deletions Orange/statistics/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ def _get_variable(dat, variable, expected_type=None, expected_name=""):


class Distribution(np.ndarray):
def __array_finalize__(self, obj):
    """Carry ``variable`` and ``unknowns`` over from ``obj``.

    This is the ndarray-subclass hook described in "Subclassing ndarray"
    in the NumPy user guide (originally referenced here as
    http://docs.scipy.org/doc/numpy/user/basics.subclassing.html).

    Args:
        obj: The object this array was derived from, or ``None``, in
            which case nothing is copied.
    """
    # The attributes are defined in derived classes,
    # pylint: disable=attribute-defined-outside-init
    if obj is None:
        return
    # `obj` is not required to have these attributes; the fallbacks are
    # "no variable" and "zero unknowns".
    self.variable = getattr(obj, 'variable', None)
    self.unknowns = getattr(obj, 'unknowns', 0)

def __reduce__(self):
    """Return the parent's reduce tuple with this class's attributes added.

    The first two items of the parent's result are passed through
    unchanged; ``variable`` and ``unknowns`` are appended to the third
    item (the state), where ``__setstate__`` picks them up again.
    """
    base = super().__reduce__()
    reconstructor, args = base[0], base[1]
    extended_state = base[2] + (self.variable, self.unknowns)
    return reconstructor, args, extended_state

def __setstate__(self, state):
    """Restore the array from ``state`` and reattach the extra attributes.

    Args:
        state: The parent class's state with ``variable`` and
            ``unknowns`` appended as its last two items, i.e. the layout
            produced by ``__reduce__``.
    """
    # The attributes are defined in derived classes,
    # pylint: disable=attribute-defined-outside-init
    parent_state, own_attributes = state[:-2], state[-2:]
    super().__setstate__(parent_state)
    self.variable, self.unknowns = own_attributes

def __eq__(self, other):
return (
np.array_equal(self, other) and
Expand Down
9 changes: 8 additions & 1 deletion Orange/tests/test_contingency.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring

import copy
import unittest
from unittest.mock import Mock

Expand Down Expand Up @@ -71,6 +71,13 @@ def test_discrete_missing(self):
np.testing.assert_almost_equal(cont.row_unknowns, [0, 0])
self.assertEqual(1, cont.unknowns)

def test_deepcopy(self):
    # A deep copy must compare equal to the original and keep every
    # metadata attribute that lives on the instance.
    cont = contingency.Discrete(self.zoo, 0)
    dc = copy.deepcopy(cont)
    self.assertEqual(dc, cont)
    self.assertEqual(dc.col_variable, cont.col_variable)
    self.assertEqual(dc.row_variable, cont.row_variable)
    # The unknown counts are instance attributes as well, so check that
    # they survive the copy too.
    np.testing.assert_almost_equal(dc.col_unknowns, cont.col_unknowns)
    np.testing.assert_almost_equal(dc.row_unknowns, cont.row_unknowns)
    self.assertEqual(dc.unknowns, cont.unknowns)

def test_array_with_unknowns(self):
d = data.Table("zoo")
d.Y[2] = float("nan")
Expand Down
53 changes: 52 additions & 1 deletion Orange/tests/test_distribution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Test methods with long descriptive names can omit docstrings
# Test internal methods
# pylint: disable=missing-docstring, protected-access

import copy
import pickle
import unittest
from unittest.mock import Mock
import warnings
Expand Down Expand Up @@ -110,6 +111,32 @@ def test_fallback_with_weights_and_nan(self):
np.asarray(fallback), np.asarray(default))
np.testing.assert_almost_equal(fallback.unknowns, default.unknowns)

def test_pickle(self):
    # Round-trip a discrete distribution through pickle.
    table = data.Table("zoo")
    original = distribution.Discrete(table, 0)
    restored = pickle.loads(pickle.dumps(original))
    # Equality in this direction held even earlier, because `other`
    # was not required to have `unknowns`.
    self.assertEqual(original, restored)
    # The reverse direction failed before `__reduce__` was implemented.
    self.assertEqual(restored, original)
    self.assertEqual(hash(original), hash(restored))
    # The restored object must carry the extra attributes.
    self.assertEqual(restored.variable, original.variable)
    self.assertEqual(restored.unknowns, original.unknowns)

def test_deepcopy(self):
    # Deep-copy a discrete distribution and compare it with the original.
    table = data.Table("zoo")
    original = distribution.Discrete(table, 0)
    duplicate = copy.deepcopy(original)
    # Equality in this direction held even earlier, because `other`
    # was not required to have `unknowns`.
    self.assertEqual(original, duplicate)
    # The reverse direction used to fail because the copy lacked the
    # extra attributes.
    # NOTE(review): the original comment credited `__deepcopy__`;
    # presumably the attribute copying in `__array_finalize__` is what
    # fixes it - confirm.
    self.assertEqual(duplicate, original)
    self.assertEqual(hash(original), hash(duplicate))
    # The copy must carry the extra attributes.
    self.assertEqual(duplicate.variable, original.variable)
    self.assertEqual(duplicate.unknowns, original.unknowns)

def test_equality(self):
d = data.Table("zoo")
d1 = distribution.Discrete(d, 0)
Expand Down Expand Up @@ -285,6 +312,30 @@ def test_construction(self):
self.assertEqual(disc2.unknowns, 0)
assert_dist_equal(disc2, dd)

def test_pickle(self):
    # Round-trip a continuous distribution through pickle.
    original = distribution.Continuous(self.iris, 0)
    restored = pickle.loads(pickle.dumps(original))
    # Equality in this direction held even earlier, because `other`
    # was not required to have `unknowns`.
    self.assertEqual(original, restored)
    # The reverse direction failed before `__reduce__` was implemented.
    self.assertEqual(restored, original)
    self.assertEqual(hash(original), hash(restored))
    # The restored object must carry the extra attributes.
    self.assertEqual(restored.variable, original.variable)
    self.assertEqual(restored.unknowns, original.unknowns)

def test_deepcopy(self):
    # Deep-copy a continuous distribution and compare it with the original.
    original = distribution.Continuous(self.iris, 0)
    duplicate = copy.deepcopy(original)
    # Equality in this direction held even earlier, because `other`
    # was not required to have `unknowns`.
    self.assertEqual(original, duplicate)
    # The reverse direction used to fail because the copy lacked the
    # extra attributes.
    # NOTE(review): the original comment credited `__deepcopy__`;
    # presumably the attribute copying in `__array_finalize__` is what
    # fixes it - confirm.
    self.assertEqual(duplicate, original)
    self.assertEqual(hash(original), hash(duplicate))
    # The copy must carry the extra attributes.
    self.assertEqual(duplicate.variable, original.variable)
    self.assertEqual(duplicate.unknowns, original.unknowns)

def test_hash(self):
d = self.iris
petal_length = d.columns.petal_length
Expand Down