Skip to content

Commit

Permalink
Switch to more generic types for various inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sachaa-Thanasius committed Aug 28, 2024
1 parent eea438e commit b7da392
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 20 deletions.
36 changes: 19 additions & 17 deletions src/deferred/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,15 +388,15 @@ class DeferredImportProxy:
def __init__(
self,
name: str,
global_ns: dict[str, object],
local_ns: dict[str, object],
fromlist: _tp.Optional[tuple[str, ...]],
global_ns: _tp.MutableMapping[str, object],
local_ns: _tp.MutableMapping[str, object],
fromlist: _tp.Sequence[str],
level: int = 0,
) -> None:
self.defer_proxy_name = name
self.defer_proxy_global_ns = global_ns
self.defer_proxy_local_ns = local_ns
self.defer_proxy_fromlist: tuple[str, ...] = fromlist if (fromlist is not None) else ()
self.defer_proxy_fromlist = fromlist
self.defer_proxy_level = level

# Only used in cases of non-from-import submodule aliasing a la "import a.b as c".
Expand Down Expand Up @@ -446,7 +446,7 @@ class DeferredImportKey(str):
def __new__(cls, key: str, proxy: DeferredImportProxy, /):
return super().__new__(cls, key)

def __init__(self, key: str, proxy: DeferredImportProxy, /):
def __init__(self, key: str, proxy: DeferredImportProxy, /) -> None:
self.defer_key_str = str(key)
self.defer_key_proxy = proxy
self.is_recursing = False
Expand All @@ -460,6 +460,7 @@ def __eq__(self, value: object, /) -> bool:
if self.defer_key_str != value:
return False

# The recursion guard allows self-referential imports within __init__.py files.
if not self.is_recursing:
self.is_recursing = True

Expand All @@ -483,28 +484,27 @@ def _resolve(self) -> None:
module_vars = vars(module)
for attr_key, attr_val in vars(proxy).items():
if isinstance(attr_val, DeferredImportProxy) and not hasattr(module, attr_key):
# NOTE: This originally used setattr(), but I found that pypy normalizes the attr name to a str, losing
# the DeferredImportKey properties.
# NOTE: This could have used setattr() if pypy didn't normalize the attr name to a str, so we must
# resort to direct placement in the module's __dict__ to avoid that.
module_vars[DeferredImportKey(attr_key, attr_val)] = attr_val
# Change the namespaces as well to make sure nested proxies are replaced in the right place.
attr_val.defer_proxy_global_ns = attr_val.defer_proxy_local_ns = vars(module)
attr_val.defer_proxy_global_ns = attr_val.defer_proxy_local_ns = module_vars

# Replace the proxy with the resolved module or module attribute in the relevant namespace.

# First, get the regular string key and the relevant namespace.
# 1. Let the regular string key and the relevant namespace.
key = self.defer_key_str
namespace = proxy.defer_proxy_local_ns

# Second, remove the deferred key to avoid it sticking around.
# 2. Replace the deferred version of the key to avoid it sticking around.
# NOTE: This is necessary to prevent recursive resolution for proxies, since __eq__ will be triggered again.
_is_def_tok = is_deferred.set(True)
try:
# TODO: Figure out why this works and del namespace[key] doesn't.
# TODO: Figure out why this works, but del namespace[key] doesn't.
namespace[key] = namespace.pop(key)
finally:
is_deferred.reset(_is_def_tok)

# Finally, resolve any requested attribute access.
# 3. Resolve any requested attribute access.
if proxy.defer_proxy_fromlist:
namespace[key] = getattr(module, proxy.defer_proxy_fromlist[0])
elif proxy.defer_proxy_sub:
Expand All @@ -513,7 +513,7 @@ def _resolve(self) -> None:
namespace[key] = module


def calc___package__(globals: dict[str, _tp.Any]) -> _tp.Optional[str]:
def calc___package__(globals: _tp.MutableMapping[str, _tp.Any]) -> _tp.Optional[str]:
"""Calculate what __package__ should be.
__package__ is not guaranteed to be defined or could be set to None
Expand Down Expand Up @@ -567,14 +567,16 @@ def resolve_name(name: str, package: str, level: int) -> str:

def deferred___import__( # noqa: ANN202
name: str,
globals: dict[str, object],
locals: dict[str, object],
fromlist: _tp.Optional[tuple[str, ...]] = None,
globals: _tp.MutableMapping[str, object],
locals: _tp.MutableMapping[str, object],
fromlist: _tp.Optional[_tp.Sequence[str]] = None,
level: int = 0,
/,
):
"""An limited replacement for __import__ that supports deferred imports by returning proxies."""

fromlist = fromlist or ()

# Resolve the names of relative imports.
if level > 0:
package = calc___package__(locals)
Expand Down
14 changes: 14 additions & 0 deletions src/deferred/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
"Generator",
"Iterable",
"ModuleType",
"MutableMapping",
"Optional",
"ReadableBuffer",
"Sequence",
"StrPath",
"Union",
"final",
Expand Down Expand Up @@ -86,6 +88,12 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0
globals()["ModuleType"] = ModuleType
return ModuleType

if name == "MutableMapping":
from collections.abc import MutableMapping

globals()["MutableMapping"] = MutableMapping
return MutableMapping

if name == "Optional":
from typing import Optional

Expand All @@ -103,6 +111,12 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0
globals()["ReadableBuffer"] = ReadableBuffer
return ReadableBuffer

if name == "Sequence":
from collections.abc import Sequence

globals()["Sequence"] = Sequence
return Sequence

if name == "StrPath":
import os
from typing import Union
Expand Down
4 changes: 3 additions & 1 deletion src/deferred/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ __all__ = (
"Generator",
"Iterable",
"ModuleType",
"MutableMapping",
"Optional",
"ReadableBuffer",
"Sequence",
"StrPath",
"Union",
"final",
)

from collections.abc import Generator, Iterable
from collections.abc import Generator, Iterable, MutableMapping, Sequence
from types import CodeType, ModuleType
from typing import TYPE_CHECKING as TYPING, Any, Final, Optional, TypeVar, Union, final

Expand Down
7 changes: 5 additions & 2 deletions tests/test_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,10 @@ def Y2():


def test_thread_safety(tmp_path: Path):
"""Test if trying to access a lazily loaded import from multiple threads causes race conditions."""
"""Test if trying to access a lazily loaded import from multiple threads causes race conditions.
Based on a test for importlib.util.LazyLoader in the CPython test suite.
"""

source = """\
from deferred import defer_imports_until_use
Expand Down Expand Up @@ -798,7 +801,7 @@ def access_module_attr() -> Any:

threads: list[RaisingThread] = []

for _ in range(20):
for _ in range(10):
thread = RaisingThread(target=access_module_attr)
threads.append(thread)
thread.start()
Expand Down

0 comments on commit b7da392

Please sign in to comment.