diff --git a/pyproject.toml b/pyproject.toml index ca52cf1..a12cd50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,7 @@ extend-ignore = [ unfixable = [ "ERA", # Prevent unlikely erroneous deletion. ] +typing-modules = ["deferred._typing"] [tool.ruff.lint.isort] lines-after-imports = 2 @@ -184,7 +185,7 @@ keep-runtime-typing = true ] "tests/stdlib_imports.py" = [ "F401", # Unused imports are fine; we're testing import success. - "ERA001", # Plenty of imports are commented out with explanation in comment. + "ERA001", # Plenty of imports are commented out with explanations next to them. ] "benchmark/**/*.py" = [ "T201", # Printing is fine. diff --git a/src/deferred/_core.py b/src/deferred/_core.py index 4906a61..d972b40 100644 --- a/src/deferred/_core.py +++ b/src/deferred/_core.py @@ -31,11 +31,11 @@ # region -------- Compile-time hook +SourceDataType: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]" + BYTECODE_HEADER = f"deferred{__version__}".encode() """Custom header for deferred-instrumented bytecode files. Should be updated with every version release.""" -SourceDataType: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]" - class DeferredInstrumenter(ast.NodeTransformer): """AST transformer that "instruments" imports within "with defer_imports_until_use: ..." blocks so that their @@ -468,8 +468,10 @@ def __getattr__(self, name: str, /): if name in self.defer_proxy_fromlist: sub_proxy.defer_proxy_fromlist = (name,) - else: + elif name == self.defer_proxy_name.rpartition(".")[2]: sub_proxy.defer_proxy_sub = name + else: + raise AttributeError(name) return sub_proxy @@ -480,7 +482,7 @@ class DeferredImportKey(str): When referenced, the key will replace itself in the namespace with the resolved import or the right name from it. """ - __slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing", "_lock") + __slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing", "_rlock") def __new__(cls, key: str, proxy: DeferredImportProxy, /): return super().__new__(cls, key) @@ -490,6 +492,8 @@ def __init__(self, key: str, proxy: DeferredImportProxy, /): self.defer_key_proxy = proxy self.is_recursing = False + self._rlock = original_import.get()("threading").RLock() + def __repr__(self) -> str: return f"" @@ -499,12 +503,18 @@ def __eq__(self, value: object, /) -> bool: if self.defer_key_str != value: return False - # The recursion guard allows self-referential imports within __init__.py files. - if not self.is_recursing: - self.is_recursing = True + # NOTE: This RLock prevents a scenario where the proxy begins resolution in one thread, but before it completes + # resolution, a context switch occurs and another thread that tries to resolve the same proxy just gets + # the proxy back before it has been resolved and replaced. This also partially happens because + # is_recursing is a guard that is only intended for one thread, but other threads will see it without the + # RLock. + with self._rlock: + # This recursion guard allows self-referential imports within __init__.py files. + if not self.is_recursing: + self.is_recursing = True - if not is_deferred.get(): - self._resolve() + if not is_deferred.get(): + self._resolve() return True @@ -518,8 +528,6 @@ def _resolve(self) -> None: # Perform the original __import__ and pray. module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args) - # FIXME: The below print is enough to trigger a race condition. - # print(f"{module=}, {proxy.defer_proxy_name=}, {proxy.defer_proxy_fromlist=}, {proxy.defer_proxy_sub=}") # noqa: ERA001 # Transfer nested proxies over to the resolved module. module_vars = vars(module) @@ -540,7 +548,6 @@ def _resolve(self) -> None: # NOTE: This is necessary to prevent recursive resolution for proxies, since __eq__ will be triggered again. _is_def_tok = is_deferred.set(True) try: - # TODO: Figure out why this works, but del namespace[key] doesn't. namespace[key] = namespace.pop(key) finally: is_deferred.reset(_is_def_tok) diff --git a/tests/test_deferred.py b/tests/test_deferred.py index b274626..d3d7c68 100644 --- a/tests/test_deferred.py +++ b/tests/test_deferred.py @@ -845,14 +845,6 @@ def access_module_attr() -> object: for thread in threads: thread.join() assert callable(thread.result) # pyright: ignore - # FIXME: Error on various versions; sometimes the accessed signature isn't resolved? - # Recreatable by putting "print(module)" after line 525: - # module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args) # noqa: ERA001 - # - # The error: - # AssertionError: assert False - # + where False = callable() - # + where = .result assert thread.exc is None