From 28c1da7c2b559a5686ea216255c88ad08a6ec963 Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Sat, 18 Jun 2022 11:01:52 -0300 Subject: [PATCH] Cherry-pick branch 'master' with clean-up changes --- dill/__diff.py | 2 +- dill/_dill.py | 229 +++++++++++++++++++++++++++++++++++++--- dill/_objects.py | 44 +++++++- dill/_shims.py | 40 ++++++- tests/__main__.py | 2 +- tests/test_classdef.py | 9 ++ tests/test_dictviews.py | 10 +- tests/test_functions.py | 16 +++ tests/test_objects.py | 1 - tests/test_pycapsule.py | 45 ++++++++ tests/test_session.py | 6 +- 11 files changed, 370 insertions(+), 34 deletions(-) create mode 100644 tests/test_pycapsule.py diff --git a/dill/__diff.py b/dill/__diff.py index df2589eb..3ff65763 100644 --- a/dill/__diff.py +++ b/dill/__diff.py @@ -235,6 +235,6 @@ def _imp(*args, **kwds): # memorise all already imported modules. This implies that this must be # imported first for any changes to be recorded -for mod in sys.modules.values(): +for mod in list(sys.modules.values()): memorise(mod) release_gone() diff --git a/dill/_dill.py b/dill/_dill.py index 5bd770b3..d0561be8 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -21,6 +21,8 @@ 'UnpicklingError','HANDLE_FMODE','CONTENTS_FMODE','FILE_FMODE', 'PickleError','PickleWarning','PicklingWarning','UnpicklingWarning'] +__module__ = 'dill' + import logging log = logging.getLogger("dill") log.addHandler(logging.StreamHandler()) @@ -35,6 +37,7 @@ def _trace(boolean): import sys diff = None _use_diff = False +OLD38 = (sys.hexversion < 0x3080000) OLD39 = (sys.hexversion < 0x3090000) OLD310 = (sys.hexversion < 0x30a0000) #XXX: get types from .objtypes ? @@ -64,6 +67,7 @@ def _trace(boolean): import gc # import zlib from weakref import ReferenceType, ProxyType, CallableProxyType +from collections import OrderedDict from functools import partial from operator import itemgetter, attrgetter GENERATOR_FAIL = False @@ -215,8 +219,6 @@ def get_file_type(*args, **kwargs): except NameError: ExitType = None singletontypes = [] -from collections import OrderedDict - import inspect ### Shims for different versions of Python and dill @@ -511,7 +513,6 @@ def __init__(self, *args, **kwds): self._strictio = False #_strictio self._fmode = settings['fmode'] if _fmode is None else _fmode self._recurse = settings['recurse'] if _recurse is None else _recurse - from collections import OrderedDict self._postproc = OrderedDict() def dump(self, obj): #NOTE: if settings change, need to update attributes @@ -657,6 +658,15 @@ def _create_typemap(): 'SuperType': SuperType, 'ItemGetterType': ItemGetterType, 'AttrGetterType': AttrGetterType, +}) + +# "Incidental" implementation specific types. Unpickling these types in another +# implementation of Python (PyPy -> CPython) is not gauranteed to work + +# This dictionary should contain all types that appear in Python implementations +# but are not defined in https://docs.python.org/3/library/types.html#standard-interpreter-types +x=OrderedDict() +_incedental_reverse_typemap = { 'FileType': FileType, 'BufferedRandomType': BufferedRandomType, 'BufferedReaderType': BufferedReaderType, @@ -666,18 +676,55 @@ def _create_typemap(): 'PyBufferedReaderType': PyBufferedReaderType, 'PyBufferedWriterType': PyBufferedWriterType, 'PyTextWrapperType': PyTextWrapperType, +} + +_incedental_reverse_typemap.update({ + "DictKeysType": type({}.keys()), + "DictValuesType": type({}.values()), + "DictItemsType": type({}.items()), + + "OdictKeysType": type(x.keys()), + "OdictValuesType": type(x.values()), + "OdictItemsType": type(x.items()), }) + if ExitType: - _reverse_typemap['ExitType'] = ExitType + _incedental_reverse_typemap['ExitType'] = ExitType if InputType: - _reverse_typemap['InputType'] = InputType - _reverse_typemap['OutputType'] = OutputType + _incedental_reverse_typemap['InputType'] = InputType + _incedental_reverse_typemap['OutputType'] = OutputType if not IS_PYPY: - _reverse_typemap['WrapperDescriptorType'] = WrapperDescriptorType - _reverse_typemap['MethodDescriptorType'] = MethodDescriptorType - _reverse_typemap['ClassMethodDescriptorType'] = ClassMethodDescriptorType + _incedental_reverse_typemap['WrapperDescriptorType'] = WrapperDescriptorType + _incedental_reverse_typemap['MethodDescriptorType'] = MethodDescriptorType + _incedental_reverse_typemap['ClassMethodDescriptorType'] = ClassMethodDescriptorType else: - _reverse_typemap['MemberDescriptorType'] = MemberDescriptorType + _incedental_reverse_typemap['MemberDescriptorType'] = MemberDescriptorType + +try: + import symtable + _incedental_reverse_typemap["SymtableStentryType"] = type(symtable.symtable("", "string", "exec")._table) +except: + pass + +if sys.hexversion >= 0x30a00a0: + _incedental_reverse_typemap['LineIteratorType'] = type(compile('3', '', 'eval').co_lines()) + +if sys.hexversion >= 0x30b00b0: + from types import GenericAlias + _incedental_reverse_typemap["GenericAliasIteratorType"] = type(iter(GenericAlias(list, (int,)))) + _incedental_reverse_typemap['PositionsIteratorType'] = type(compile('3', '', 'eval').co_positions()) + +try: + import winreg + _incedental_reverse_typemap["HKEYType"] = winreg.HKEYType +except: + pass + +_reverse_typemap.update(_incedental_reverse_typemap) +_incedental_types = set(_incedental_reverse_typemap.values()) + +del x + _typemap = dict((v, k) for k, v in _reverse_typemap.items()) def _unmarshal(string): @@ -1058,6 +1105,36 @@ def _create_namedtuple(name, fieldnames, modulename, defaults=None): t = collections.namedtuple(name, fieldnames, defaults=defaults, module=modulename) return t +def _create_capsule(pointer, name, context, destructor): + attr_found = False + try: + # based on https://github.com/python/cpython/blob/f4095e53ab708d95e019c909d5928502775ba68f/Objects/capsule.c#L209-L231 + uname = name.decode('utf8') + for i in range(1, uname.count('.')+1): + names = uname.rsplit('.', i) + try: + module = __import__(names[0]) + except: + pass + obj = module + for attr in names[1:]: + obj = getattr(obj, attr) + capsule = obj + attr_found = True + break + except: + pass + + if attr_found: + if _PyCapsule_IsValid(capsule, name): + return capsule + raise UnpicklingError("%s object exists at %s but a PyCapsule object was expected." % (type(capsule), name)) + else: + warnings.warn('Creating a new PyCapsule %s for a C data structure that may not be present in memory. Segmentation faults or other memory errors are possible.' % (name,), UnpicklingWarning) + capsule = _PyCapsule_New(pointer, name, destructor) + _PyCapsule_SetContext(capsule, context) + return capsule + def _getattr(objclass, name, repr_str): # hack to grab the reference directly try: #XXX: works only for __builtin__ ? @@ -1099,13 +1176,35 @@ def _import_module(import_name, safe=False): return None raise +# https://github.com/python/cpython/blob/a8912a0f8d9eba6d502c37d522221f9933e976db/Lib/pickle.py#L322-L333 +def _getattribute(obj, name): + for subpath in name.split('.'): + if subpath == '': + raise AttributeError("Can't get local attribute {!r} on {!r}" + .format(name, obj)) + try: + parent = obj + obj = getattr(obj, subpath) + except AttributeError: + raise AttributeError("Can't get attribute {!r} on {!r}" + .format(name, obj)) + return obj, parent + def _locate_function(obj, pickler=None): - if obj.__module__ in ['__main__', None] or \ - pickler and is_dill(pickler, child=False) and pickler._session and obj.__module__ == pickler._main.__name__: + module_name = getattr(obj, '__module__', None) + if module_name in ['__main__', None] or \ + pickler and is_dill(pickler, child=False) and pickler._session and module_name == pickler._main.__name__: return False - - found = _import_module(obj.__module__ + '.' + obj.__name__, safe=True) - return found is obj + if hasattr(obj, '__qualname__'): + module = _import_module(module_name, safe=True) + try: + found, _ = _getattribute(module, obj.__qualname__) + return found is obj + except: + return False + else: + found = _import_module(module_name + '.' + obj.__name__, safe=True) + return found is obj def _setitems(dest, source): @@ -1451,6 +1550,31 @@ def save_super(pickler, obj): log.info("# Su") return +if IS_PYPY: + @register(MethodType) + def save_instancemethod0(pickler, obj): + code = getattr(obj.__func__, '__code__', None) + if code is not None and type(code) is not CodeType \ + and getattr(obj.__self__, obj.__name__) == obj: + # Some PyPy builtin functions have no module name + log.info("Me2: %s" % obj) + # TODO: verify that this works for all PyPy builtin methods + pickler.save_reduce(getattr, (obj.__self__, obj.__name__), obj=obj) + log.info("# Me2") + return + + log.info("Me1: %s" % obj) + pickler.save_reduce(MethodType, (obj.__func__, obj.__self__), obj=obj) + log.info("# Me1") + return +else: + @register(MethodType) + def save_instancemethod0(pickler, obj): + log.info("Me1: %s" % obj) + pickler.save_reduce(MethodType, (obj.__func__, obj.__self__), obj=obj) + log.info("# Me1") + return + if not IS_PYPY: @register(MemberDescriptorType) @register(GetSetDescriptorType) @@ -1533,7 +1657,7 @@ def save_dictproxy(pickler, obj): pickler.save_reduce(DictProxyType, (mapping,), obj=obj) log.info("# Mp") return -elif not IS_PYPY: +else: @register(DictProxyType) def save_dictproxy(pickler, obj): log.info("Mp: %s" % obj) @@ -1675,6 +1799,8 @@ def save_module(pickler, obj): def save_type(pickler, obj, postproc_list=None): if obj in _typemap: log.info("T1: %s" % obj) + # if obj in _incedental_types: + # warnings.warn('Type %r may only exist on this implementation of Python and cannot be unpickled in other implementations.' % (obj,), PicklingWarning) pickler.save_reduce(_load_type, (_typemap[obj],), obj=obj) log.info("# T1") elif obj.__bases__ == (tuple,) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]): @@ -1777,9 +1903,29 @@ def save_classmethod(pickler, obj): @register(FunctionType) def save_function(pickler, obj): if not _locate_function(obj, pickler): + if type(obj.__code__) is not CodeType: + # Some PyPy builtin functions have no module name, and thus are not + # able to be located + module_name = getattr(obj, '__module__', None) + if module_name is None: + module_name = __builtin__.__name__ + module = _import_module(module_name, safe=True) + _pypy_builtin = False + try: + found, _ = _getattribute(module, obj.__qualname__) + if getattr(found, '__func__', None) is obj: + _pypy_builtin = True + except: + pass + + if _pypy_builtin: + log.info("F3: %s" % obj) + pickler.save_reduce(getattr, (found, '__func__'), obj=obj) + log.info("# F3") + return + log.info("F1: %s" % obj) _recurse = getattr(pickler, '_recurse', None) - _byref = getattr(pickler, '_byref', None) _postproc = getattr(pickler, '_postproc', None) _main_modified = getattr(pickler, '_main_modified', None) _original_main = getattr(pickler, '_original_main', __builtin__)#'None' @@ -1868,6 +2014,55 @@ def save_function(pickler, obj): log.info("# F2") return +if HAS_CTYPES and hasattr(ctypes, 'pythonapi'): + _PyCapsule_New = ctypes.pythonapi.PyCapsule_New + _PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p) + _PyCapsule_New.restype = ctypes.py_object + _PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer + _PyCapsule_GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p) + _PyCapsule_GetPointer.restype = ctypes.c_void_p + _PyCapsule_GetDestructor = ctypes.pythonapi.PyCapsule_GetDestructor + _PyCapsule_GetDestructor.argtypes = (ctypes.py_object,) + _PyCapsule_GetDestructor.restype = ctypes.c_void_p + _PyCapsule_GetContext = ctypes.pythonapi.PyCapsule_GetContext + _PyCapsule_GetContext.argtypes = (ctypes.py_object,) + _PyCapsule_GetContext.restype = ctypes.c_void_p + _PyCapsule_GetName = ctypes.pythonapi.PyCapsule_GetName + _PyCapsule_GetName.argtypes = (ctypes.py_object,) + _PyCapsule_GetName.restype = ctypes.c_char_p + _PyCapsule_IsValid = ctypes.pythonapi.PyCapsule_IsValid + _PyCapsule_IsValid.argtypes = (ctypes.py_object, ctypes.c_char_p) + _PyCapsule_IsValid.restype = ctypes.c_bool + _PyCapsule_SetContext = ctypes.pythonapi.PyCapsule_SetContext + _PyCapsule_SetContext.argtypes = (ctypes.py_object, ctypes.c_void_p) + _PyCapsule_SetDestructor = ctypes.pythonapi.PyCapsule_SetDestructor + _PyCapsule_SetDestructor.argtypes = (ctypes.py_object, ctypes.c_void_p) + _PyCapsule_SetName = ctypes.pythonapi.PyCapsule_SetName + _PyCapsule_SetName.argtypes = (ctypes.py_object, ctypes.c_char_p) + _PyCapsule_SetPointer = ctypes.pythonapi.PyCapsule_SetPointer + _PyCapsule_SetPointer.argtypes = (ctypes.py_object, ctypes.c_void_p) + _testcapsule = _PyCapsule_New( + ctypes.cast(_PyCapsule_New, ctypes.c_void_p), + ctypes.create_string_buffer(b'dill._dill._testcapsule'), + None + ) + PyCapsuleType = type(_testcapsule) + @register(PyCapsuleType) + def save_capsule(pickler, obj): + log.info("Cap: %s", obj) + name = _PyCapsule_GetName(obj) + warnings.warn('Pickling a PyCapsule (%s) does not pickle any C data structures and could cause segmentation faults or other memory errors when unpickling.' % (name,), PicklingWarning) + pointer = _PyCapsule_GetPointer(obj, name) + context = _PyCapsule_GetContext(obj) + destructor = _PyCapsule_GetDestructor(obj) + pickler.save_reduce(_create_capsule, (pointer, name, context, destructor), obj=obj) + log.info("# Cap") + _incedental_reverse_typemap['PyCapsuleType'] = PyCapsuleType + _reverse_typemap['PyCapsuleType'] = PyCapsuleType + _incedental_types.add(PyCapsuleType) +else: + _testcapsule = None + # quick sanity checking def pickles(obj,exact=False,safe=False,**kwds): """ diff --git a/dill/_objects.py b/dill/_objects.py index bd8a2b3b..abdcb14c 100644 --- a/dill/_objects.py +++ b/dill/_objects.py @@ -373,9 +373,44 @@ class _Struct(ctypes.Structure): a['FileType'] = open(os.devnull, 'rb', buffering=0) # same 'wb','wb+','rb+' # FIXME: FileType fails >= 3.1 # built-in functions (CH 2) +# Iterators: a['ListIteratorType'] = iter(_list) # empty vs non-empty FIXME: fail < 3.2 +x['SetIteratorType'] = iter(_set) #XXX: empty vs non-empty a['TupleIteratorType']= iter(_tuple) # empty vs non-empty FIXME: fail < 3.2 a['XRangeIteratorType'] = iter(_xrange) # empty vs non-empty FIXME: fail < 3.2 +a["BytesIteratorType"] = iter(b'') +a["BytearrayIteratorType"] = iter(bytearray(b'')) +a["CallableIteratorType"] = iter(iter, None) +a["MemoryIteratorType"] = iter(memoryview(b'')) +a["ListReverseiteratorType"] = reversed([]) +X = a['OrderedDictType'] +a["OdictKeysType"] = X.keys() +a["OdictValuesType"] = X.values() +a["OdictItemsType"] = X.items() +a["OdictIteratorType"] = iter(X.keys()) +del X +x['DictionaryItemIteratorType'] = iter(type.__dict__.items()) +x['DictionaryKeyIteratorType'] = iter(type.__dict__.keys()) +x['DictionaryValueIteratorType'] = iter(type.__dict__.values()) +if sys.hexversion >= 0x30800a0: + a["DictReversekeyiteratorType"] = reversed({}.keys()) + a["DictReversevalueiteratorType"] = reversed({}.values()) + a["DictReverseitemiteratorType"] = reversed({}.items()) + +try: + import symtable + a["SymtableEntryType"] = symtable.symtable("", "string", "exec")._table +except: + pass + +if sys.hexversion >= 0x30a00a0: + a['LineIteratorType'] = compile('3', '', 'eval').co_lines() + +if sys.hexversion >= 0x30b00b0: + from types import GenericAlias + a["GenericAliasIteratorType"] = iter(GenericAlias(list, (int,))) + a['PositionsIteratorType'] = compile('3', '', 'eval').co_positions() + # data types (CH 8) a['PrettyPrinterType'] = pprint.PrettyPrinter() #FIXME: fail >= 3.2 and == 2.5 # numeric and mathematical types (CH 9) @@ -408,11 +443,7 @@ class _Struct(ctypes.Structure): # other (concrete) object types # (also: Capsule / CObject ?) # built-in functions (CH 2) -x['SetIteratorType'] = iter(_set) #XXX: empty vs non-empty # built-in types (CH 5) -x['DictionaryItemIteratorType'] = iter(type.__dict__.items()) -x['DictionaryKeyIteratorType'] = iter(type.__dict__.keys()) -x['DictionaryValueIteratorType'] = iter(type.__dict__.values()) # string services (CH 7) x['StructType'] = struct.Struct('c') x['CallableIteratorType'] = _srepattern.finditer('') @@ -492,6 +523,11 @@ class _Struct(ctypes.Structure): # oddities: removed, etc x['BufferType'] = x['MemoryType'] +from dill._dill import _testcapsule +if _testcapsule is not None: + x['PyCapsuleType'] = _testcapsule +del _testcapsule + # -- cleanup ---------------------------------------------------------------- a.update(d) # registered also succeed if sys.platform[:3] == 'win': diff --git a/dill/_shims.py b/dill/_shims.py index dfb336f2..bfdb385c 100644 --- a/dill/_shims.py +++ b/dill/_shims.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) -# Author: Anirudh Vegesana (avegesan@stanford.edu) +# Author: Anirudh Vegesana (avegesan@cs.stanford.edu) # Copyright (c) 2021-2022 The Uncertainty Quantification Foundation. # License: 3-clause BSD. The full license text is available at: # - https://github.com/uqfoundation/dill/blob/master/LICENSE @@ -150,11 +150,43 @@ def decorator(func): return func return decorator +def register_shim(name, default): + """ + A easier to understand and more compact way of "softly" defining a function. + These two pieces of code are equivalent: + + if _dill.OLD3X: + def _create_class(): + ... + _create_class = register_shim('_create_class', types.new_class) + + if _dill.OLD3X: + @move_to(_dill) + def _create_class(): + ... + _create_class = Getattr(_dill, '_create_class', types.new_class) + + Intuitively, it creates a function or object in the versions of dill/python + that require special reimplementations, and use a core library or default + implementation if that function or object does not exist. + """ + func = globals().get(name) + if func is not None: + _dill.__dict__[name] = func + func.__module__ = _dill.__name__ + + if default is Getattr.NO_DEFAULT: + reduction = (getattr, (_dill, name)) + else: + reduction = (getattr, (_dill, name, default)) + + return Reduce(*reduction, is_callable=callable(default)) + ###################### ## Compatibility Shims are defined below ###################### -_CELL_EMPTY = Getattr(_dill, '_CELL_EMPTY', None) +_CELL_EMPTY = register_shim('_CELL_EMPTY', None) -_setattr = Getattr(_dill, '_setattr', setattr) -_delattr = Getattr(_dill, '_delattr', delattr) +_setattr = register_shim('_setattr', setattr) +_delattr = register_shim('_delattr', delattr) diff --git a/tests/__main__.py b/tests/__main__.py index e82993c6..312e5dae 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -26,5 +26,5 @@ for test in tests: p = sp.Popen([python, test], shell=shell).wait() if not p: - print('.', end='') + print('.', end='', flush=True) print('') diff --git a/tests/test_classdef.py b/tests/test_classdef.py index 6bc78cbc..8edf5daf 100644 --- a/tests/test_classdef.py +++ b/tests/test_classdef.py @@ -211,6 +211,15 @@ def test_slots(): assert dill.pickles(Y.y) assert dill.copy(y).y == value +def test_attr(): + import attr + @attr.s + class A: + a = attr.ib() + + v = A(1) + assert dill.copy(v) == v + def test_metaclass(): class metaclass_with_new(type): def __new__(mcls, name, bases, ns, **kwds): diff --git a/tests/test_dictviews.py b/tests/test_dictviews.py index 3bbc5d62..87d0a9eb 100644 --- a/tests/test_dictviews.py +++ b/tests/test_dictviews.py @@ -1,13 +1,16 @@ #!/usr/bin/env python # # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) -# Copyright (c) 2008-2016 California Institute of Technology. -# Copyright (c) 2016-2021 The Uncertainty Quantification Foundation. +# Author: Anirudh Vegesana (avegesan@cs.stanford.edu) +# Copyright (c) 2021-2022 The Uncertainty Quantification Foundation. # License: 3-clause BSD. The full license text is available at: # - https://github.com/uqfoundation/dill/blob/master/LICENSE import dill -from dill._dill import OLD310, MAPPING_PROXY_TRICK +from dill._dill import OLD310, MAPPING_PROXY_TRICK, DictProxyType + +def test_dictproxy(): + assert dill.copy(DictProxyType({'a': 2})) def test_dictviews(): x = {'a': 1} @@ -31,5 +34,6 @@ def test_dictproxy_trick(): assert dict(seperate_views[1]) == new_x if __name__ == '__main__': + test_dictproxy() test_dictviews() test_dictproxy_trick() diff --git a/tests/test_functions.py b/tests/test_functions.py index 1bf05026..4e8936c4 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -49,6 +49,21 @@ def function_with_unassigned_variable(): return (lambda: value) +def test_issue_510(): + # A very bizzare use of functions and methods that pickle doesn't get + # correctly for odd reasons. + class Foo: + def __init__(self): + def f2(self): + return self + self.f2 = f2.__get__(self) + + import dill, pickletools + f = Foo() + f1 = dill.copy(f) + assert f1.f2() is f1 + + def test_functions(): dumped_func_a = dill.dumps(function_a) assert dill.loads(dumped_func_a)(0) == 0 @@ -97,3 +112,4 @@ def test_functions(): if __name__ == '__main__': test_functions() + test_issue_510() diff --git a/tests/test_objects.py b/tests/test_objects.py index 985041be..c83060c3 100644 --- a/tests/test_objects.py +++ b/tests/test_objects.py @@ -57,6 +57,5 @@ def test_objects(): #pickles(member, exact=True) pickles(member, exact=False) - if __name__ == '__main__': test_objects() diff --git a/tests/test_pycapsule.py b/tests/test_pycapsule.py new file mode 100644 index 00000000..6e115ffd --- /dev/null +++ b/tests/test_pycapsule.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# +# Author: Mike McKerns (mmckerns @caltech and @uqfoundation) +# Author: Anirudh Vegesana (avegesan@cs.stanford.edu) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +""" +test pickling a PyCapsule object +""" + +import dill +import warnings + +test_pycapsule = None + +if dill._dill._testcapsule is not None: + import ctypes + def test_pycapsule(): + name = ctypes.create_string_buffer(b'dill._testcapsule') + capsule = dill._dill._PyCapsule_New( + ctypes.cast(dill._dill._PyCapsule_New, ctypes.c_void_p), + name, + None + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dill.copy(capsule) + dill._testcapsule = capsule + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dill.copy(capsule) + dill._testcapsule = None + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", dill.PicklingWarning) + dill.copy(capsule) + except dill.UnpicklingError: + pass + else: + raise AssertionError("Expected a different error") + +if __name__ == '__main__': + if test_pycapsule is not None: + test_pycapsule() diff --git a/tests/test_session.py b/tests/test_session.py index b2f53728..ed78f7f6 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -208,14 +208,14 @@ def test_objects(main, copy_dict, byref): dump = test_file.getvalue() test_file.close() - sys.modules[modname] = ModuleType(modname) # empty + main = sys.modules[modname] = ModuleType(modname) # empty # This should work after fixing https://github.com/uqfoundation/dill/issues/462 test_file = dill._dill.StringIO(dump) - dill.load_session(test_file) + dill.load_session(test_file, main=main) finally: test_file.close() - assert x == 42 + assert main.x == 42 # Dump session for module that is not __main__: