diff --git a/README.rst b/README.rst index fc8c673..54e0c7e 100644 --- a/README.rst +++ b/README.rst @@ -43,7 +43,7 @@ See the docstrings and comments in the codebase for more details. Setup ----- -To do its work, ``defer-imports`` must hook into the Python import system. To do that, include the following call somewhere such that it will be executed before your code: +To do its work, ``defer-imports`` must hook into the Python import system. Include the following call somewhere such that it will be executed before your code: .. code-block:: python diff --git a/src/defer_imports/__init__.py b/src/defer_imports/__init__.py index b8e8603..a6603a6 100644 --- a/src/defer_imports/__init__.py +++ b/src/defer_imports/__init__.py @@ -314,6 +314,13 @@ def _calc___package__(globals: typing.MutableMapping[str, typing.Any]) -> typing """Custom header for defer_imports-instrumented bytecode files. Should be updated with every version release.""" +_is_loaded_using_defer = False +"""Whether the defer_imports import loader is being used to load a module.""" + +_is_loaded_lock = threading.Lock() +"""A lock to guard reading from and writing to _is_loaded_using_defer.""" + + class _DeferredInstrumenter: """AST transformer that instruments imports within "with defer_imports.until_use: ..." blocks so that their results are assigned to custom keys in the global namespace. @@ -806,10 +813,21 @@ def source_to_code(self, data: _SourceData, path: _ModulePath, *, _optimize: int def exec_module(self, module: types.ModuleType) -> None: """Execute the module, but only after getting state from module.__spec__.loader_state if present.""" + global _is_loaded_using_defer # noqa: PLW0603 + if (spec := module.__spec__) and spec.loader_state is not None: self.defer_module_level = spec.loader_state["defer_module_level"] - return super().exec_module(module) + # Signal to defer_imports.until_use that it's not a no-op during this module's execution. + with _is_loaded_lock: + _temp = _is_loaded_using_defer + _is_loaded_using_defer = True + + try: + return super().exec_module(module) + finally: + with _is_loaded_lock: + _is_loaded_using_defer = _temp class _DeferredFileFinder(FileFinder): @@ -1204,19 +1222,22 @@ class DeferredContext: As part of its implementation, this temporarily replaces builtins.__import__. """ - __slots__ = ("_import_ctx_token", "_defer_ctx_token") - - # TODO: Have this turn into a no-op when not being executed with a defer_imports loader. + __slots__ = ("_is_active", "_import_ctx_token", "_defer_ctx_token") def __enter__(self) -> None: - self._defer_ctx_token = _is_deferred.set(True) - self._import_ctx_token = _original_import.set(builtins.__import__) - builtins.__import__ = _deferred___import__ + with _is_loaded_lock: + self._is_active = _is_loaded_using_defer + + if self._is_active: + self._defer_ctx_token = _is_deferred.set(True) + self._import_ctx_token = _original_import.set(builtins.__import__) + builtins.__import__ = _deferred___import__ def __exit__(self, *exc_info: object) -> None: - _original_import.reset(self._import_ctx_token) - _is_deferred.reset(self._defer_ctx_token) - builtins.__import__ = _original_import.get() + if self._is_active: + _original_import.reset(self._import_ctx_token) + _is_deferred.reset(self._defer_ctx_token) + builtins.__import__ = _original_import.get() until_use: typing.Final[DeferredContext] = DeferredContext() diff --git a/tests/test_defer_imports.py b/tests/test_defer_imports.py index 68e7859..d50eccb 100644 --- a/tests/test_defer_imports.py +++ b/tests/test_defer_imports.py @@ -1114,12 +1114,20 @@ def Y2(): def test_import_stdlib(): - """Test that we can import most of the stdlib.""" + """Test that defer_imports.until_use works when wrapping imports for most of the stdlib.""" - import tests.sample_stdlib_imports + # The finder for tests.sample_stdlib_imports is already cached, so we need to temporarily reset that cache. + _temp_cache = dict(sys.path_importer_cache) + sys.path_importer_cache.clear() + + with install_import_hook(uninstall_after=True): + import tests.sample_stdlib_imports assert tests.sample_stdlib_imports + # Revert changes to the path finder cache. + sys.path_importer_cache = _temp_cache + @pytest.mark.skip(reason="Leaking patch problem is currently out of scope.") def test_leaking_patch(tmp_path: Path): # pragma: no cover