diff --git a/Orange/preprocess/transformation.py b/Orange/preprocess/transformation.py index 43dcb4bf79a..10c92cbc364 100644 --- a/Orange/preprocess/transformation.py +++ b/Orange/preprocess/transformation.py @@ -16,12 +16,28 @@ def __init__(self, variable): :type variable: int or str or :obj:`~Orange.data.Variable` """ self.variable = variable + self._create_cached_target_domain() + def _create_cached_target_domain(self): + """ If the same domain is used everytime this allows better caching of + domain transformations in from_table""" if self.variable is not None: if self.variable.is_primitive(): - self.need_domain = Domain([self.variable]) + self._target_domain = Domain([self.variable]) else: - self.need_domain = Domain([], metas=[self.variable]) + self._target_domain = Domain([], metas=[self.variable]) + + def __getstate__(self): + # Do not pickle the cached domain; rather recreate it after unpickling + state = self.__dict__.copy() + state.pop("_target_domain") + return state + + def __setstate__(self, state): + # Ensure that cached target domain is created after unpickling. + # This solves the problem of unpickling old pickled models. + self.__dict__.update(state) + self._create_cached_target_domain() def __call__(self, data): """ @@ -31,7 +47,7 @@ def __call__(self, data): inst = isinstance(data, Instance) if inst: data = Table.from_list(data.domain, [data]) - data = data.transform(self.need_domain) + data = data.transform(self._target_domain) if self.variable.is_primitive(): col = data.X else: diff --git a/Orange/tests/test_transformation.py b/Orange/tests/test_transformation.py index 674b7fe1983..9d28de39402 100644 --- a/Orange/tests/test_transformation.py +++ b/Orange/tests/test_transformation.py @@ -1,3 +1,4 @@ +import pickle import unittest import numpy as np @@ -43,6 +44,17 @@ def test_transform_fails(self): trans = Transformation(self.data.domain[2]) self.assertRaises(NotImplementedError, trans, self.data) + def test_pickling_target_domain(self): + data = self.data + trans = self.TransformationMock(data.domain[2]) + self.assertIn("_target_domain", trans.__dict__) + # _target_domain should not be pickled + state = trans.__getstate__() + self.assertNotIn("_target_domain", state) + # _target_domain should be recreated when unpickled + unpickled = pickle.loads(pickle.dumps(trans)) + self.assertIn("_target_domain", unpickled.__dict__) + class IdentityTest(unittest.TestCase): def test_identity(self):