diff --git a/src/deferred/_core.py b/src/deferred/_core.py index 628b931..9d59099 100644 --- a/src/deferred/_core.py +++ b/src/deferred/_core.py @@ -388,15 +388,15 @@ class DeferredImportProxy: def __init__( self, name: str, - global_ns: dict[str, object], - local_ns: dict[str, object], - fromlist: _tp.Optional[tuple[str, ...]], + global_ns: _tp.MutableMapping[str, object], + local_ns: _tp.MutableMapping[str, object], + fromlist: _tp.Sequence[str], level: int = 0, ) -> None: self.defer_proxy_name = name self.defer_proxy_global_ns = global_ns self.defer_proxy_local_ns = local_ns - self.defer_proxy_fromlist: tuple[str, ...] = fromlist if (fromlist is not None) else () + self.defer_proxy_fromlist = fromlist self.defer_proxy_level = level # Only used in cases of non-from-import submodule aliasing a la "import a.b as c". @@ -446,7 +446,7 @@ class DeferredImportKey(str): def __new__(cls, key: str, proxy: DeferredImportProxy, /): return super().__new__(cls, key) - def __init__(self, key: str, proxy: DeferredImportProxy, /): + def __init__(self, key: str, proxy: DeferredImportProxy, /) -> None: self.defer_key_str = str(key) self.defer_key_proxy = proxy self.is_recursing = False @@ -460,6 +460,7 @@ def __eq__(self, value: object, /) -> bool: if self.defer_key_str != value: return False + # The recursion guard allows self-referential imports within __init__.py files. if not self.is_recursing: self.is_recursing = True @@ -483,28 +484,27 @@ def _resolve(self) -> None: module_vars = vars(module) for attr_key, attr_val in vars(proxy).items(): if isinstance(attr_val, DeferredImportProxy) and not hasattr(module, attr_key): - # NOTE: This originally used setattr(), but I found that pypy normalizes the attr name to a str, losing - # the DeferredImportKey properties. + # NOTE: This could have used setattr() if pypy didn't normalize the attr name to a str, so we must + # resort to direct placement in the module's __dict__ to avoid that. module_vars[DeferredImportKey(attr_key, attr_val)] = attr_val # Change the namespaces as well to make sure nested proxies are replaced in the right place. - attr_val.defer_proxy_global_ns = attr_val.defer_proxy_local_ns = vars(module) + attr_val.defer_proxy_global_ns = attr_val.defer_proxy_local_ns = module_vars # Replace the proxy with the resolved module or module attribute in the relevant namespace. - - # First, get the regular string key and the relevant namespace. + # 1. Let the regular string key and the relevant namespace. key = self.defer_key_str namespace = proxy.defer_proxy_local_ns - # Second, remove the deferred key to avoid it sticking around. + # 2. Replace the deferred version of the key to avoid it sticking around. # NOTE: This is necessary to prevent recursive resolution for proxies, since __eq__ will be triggered again. _is_def_tok = is_deferred.set(True) try: - # TODO: Figure out why this works and del namespace[key] doesn't. + # TODO: Figure out why this works, but del namespace[key] doesn't. namespace[key] = namespace.pop(key) finally: is_deferred.reset(_is_def_tok) - # Finally, resolve any requested attribute access. + # 3. Resolve any requested attribute access. if proxy.defer_proxy_fromlist: namespace[key] = getattr(module, proxy.defer_proxy_fromlist[0]) elif proxy.defer_proxy_sub: @@ -513,7 +513,7 @@ def _resolve(self) -> None: namespace[key] = module -def calc___package__(globals: dict[str, _tp.Any]) -> _tp.Optional[str]: +def calc___package__(globals: _tp.MutableMapping[str, _tp.Any]) -> _tp.Optional[str]: """Calculate what __package__ should be. __package__ is not guaranteed to be defined or could be set to None @@ -567,14 +567,16 @@ def resolve_name(name: str, package: str, level: int) -> str: def deferred___import__( # noqa: ANN202 name: str, - globals: dict[str, object], - locals: dict[str, object], - fromlist: _tp.Optional[tuple[str, ...]] = None, + globals: _tp.MutableMapping[str, object], + locals: _tp.MutableMapping[str, object], + fromlist: _tp.Optional[_tp.Sequence[str]] = None, level: int = 0, /, ): """An limited replacement for __import__ that supports deferred imports by returning proxies.""" + fromlist = fromlist or () + # Resolve the names of relative imports. if level > 0: package = calc___package__(locals) diff --git a/src/deferred/_typing.py b/src/deferred/_typing.py index 61966e1..bac82ea 100644 --- a/src/deferred/_typing.py +++ b/src/deferred/_typing.py @@ -14,8 +14,10 @@ "Generator", "Iterable", "ModuleType", + "MutableMapping", "Optional", "ReadableBuffer", + "Sequence", "StrPath", "Union", "final", @@ -86,6 +88,12 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0 globals()["ModuleType"] = ModuleType return ModuleType + if name == "MutableMapping": + from collections.abc import MutableMapping + + globals()["MutableMapping"] = MutableMapping + return MutableMapping + if name == "Optional": from typing import Optional @@ -103,6 +111,12 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0 globals()["ReadableBuffer"] = ReadableBuffer return ReadableBuffer + if name == "Sequence": + from collections.abc import Sequence + + globals()["Sequence"] = Sequence + return Sequence + if name == "StrPath": import os from typing import Union diff --git a/src/deferred/_typing.pyi b/src/deferred/_typing.pyi index a179044..00f98b0 100644 --- a/src/deferred/_typing.pyi +++ b/src/deferred/_typing.pyi @@ -11,14 +11,16 @@ __all__ = ( "Generator", "Iterable", "ModuleType", + "MutableMapping", "Optional", "ReadableBuffer", + "Sequence", "StrPath", "Union", "final", ) -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterable, MutableMapping, Sequence from types import CodeType, ModuleType from typing import TYPE_CHECKING as TYPING, Any, Final, Optional, TypeVar, Union, final diff --git a/tests/test_deferred.py b/tests/test_deferred.py index d7fca07..12ad678 100644 --- a/tests/test_deferred.py +++ b/tests/test_deferred.py @@ -765,7 +765,10 @@ def Y2(): def test_thread_safety(tmp_path: Path): - """Test if trying to access a lazily loaded import from multiple threads causes race conditions.""" + """Test if trying to access a lazily loaded import from multiple threads causes race conditions. + + Based on a test for importlib.util.LazyLoader in the CPython test suite. + """ source = """\ from deferred import defer_imports_until_use @@ -798,7 +801,7 @@ def access_module_attr() -> Any: threads: list[RaisingThread] = [] - for _ in range(20): + for _ in range(10): thread = RaisingThread(target=access_module_attr) threads.append(thread) thread.start()