Skip to content

Commit

Permalink
[torch/package_importer] add compatibility name mapping (pytorch#134376)
Browse files Browse the repository at this point in the history
Summary:
This enables patching extern modules to provide compatibility with serialized code depending on different versions of those extern modules.

The main motivation is to enable Numpy upgrade. In the recent release many alias to builtin types were deprecated and removed [1]. This breaks loading pickled modules that reference the removed aliases. While the proper solution is to re-generate pickled modules, it's not always feasible.

This proposes a way to define mapping with a new type, for a module member. It is only set if it's not present in the loaded module, thus removes the need to check for exact versions.

https://numpy.org/doc/stable/release/1.20.0-notes.html#using-the-aliases-of-builtin-types-like-np-int-is-deprecated

Differential Revision: D61556888

Pull Request resolved: pytorch#134376
Approved by: https://github.com/SherlockNoMad
  • Loading branch information
igorsugak authored and pytorchmergebot committed Aug 25, 2024
1 parent 8160618 commit 7940f24
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions torch/package/package_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@
]


# Compatibility name mapping to facilitate upgrade of external modules.
# The primary motivation is to enable Numpy upgrade that many modules
# depend on. The latest release of Numpy removed `numpy.str` and
# `numpy.bool` breaking unpickling for many modules.
EXTERN_IMPORT_COMPAT_NAME_MAPPING: Dict[str, Dict[str, Any]] = {
"numpy": {
"str": str,
"bool": bool,
},
}


class PackageImporter(Importer):
"""Importers allow you to load code written to packages by :class:`PackageExporter`.
Code is loaded in a hermetic way, using files from the package
Expand Down Expand Up @@ -410,6 +422,11 @@ def _load_module(self, name: str, parent: str):
cur = cur.children[atom]
if isinstance(cur, _ExternNode):
module = self.modules[name] = importlib.import_module(name)

if compat_mapping := EXTERN_IMPORT_COMPAT_NAME_MAPPING.get(name):
for old_name, new_name in compat_mapping.items():
module.__dict__.setdefault(old_name, new_name)

return module
return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]

Expand Down

0 comments on commit 7940f24

Please sign in to comment.