Skip to content

Commit

Permalink
Fix unpickling pre-3.28.0 Transformation
Browse files Browse the repository at this point in the history
At around 3.28 we introduced domain caching into Tranformation in a
way that prevented older pickles to work. This fix recreates the cache
after unpickling. Furthermore, the cache (before "need_domain", now
"_target_domain") is never pickled anymore.
  • Loading branch information
markotoplak committed Jan 4, 2022
1 parent 300dbba commit 246d534
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
22 changes: 19 additions & 3 deletions Orange/preprocess/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,28 @@ def __init__(self, variable):
:type variable: int or str or :obj:`~Orange.data.Variable`
"""
self.variable = variable
self._create_cached_target_domain()

def _create_cached_target_domain(self):
""" If the same domain is used everytime this allows better caching of
domain transformations in from_table"""
if self.variable is not None:
if self.variable.is_primitive():
self.need_domain = Domain([self.variable])
self._target_domain = Domain([self.variable])
else:
self.need_domain = Domain([], metas=[self.variable])
self._target_domain = Domain([], metas=[self.variable])

def __getstate__(self):
# Do not pickle the cached domain; rather recreate it after unpickling
state = self.__dict__.copy()
state.pop("_target_domain")
return state

def __setstate__(self, state):
# Ensure that cached target domain is created after unpickling.
# This solves the problem of unpickling old pickled models.
self.__dict__.update(state)
self._create_cached_target_domain()

def __call__(self, data):
"""
Expand All @@ -31,7 +47,7 @@ def __call__(self, data):
inst = isinstance(data, Instance)
if inst:
data = Table.from_list(data.domain, [data])
data = data.transform(self.need_domain)
data = data.transform(self._target_domain)
if self.variable.is_primitive():
col = data.X
else:
Expand Down
12 changes: 12 additions & 0 deletions Orange/tests/test_transformation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
import unittest

import numpy as np
Expand Down Expand Up @@ -43,6 +44,17 @@ def test_transform_fails(self):
trans = Transformation(self.data.domain[2])
self.assertRaises(NotImplementedError, trans, self.data)

def test_pickling_target_domain(self):
data = self.data
trans = self.TransformationMock(data.domain[2])
self.assertIn("_target_domain", trans.__dict__)
# _target_domain should not be pickled
state = trans.__getstate__()
self.assertNotIn("_target_domain", state)
# _target_domain should be recreated when unpickled
unpickled = pickle.loads(pickle.dumps(trans))
self.assertIn("_target_domain", unpickled.__dict__)


class IdentityTest(unittest.TestCase):
def test_identity(self):
Expand Down

0 comments on commit 246d534

Please sign in to comment.